// ostd/sync/rcu/non_null/either.rs
use core::{marker::PhantomData, ptr::NonNull};
use super::NonNullPtr;
use crate::util::Either;
// Packs an `Either<L, R>` into a single raw pointer by using one low address bit as a variant
// tag: the bit at position `Self::ALIGN_BITS` stays clear for `Left` and is set for `Right`.
//
// NOTE(review): this presumes `ALIGN_BITS` is the number of low address bits that are always
// zero in a raw pointer of the implementing type — confirm against the `NonNullPtr` trait.
unsafe impl<L: NonNullPtr, R: NonNullPtr> NonNullPtr for Either<L, R> {
    // The raw pointer is type-erased: it may carry either an `L` or an `R` pointer.
    type Target = PhantomData<Self>;

    type Ref<'a>
        = Either<L::Ref<'a>, R::Ref<'a>>
    where
        Self: 'a;

    // One of the bits shared by both sides is spent on the tag, so one fewer remains. If either
    // side has no spare bit at all, constant evaluation fails with the message below.
    const ALIGN_BITS: u32 = {
        let shared_bits = min(L::ALIGN_BITS, R::ALIGN_BITS);
        shared_bits
            .checked_sub(1)
            .expect("`L` and `R` alignments should be at least 2 to pack `Either` into one pointer")
    };

    // Converts the owned value into a raw pointer; `Left` is passed through untouched, `Right`
    // gets the tag bit set.
    fn into_raw(self) -> NonNull<Self::Target> {
        match self {
            Either::Left(inner) => inner.into_raw().cast(),
            Either::Right(inner) => {
                let tag_mask: usize = 1 << Self::ALIGN_BITS;
                let tagged = inner.into_raw().map_addr(|addr| addr | tag_mask);
                tagged.cast()
            }
        }
    }

    // Rebuilds the owned value from a raw pointer, using the tag bit to pick the variant.
    unsafe fn from_raw(ptr: NonNull<Self::Target>) -> Self {
        let tag_mask: usize = 1 << Self::ALIGN_BITS;
        // SAFETY: `remove_bits` needs the address to stay non-zero once the tag is cleared. That
        // holds for pointers built by `into_raw`, which only adds the tag on top of a non-null
        // pointer (presumably the caller's obligation for `from_raw` — confirm against the trait).
        let (tag, untagged) = unsafe { remove_bits(ptr, tag_mask) };

        match tag {
            // SAFETY: a clear tag means `into_raw` forwarded `L::into_raw` unchanged (same
            // assumption about the caller as above).
            0 => Either::Left(unsafe { L::from_raw(untagged.cast()) }),
            // SAFETY: a set tag means the pointer was `R::into_raw` plus the tag, which has just
            // been stripped (same assumption about the caller as above).
            _ => Either::Right(unsafe { R::from_raw(untagged.cast()) }),
        }
    }

    // Borrows through a raw pointer, using the tag bit to pick the variant.
    unsafe fn raw_as_ref<'a>(raw: NonNull<Self::Target>) -> Self::Ref<'a> {
        let tag_mask: usize = 1 << Self::ALIGN_BITS;
        // SAFETY: as in `from_raw`, the address stays non-zero after clearing the tag as long as
        // `raw` was built by `into_raw` (presumably required of the caller — confirm).
        let (tag, untagged) = unsafe { remove_bits(raw, tag_mask) };

        match tag {
            // SAFETY: clear tag, so the untagged pointer is the one `L` produced (see above).
            0 => Either::Left(unsafe { L::raw_as_ref(untagged.cast()) }),
            // SAFETY: set tag, so the untagged pointer is the one `R` produced (see above).
            _ => Either::Right(unsafe { R::raw_as_ref(untagged.cast()) }),
        }
    }

