ostd/sync/
rwmutex.rs

1// SPDX-License-Identifier: MPL-2.0
2
3use core::{
4    cell::UnsafeCell,
5    fmt,
6    ops::{Deref, DerefMut},
7    sync::atomic::{
8        AtomicUsize,
9        Ordering::{AcqRel, Acquire, Relaxed, Release},
10    },
11};
12
13use super::WaitQueue;
14
15/// A mutex that provides data access to either one writer or many readers.
16///
17/// # Overview
18///
19/// This mutex allows for multiple readers, or at most one writer to access
20/// at any point in time. The writer of this mutex has exclusive access to
21/// modify the underlying data, while the readers are allowed shared and
22/// read-only access.
23///
24/// The writing and reading portions cannot be active simultaneously, when
25/// one portion is in progress, the other portion will sleep. This is
26/// suitable for scenarios where the mutex is expected to be held for a
27/// period of time, which can avoid wasting CPU resources.
28///
29/// This implementation provides the upgradeable read mutex (`upread mutex`).
30/// The `upread mutex` can be upgraded to write mutex atomically, useful in
31/// scenarios where a decision to write is made after reading.
32///
33/// The type parameter `T` represents the data that this mutex is protecting.
34/// It is necessary for `T` to satisfy [`Send`] to be shared across tasks and
35/// [`Sync`] to permit concurrent access via readers. The [`Deref`] method (and
36/// [`DerefMut`] for the writer) is implemented for the RAII guards returned
37/// by the locking methods, which allows for the access to the protected data
38/// while the mutex is held.
39///
40/// # Usage
41///
42/// The mutex can be used in scenarios where data needs to be read frequently
43/// but written to occasionally.
44///
45/// Use `upread mutex` in scenarios where related checking is performed before
46/// modification to effectively avoid deadlocks and improve efficiency.
47///
48/// # Safety
49///
50/// Avoid using `RwMutex` in an interrupt context, as it may result in sleeping
51/// and never being awakened.
52///
53/// # Examples
54///
55/// ```
56/// use ostd::sync::RwMutex;
57///
58/// let mutex = RwMutex::new(5)
59///
60/// // many read mutexes can be held at once
61/// {
62///     let r1 = mutex.read();
63///     let r2 = mutex.read();
64///     assert_eq!(*r1, 5);
65///     assert_eq!(*r2, 5);
66///     
67///     // Upgradeable read mutex can share access to data with read mutexes
68///     let r3 = mutex.upread();
69///     assert_eq!(*r3, 5);
70///     drop(r1);
71///     drop(r2);
72///     // read mutexes are dropped at this point
73///
74///     // An upread mutex can only be upgraded successfully after all the
75///     // read mutexes are released, otherwise it will spin-wait.
76///     let mut w1 = r3.upgrade();
77///     *w1 += 1;
78///     assert_eq!(*w1, 6);
79/// }   // upread mutex are dropped at this point
80///
81/// {   
82///     // Only one write mutex can be held at a time
83///     let mut w2 = mutex.write();
84///     *w2 += 1;
85///     assert_eq!(*w2, 7);
86/// }   // write mutex is dropped at this point
87/// ```
88pub struct RwMutex<T: ?Sized> {
89    /// The internal representation of the mutex state is as follows:
90    /// - **Bit 63:** Writer mutex.
91    /// - **Bit 62:** Upgradeable reader mutex.
92    /// - **Bit 61:** Indicates if an upgradeable reader is being upgraded.
93    /// - **Bits 60-0:** Reader mutex count.
94    lock: AtomicUsize,
95    /// Threads that fail to acquire the mutex will sleep on this waitqueue.
96    queue: WaitQueue,
97    val: UnsafeCell<T>,
98}
99
/// One unit of the reader count, which occupies the low bits of the state word.
const READER: usize = 1;
/// Set while a writer holds the mutex; this is the most significant bit.
const WRITER: usize = 1 << (usize::BITS - 1);
/// Set while an upgradeable reader holds the mutex (the bit just below `WRITER`).
const UPGRADEABLE_READER: usize = WRITER >> 1;
/// Set while the upgradeable reader is being upgraded to a writer.
const BEING_UPGRADED: usize = WRITER >> 2;

