// NOTE(review): the targets of this file's original #include directives could
// not be recovered, so the list below is rebuilt from the names the code
// actually uses. Re-add any project header that other translation units expect
// this file to pull in transitively.
#include <array>
#include <atomic>
#include <cstdint>
#include <mutex>
#include <thread>
#include <type_traits>

namespace c10 {

namespace detail {

// Scope guard over an atomic counter: adds 1 to `*counter` on construction and
// subtracts 1 again on destruction. `counter` is dereferenced unconditionally,
// so it must be non-null and must outlive the guard.
struct IncrementRAII final {
 public:
  explicit IncrementRAII(std::atomic<int32_t>* counter) : _counter(counter) {
    _counter->fetch_add(1);
  }

  ~IncrementRAII() {
    _counter->fetch_sub(1);
  }

  // Copying a guard would decrement the counter twice for one increment.
  // Spelled out explicitly (instead of C10_DISABLE_COPY_AND_ASSIGN) so this
  // header does not depend on the project macro header.
  IncrementRAII(const IncrementRAII&) = delete;
  IncrementRAII& operator=(const IncrementRAII&) = delete;

 private:
  std::atomic<int32_t>* _counter;
};

} // namespace detail

// LeftRight wait-free readers synchronization primitive
// https://hal.archives-ouvertes.fr/hal-01207881/document
//
// LeftRight is quite easy to use (it can make an arbitrary
// data structure permit wait-free reads), but it has some
// particular performance characteristics you should be aware
// of if you're deciding to use it:
//
//  - Reads still incur an atomic write (this is how LeftRight
//    keeps track of how long it needs to keep around the old
//    data structure)
//
//  - Writes get executed twice, to keep both the left and right
//    versions up to date.  So if your write is expensive or
//    nondeterministic, this is also an inappropriate structure
//
// LeftRight is used fairly rarely in PyTorch's codebase.  If you
// are still not sure if you need it or not, consult your local
// C++ expert.
//
template <class T>
class LeftRight final {
 public:
  // Builds both internal copies of T from the same constructor arguments,
  // i.e. T{args...} is evaluated twice.
  template <class... Args>
  explicit LeftRight(const Args&... args)
      : _counters{{{0}, {0}}},
        _foregroundCounterIndex(0),
        _foregroundDataIndex(0),
        _data{{T{args...}, T{args...}}},
        _writeMutex() {}

  // Copying and moving would not be threadsafe.
  // Needs more thought and careful design to make that work.
  LeftRight(const LeftRight&) = delete;
  LeftRight(LeftRight&&) noexcept = delete;
  LeftRight& operator=(const LeftRight&) = delete;
  LeftRight& operator=(LeftRight&&) noexcept = delete;

  ~LeftRight() {
    // wait until any potentially running writers are finished
    // (acquire and immediately release the write mutex)
    { std::unique_lock<std::mutex> lock(_writeMutex); }

    // wait until any potentially running readers are finished
    while (_counters[0].load() != 0 || _counters[1].load() != 0) {
      std::this_thread::yield();
    }
  }

  // Calls readFunc with a const reference to the current foreground copy and
  // returns its result. The foreground reader counter is held incremented for
  // the duration of the call; the write mutex is not taken.
  template <typename F>
  auto read(F&& readFunc) const -> typename std::result_of<F(const T&)>::type {
    detail::IncrementRAII _increment_counter(
        &_counters[_foregroundCounterIndex.load()]);

    return readFunc(_data[_foregroundDataIndex.load()]);
  }

  // Applies writeFunc to both copies (so it runs twice) while holding the
  // write mutex, and returns the result of the second invocation.
  //
  // Throwing an exception in writeFunc is ok but causes the state to be either
  // the old or the new state, depending on if the first or the second call to
  // writeFunc threw.
  template <typename F>
  auto write(F&& writeFunc) -> typename std::result_of<F(T&)>::type {
    std::unique_lock<std::mutex> lock(_writeMutex);

    return _write(writeFunc);
  }

