Created November 9, 2018 19:37
-
-
Save nikkon-dev/b7d1e84c28f5cfea64bf54500cee4b72 to your computer and use it in GitHub Desktop.
Rust-like Mutex
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <atomic>
#include <chrono>
#include <cstdio>
#include <cstdlib>
#include <future>
#include <mutex>
#include <string>
#include <thread>
#include <type_traits>
#include <utility>
#include <vector>
/// Default error-logging policy for BaseMutex. Despite the name it does emit
/// output: the printf-style message is written verbatim to stderr.
struct NoLog {
    template <typename... Ts>
    static void LogError(const char* format, Ts&&... values) {
        std::fprintf(stderr, format, std::forward<Ts>(values)...);
    }
};
/// RAII accessor handed out by a mutex's get(): it owns the mutex's lock for its
/// whole lifetime and exposes the protected value through `->` and `*`.
///
/// The lock is adopted (the mutex must already be locked when the guard is
/// built) and released by the guard's destructor. Only TMutex can construct
/// guards. Marked [[nodiscard]] because a discarded guard unlocks immediately.
template <class TMutex>
class [[nodiscard]] MutexGuard final{
public:
    using value_type = typename TMutex::value_type;

    value_type* operator->() noexcept {
        return &obj_;
    }
    value_type const* operator->() const noexcept {
        return &obj_;
    }

    /// Direct access to the protected value (needed e.g. for scalar value types
    /// or to pass the value on by reference).
    value_type& operator*() noexcept {
        return obj_;
    }
    value_type const& operator*() const noexcept {
        return obj_;
    }

private:
    friend TMutex;

    /// @param obj   the value protected by `mutex`
    /// @param mutex a mutex already locked by the caller; ownership of that lock
    ///              is adopted and released when the guard is destroyed
    MutexGuard(value_type& obj, TMutex& mutex)
        : obj_(obj)
        , lock_{mutex, std::adopt_lock}
    {
    }

    value_type& obj_;
    std::unique_lock<TMutex> lock_;
};
template <class T, class M, class TLogger = NoLog, bool Rec = false> | |
class BaseMutex{ | |
public: | |
using value_type = T; | |
BaseMutex(BaseMutex&& r) | |
: obj_(std::move(*r.get().operator->())) | |
, poisoned_(r.poisoned_) | |
{ | |
r.poisoned_ = true; | |
} | |
BaseMutex& operator=(BaseMutex&& r){ | |
if (this != &r){ | |
std::lock_guard<BaseMutex> this_lock{*this, std::defer_lock}; | |
std::lock_guard<BaseMutex> that_lock{r, std::defer_lock}; | |
std::lock(this_lock, that_lock); | |
obj_ = std::move(r.obj_); | |
poisoned_ = r.poisoned_; | |
r.poisoned_ = true; | |
} | |
return *this; | |
} | |
explicit BaseMutex(value_type& obj) : obj_(std::move(obj)) {} | |
template <class ... TArgs> | |
explicit BaseMutex(TArgs&& ... args) : obj_{value_type(std::forward<TArgs>(args)...)} {;} | |
MutexGuard<BaseMutex> get() noexcept { | |
lock(); | |
if (poisoned_){ | |
TLogger::LogError("BaseMutex: Use after move."); | |
std::abort(); | |
} | |
return MutexGuard<BaseMutex>{obj_, *this}; | |
} | |
void lock(){ | |
while(!mutex_.try_lock()){ | |
if constexpr (!Rec){ | |
if (owner_thread_id_ == std::this_thread::get_id()) { | |
TLogger::LogError("BaseMutex: Recursive use is not supported"); | |
std::abort(); | |
} | |
} | |
} | |
owner_thread_id_ = std::this_thread::get_id(); | |
} | |
void unlock(){ | |
owner_thread_id_ = std::thread::id{}; | |
mutex_.unlock(); | |
} | |
private: | |
friend class MutexGuard<BaseMutex>; | |
value_type obj_; | |
M mutex_; | |
std::thread::id owner_thread_id_{}; | |
bool poisoned_ = false; | |
}; | |
template <class T> | |
using Mutex = BaseMutex<T, std::mutex, NoLog, false>; | |
template <class T> | |
using RecursiveMutex = BaseMutex<T, std::recursive_mutex, NoLog, true>; | |
/// Demo payload protected by the mutex: a counter plus a string. Move-only.
struct MyObj{
    explicit MyObj(int cnt, std::string str) : cnt(cnt), some(std::move(str)){}
    MyObj(MyObj&&) = default;
    // Declaring the move constructor suppresses the implicit move assignment;
    // default it so BaseMutex's move assignment (obj_ = std::move(r.obj_)) works.
    MyObj& operator=(MyObj&&) = default;
    int cnt = 0;
    std::string some;
};
//=========================== | |
#define SIMPLE_FAILURE | |
// Choose wisely :) | |
// #ifdef RECURSIVE_SUCCESS | |
// #define USE_RECURSIVE | |
// #define RECURSIVE_TEST | |
// #endif | |
// #ifdef RECURSIVE_FAILURE | |
// #define TEST_MOVE | |
// #define USE_RECURSIVE | |
// #define RECURSIVE_TEST | |
// #endif | |
// #ifdef SIMPLE_SUCCESS | |
// #define USE_SIMPLE | |
// #endif | |
#ifdef SIMPLE_FAILURE | |
#define USE_SIMPLE | |
#define RECURSIVE_TEST | |
#endif | |
template <class T> | |
using TestMutex = | |
#ifdef USE_RECURSIVE | |
#error | |
RecursiveMutex<T> | |
#endif | |
#ifdef USE_SIMPLE | |
Mutex<T> | |
#endif | |
; | |
//--------------------------- | |
#ifdef TEST_MOVE
/// Takes the mutex by value (ownership is moved in) and adds 10 to its counter.
void own_func(TestMutex<MyObj> obj){
    auto guard = obj.get();
    guard->cnt += 10;
}
#endif
/// Task body run concurrently by main(): bumps the shared counter through the
/// mutex. With RECURSIVE_TEST it holds two guards on this thread at once (+1 and
/// +2); with TEST_MOVE it then moves the shared mutex into own_func().
void func(TestMutex<MyObj>& obj){
    // NOTE(review): presumably a token sleep to perturb task scheduling.
    std::this_thread::sleep_for(std::chrono::nanoseconds{1});
#ifndef RECURSIVE_TEST
    obj.get()->cnt += 1;
#else
    auto outer = obj.get();
    auto inner = obj.get();  // second lock taken by the same thread
    outer->cnt += 1;
    inner->cnt += 2;
#endif
#ifdef TEST_MOVE
    own_func(std::move(obj));
#endif
}
int main(){ | |
std::vector<std::future<void>> threads; | |
threads.reserve(100); | |
TestMutex<MyObj> mobj{10, "hello"}; | |
for (size_t i = 0; i < 100; ++i){ | |
threads.emplace_back(std::async(std::launch::async, [&](){func(mobj);})); | |
} | |
for (auto& f : threads){ | |
f.wait(); | |
} | |
printf("%d\n", mobj.get()->cnt); | |
return 0; | |
} |
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment