ostd/sync/rwarc.rs
1// SPDX-License-Identifier: MPL-2.0
2
3use alloc::sync::Arc;
4use core::sync::atomic::{AtomicUsize, Ordering, fence};
5
6use super::{PreemptDisabled, RwLock, RwLockReadGuard, RwLockWriteGuard};
7
/// A reference-counting pointer with read-write capabilities.
///
/// This is essentially `Arc<RwLock<T>>`, so it provides locked read and write access through
/// [`RwArc::read`] and [`RwArc::write`].
///
/// In addition, another reference-counting pointer with read-only capabilities ([`RoArc`]) can
/// be derived from it via [`RwArc::clone_ro`].
///
/// The purpose of this type is to allow lockless (read) access to the underlying data when
/// there is only one [`RwArc`] instance for the particular allocation (note that there can be
/// any number of [`RoArc`] instances for that allocation). See the [`RwArc::get`] method for
/// more details.
pub struct RwArc<T>(Arc<Inner<T>>);
21
/// A reference-counting pointer with read-only capabilities.
///
/// An `RoArc` shares its allocation with the [`RwArc`] it was created from (see
/// [`RwArc::clone_ro`]), but it only exposes read access to the underlying data. See the
/// documentation of [`RwArc`] and [`RwArc::clone_ro`] for more details.
pub struct RoArc<T>(Arc<Inner<T>>);
27
/// The shared allocation behind both [`RwArc`] and [`RoArc`].
struct Inner<T> {
    /// The protected value, guarded by a read-write lock.
    data: RwLock<T>,
    /// The number of live [`RwArc`] instances pointing to this allocation.
    ///
    /// It starts at one in `RwArc::new`, is incremented by `RwArc::clone`, and is decremented
    /// when an `RwArc` is dropped. Creating an [`RoArc`] does not change it.
    num_rw: AtomicUsize,
}
32
33impl<T> RwArc<T> {
34 /// Creates a new `RwArc<T>`.
35 pub fn new(data: T) -> Self {
36 let inner = Inner {
37 data: RwLock::new(data),
38 num_rw: AtomicUsize::new(1),
39 };
40 Self(Arc::new(inner))
41 }
42
43 /// Acquires the read lock for immutable access.
44 pub fn read(&self) -> RwLockReadGuard<'_, T, PreemptDisabled> {
45 self.0.data.read()
46 }
47
48 /// Acquires the write lock for mutable access.
49 pub fn write(&self) -> RwLockWriteGuard<'_, T, PreemptDisabled> {
50 self.0.data.write()
51 }
52
53 /// Returns an immutable reference if no other `RwArc` points to the same allocation.
54 ///
55 /// This method is cheap because it does not acquire a lock.
56 ///
57 /// It's still sound because:
58 /// - The mutable reference to `self` and the condition ensure that we are exclusively
59 /// accessing the unique `RwArc` instance for the particular allocation.
60 /// - There may be any number of [`RoArc`]s pointing to the same allocation, but they may only
61 /// produce immutable references to the underlying data.
62 pub fn get(&mut self) -> Option<&T> {
63 if self.0.num_rw.load(Ordering::Relaxed) > 1 {
64 return None;
65 }
66
67 // This will synchronize with `RwArc::drop` to make sure its changes are visible to us.
68 fence(Ordering::Acquire);
69
70 let data_ptr = self.0.data.as_ptr();
71
72 // SAFETY: The data is valid. During the lifetime, no one will be able to create a mutable
73 // reference to the data, so it's okay to create an immutable reference like the one below.
74 Some(unsafe { &*data_ptr })
75 }
76
77 /// Clones a [`RoArc`] that points to the same allocation.
78 pub fn clone_ro(&self) -> RoArc<T> {
79 RoArc(self.0.clone())
80 }
81}
82
83impl<T> Clone for RwArc<T> {
84 fn clone(&self) -> Self {
85 let inner = self.0.clone();
86
87 // Note that overflowing the counter will make it unsound. But not to worry: the above
88 // `Arc::clone` must have already aborted the kernel before this happens.
89 inner.num_rw.fetch_add(1, Ordering::Relaxed);
90
91 Self(inner)
92 }
93}
94
95impl<T> Drop for RwArc<T> {
96 fn drop(&mut self) {
97 self.0.num_rw.fetch_sub(1, Ordering::Release);
98 }
99}
100
101impl<T: Clone> RwArc<T> {
102 /// Returns the contained value by cloning it.
103 pub fn get_cloned(&self) -> T {
104 let guard = self.read();
105 guard.clone()
106 }
107}
108
109impl<T> RoArc<T> {
110 /// Acquires the read lock for immutable access.
111 pub fn read(&self) -> RwLockReadGuard<'_, T, PreemptDisabled> {
112 self.0.data.read()
113 }
114}
115
#[cfg(ktest)]
mod test {
    use super::*;
    use crate::prelude::*;

    /// Checks that `RwArc::get` succeeds exactly when there is a single `RwArc`.
    #[ktest]
    fn lockless_get() {
        // A freshly created pointer is unique.
        let mut primary = RwArc::new(1u32);
        assert_eq!(primary.get(), Some(&1));

        // Read-only clones do not count against uniqueness.
        let _reader = primary.clone_ro();
        assert_eq!(primary.get(), Some(&1));

        // A second read-write pointer disables the lockless path.
        let secondary = primary.clone();
        assert_eq!(primary.get(), None);

        // Dropping it restores the lockless path.
        drop(secondary);
        assert_eq!(primary.get(), Some(&1));
    }
}