diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..ce67670 --- /dev/null +++ b/.gitignore @@ -0,0 +1,12 @@ +/msvc/.vs/ +*.exe +*.pdb +*.recipe +*.ilk +*.log +*.tlog +*.o +*.lastbuildstate +*.obj +*.idb +*.user diff --git a/README.txt b/README.txt new file mode 100644 index 0000000..a840f33 --- /dev/null +++ b/README.txt @@ -0,0 +1,42 @@ +This is a simple multiple-reader, single-writer mutex class for C++14. Reader +locks can be upgraded and then downgraded without being released. It uses a +queue (allocated on the stack) to prevent writer starvation and to resolve lock +acquisition conflicts. Writer lock acquisition requests are always fulfilled in +FIFO order. + + +Usage examples: + +helios::rwmutex m; + +void reader(){ + helios::rwlock l(m); + //reader code +} + +void conditional_writer(){ + helios::rwlock l(m); + //reader code + if (/*condition*/) + return; + l.upgrade(); + //note: calling l.upgrade() again would throw an exception + //writer code +} + +void nested_writer(){ + helios::rwlock l(m); + //reader code + if (/*condition*/) + return; + { + helios::rwlock_writer_region r(l); + //writer code + } + //some more reader code +} + +void unconditional_writer(){ + helios::wlock l(m); + //writer code +} diff --git a/build_test.sh b/build_test.sh new file mode 100644 index 0000000..6953160 --- /dev/null +++ b/build_test.sh @@ -0,0 +1,2 @@ +#!/bin/sh +c++ src/rwmutex.cpp test/test_concurrency.cpp -o test_concurrency diff --git a/msvc/rwmutex_test.sln b/msvc/rwmutex_test.sln new file mode 100644 index 0000000..d9eca84 --- /dev/null +++ b/msvc/rwmutex_test.sln @@ -0,0 +1,31 @@ + +Microsoft Visual Studio Solution File, Format Version 12.00 +# Visual Studio Version 16 +VisualStudioVersion = 16.0.31424.327 +MinimumVisualStudioVersion = 10.0.40219.1 +Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "rwmutex_test", "rwmutex_test\rwmutex_test.vcxproj", "{FEE9AC51-BB7E-4164-8581-5A4A9EBD007C}" +EndProject +Global + GlobalSection(SolutionConfigurationPlatforms) = preSolution + Debug|x64 = Debug|x64 + Debug|x86 = Debug|x86 + Release|x64 = Release|x64 + Release|x86 = Release|x86 + EndGlobalSection + GlobalSection(ProjectConfigurationPlatforms) = postSolution + {FEE9AC51-BB7E-4164-8581-5A4A9EBD007C}.Debug|x64.ActiveCfg = Debug|x64 + {FEE9AC51-BB7E-4164-8581-5A4A9EBD007C}.Debug|x64.Build.0 = Debug|x64 + {FEE9AC51-BB7E-4164-8581-5A4A9EBD007C}.Debug|x86.ActiveCfg = Debug|Win32 + {FEE9AC51-BB7E-4164-8581-5A4A9EBD007C}.Debug|x86.Build.0 = Debug|Win32 + {FEE9AC51-BB7E-4164-8581-5A4A9EBD007C}.Release|x64.ActiveCfg = Release|x64 + {FEE9AC51-BB7E-4164-8581-5A4A9EBD007C}.Release|x64.Build.0 = Release|x64 + {FEE9AC51-BB7E-4164-8581-5A4A9EBD007C}.Release|x86.ActiveCfg = Release|Win32 + {FEE9AC51-BB7E-4164-8581-5A4A9EBD007C}.Release|x86.Build.0 = Release|Win32 + EndGlobalSection + GlobalSection(SolutionProperties) = preSolution + HideSolutionNode = FALSE + EndGlobalSection + GlobalSection(ExtensibilityGlobals) = postSolution + SolutionGuid = {7BED76C7-7591-4569-AF6A-52168DA71F59} + EndGlobalSection +EndGlobal diff --git a/msvc/rwmutex_test/rwmutex_test.vcxproj b/msvc/rwmutex_test/rwmutex_test.vcxproj new file mode 100644 index 0000000..76f5587 --- /dev/null +++ b/msvc/rwmutex_test/rwmutex_test.vcxproj @@ -0,0 +1,151 @@ + + + + + Debug + Win32 + + + Release + Win32 + + + Debug + x64 + + + Release + x64 + + + + + + + + + + + 16.0 + Win32Proj + {fee9ac51-bb7e-4164-8581-5a4a9ebd007c} + rwmutextest + 10.0 + + + + Application + true + v142 + Unicode + + + Application + false + v142 + true + Unicode + + + Application + true + v142 + Unicode + + + Application + false + v142 + true + Unicode + + + + + + + + + + + + + + + + + + + + + true + + + false + + + true + + + false + + + + Level3 + true + WIN32;_DEBUG;_CONSOLE;%(PreprocessorDefinitions) + true + + + Console + true + + + + + Level3 + true + true + true + WIN32;NDEBUG;_CONSOLE;%(PreprocessorDefinitions) + true + + + Console + true + true + true + + + + + Level3 + true + _DEBUG;_CONSOLE;%(PreprocessorDefinitions) + true + + + Console + true + + + + + Level3 + true + true + true + NDEBUG;_CONSOLE;%(PreprocessorDefinitions) + true + + + Console + true + true + true + + + + + + \ No newline at end of file diff --git a/msvc/rwmutex_test/rwmutex_test.vcxproj.filters b/msvc/rwmutex_test/rwmutex_test.vcxproj.filters new file mode 100644 index 0000000..745fefb --- /dev/null +++ b/msvc/rwmutex_test/rwmutex_test.vcxproj.filters @@ -0,0 +1,30 @@ + + + + + {4FC737F1-C7A5-4376-A066-2A32D752A2FF} + cpp;c;cc;cxx;c++;cppm;ixx;def;odl;idl;hpj;bat;asm;asmx + + + {93995380-89BD-4b04-88EB-625FBE52EBFB} + h;hh;hpp;hxx;h++;hm;inl;inc;ipp;xsd + + + {67DA6AB6-F800-4c08-8B7A-83BB121AAD01} + rc;ico;cur;bmp;dlg;rc2;rct;bin;rgs;gif;jpg;jpeg;jpe;resx;tiff;tif;png;wav;mfcribbon-ms + + + + + Source Files + + + Source Files + + + + + Header Files + + + \ No newline at end of file diff --git a/src/rwmutex.cpp b/src/rwmutex.cpp new file mode 100644 index 0000000..74ce1de --- /dev/null +++ b/src/rwmutex.cpp @@ -0,0 +1,115 @@ +#include "rwmutex.hpp" +#include + +namespace helios{ + +rwlock::rwlock(rwmutex &m): m(m){ + node.id = std::this_thread::get_id(); + m.lock_reader(node); + state = 1; +} + +rwlock::~rwlock(){ + if (state == 1) + m.unlock_reader(node); + else + m.unlock_writer(node); +} + +void rwlock::upgrade(){ + if (state != 1) + throw std::runtime_error("incorrect use"); + m.upgrade(); + state = 2; +} + +void rwlock::downgrade(){ + if (state != 2) + throw std::runtime_error("incorrect use"); + m.upgrade(); + state = 1; +} + +void rwmutex::lock_reader(rwlock_node &node){ + UL lg(m); + push(node); + while (state == 2) + cv.wait(lg); + state = 1; +} + +void rwmutex::unlock_reader(rwlock_node &node){ + UL lg(m); + remove(node); + if (!queue_size) + state = 0; + cv.notify_all(); +} + +void rwmutex::upgrade(){ + auto id = std::this_thread::get_id(); + UL lg(m); + writers_waiting++; + cv.notify_all(); + while (queue_head->id != id) + cv.wait(lg); + while (queue_size > writers_waiting && state != 0) + cv.wait(lg); + writers_waiting--; + state = 2; +} + +void rwmutex::downgrade(){ + UL lg(m); + state = 1; + cv.notify_all(); +} + +void rwmutex::lock_writer(rwlock_node &node){ + auto id = std::this_thread::get_id(); + UL lg(m); + push(node); + writers_waiting++; + cv.notify_all(); + while (queue_head->id != id) + cv.wait(lg); + while (queue_size > writers_waiting && state != 0) + cv.wait(lg); + writers_waiting--; + state = 2; +} + +void rwmutex::unlock_writer(rwlock_node &node){ + UL lg(m); + remove(node); + state = !!queue_size; + cv.notify_all(); +} + +void rwmutex::push(rwlock_node &node){ + queue_size++; + node.next = nullptr; + if (queue_size == 1){ + queue_head = &node; + queue_tail = &node; + node.previous = nullptr; + return; + } + queue_tail->next = &node; + node.previous = queue_tail; + queue_tail = &node; +} + +void rwmutex::remove(rwlock_node &node){ + queue_size--; + if (node.previous) + node.previous->next = node.next; + else + queue_head = node.next; + if (node.next) + node.next->previous = node.previous; + else + queue_tail = node.previous; +} + +} diff --git a/src/rwmutex.hpp b/src/rwmutex.hpp new file mode 100644 index 0000000..f0c3880 --- /dev/null +++ b/src/rwmutex.hpp @@ -0,0 +1,84 @@ +#pragma once + +#include +#include +#include + +namespace helios{ + +struct rwlock_node{ + rwlock_node *previous; + rwlock_node *next; + std::thread::id id; +}; + +class rwmutex{ + int state = 0; + rwlock_node *queue_head = nullptr; + rwlock_node *queue_tail = nullptr; + size_t queue_size = 0; + size_t writers_waiting = 0; + std::mutex m; + std::condition_variable cv; + + friend class rwlock; + friend class wlock; + void lock_reader(rwlock_node &); + void unlock_reader(rwlock_node &); + void upgrade(); + void downgrade(); + void lock_writer(rwlock_node &); + void unlock_writer(rwlock_node &); + void push(rwlock_node &); + void remove(rwlock_node &); + typedef std::unique_lock UL; +}; + +class rwlock{ + rwmutex &m; + int state; + rwlock_node node; +public: + rwlock(rwmutex &); + ~rwlock(); + rwlock(const rwlock &) = delete; + rwlock &operator=(const rwlock &) = delete; + rwlock(rwlock &&) = delete; + rwlock &operator=(rwlock &&) = delete; + void upgrade(); + void downgrade(); +}; + +class wlock{ + rwmutex &m; + rwlock_node node; +public: + wlock(rwmutex &m): m(m){ + node.id = std::this_thread::get_id(); + m.lock_writer(node); + } + ~wlock(){ + m.unlock_writer(node); + } + wlock(const wlock &) = delete; + wlock &operator=(const wlock &) = delete; + wlock(wlock &&) = delete; + wlock &operator=(wlock &&) = delete; +}; + +class rwlock_writer_region{ + rwlock &lock; +public: + rwlock_writer_region(rwlock &lock): lock(lock){ + lock.upgrade(); + } + ~rwlock_writer_region(){ + lock.downgrade(); + } + rwlock_writer_region(const rwlock_writer_region &) = delete; + rwlock_writer_region &operator=(const rwlock_writer_region &) = delete; + rwlock_writer_region(rwlock_writer_region &&) = delete; + rwlock_writer_region &operator=(rwlock_writer_region &&) = delete; +}; + +} diff --git a/test/test_concurrency.cpp b/test/test_concurrency.cpp new file mode 100644 index 0000000..2c092f0 --- /dev/null +++ b/test/test_concurrency.cpp @@ -0,0 +1,224 @@ +/* + * This program tests that the mutex permits the maximum possible concurrency. + * Readers must always be able to run in parallel, but at most a single + * writer should run at any time. + */ + +#include "../src/rwmutex.hpp" +#include +#include +#include +#include +#include +#include +#include +#include + +typedef typename std::make_signed::type ssize_t; + +helios::rwmutex mutex; + +void simulate_work(){ + std::this_thread::sleep_for(std::chrono::milliseconds(250)); +} + +class AutoResetEvent{ + ssize_t state; + ssize_t waiting = 0; + std::mutex mutex; + std::condition_variable cv; +public: + AutoResetEvent(ssize_t initial_state = 0): state(initial_state){} + void signal(){ + std::lock_guard lock(this->mutex); + this->state++; + this->cv.notify_one(); + } + void signal_all(){ + std::lock_guard lock(this->mutex); + this->state += this->waiting; + this->cv.notify_all(); + } + void wait(){ + std::unique_lock lock(this->mutex); + this->waiting++; + while (this->state <= 0) + this->cv.wait(lock); + this->waiting--; + this->state--; + } + void set_state(ssize_t state){ + this->state = state; + } +}; + +class Counter{ + int state = 0; + int max_state = 0; + std::mutex m; +public: + void increment(){ + std::lock_guard lg(m); + max_state = std::max(max_state, ++state); + } + void decrement(){ + std::lock_guard lg(m); + --state; + } + int get(){ + std::lock_guard lg(m); + return max_state; + } +}; + +class AutoCounter{ + Counter &c; +public: + AutoCounter(Counter &c): c(c){ + c.increment(); + } + ~AutoCounter(){ + c.decrement(); + } +}; + +void read(Counter &shared, Counter &unique){ + helios::rwlock l(mutex); + { + AutoCounter ac(shared); + simulate_work(); + } +} + +void read_then_write(Counter &shared, Counter &unique){ + helios::rwlock l(mutex); + AutoCounter ac1(shared); + simulate_work(); + l.upgrade(); + AutoCounter ac2(unique); + simulate_work(); +} + +void read_then_write_then_read(Counter &shared, Counter &unique){ + helios::rwlock l(mutex); + { + AutoCounter ac1(shared); + simulate_work(); + { + helios::rwlock_writer_region r(l); + AutoCounter ac2(unique); + simulate_work(); + } + simulate_work(); + } +} + +void write(Counter &shared, Counter &unique){ + helios::wlock l(mutex); + { + AutoCounter ac2(unique); + simulate_work(); + } +} + +void test(AutoResetEvent &event1, AutoResetEvent &event2, int f, Counter &shared, Counter &unique){ + event1.signal(); + event2.wait(); + switch (f){ + case 0: + read(shared, unique); + break; + case 1: + read_then_write(shared, unique); + break; + case 2: + read_then_write_then_read(shared, unique); + break; + case 3: + write(shared, unique); + break; + } +} + +const char *function_to_string(int f){ + switch (f){ + case 0: + return "read"; + case 1: + return "read_then_write"; + case 2: + return "read_then_write_then_read"; + case 3: + return "write"; + } + return nullptr; +} + +void failed_test(int f1, int f2, int shared, int unique){ + std::cout << "Test " << function_to_string(f1) << " + " << function_to_string(f2) << " failed: " << shared << ", " << unique << std::endl; +} + +bool contains(const std::set &set, int i){ + return set.find(i) != set.end(); +} + +const int r = 0; +const int rw = 1; +const int rwr = 2; +const int w = 3; + +int main(){ + std::map, std::pair, std::set>> data = { + {{r , r }, {{2}, {0}}}, + {{r , rw }, {{2}, {1}}}, + {{r , rwr}, {{2}, {1}}}, + {{r , w }, {{1}, {1}}}, + + {{rw , r }, {{2}, {1}}}, + {{rw , rw }, {{2}, {1}}}, + {{rw , rwr}, {{2}, {1}}}, + {{rw , w }, {{1}, {1}}}, + + {{rwr, r }, {{2}, {1}}}, + {{rwr, rw }, {{2}, {1}}}, + {{rwr, rwr}, {{2}, {1}}}, + {{rwr, w }, {{1}, {1}}}, + + {{w , r }, {{1}, {1}}}, + {{w , rw }, {{1}, {1}}}, + {{w , rwr}, {{1}, {1}}}, + {{w , w }, {{0}, {1}}}, + }; + + bool ok = true; + int i = 0; + for (int f1 = 0; f1 < 4; f1++){ + for (int f2 = 0; f2 < 4; f2++){ + std::cout << "Test " << i++ << std::endl; + + AutoResetEvent event1(-1); + AutoResetEvent event2; + Counter shared; + Counter unique; + std::thread t1([&event1, &event2, f = f1, &shared, &unique](){ + test(event1, event2, f, shared, unique); + }); + std::thread t2([&event1, &event2, f = f2, &shared, &unique](){ + test(event1, event2, f, shared, unique); + }); + event1.wait(); + event2.signal_all(); + t1.join(); + t2.join(); + + auto &pair = data.find({ f1, f2 })->second; + + if (!contains(pair.first, shared.get()) || !contains(pair.second, unique.get())){ + failed_test(f1, f2, shared.get(), unique.get()); + ok = false; + } + } + } + if (ok) + std::cout << "OK\n"; +}