ostd/sync/
rwmutex.rs

1// SPDX-License-Identifier: MPL-2.0
2
3use core::{
4    cell::UnsafeCell,
5    fmt,
6    ops::{Deref, DerefMut},
7    sync::atomic::{
8        AtomicUsize,
9        Ordering::{AcqRel, Acquire, Relaxed, Release},
10    },
11};
12
13use super::WaitQueue;
14
15/// A mutex that provides data access to either one writer or many readers.
16///
17/// # Overview
18///
19/// This mutex allows for multiple readers, or at most one writer to access
20/// at any point in time. The writer of this mutex has exclusive access to
21/// modify the underlying data, while the readers are allowed shared and
22/// read-only access.
23///
24/// The writing and reading portions cannot be active simultaneously, when
25/// one portion is in progress, the other portion will sleep. This is
26/// suitable for scenarios where the mutex is expected to be held for a
27/// period of time, which can avoid wasting CPU resources.
28///
29/// This implementation provides the upgradeable read mutex (`upread mutex`).
30/// The `upread mutex` can be upgraded to write mutex atomically, useful in
31/// scenarios where a decision to write is made after reading.
32///
33/// The type parameter `T` represents the data that this mutex is protecting.
34/// It is necessary for `T` to satisfy [`Send`] to be shared across tasks and
35/// [`Sync`] to permit concurrent access via readers. The [`Deref`] method (and
36/// [`DerefMut`] for the writer) is implemented for the RAII guards returned
37/// by the locking methods, which allows for the access to the protected data
38/// while the mutex is held.
39///
40/// # Usage
41/// The mutex can be used in scenarios where data needs to be read frequently
42/// but written to occasionally.
43///
44/// Use `upread mutex` in scenarios where related checking is performed before
45/// modification to effectively avoid deadlocks and improve efficiency.
46///
47/// # Safety
48///
49/// Avoid using `RwMutex` in an interrupt context, as it may result in sleeping
50/// and never being awakened.
51///
52/// # Examples
53///
54/// ```
55/// use ostd::sync::RwMutex;
56///
57/// let mutex = RwMutex::new(5)
58///
59/// // many read mutexes can be held at once
60/// {
61///     let r1 = mutex.read();
62///     let r2 = mutex.read();
63///     assert_eq!(*r1, 5);
64///     assert_eq!(*r2, 5);
65///     
66///     // Upgradeable read mutex can share access to data with read mutexes
67///     let r3 = mutex.upread();
68///     assert_eq!(*r3, 5);
69///     drop(r1);
70///     drop(r2);
71///     // read mutexes are dropped at this point
72///
73///     // An upread mutex can only be upgraded successfully after all the
74///     // read mutexes are released, otherwise it will spin-wait.
75///     let mut w1 = r3.upgrade();
76///     *w1 += 1;
77///     assert_eq!(*w1, 6);
78/// }   // upread mutex are dropped at this point
79///
80/// {   
81///     // Only one write mutex can be held at a time
82///     let mut w2 = mutex.write();
83///     *w2 += 1;
84///     assert_eq!(*w2, 7);
85/// }   // write mutex is dropped at this point
86/// ```
87pub struct RwMutex<T: ?Sized> {
88    /// The internal representation of the mutex state is as follows:
89    /// - **Bit 63:** Writer mutex.
90    /// - **Bit 62:** Upgradeable reader mutex.
91    /// - **Bit 61:** Indicates if an upgradeable reader is being upgraded.
92    /// - **Bits 60-0:** Reader mutex count.
93    lock: AtomicUsize,
94    /// Threads that fail to acquire the mutex will sleep on this waitqueue.
95    queue: WaitQueue,
96    val: UnsafeCell<T>,
97}
98
/// One unit of the reader count, which lives in the low bits of the lock word.
const READER: usize = 1;
/// Set while a writer holds the mutex (the most significant bit).
const WRITER: usize = !(usize::MAX >> 1);
/// Set while an upgradeable reader holds the mutex.
const UPGRADEABLE_READER: usize = WRITER >> 1;
/// Set while the upgradeable reader is waiting to become the writer.
const BEING_UPGRADED: usize = WRITER >> 2;
/// Reader-count guard bit: new readers are rejected while it is set.
const MAX_READER: usize = WRITER >> 3;
104
105impl<T> RwMutex<T> {
106    /// Creates a new read-write mutex with an initial value.
107    pub const fn new(val: T) -> Self {
108        Self {
109            val: UnsafeCell::new(val),
110            lock: AtomicUsize::new(0),
111            queue: WaitQueue::new(),
112        }
113    }
114}
115
116impl<T: ?Sized> RwMutex<T> {
117    /// Acquires a read mutex and sleep until it can be acquired.
118    ///
119    /// The calling thread will sleep until there are no writers or upgrading
120    /// upreaders present. The implementation of [`WaitQueue`] guarantees the
121    /// order in which other concurrent readers or writers waiting simultaneously
122    /// will acquire the mutex.
123    #[track_caller]
124    pub fn read(&self) -> RwMutexReadGuard<'_, T> {
125        self.queue.wait_until(|| self.try_read())
126    }
127
128    /// Acquires a write mutex and sleep until it can be acquired.
129    ///
130    /// The calling thread will sleep until there are no writers, upreaders,
131    /// or readers present. The implementation of [`WaitQueue`] guarantees the
132    /// order in which other concurrent readers or writers waiting simultaneously
133    /// will acquire the mutex.
134    #[track_caller]
135    pub fn write(&self) -> RwMutexWriteGuard<'_, T> {
136        self.queue.wait_until(|| self.try_write())
137    }
138
139    /// Acquires a upread mutex and sleep until it can be acquired.
140    ///
141    /// The calling thread will sleep until there are no writers or upreaders present.
142    /// The implementation of [`WaitQueue`] guarantees the order in which other concurrent
143    /// readers or writers waiting simultaneously will acquire the mutex.
144    ///
145    /// Upreader will not block new readers until it tries to upgrade. Upreader
146    /// and reader do not differ before invoking the upgread method. However,
147    /// only one upreader can exist at any time to avoid deadlock in the
148    /// upgread method.
149    #[track_caller]
150    pub fn upread(&self) -> RwMutexUpgradeableGuard<'_, T> {
151        self.queue.wait_until(|| self.try_upread())
152    }
153
154    /// Attempts to acquire a read mutex.
155    ///
156    /// This function will never sleep and will return immediately.
157    pub fn try_read(&self) -> Option<RwMutexReadGuard<'_, T>> {
158        let lock = self.lock.fetch_add(READER, Acquire);
159        if lock & (WRITER | BEING_UPGRADED | MAX_READER) == 0 {
160            Some(RwMutexReadGuard { inner: self })
161        } else {
162            self.lock.fetch_sub(READER, Release);
163            None
164        }
165    }
166
167    /// Attempts to acquire a write mutex.
168    ///
169    /// This function will never sleep and will return immediately.
170    pub fn try_write(&self) -> Option<RwMutexWriteGuard<'_, T>> {
171        if self
172            .lock
173            .compare_exchange(0, WRITER, Acquire, Relaxed)
174            .is_ok()
175        {
176            Some(RwMutexWriteGuard { inner: self })
177        } else {
178            None
179        }
180    }
181
182    /// Attempts to acquire a upread mutex.
183    ///
184    /// This function will never sleep and will return immediately.
185    pub fn try_upread(&self) -> Option<RwMutexUpgradeableGuard<'_, T>> {
186        let lock = self.lock.fetch_or(UPGRADEABLE_READER, Acquire) & (WRITER | UPGRADEABLE_READER);
187        if lock == 0 {
188            return Some(RwMutexUpgradeableGuard { inner: self });
189        } else if lock == WRITER {
190            self.lock.fetch_sub(UPGRADEABLE_READER, Release);
191        }
192        None
193    }
194
195    /// Returns a mutable reference to the underlying data.
196    ///
197    /// This method is zero-cost: By holding a mutable reference to the lock, the compiler has
198    /// already statically guaranteed that access to the data is exclusive.
199    pub fn get_mut(&mut self) -> &mut T {
200        self.val.get_mut()
201    }
202}
203
204impl<T: ?Sized + fmt::Debug> fmt::Debug for RwMutex<T> {
205    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
206        fmt::Debug::fmt(&self.val, f)
207    }
208}
209
210/// Because there can be more than one readers to get the T's immutable ref,
211/// so T must be Sync to guarantee the sharing safety.
212unsafe impl<T: ?Sized + Send> Send for RwMutex<T> {}
213unsafe impl<T: ?Sized + Send + Sync> Sync for RwMutex<T> {}
214
215impl<T: ?Sized> !Send for RwMutexWriteGuard<'_, T> {}
216unsafe impl<T: ?Sized + Sync> Sync for RwMutexWriteGuard<'_, T> {}
217
218impl<T: ?Sized> !Send for RwMutexReadGuard<'_, T> {}
219unsafe impl<T: ?Sized + Sync> Sync for RwMutexReadGuard<'_, T> {}
220
221impl<T: ?Sized> !Send for RwMutexUpgradeableGuard<'_, T> {}
222unsafe impl<T: ?Sized + Sync> Sync for RwMutexUpgradeableGuard<'_, T> {}
223
224/// A guard that provides immutable data access.
225pub struct RwMutexReadGuard<'a, T: ?Sized> {
226    inner: &'a RwMutex<T>,
227}
228
229impl<T: ?Sized> Deref for RwMutexReadGuard<'_, T> {
230    type Target = T;
231
232    fn deref(&self) -> &T {
233        unsafe { &*self.inner.val.get() }
234    }
235}
236
237impl<T: ?Sized> Drop for RwMutexReadGuard<'_, T> {
238    fn drop(&mut self) {
239        // When there are no readers, wake up a waiting writer.
240        if self.inner.lock.fetch_sub(READER, Release) == READER {
241            self.inner.queue.wake_one();
242        }
243    }
244}
245
246/// A guard that provides mutable data access.
247#[clippy::has_significant_drop]
248#[must_use]
249pub struct RwMutexWriteGuard<'a, T: ?Sized> {
250    inner: &'a RwMutex<T>,
251}
252
253impl<T: ?Sized> Deref for RwMutexWriteGuard<'_, T> {
254    type Target = T;
255
256    fn deref(&self) -> &T {
257        unsafe { &*self.inner.val.get() }
258    }
259}
260
261impl<'a, T: ?Sized> RwMutexWriteGuard<'a, T> {
262    /// Atomically downgrades a write guard to an upgradeable reader guard.
263    ///
264    /// This method always succeeds because the lock is exclusively held by the writer.
265    pub fn downgrade(mut self) -> RwMutexUpgradeableGuard<'a, T> {
266        loop {
267            self = match self.try_downgrade() {
268                Ok(guard) => return guard,
269                Err(e) => e,
270            };
271        }
272    }
273
274    /// This is not exposed as a public method to prevent intermediate lock states from affecting the
275    /// downgrade process.
276    fn try_downgrade(self) -> Result<RwMutexUpgradeableGuard<'a, T>, Self> {
277        let inner = self.inner;
278        let res = self
279            .inner
280            .lock
281            .compare_exchange(WRITER, UPGRADEABLE_READER, AcqRel, Relaxed);
282        if res.is_ok() {
283            drop(self);
284            Ok(RwMutexUpgradeableGuard { inner })
285        } else {
286            Err(self)
287        }
288    }
289}
290
291impl<T: ?Sized> DerefMut for RwMutexWriteGuard<'_, T> {
292    fn deref_mut(&mut self) -> &mut Self::Target {
293        unsafe { &mut *self.inner.val.get() }
294    }
295}
296
297impl<T: ?Sized> Drop for RwMutexWriteGuard<'_, T> {
298    fn drop(&mut self) {
299        self.inner.lock.fetch_and(!WRITER, Release);
300
301        // When the current writer releases, wake up all the sleeping threads.
302        // All awakened threads may include readers and writers.
303        // Thanks to the `wait_until` method, either all readers
304        // continue to execute or one writer continues to execute.
305        self.inner.queue.wake_all();
306    }
307}
308
309/// A guard that provides immutable data access but can be atomically
310/// upgraded to [`RwMutexWriteGuard`].
311pub struct RwMutexUpgradeableGuard<'a, T: ?Sized> {
312    inner: &'a RwMutex<T>,
313}
314
315impl<'a, T: ?Sized> RwMutexUpgradeableGuard<'a, T> {
316    /// Upgrades this upread guard to a write guard atomically.
317    ///
318    /// After calling this method, subsequent readers will be blocked
319    /// while previous readers remain unaffected.
320    ///
321    /// The calling thread will not sleep, but spin to wait for the existing
322    /// reader to be released. There are two main reasons.
323    /// - First, it needs to sleep in an extra waiting queue and needs extra wake-up logic and overhead.
324    /// - Second, upgrading method usually requires a high response time (because the mutex is being used now).
325    pub fn upgrade(mut self) -> RwMutexWriteGuard<'a, T> {
326        self.inner.lock.fetch_or(BEING_UPGRADED, Acquire);
327        loop {
328            self = match self.try_upgrade() {
329                Ok(guard) => return guard,
330                Err(e) => e,
331            };
332        }
333    }
334
335    /// Attempts to upgrade this upread guard to a write guard atomically.
336    ///
337    /// This function will return immediately.
338    pub fn try_upgrade(self) -> Result<RwMutexWriteGuard<'a, T>, Self> {
339        let res = self.inner.lock.compare_exchange(
340            UPGRADEABLE_READER | BEING_UPGRADED,
341            WRITER | UPGRADEABLE_READER,
342            AcqRel,
343            Relaxed,
344        );
345        if res.is_ok() {
346            let inner = self.inner;
347            drop(self);
348            Ok(RwMutexWriteGuard { inner })
349        } else {
350            Err(self)
351        }
352    }
353}
354
355impl<T: ?Sized> Deref for RwMutexUpgradeableGuard<'_, T> {
356    type Target = T;
357
358    fn deref(&self) -> &T {
359        unsafe { &*self.inner.val.get() }
360    }
361}
362
363impl<T: ?Sized> Drop for RwMutexUpgradeableGuard<'_, T> {
364    fn drop(&mut self) {
365        let res = self.inner.lock.fetch_sub(UPGRADEABLE_READER, Release);
366        if res == UPGRADEABLE_READER {
367            self.inner.queue.wake_all();
368        }
369    }
370}