    // Turns a borrowed form back into the raw pointer, re-applying the tag for `Right` so the
    // result has the same encoding `into_raw` uses.
    fn ref_as_raw(ptr_ref: Self::Ref<'_>) -> NonNull<Self::Target> {
        match ptr_ref {
            Either::Left(left_ref) => L::ref_as_raw(left_ref).cast(),
            Either::Right(right_ref) => {
                let tag_mask: usize = 1 << Self::ALIGN_BITS;
                let tagged = R::ref_as_raw(right_ref).map_addr(|addr| addr | tag_mask);
                tagged.cast()
            }
        }
    }
}
// Returns the smaller of two `u32` values.
//
// Declared as a `const fn` so it can be called while a constant is being evaluated.
const fn min(a: u32, b: u32) -> u32 {
    if b <= a {
        b
    } else {
        a
    }
}
// Clears the `bits` mask from the address of `ptr`.
//
// Returns the bits that were set under the mask (`address & bits`) together with a pointer whose
// address has those bits cleared.
//
// # Safety
//
// The address of `ptr` with `bits` cleared must be non-zero, because the result is built with
// `NonZeroUsize::new_unchecked`.
unsafe fn remove_bits<T>(ptr: NonNull<T>, bits: usize) -> (usize, NonNull<T>) {
    use core::num::NonZeroUsize;

    let address = ptr.addr().get();
    let masked_out = address & bits;

    // SAFETY: the caller guarantees that the address is still non-zero after clearing `bits`.
    let cleared = unsafe { NonZeroUsize::new_unchecked(address & !bits) };

    (masked_out, ptr.with_addr(cleared))
}
// Kernel-test (`ktest`) cases for the `NonNullPtr` implementation on `Either`.
#[cfg(ktest)]
mod test {
use alloc::{boxed::Box, sync::Arc};
use super::*;
use crate::{prelude::ktest, sync::RcuOption};
// Pairing for which `alignment` expects `ALIGN_BITS == 1`, i.e. the variant tag sits in bit 1.
type Either32 = Either<Arc<u32>, Box<u32>>;
// Pairing for which `alignment` expects `ALIGN_BITS == 0`, i.e. the variant tag sits in bit 0.
type Either16 = Either<Arc<u32>, Box<u16>>;
// Pins the `ALIGN_BITS` computed for both pairings (the smaller of the two sides, minus one).
// NOTE(review): presumably `Arc<u32>`/`Box<u32>` report 2 alignment bits and `Box<u16>`
// reports 1 — verify against their `NonNullPtr` impls.
#[ktest]
fn alignment() {
assert_eq!(<Either32 as NonNullPtr>::ALIGN_BITS, 1);
assert_eq!(<Either16 as NonNullPtr>::ALIGN_BITS, 0);
}
// Round-trips a `Left` value: owned -> raw -> borrowed -> raw -> owned.
#[ktest]
fn left_pointer() {
let val: Either16 = Either::Left(Arc::new(123));
let ptr = NonNullPtr::into_raw(val);
// `Left` is stored without the tag, so bit 0 (the tag bit for `Either16`) must be clear.
assert_eq!(ptr.addr().get() & 1, 0);
// Borrowing through the raw pointer must yield the `Left` variant holding the stored value.
let ref_ = unsafe { <Either16 as NonNullPtr>::raw_as_ref(ptr) };
assert!(matches!(ref_, Either::Left(ref r) if ***r == 123));
// Converting the borrowed form back must reproduce the exact same raw pointer.
let ptr2 = <Either16 as NonNullPtr>::ref_as_raw(ref_);
assert_eq!(ptr, ptr2);
// Rebuild the owning `Either` from the raw pointer, check it, and drop it explicitly.
let val = unsafe { <Either16 as NonNullPtr>::from_raw(ptr) };
assert!(matches!(val, Either::Left(ref r) if **r == 123));
drop(val);
}
// Round-trips a `Right` value: owned -> raw -> borrowed -> raw -> owned.
#[ktest]
fn right_pointer() {
let val: Either16 = Either::Right(Box::new(456));
let ptr = NonNullPtr::into_raw(val);
// `Right` is stored with the tag set, so bit 0 (the tag bit for `Either16`) must be set.
assert_eq!(ptr.addr().get() & 1, 1);
// Borrowing through the raw pointer must yield the `Right` variant holding the stored value.
let ref_ = unsafe { <Either16 as NonNullPtr>::raw_as_ref(ptr) };
assert!(matches!(ref_, Either::Right(ref r) if ***r == 456));
// Converting the borrowed form back must reproduce the same raw pointer, tag included.
let ptr2 = <Either16 as NonNullPtr>::ref_as_raw(ref_);
assert_eq!(ptr, ptr2);
// Rebuild the owning `Either` from the raw pointer, check it, and drop it explicitly.
let val = unsafe { <Either16 as NonNullPtr>::from_raw(ptr) };
assert!(matches!(val, Either::Right(ref r) if **r == 456));
drop(val);
}
// Stores each variant in an `RcuOption<Either32>` in turn and reads it back.
#[ktest]
fn rcu_store_load() {
let rcu: RcuOption<Either32> = RcuOption::new_none();
// A freshly created `new_none` cell must read back as empty.
assert!(rcu.read().get().is_none());
rcu.update(Some(Either::Left(Arc::new(888))));
assert!(matches!(rcu.read().get().unwrap(), Either::Left(r) if **r == 888));
// Replacing the `Left` value with a `Right` one must be visible to the next read.
rcu.update(Some(Either::Right(Box::new(999))));
assert!(matches!(rcu.read().get().unwrap(), Either::Right(r) if **r == 999));
}
}