diff --git a/include/glaze/thread/sync.hpp b/include/glaze/thread/sync.hpp new file mode 100644 index 0000000000..4eafa13ac1 --- /dev/null +++ b/include/glaze/thread/sync.hpp @@ -0,0 +1,121 @@ +// Glaze Library +// For the license information refer to glaze.hpp + +#pragma once + +#include +#include +#include +#include +#include + +#include "glaze/util/type_traits.hpp" + +// The purpose of glz::sync is to create a thread-safe wrapper around a type +// The only way to access the data is by supplying lambdas to `read` or `write` +// methods, which feed underlying data into the lambda. +// A lock is held for the duration of the call. + +// Example: +// struct foo { int x{}; }; +// sync s{}; +// s.write([](auto& value) { value.x = 42; }); +// s.read([](const auto& value) { std::cout << value.x << '\n'; }); + +namespace glz +{ + template + concept const_callable = std::invocable; + + template + concept non_const_callable = + std::invocable || std::invocable; + + template + concept void_return = std::same_as, void>; + + template + class sync { + T data{}; + mutable std::shared_mutex mtx{}; + + public: + sync() = default; + + template + requires (!is_specialization_v, sync>) + sync(U&& initial_value) : data(std::forward(initial_value)) {} + + sync(const sync& other) + requires(std::copy_constructible) + { + std::shared_lock lock(other.mtx); + data = other.data; + } + + sync(sync&& other) noexcept(std::is_nothrow_move_constructible_v) + requires(std::move_constructible) + { + std::unique_lock lock(other.mtx); + data = std::move(other.data); + } + + sync& operator=(const sync& other) + requires(std::is_copy_assignable_v) + { + if (this != &other) { + std::scoped_lock lock(mtx, other.mtx); + data = other.data; + } + return *this; + } + + sync& operator=(sync&& other) noexcept(std::is_nothrow_move_assignable_v) + requires(std::is_move_assignable_v) + { + if (this != &other) { + std::scoped_lock lock(mtx, other.mtx); + data = std::move(other.data); + } + return *this; + } + + T copy() const { + std::shared_lock lock(mtx); + return data; + } + + // Read with non-void return value. + template + requires(const_callable && + !void_return) + auto read(Callable&& f) const -> std::invoke_result_t { + std::shared_lock lock(mtx); + return std::forward(f)(data); + } + + // Read with void return. + template + requires(const_callable && void_return) + void read(Callable&& f) const { + std::shared_lock lock(mtx); + std::forward(f)(data); + } + + // Write with non-void return value. + template + requires(non_const_callable && !void_return) + auto write(Callable&& f) -> std::invoke_result_t { + std::unique_lock lock(mtx); + return std::forward(f)(data); + } + + // Write with void return. + template + requires(non_const_callable && void_return) + void write(Callable&& f) { + std::unique_lock lock(mtx); + std::forward(f)(data); + } + }; +} diff --git a/tests/exceptions_test/exceptions_test.cpp b/tests/exceptions_test/exceptions_test.cpp index f1d163a85b..f0108fa38f 100644 --- a/tests/exceptions_test/exceptions_test.cpp +++ b/tests/exceptions_test/exceptions_test.cpp @@ -5,6 +5,7 @@ #include "glaze/thread/async_string.hpp" #include "glaze/thread/shared_async_map.hpp" #include "glaze/thread/shared_async_vector.hpp" +#include "glaze/thread/sync.hpp" #include "glaze/thread/threadpool.hpp" #include "ut/ut.hpp" @@ -1125,4 +1126,115 @@ suite async_string_tests = [] { }; }; +suite sync_tests = [] { + "non-void read and write operations"_test = [] + { + // Initialize with 10. + glz::sync s{10}; + + // Read with a lambda that returns a value. + auto doubled = s.read([](const int &x) -> int { + return x * 2; + }); + expect(doubled == 20); + + // Write with a lambda that returns a value. + auto new_value = s.write([](int &x) -> int { + x += 5; + return x; + }); + expect(new_value == 15); + + // Confirm the new value via a read lambda (void-returning). + s.read([](const int &x) { + expect(x == 15); + }); + }; + + "void read operation"_test = [] + { + glz::sync s{20}; + bool flag = false; + s.read([&flag](const int &x) { + if (x == 20) + flag = true; + }); + expect(flag); + }; + + "void write operation"_test = [] + { + glz::sync s{100}; + s.write([](int &x) { + x = 200; + }); + s.read([](const int &x) { + expect(x == 200); + }); + }; + + "copy constructor"_test = [] + { + glz::sync original{123}; + glz::sync copy = original; + copy.read([](const int &x) { + expect(x == 123); + }); + }; + + "move constructor"_test = [] + { + glz::sync original{"hello"}; + glz::sync moved = std::move(original); + moved.read([](const std::string &s) { + expect(s == "hello"); + }); + }; + + "copy assignment."_test = [] + { + glz::sync a{10}, b{20}; + a = b; // requires T to be copy-assignable + a.read([](const int &x) { + expect(x == 20); + }); + }; + + "move assignment."_test = [] + { + glz::sync a{"foo"}, b{"bar"}; + a = std::move(b); // requires T to be move-assignable + a.read([](const std::string &s) { + expect(s == "bar"); + }); + }; + + "concurrent access."_test = [] + { + glz::sync s{0}; + const int num_threads = 10; + const int increments = 1000; + std::vector threads; + + for (int i = 0; i < num_threads; ++i) { + threads.emplace_back([&] { + for (int j = 0; j < increments; ++j) { + s.write([](int &value) { + ++value; + }); + } + }); + } + + for (auto &th : threads) { + th.join(); + } + + // Verify that the value is the expected total. + s.read([&](const int &value) { + expect(value == num_threads * increments); + }); + }; +}; + int main() { return 0; }