ostd/sync/rcu/non_null/
either.rs

1// SPDX-License-Identifier: MPL-2.0
2
3use core::{marker::PhantomData, ptr::NonNull};
4
5use super::NonNullPtr;
6use crate::util::Either;
7
8// If both `L` and `R` have at least one alignment bit (i.e., their alignments are at least 2), we
9// can use the alignment bit to indicate whether a pointer is `L` or `R`, so it's possible to
10// implement `NonNullPtr` for `Either<L, R>`.
11unsafe impl<L: NonNullPtr, R: NonNullPtr> NonNullPtr for Either<L, R> {
12    type Target = PhantomData<Self>;
13
14    type Ref<'a>
15        = Either<L::Ref<'a>, R::Ref<'a>>
16    where
17        Self: 'a;
18
19    const ALIGN_BITS: u32 = min(L::ALIGN_BITS, R::ALIGN_BITS)
20        .checked_sub(1)
21        .expect("`L` and `R` alignments should be at least 2 to pack `Either` into one pointer");
22
23    fn into_raw(self) -> NonNull<Self::Target> {
24        match self {
25            Self::Left(left) => left.into_raw().cast(),
26            Self::Right(right) => right
27                .into_raw()
28                .map_addr(|addr| addr | (1 << Self::ALIGN_BITS))
29                .cast(),
30        }
31    }
32
33    unsafe fn from_raw(ptr: NonNull<Self::Target>) -> Self {
34        // SAFETY: The caller ensures that the pointer comes from `Self::into_raw`, which
35        // guarantees that `real_ptr` is a non-null pointer.
36        let (is_right, real_ptr) = unsafe { remove_bits(ptr, 1 << Self::ALIGN_BITS) };
37
38        if is_right == 0 {
39            // SAFETY: `Self::into_raw` guarantees that `real_ptr` comes from `L::into_raw`. Other
40            // safety requirements are upheld by the caller.
41            Either::Left(unsafe { L::from_raw(real_ptr.cast()) })
42        } else {
43            // SAFETY: `Self::into_raw` guarantees that `real_ptr` comes from `R::into_raw`. Other
44            // safety requirements are upheld by the caller.
45            Either::Right(unsafe { R::from_raw(real_ptr.cast()) })
46        }
47    }
48
49    unsafe fn raw_as_ref<'a>(raw: NonNull<Self::Target>) -> Self::Ref<'a> {
50        // SAFETY: The caller ensures that the pointer comes from `Self::into_raw`, which
51        // guarantees that `real_ptr` is a non-null pointer.
52        let (is_right, real_ptr) = unsafe { remove_bits(raw, 1 << Self::ALIGN_BITS) };
53
54        if is_right == 0 {
55            // SAFETY: `Self::into_raw` guarantees that `real_ptr` comes from `L::into_raw`. Other
56            // safety requirements are upheld by the caller.
57            Either::Left(unsafe { L::raw_as_ref(real_ptr.cast()) })
58        } else {
59            // SAFETY: `Self::into_raw` guarantees that `real_ptr` comes from `R::into_raw`. Other
60            // safety requirements are upheld by the caller.
61            Either::Right(unsafe { R::raw_as_ref(real_ptr.cast()) })
62        }
63    }
64
65    fn ref_as_raw(ptr_ref: Self::Ref<'_>) -> NonNull<Self::Target> {
66        match ptr_ref {
67            Either::Left(left) => L::ref_as_raw(left).cast(),
68            Either::Right(right) => R::ref_as_raw(right)
69                .map_addr(|addr| addr | (1 << Self::ALIGN_BITS))
70                .cast(),
71        }
72    }
73}
74
// Returns the smaller of the two values. It is a hand-written `const fn` so that it can be used
// in constant evaluation.
const fn min(a: u32, b: u32) -> u32 {
    if b <= a { b } else { a }
}
79
/// Splits the address bits selected by `bits` off a non-null pointer.
///
/// Returns the selected bits of the pointer's address, together with the pointer whose address
/// has those bits cleared.
///
/// # Safety
///
/// The caller must ensure that removing the bits from the non-null pointer will result in another
/// non-null pointer.
unsafe fn remove_bits<T>(ptr: NonNull<T>, bits: usize) -> (usize, NonNull<T>) {
    let extracted = ptr.addr().get() & bits;

    let cleared = ptr.map_addr(|addr| {
        let remaining = addr.get() & !bits;
        // SAFETY: The caller guarantees that the address is still non-zero after the bits have
        // been removed.
        unsafe { core::num::NonZeroUsize::new_unchecked(remaining) }
    });

    (extracted, cleared)
}
94
#[cfg(ktest)]
mod test {
    use alloc::{boxed::Box, sync::Arc};