/// This bit is reserved as an overflow sentinel.
/// For more details, see comments on the `MAX_READER` constant
/// in the [`super::rwlock`] module.
const MAX_READER: usize = WRITER >> 3;
109
110impl<T> RwMutex<T> {
111    /// Creates a new read-write mutex with an initial value.
112    pub const fn new(val: T) -> Self {
113        Self {
114            val: UnsafeCell::new(val),
115            lock: AtomicUsize::new(0),
116            queue: WaitQueue::new(),
117        }
118    }
119}
120
121impl<T: ?Sized> RwMutex<T> {
122    /// Acquires a read mutex and sleep until it can be acquired.
123    ///
124    /// The calling thread will sleep until there are no writers or upgrading
125    /// upreaders present. The implementation of [`WaitQueue`] guarantees the
126    /// order in which other concurrent readers or writers waiting simultaneously
127    /// will acquire the mutex.
128    #[track_caller]
129    pub fn read(&self) -> RwMutexReadGuard<'_, T> {
130        self.queue.wait_until(|| self.try_read())
131    }
132
133    /// Acquires a write mutex and sleep until it can be acquired.
134    ///
135    /// The calling thread will sleep until there are no writers, upreaders,
136    /// or readers present. The implementation of [`WaitQueue`] guarantees the
137    /// order in which other concurrent readers or writers waiting simultaneously
138    /// will acquire the mutex.
139    #[track_caller]
140    pub fn write(&self) -> RwMutexWriteGuard<'_, T> {
141        self.queue.wait_until(|| self.try_write())
142    }
143
144    /// Acquires a upread mutex and sleep until it can be acquired.
145    ///
146    /// The calling thread will sleep until there are no writers or upreaders present.
147    /// The implementation of [`WaitQueue`] guarantees the order in which other concurrent
148    /// readers or writers waiting simultaneously will acquire the mutex.
149    ///
150    /// Upreader will not block new readers until it tries to upgrade. Upreader
151    /// and reader do not differ before invoking the upgrade method. However,
152    /// only one upreader can exist at any time to avoid deadlock in the
153    /// upgrade method.
154    #[track_caller]
155    pub fn upread(&self) -> RwMutexUpgradeableGuard<'_, T> {
156        self.queue.wait_until(|| self.try_upread())
157    }
158
159    /// Attempts to acquire a read mutex.
160    ///
161    /// This function will never sleep and will return immediately.
162    pub fn try_read(&self) -> Option<RwMutexReadGuard<'_, T>> {
163        let lock = self.lock.fetch_add(READER, Acquire);
164        if lock & (WRITER | BEING_UPGRADED | MAX_READER) == 0 {
165            Some(RwMutexReadGuard { inner: self })
166        } else {
167            self.lock.fetch_sub(READER, Release);
168            None
169        }
170    }
171
172    /// Attempts to acquire a write mutex.
173    ///
174    /// This function will never sleep and will return immediately.
175    pub fn try_write(&self) -> Option<RwMutexWriteGuard<'_, T>> {
176        if self
177            .lock
178            .compare_exchange(0, WRITER, Acquire, Relaxed)
179            .is_ok()
180        {
181            Some(RwMutexWriteGuard { inner: self })
182        } else {
183            None
184        }
185    }
186
187    /// Attempts to acquire a upread mutex.
188    ///
189    /// This function will never sleep and will return immediately.
190    pub fn try_upread(&self) -> Option<RwMutexUpgradeableGuard<'_, T>> {
191        let lock = self.lock.fetch_or(UPGRADEABLE_READER, Acquire) & (WRITER | UPGRADEABLE_READER);
192        if lock == 0 {
193            return Some(RwMutexUpgradeableGuard { inner: self });
194        } else if lock == WRITER {
195            self.lock.fetch_sub(UPGRADEABLE_READER, Release);
196        }
197        None
198    }
199
200    /// Returns a mutable reference to the underlying data.
201    ///
202    /// This method is zero-cost: By holding a mutable reference to the lock, the compiler has
203    /// already statically guaranteed that access to the data is exclusive.
204    pub fn get_mut(&mut self) -> &mut T {
205        self.val.get_mut()
206    }
207}
208
209impl<T: ?Sized + fmt::Debug> fmt::Debug for RwMutex<T> {
210    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
211        fmt::Debug::fmt(&self.val, f)
212    }
213}
214
215/// Because there can be more than one readers to get the T's immutable ref,
216/// so T must be Sync to guarantee the sharing safety.
217unsafe impl<T: ?Sized + Send> Send for RwMutex<T> {}
218unsafe impl<T: ?Sized + Send + Sync> Sync for RwMutex<T> {}
219
220impl<T: ?Sized> !Send for RwMutexWriteGuard<'_, T> {}
221unsafe impl<T: ?Sized + Sync> Sync for RwMutexWriteGuard<'_, T> {}
222
223impl<T: ?Sized> !Send for RwMutexReadGuard<'_, T> {}
224unsafe impl<T: ?Sized + Sync> Sync for RwMutexReadGuard<'_, T> {}
225
226impl<T: ?Sized> !Send for RwMutexUpgradeableGuard<'_, T> {}
227unsafe impl<T: ?Sized + Sync> Sync for RwMutexUpgradeableGuard<'_, T> {}
228
229/// A guard that provides immutable data access.
230pub struct RwMutexReadGuard<'a, T: ?Sized> {
231    inner: &'a RwMutex<T>,
232}
233
234impl<T: ?Sized> Deref for RwMutexReadGuard<'_, T> {
235    type Target = T;
236
237    fn deref(&self) -> &T {
238        unsafe { &*self.inner.val.get() }
239    }
240}
241
242impl<T: ?Sized> Drop for RwMutexReadGuard<'_, T> {
243    fn drop(&mut self) {
244        // When there are no readers, wake up a waiting writer.
245        if self.inner.lock.fetch_sub(READER, Release) == READER {
246            self.inner.queue.wake_one();
247        }
248    }
249}
250
251/// A guard that provides mutable data access.
252#[clippy::has_significant_drop]
253#[must_use]
254pub struct RwMutexWriteGuard<'a, T: ?Sized> {
255    inner: &'a RwMutex<T>,
256}
257
258impl<T: ?Sized> Deref for RwMutexWriteGuard<'_, T> {
259    type Target = T;
260
261    fn deref(&self) -> &T {
262        unsafe { &*self.inner.val.get() }
263    }
264}
265
266impl<'a, T: ?Sized> RwMutexWriteGuard<'a, T> {
267    /// Atomically downgrades a write guard to an upgradeable reader guard.
268    ///
269    /// This method always succeeds because the lock is exclusively held by the writer.
270    pub fn downgrade(mut self) -> RwMutexUpgradeableGuard<'a, T> {
271        loop {
272            self = match self.try_downgrade() {
273                Ok(guard) => return guard,
274                Err(e) => e,
275            };
276        }
277    }
278
279    /// This is not exposed as a public method to prevent intermediate lock states from affecting the
280    /// downgrade process.
281    fn try_downgrade(self) -> Result<RwMutexUpgradeableGuard<'a, T>, Self> {
282        let inner = self.inner;
283        let res = self
284            .inner
285            .lock
286            .compare_exchange(WRITER, UPGRADEABLE_READER, AcqRel, Relaxed);
287        if res.is_ok() {
288            drop(self);
289            Ok(RwMutexUpgradeableGuard { inner })
290        } else {
291            Err(self)
292        }
293    }
294}
295
296impl<T: ?Sized> DerefMut for RwMutexWriteGuard<'_, T> {
297    fn deref_mut(&mut self) -> &mut Self::Target {
298        unsafe { &mut *self.inner.val.get() }
299    }
300}
301
302impl<T: ?Sized> Drop for RwMutexWriteGuard<'_, T> {
303    fn drop(&mut self) {
304        self.inner.lock.fetch_and(!WRITER, Release);
305
306        // When the current writer releases, wake up all the sleeping threads.
307        // All awakened threads may include readers and writers.
308        // Thanks to the `wait_until` method, either all readers
309        // continue to execute or one writer continues to execute.
310        self.inner.queue.wake_all();
311    }
312}
313
314/// A guard that provides immutable data access but can be atomically
315/// upgraded to [`RwMutexWriteGuard`].
316pub struct RwMutexUpgradeableGuard<'a, T: ?Sized> {
317    inner: &'a RwMutex<T>,
318}
319
320impl<'a, T: ?Sized> RwMutexUpgradeableGuard<'a, T> {
321    /// Upgrades this upread guard to a write guard atomically.
322    ///
323    /// After calling this method, subsequent readers will be blocked
324    /// while previous readers remain unaffected.
325    ///
326    /// The calling thread will not sleep, but spin to wait for the existing
327    /// reader to be released. There are two main reasons.
328    /// - First, it needs to sleep in an extra waiting queue and needs extra wake-up logic and overhead.
329    /// - Second, upgrading method usually requires a high response time (because the mutex is being used now).
330    pub fn upgrade(mut self) -> RwMutexWriteGuard<'a, T> {
331        self.inner.lock.fetch_or(BEING_UPGRADED, Acquire);
332        loop {
333            self = match self.try_upgrade() {
334                Ok(guard) => return guard,
335                Err(e) => e,
336            };
337        }
338    }
339
340    /// Attempts to upgrade this upread guard to a write guard atomically.
341    ///
342    /// This function will return immediately.
343    ///
344    /// This function is not exposed publicly because the `BEING_UPGRADED` bit
345    /// is set only in [`Self::upgrade`].
346    fn try_upgrade(self) -> Result<RwMutexWriteGuard<'a, T>, Self> {
347        let res = self.inner.lock.compare_exchange(
348            UPGRADEABLE_READER | BEING_UPGRADED,
349            WRITER | UPGRADEABLE_READER,
350            AcqRel,
351            Relaxed,
352        );
353        if res.is_ok() {
354            let inner = self.inner;
355            drop(self);
356            Ok(RwMutexWriteGuard { inner })
357        } else {
358            Err(self)
359        }
360    }
361}
362
363impl<T: ?Sized> Deref for RwMutexUpgradeableGuard<'_, T> {
364    type Target = T;
365
366    fn deref(&self) -> &T {
367        unsafe { &*self.inner.val.get() }
368    }
369}
370
371impl<T: ?Sized> Drop for RwMutexUpgradeableGuard<'_, T> {
372    fn drop(&mut self) {
373        let res = self.inner.lock.fetch_sub(UPGRADEABLE_READER, Release);
374        if res == UPGRADEABLE_READER {
375            self.inner.queue.wake_all();
376        }
377    }
378}