 private:
  // Write algorithm proper. Callers must hold _writeMutex.
  template <class F>
  auto _write(const F& writeFunc) -> typename std::result_of<F(T&)>::type {
    /*
     * Assume, A is in background and B in foreground. In simplified terms, we
     * want to do the following:
     * 1. Write to A (old background)
     * 2. Switch A/B
     * 3. Write to B (new background)
     *
     * More detailed algorithm (explanations on why this is important are below
     * in code):
     * 1. Write to A
     * 2. Switch A/B data pointers
     * 3. Wait until A counter is zero
     * 4. Switch A/B counters
     * 5. Wait until B counter is zero
     * 6. Write to B
     */

    auto localDataIndex = _foregroundDataIndex.load();

    // 1. Write to A (the result of this first call is discarded)
    _callWriteFuncOnBackgroundInstance(writeFunc, localDataIndex);

    // 2. Switch A/B data pointers
    localDataIndex = localDataIndex ^ 1;
    _foregroundDataIndex = localDataIndex;

    /*
     * 3. Wait until A counter is zero
     *
     * In the previous write run, A was foreground and B was background.
     * There was a time after switching _foregroundDataIndex (B to foreground)
     * and before switching _foregroundCounterIndex, in which new readers could
     * have read B but incremented A's counter.
     *
     * In this current run, we just switched _foregroundDataIndex (A back to
     * foreground), but before writing to the new background B, we have to make
     * sure A's counter was zero briefly, so all these old readers are gone.
     */
    auto localCounterIndex = _foregroundCounterIndex.load();
    _waitForBackgroundCounterToBeZero(localCounterIndex);

    /*
     * 4. Switch A/B counters
     *
     * Now that we know all readers on B are really gone, we can switch the
     * counters and have new readers increment A's counter again, which is the
     * correct counter since they're reading A.
     */
    localCounterIndex = localCounterIndex ^ 1;
    _foregroundCounterIndex = localCounterIndex;

    /*
     * 5. Wait until B counter is zero
     *
     * This waits for all the readers on B that came in while both data and
     * counter for B was in foreground, i.e. normal readers that happened
     * outside of that brief gap between switching data and counter.
     */
    _waitForBackgroundCounterToBeZero(localCounterIndex);

    // 6. Write to B
    return _callWriteFuncOnBackgroundInstance(writeFunc, localDataIndex);
  }

  // Runs writeFunc on the copy that is NOT at localDataIndex (the background
  // copy). If writeFunc throws, the background copy is overwritten with the
  // copy at localDataIndex and the exception is rethrown.
  template <class F>
  auto _callWriteFuncOnBackgroundInstance(
      const F& writeFunc,
      uint8_t localDataIndex) -> typename std::result_of<F(T&)>::type {
    try {
      return writeFunc(_data[localDataIndex ^ 1]);
    } catch (...) {
      // recover invariant by copying from the foreground instance
      _data[localDataIndex ^ 1] = _data[localDataIndex];
      // rethrow
      throw;
    }
  }

  // Spins (yielding the thread) until the counter that is NOT at counterIndex
  // reads zero.
  void _waitForBackgroundCounterToBeZero(uint8_t counterIndex) {
    while (_counters[counterIndex ^ 1].load() != 0) {
      std::this_thread::yield();
    }
  }

  // mutable: read() is const but has to bump a reader counter.
  // NOTE(review): the counter element type was reconstructed as int32_t.
  mutable std::array<std::atomic<int32_t>, 2> _counters;
  std::atomic<uint8_t> _foregroundCounterIndex;
  std::atomic<uint8_t> _foregroundDataIndex;
  std::array<T, 2> _data;
  std::mutex _writeMutex;
};

// RWSafeLeftRightWrapper is API compatible with LeftRight and uses a
// lock to protect T (data).
//
// Note: both rLockType and wLockType are exclusive locks on the same
// std::mutex, so concurrent read() calls are serialized as well.
template <class T>
class RWSafeLeftRightWrapper final {
  using mutexType = std::mutex;
  using rLockType = std::unique_lock<mutexType>;
  using wLockType = std::unique_lock<mutexType>;

 public:
  // Constructs the single protected T as T{args...}.
  template <class... Args>
  explicit RWSafeLeftRightWrapper(const Args&... args) : _data{args...} {}

  // RWSafeLeftRightWrapper is not copyable or moveable since LeftRight
  // is not copyable or moveable.
  RWSafeLeftRightWrapper(const RWSafeLeftRightWrapper&) = delete;
  RWSafeLeftRightWrapper(RWSafeLeftRightWrapper&&) noexcept = delete;
  RWSafeLeftRightWrapper& operator=(const RWSafeLeftRightWrapper&) = delete;
  RWSafeLeftRightWrapper& operator=(RWSafeLeftRightWrapper&&) noexcept = delete;

  // Calls readFunc with a const reference to the data while holding the mutex
  // and returns its result.
  template <typename F>
  auto read(F&& readFunc) const -> typename std::result_of<F(const T&)>::type {
    rLockType lock(mutex_);
    return readFunc(_data);
  }

  // Calls writeFunc once with a mutable reference to the data while holding
  // the mutex and returns its result.
  template <typename F>
  auto write(F&& writeFunc) -> typename std::result_of<F(T&)>::type {
    wLockType lock(mutex_);
    return writeFunc(_data);
  }

 private:
  T _data;
  // mutable so that the const read() can lock it.
  mutable mutexType mutex_;
};

} // namespace c10