    use super::*;
    use crate::{prelude::ktest, sync::RcuOption};

    // The `alignment` test below expects `Either32` to report one alignment bit and `Either16`
    // to report none.
    type Either32 = Either<Arc<u32>, Box<u32>>;
    type Either16 = Either<Arc<u32>, Box<u16>>;

    #[ktest]
    fn alignment() {
        let bits32 = <Either32 as NonNullPtr>::ALIGN_BITS;
        let bits16 = <Either16 as NonNullPtr>::ALIGN_BITS;

        assert_eq!(bits32, 1);
        assert_eq!(bits16, 0);
    }

    #[ktest]
    fn left_pointer() {
        let either: Either16 = Either::Left(Arc::new(123));

        // `Either16` uses bit 0 as the tag; it must be clear for a left pointer.
        let raw = <Either16 as NonNullPtr>::into_raw(either);
        let tag = raw.addr().get() & 1;
        assert_eq!(tag, 0);

        // SAFETY: `raw` was returned by `into_raw` above and has not been converted back yet.
        let borrowed = unsafe { <Either16 as NonNullPtr>::raw_as_ref(raw) };
        assert!(matches!(borrowed, Either::Left(ref r) if ***r == 123));

        // Converting the reference back must yield the very same raw pointer.
        let raw_again = <Either16 as NonNullPtr>::ref_as_raw(borrowed);
        assert_eq!(raw, raw_again);

        // SAFETY: `raw` was returned by `into_raw` above and is converted back only here.
        let restored = unsafe { <Either16 as NonNullPtr>::from_raw(raw) };
        assert!(matches!(restored, Either::Left(ref r) if **r == 123));
        drop(restored);
    }

    #[ktest]
    fn right_pointer() {
        let either: Either16 = Either::Right(Box::new(456));

        // `Either16` uses bit 0 as the tag; it must be set for a right pointer.
        let raw = <Either16 as NonNullPtr>::into_raw(either);
        let tag = raw.addr().get() & 1;
        assert_eq!(tag, 1);

        // SAFETY: `raw` was returned by `into_raw` above and has not been converted back yet.
        let borrowed = unsafe { <Either16 as NonNullPtr>::raw_as_ref(raw) };
        assert!(matches!(borrowed, Either::Right(ref r) if ***r == 456));

        // Converting the reference back must yield the very same raw pointer.
        let raw_again = <Either16 as NonNullPtr>::ref_as_raw(borrowed);
        assert_eq!(raw, raw_again);

        // SAFETY: `raw` was returned by `into_raw` above and is converted back only here.
        let restored = unsafe { <Either16 as NonNullPtr>::from_raw(raw) };
        assert!(matches!(restored, Either::Right(ref r) if **r == 456));
        drop(restored);
    }

    #[ktest]
    fn rcu_store_load() {
        // The cell starts out empty.
        let cell: RcuOption<Either32> = RcuOption::new_none();
        assert!(cell.read().get().is_none());

        // A stored left value is read back as a left value.
        cell.update(Some(Either::Left(Arc::new(888))));
        assert!(matches!(cell.read().get().unwrap(), Either::Left(r) if **r == 888));

        // Replacing it with a right value is observed by the next read.
        cell.update(Some(Either::Right(Box::new(999))));
        assert!(matches!(cell.read().get().unwrap(), Either::Right(r) if **r == 999));
    }
}