ostd/sync/rcu/non_null/
either.rs1use core::{marker::PhantomData, ptr::NonNull};
4
5use super::NonNullPtr;
6use crate::util::Either;
7
8unsafe impl<L: NonNullPtr, R: NonNullPtr> NonNullPtr for Either<L, R> {
12 type Target = PhantomData<Self>;
13
14 type Ref<'a>
15 = Either<L::Ref<'a>, R::Ref<'a>>
16 where
17 Self: 'a;
18
19 const ALIGN_BITS: u32 = min(L::ALIGN_BITS, R::ALIGN_BITS)
20 .checked_sub(1)
21 .expect("`L` and `R` alignments should be at least 2 to pack `Either` into one pointer");
22
23 fn into_raw(self) -> NonNull<Self::Target> {
24 match self {
25 Self::Left(left) => left.into_raw().cast(),
26 Self::Right(right) => right
27 .into_raw()
28 .map_addr(|addr| addr | (1 << Self::ALIGN_BITS))
29 .cast(),
30 }
31 }
32
33 unsafe fn from_raw(ptr: NonNull<Self::Target>) -> Self {
34 let (is_right, real_ptr) = unsafe { remove_bits(ptr, 1 << Self::ALIGN_BITS) };
37
38 if is_right == 0 {
39 Either::Left(unsafe { L::from_raw(real_ptr.cast()) })
42 } else {
43 Either::Right(unsafe { R::from_raw(real_ptr.cast()) })
46 }
47 }
48
49 unsafe fn raw_as_ref<'a>(raw: NonNull<Self::Target>) -> Self::Ref<'a> {
50 let (is_right, real_ptr) = unsafe { remove_bits(raw, 1 << Self::ALIGN_BITS) };
53
54 if is_right == 0 {
55 Either::Left(unsafe { L::raw_as_ref(real_ptr.cast()) })
58 } else {
59 Either::Right(unsafe { R::raw_as_ref(real_ptr.cast()) })
62 }
63 }
64
65 fn ref_as_raw(ptr_ref: Self::Ref<'_>) -> NonNull<Self::Target> {
66 match ptr_ref {
67 Either::Left(left) => L::ref_as_raw(left).cast(),
68 Either::Right(right) => R::ref_as_raw(right)
69 .map_addr(|addr| addr | (1 << Self::ALIGN_BITS))
70 .cast(),
71 }
72 }
73}
74
/// Returns the smaller of `a` and `b`.
///
/// Written by hand as a `const fn` because `Ord::min` cannot be called during constant
/// evaluation, and the `Either` impl needs the minimum to compute its `ALIGN_BITS` constant.
const fn min(a: u32, b: u32) -> u32 {
    if b < a { b } else { a }
}
79
/// Clears `bits` from the address of `ptr`.
///
/// Returns a pair of values:
/// - the address of `ptr` masked with `bits` (zero when none of `bits` were set);
/// - `ptr` with those bits cleared. The new pointer is derived with `map_addr`, so it keeps the
///   original provenance.
///
/// # Safety
///
/// The caller must ensure that clearing `bits` from the address of `ptr` leaves a non-zero
/// address, i.e., the returned pointer is still non-null.
unsafe fn remove_bits<T>(ptr: NonNull<T>, bits: usize) -> (usize, NonNull<T>) {
    use core::num::NonZeroUsize;

    let tag = ptr.addr().get() & bits;
    let untagged = ptr.map_addr(|addr| {
        let cleared = addr.get() & !bits;
        // SAFETY: The caller guarantees that clearing `bits` leaves a non-zero address.
        unsafe { NonZeroUsize::new_unchecked(cleared) }
    });
    (tag, untagged)
}
94
#[cfg(ktest)]
mod test {
    use alloc::{boxed::Box, sync::Arc};

    use super::*;
    use crate::{prelude::ktest, sync::RcuOption};

    // Per the `alignment` test, packing these two pointers still leaves one spare alignment bit.
    type Either32 = Either<Arc<u32>, Box<u32>>;
    // Per the `alignment` test, the tag consumes every spare alignment bit here (the tag is bit 0).
    type Either16 = Either<Arc<u32>, Box<u16>>;

    #[ktest]
    fn alignment() {
        let bits32 = <Either32 as NonNullPtr>::ALIGN_BITS;
        assert_eq!(bits32, 1);

        let bits16 = <Either16 as NonNullPtr>::ALIGN_BITS;
        assert_eq!(bits16, 0);
    }

    /// Round-trips a `Left` value through `into_raw`, `raw_as_ref`, `ref_as_raw`, and `from_raw`.
    #[ktest]
    fn left_pointer() {
        let raw = NonNullPtr::into_raw(Either16::Left(Arc::new(123)));
        // The tag is bit 0 for `Either16`, and it stays clear for `Left`.
        assert_eq!(raw.addr().get() & 1, 0);

        // SAFETY: `raw` was just produced by `into_raw` and has not been consumed yet.
        let borrowed = unsafe { <Either16 as NonNullPtr>::raw_as_ref(raw) };
        assert!(matches!(borrowed, Either::Left(ref r) if ***r == 123));
        assert_eq!(<Either16 as NonNullPtr>::ref_as_raw(borrowed), raw);

        // SAFETY: `raw` came from `into_raw` and is converted back exactly once.
        let owned = unsafe { <Either16 as NonNullPtr>::from_raw(raw) };
        assert!(matches!(owned, Either::Left(ref r) if **r == 123));
        drop(owned);
    }

    /// Round-trips a `Right` value through `into_raw`, `raw_as_ref`, `ref_as_raw`, and `from_raw`.
    #[ktest]
    fn right_pointer() {
        let raw = NonNullPtr::into_raw(Either16::Right(Box::new(456)));
        // The tag is bit 0 for `Either16`, and it is set for `Right`.
        assert_eq!(raw.addr().get() & 1, 1);

        // SAFETY: `raw` was just produced by `into_raw` and has not been consumed yet.
        let borrowed = unsafe { <Either16 as NonNullPtr>::raw_as_ref(raw) };
        assert!(matches!(borrowed, Either::Right(ref r) if ***r == 456));
        assert_eq!(<Either16 as NonNullPtr>::ref_as_raw(borrowed), raw);

        // SAFETY: `raw` came from `into_raw` and is converted back exactly once.
        let owned = unsafe { <Either16 as NonNullPtr>::from_raw(raw) };
        assert!(matches!(owned, Either::Right(ref r) if **r == 456));
        drop(owned);
    }

    /// Stores each variant in an `RcuOption` and reads it back.
    #[ktest]
    fn rcu_store_load() {
        let rcu: RcuOption<Either32> = RcuOption::new_none();
        assert!(rcu.read().get().is_none());

        rcu.update(Some(Either32::Left(Arc::new(888))));
        assert!(matches!(rcu.read().get().unwrap(), Either::Left(r) if **r == 888));

        rcu.update(Some(Either32::Right(Box::new(999))));
        assert!(matches!(rcu.read().get().unwrap(), Either::Right(r) if **r == 999));
    }
}