//! `NonNullPtr`/`NonNullPtrRef` implementations for `Either<L, R>`
//! (ostd/sync/rcu/non_null/either.rs).
// SPDX-License-Identifier: MPL-2.0
use core::{marker::PhantomData, ptr::NonNull};

use vstd::raw_ptr::group_raw_ptr_axioms;
use vstd::{bits, prelude::*};
use vstd_extra::{external::nonzero::*, prelude::*, sum::Sum};

use super::{NonNullPtr, NonNullPtrRef};
use crate::util::Either;

verus! {

broadcast use {group_nonull_axioms, group_raw_ptr_axioms, group_nonzero_axioms};

// Proves that an address aligned to `2^ptr_align_bits` has the lower bit
// `tag == 2^align_bits` clear whenever `align_bits < ptr_align_bits`.
//
// This is the key fact that lets the low alignment bit be used as a
// left/right discriminant without corrupting the pointer value.
proof fn lemma_aligned_addr_clears_tag_bit(
    addr: usize,
    tag: usize,
    align_bits: u32,
    ptr_align_bits: u32,
)
    requires
        addr % (1usize << ptr_align_bits) == 0,
        tag == 1usize << align_bits,
        align_bits < ptr_align_bits < usize::BITS,
    ensures
        addr & tag == 0,
{
    // Discharged directly by the bit-vector solver. The `requires` clauses are
    // restated because `by (bit_vector)` does not inherit the outer context.
    assert(addr & tag == 0) by (bit_vector)
        requires
            addr % (1usize << ptr_align_bits) == 0,
            tag == 1usize << align_bits,
            align_bits < ptr_align_bits < usize::BITS,
    ;
}
35
// If both `L` and `R` have at least one alignment bit (i.e., their alignments are at least 2), we
// can use the alignment bit to indicate whether a pointer is `L` or `R`, so it's possible to
// implement `NonNullPtr` for `Either<L, R>`.
unsafe impl<L: NonNullPtr, R: NonNullPtr> NonNullPtr for Either<L, R> {
    // The pointee type is opaque: the raw pointer actually points to either an
    // `L::Target` or an `R::Target`, distinguished by the tag bit below.
    type Target = PhantomData<Self>;

    // type Ref<'a>
    //     = Either<L::Ref<'a>, R::Ref<'a>>
    // where
    //     Self: 'a;
    // The ghost permission mirrors the runtime tag: `Sum::Left` for `L`
    // pointers and `Sum::Right` for `R` pointers.
    type Permission = Sum<L::Permission, R::Permission>;

    // One alignment bit is sacrificed to store the tag, so `Either` guarantees
    // one bit less than the weaker of the two alignments. If either alignment
    // has no spare bit, `checked_sub` yields `None` and the `expect` fails
    // during constant evaluation, i.e., at compile time.
    #[verifier::external_body]
    const ALIGN_BITS: u32 = min(L::ALIGN_BITS, R::ALIGN_BITS).checked_sub(1).expect(
        "`L` and `R` alignments should be at least 2 to pack `Either` into one pointer",
    );

    // Converts `self` into a tagged raw pointer: a `Left` pointer is passed
    // through unchanged (its tag bit is provably 0 thanks to alignment), while
    // a `Right` pointer gets the tag bit `1 << Self::ALIGN_BITS` set.
    #[verus_spec]
    #[verifier::spinoff_prover]
    fn into_raw(self) -> (ret: (NonNull<Self::Target>, Tracked<Self::Permission>)) {
        // Ghost copies of the alignment constants and the tag mask used
        // throughout the proof below.
        proof_decl!{
           let ghost align_bits = Self::ALIGN_BITS;
           let ghost l_align_bits = L::ALIGN_BITS;
           let ghost r_align_bits = R::ALIGN_BITS;
           let ghost tag = 1usize << align_bits;
        }
        // Establish shared facts: the alignment-bit ranges and that each
        // `1usize << bits` neither overflows nor equals zero.
        proof! {
            L::lemma_align_bits_range();
            R::lemma_align_bits_range();
            Self::lemma_align_bits_range();
            vstd::bits::lemma_usize_pow2_no_overflow(align_bits as nat);
            vstd::bits::lemma_usize_pow2_no_overflow(l_align_bits as nat);
            vstd::bits::lemma_usize_pow2_no_overflow(r_align_bits as nat);
            vstd::bits::lemma_usize_shl_is_mul(1, align_bits as usize);
            vstd::bits::lemma_usize_shl_is_mul(1, l_align_bits as usize);
            vstd::bits::lemma_usize_shl_is_mul(1, r_align_bits as usize);
        }
        match self {
            Self::Left(left) => {
                // left.into_raw().cast(),
                let (left, Tracked(perm)) = left.into_raw();
                proof! {
                    let left_addr = left.cast::<Self::Target>().view_ptr_mut().addr();
                    let extra_bits: u32 = (l_align_bits - align_bits) as u32;
                    let scale = 1usize << extra_bits;
                    vstd::bits::lemma_usize_pow2_no_overflow(extra_bits as nat);
                    vstd::bits::lemma_usize_shl_is_mul(1, extra_bits as usize);
                    vstd::arithmetic::power2::lemma_pow2_adds(align_bits as nat, extra_bits as nat);
                    // `left_addr` is a multiple of `big == 2^l_align_bits ==
                    // tag * scale`, hence also a multiple of `tag`.
                    assert(left_addr % tag == 0) by {
                        let big = 1usize << l_align_bits;
                        let q = left_addr / big;
                        vstd::arithmetic::div_mod::lemma_fundamental_div_mod(left_addr as int, big as int);
                        assert(left_addr as int == (q as int * scale as int) * tag as int) by (nonlinear_arith)
                        requires
                            left_addr as int == q as int * big as int,
                            big == tag * scale,
                        ;
                        vstd::arithmetic::div_mod::lemma_mod_multiples_basic(q as int * scale as int, tag as int);
                    };
                    // Therefore the tag bit is clear, which is what
                    // `ptr_perm_match` requires for the `Left` case.
                    lemma_aligned_addr_clears_tag_bit(left_addr, tag, align_bits, l_align_bits);
                }
                (left.cast(), Tracked(Sum::Left(perm)))
            },
            Self::Right(right) => {
                /* right
                .into_raw()
                .map_addr(|addr| addr | (1 << Self::ALIGN_BITS))
                .cast(), */
                let (right, Tracked(perm)) = right.into_raw();
                // Set the tag bit to mark this pointer as a `Right` pointer.
                let right_tagged = right.map_addr_v(
                    |addr: NonZeroUsize| -> (ret: NonZeroUsize)
                        ensures
                            ret@ == addr@ | (1usize << Self::ALIGN_BITS),
                        { addr | 1usize << Self::ALIGN_BITS },
                );
                proof! {
                    let addr = right.addr_spec()@;
                    let tagged_addr = right_tagged.addr_spec()@;
                    // After OR-ing in `tag`, the tag bit is set.
                    assert(tagged_addr & tag == tag) by (bit_vector)
                    requires
                        tagged_addr == addr | tag,
                        tag == 1usize << align_bits,
                        1 <= r_align_bits < usize::BITS,
                        align_bits < r_align_bits,
                        addr % (1usize << r_align_bits) == 0,
                        addr != 0,
                    ;
                    let extra_bits: u32 = (r_align_bits - align_bits) as u32;
                    let scale = 1usize << extra_bits;
                    vstd::bits::lemma_usize_pow2_no_overflow(extra_bits as nat);
                    vstd::bits::lemma_usize_shl_is_mul(1, extra_bits as usize);
                    vstd::arithmetic::power2::lemma_pow2_adds(align_bits as nat, extra_bits as nat);
                    // Since the tag bit was clear, OR behaves like addition.
                    assert(tagged_addr == addr + tag) by (bit_vector)
                    requires
                        tagged_addr == addr | tag,
                        tag == 1usize << align_bits,
                        1 <= r_align_bits < usize::BITS,
                        align_bits < r_align_bits,
                        addr % (1usize << r_align_bits) == 0,
                        addr != 0,
                    ;
                    // Same divisibility argument as in the `Left` arm:
                    // `addr` is a multiple of `tag * scale`, hence of `tag`.
                    assert(addr % tag == 0) by {
                        let big = 1usize << r_align_bits;
                        let q = addr / big;
                        vstd::arithmetic::div_mod::lemma_fundamental_div_mod(addr as int, big as int);
                        assert(addr as int == (q as int * scale as int) * tag as int) by (nonlinear_arith)
                        requires
                            addr as int == q as int * big as int,
                            big == tag * scale,
                        ;
                        vstd::arithmetic::div_mod::lemma_mod_multiples_basic(q as int * scale as int, tag as int);
                    }
                    lemma_aligned_addr_clears_tag_bit(addr, tag, align_bits, r_align_bits);
                    // Clearing the tag recovers the original (non-null) address.
                    assert(tagged_addr & !tag == addr) by (bit_vector)
                    requires
                        tagged_addr == addr | tag,
                        addr & tag == 0,
                    ;
                    // The tagged address is still aligned to `Self`'s alignment.
                    assert(tagged_addr % (1usize << align_bits) == 0) by {
                        vstd::arithmetic::div_mod::lemma_mod_add_multiples_vanish(addr as int, tag as int);
                    }
                }
                (right_tagged.cast(), Tracked(Sum::Right(perm)))
            },
        }
    }

    // Reconstitutes an `Either` from a pointer produced by `Self::into_raw`:
    // the tag bit `1 << Self::ALIGN_BITS` selects the variant (clear => `L`,
    // set => `R`). The permission's variant must match the tag (see
    // `ptr_perm_match`), which supplies the facts needed to satisfy
    // `remove_bits`'s preconditions below.
    unsafe fn from_raw(
        ptr: NonNull<Self::Target>,
        Tracked(perm): Tracked<Self::Permission>,
    ) -> Self {
        proof! {
            Self::lemma_align_bits_range();
        }
        proof_decl! {
            let ghost align_bits = Self::ALIGN_BITS;
            let ghost tag = 1usize << Self::ALIGN_BITS;
            let ghost ptr_addr = ptr.view_ptr_mut()@.addr;
        }
        proof! {
            assert(tag > 0) by (bit_vector)
            requires
                tag == 1usize << align_bits,
                align_bits < usize::BITS,
            ;
            // In either case, clearing the tag bit leaves a non-null address,
            // establishing `remove_bits`'s preconditions.
            match perm {
                Sum::Left(_) => {
                    assert((ptr_addr & !tag) == ptr_addr) by (bit_vector)
                    requires
                        ptr_addr & tag == 0,
                    ;
                },
                Sum::Right(_) => {
                    assert((ptr_addr & tag) < ptr_addr) by (bit_vector)
                    requires
                        ptr_addr & tag == tag,
                        (ptr_addr & !tag) != 0,
                    ;
                },
            }
        }
        // SAFETY: The caller ensures that the pointer comes from `Self::into_raw`, which
        // guarantees that `real_ptr` is a non-null pointer.
        let (is_right, real_ptr) = unsafe { remove_bits(ptr, 1 << Self::ALIGN_BITS) };

        if is_right == 0 {
            // SAFETY: `Self::into_raw` guarantees that `real_ptr` comes from `L::into_raw`. Other
            // safety requirements are upheld by the caller.
            Either::Left(unsafe { L::from_raw(real_ptr.cast(), Tracked(perm.tracked_take_left())) })
        } else {
            // SAFETY: `Self::into_raw` guarantees that `real_ptr` comes from `R::into_raw`. Other
            // safety requirements are upheld by the caller.
            Either::Right(
                unsafe { R::from_raw(real_ptr.cast(), Tracked(perm.tracked_take_right())) },
            )
        }
    }

    // Ties a tagged raw pointer to its ghost permission: a `Left` permission
    // requires the tag bit clear and defers to `L`'s own invariant; a `Right`
    // permission requires the tag bit set, a non-null untagged address, and
    // `R`'s invariant on the untagged pointer.
    open spec fn ptr_perm_match(ptr: NonNull<Self::Target>, perm: Self::Permission) -> bool {
        let tag = 1usize << Self::ALIGN_BITS;
        match perm {
            Sum::Left(left) => {
                &&& ptr.view_ptr_mut().addr() & tag == 0
                &&& L::ptr_perm_match(ptr.cast(), left)
            },
            Sum::Right(right) => {
                let untagged_ptr = ptr.view_ptr_mut().with_addr((ptr.view_ptr_mut().addr() & !tag));
                let right_nonnull = nonnull_from_ptr_mut_spec(untagged_ptr);
                &&& ptr.view_ptr_mut().addr() & tag == tag
                &&& (ptr.view_ptr_mut().addr() & !tag) != 0
                &&& R::ptr_perm_match(right_nonnull.cast(), right)
            },
        }
    }

    // A value and a permission are related only when their variants agree and
    // the inner value/permission pair is related in turn.
    open spec fn rel_perm(self, perm: Self::Permission) -> bool {
        match (self, perm) {
            (Either::Left(left), Sum::Left(left_perm)) => left.rel_perm(left_perm),
            (Either::Right(right), Sum::Right(right_perm)) => right.rel_perm(right_perm),
            _ => false,
        }
    }

    // `ALIGN_BITS` is computed by the `#[verifier::external_body]` constant
    // above, so its value is introduced to the verifier as an axiom:
    // `min(L::ALIGN_BITS, R::ALIGN_BITS) - 1`.
    axiom fn lemma_align_bits_range()
        ensures
            Self::ALIGN_BITS == if L::ALIGN_BITS < R::ALIGN_BITS {
                L::ALIGN_BITS - 1
            } else {
                R::ALIGN_BITS - 1
            },
    ;
}
248
// Shared-reference counterpart of the tagged-pointer scheme: `Either` of two
// `NonNullPtrRef` types is itself a `NonNullPtrRef`, using the same tag bit.
unsafe impl<'a, L: NonNullPtrRef<'a>, R: NonNullPtrRef<'a>> NonNullPtrRef<'a> for Either<L, R> {
    type Ref = Either<L::Ref, R::Ref>;

    type RefPermission = Sum<L::RefPermission, R::RefPermission>;

    // Viewing a reference permission as an owning permission preserves the
    // left/right variant and delegates to the inner type's view.
    open spec fn ref_perm_view_permission(perm: Self::RefPermission) -> Self::Permission {
        match perm {
            Sum::Left(left) => Sum::Left(L::ref_perm_view_permission(left)),
            Sum::Right(right) => Sum::Right(R::ref_perm_view_permission(right)),
        }
    }

    // NOTE(review): this relation is trivially `true`, unlike `rel_perm` above
    // which checks variant agreement — presumably the needed invariants are
    // carried by `ref_perm_view_permission`/`ptr_perm_match` instead; confirm
    // this weakening is intentional.
    open spec fn ref_rel_perm(r: Self::Ref, perm: Self::RefPermission) -> bool {
        true
    }

    // The reference-permission invariant implies the owning-permission
    // invariant, by case analysis delegating to `L`/`R`.
    proof fn lemma_ref_perm_inv_impl_perm_inv(perm: Self::RefPermission) {
        match perm {
            Sum::Left(left) => L::lemma_ref_perm_inv_impl_perm_inv(left),
            Sum::Right(right) => R::lemma_ref_perm_inv_impl_perm_inv(right),
        }
    }

    // Dereferences a tagged raw pointer into `Either` of the inner reference
    // types, dispatching on the tag bit (clear => `L`, set => `R`).
    unsafe fn raw_as_ref(
        raw: NonNull<Self::Target>,
        Tracked(perm): Tracked<Self::RefPermission>,
    ) -> Self::Ref {
        proof_decl! {
            let ghost align_bits = Self::ALIGN_BITS;
            let ghost tag = 1usize << align_bits;
            let ghost raw_addr = raw.view_ptr_mut()@.addr;
        }
        proof! {
            Self::lemma_align_bits_range();
            // In either case, clearing the tag bit leaves a non-null address,
            // establishing `remove_bits`'s preconditions.
            if perm is Left {
                assert((raw_addr & !tag) == raw_addr) by (bit_vector)
                    requires
                        raw_addr & tag == 0,
                    ;
            } else {
                assert((raw_addr & tag) < raw_addr) by (bit_vector)
                requires
                    raw_addr & tag == tag,
                    (raw_addr & !tag) != 0,
                ;
            }
        }
        // SAFETY: The caller ensures that the pointer comes from `Self::into_raw`, which
        // guarantees that `real_ptr` is a non-null pointer.
        let (is_right, real_ptr) = unsafe { remove_bits(raw, 1 << Self::ALIGN_BITS) };

        if is_right == 0 {
            // A clear tag bit is incompatible with a `Right` permission: the
            // tag is non-zero, so `raw_addr & tag == tag` would contradict
            // `is_right == 0`. This rules out the mismatched case.
            proof!{
                if perm is Right {
                    assert(tag != 0) by (bit_vector)
                    requires
                        tag == 1usize << align_bits,
                        align_bits < usize::BITS,
                    ;
                    assert(false);
                }
            }
            // SAFETY: `Self::into_raw` guarantees that `real_ptr` comes from `L::into_raw`. Other
            // safety requirements are upheld by the caller.
            Either::Left(
                unsafe { L::raw_as_ref(real_ptr.cast(), Tracked(perm.tracked_take_left())) },
            )
        } else {
            // SAFETY: `Self::into_raw` guarantees that `real_ptr` comes from `R::into_raw`. Other
            // safety requirements are upheld by the caller.
            Either::Right(
                unsafe { R::raw_as_ref(real_ptr.cast(), Tracked(perm.tracked_take_right())) },
            )
        }
    }

    // Converts a reference back into the tagged raw pointer, mirroring
    // `into_raw`: `Left` is passed through, `Right` gets the tag bit set. The
    // proofs re-establish the same tag-bit invariants as `into_raw`.
    fn ref_as_raw(ptr_ref: Self::Ref) -> (NonNull<Self::Target>, Tracked<Self::RefPermission>) {
        proof!{
            Self::lemma_align_bits_range();
        }
        // Ghost copies of the alignment constants and the tag mask.
        proof_decl!{
            let ghost align_bits = Self::ALIGN_BITS;
            let ghost tag = 1usize << align_bits;
            let ghost l_align_bits = L::ALIGN_BITS;
            let ghost r_align_bits = R::ALIGN_BITS;
        }
        // Establish shared facts: alignment-bit ranges and shift/pow2 bounds.
        proof!{
            L::lemma_align_bits_range();
            R::lemma_align_bits_range();
            vstd::bits::lemma_usize_pow2_no_overflow(align_bits as nat);
            vstd::bits::lemma_usize_pow2_no_overflow(l_align_bits as nat);
            vstd::bits::lemma_usize_pow2_no_overflow(r_align_bits as nat);
            vstd::bits::lemma_usize_shl_is_mul(1, align_bits as usize);
            vstd::bits::lemma_usize_shl_is_mul(1, l_align_bits as usize);
            vstd::bits::lemma_usize_shl_is_mul(1, r_align_bits as usize);
        }
        match ptr_ref {
            Either::Left(left) => {
                // L::ref_as_raw(left).cast()
                let (ptr, Tracked(perm)) = L::ref_as_raw(left);
                proof! {
                    let ghost ptr_addr = ptr.view_ptr_mut().addr();
                    L::lemma_ref_perm_inv_impl_perm_inv(perm);
                    let extra_bits: u32 = (l_align_bits - align_bits) as u32;
                    let scale = 1usize << extra_bits;
                    vstd::bits::lemma_usize_pow2_no_overflow(extra_bits as nat);
                    vstd::bits::lemma_usize_shl_is_mul(1, extra_bits as usize);
                    vstd::arithmetic::power2::lemma_pow2_adds(align_bits as nat, extra_bits as nat);
                    // `ptr_addr` is a multiple of `big == 2^l_align_bits ==
                    // tag * scale`, hence also a multiple of `tag`.
                    assert(ptr_addr % tag == 0) by {
                        let big = 1usize << l_align_bits;
                        let q = ptr_addr / big;
                        vstd::arithmetic::div_mod::lemma_fundamental_div_mod(ptr_addr as int, big as int);
                        assert(ptr_addr as int == (q as int * scale as int) * tag as int) by (nonlinear_arith)
                        requires
                            ptr_addr as int == q as int * big as int,
                            big == tag * scale,
                        ;
                        vstd::arithmetic::div_mod::lemma_mod_multiples_basic(q as int * scale as int, tag as int);
                    };
                    // Therefore the tag bit is clear, as required for `Left`.
                    assert(ptr_addr & tag == 0) by (bit_vector)
                    requires
                        ptr_addr % (1usize << l_align_bits) == 0,
                        tag == 1usize << align_bits,
                        align_bits < l_align_bits < usize::BITS,
                    ;
                }
                (ptr.cast(), Tracked(Sum::Left(perm)))
            },
            Either::Right(right) => {
                /* R::ref_as_raw(right)
                .map_addr(|addr| addr | (1 << Self::ALIGN_BITS))
                .cast() */
                let (ptr, Tracked(perm)) = R::ref_as_raw(right);
                proof! {
                    Self::lemma_align_bits_range();
                }
                // Set the tag bit to mark this pointer as a `Right` pointer.
                let tagged_ptr = ptr.map_addr_v(
                    |addr: NonZeroUsize| -> (ret: NonZeroUsize)
                        ensures
                            ret@ == addr@ | (1usize << Self::ALIGN_BITS),
                        { addr | 1usize << Self::ALIGN_BITS },
                );
                proof! {
                    let ghost ptr_addr = ptr.view_ptr_mut().addr();
                    let ghost tagged_addr = tagged_ptr.view_ptr_mut().addr();
                    R::lemma_ref_perm_inv_impl_perm_inv(perm);
                    // After OR-ing in `tag`, the tag bit is set.
                    assert(tagged_addr & tag == tag) by (bit_vector)
                    requires
                        tagged_addr == ptr_addr | tag,
                        tag == 1usize << align_bits,
                        align_bits < r_align_bits < usize::BITS,
                        ptr_addr % (1usize << r_align_bits) == 0,
                        ptr_addr != 0,
                    ;
                    // The original address had the tag bit clear (alignment).
                    assert(ptr_addr & tag == 0) by (bit_vector)
                    requires
                        ptr_addr % (1usize << r_align_bits) == 0,
                        tag == 1usize << align_bits,
                        align_bits < r_align_bits < usize::BITS,
                    ;
                    // Clearing the tag recovers the original address.
                    assert(tagged_addr & !tag == ptr_addr) by (bit_vector)
                    requires
                        tagged_addr == ptr_addr | tag,
                        ptr_addr & tag == 0,
                    ;
                    let extra_bits: u32 = (r_align_bits - align_bits) as u32;
                    let scale = 1usize << extra_bits;
                    vstd::bits::lemma_usize_pow2_no_overflow(extra_bits as nat);
                    vstd::bits::lemma_usize_shl_is_mul(1, extra_bits as usize);
                    vstd::arithmetic::power2::lemma_pow2_adds(align_bits as nat, extra_bits as nat);
                    // Since the tag bit was clear, OR behaves like addition.
                    assert(tagged_addr == ptr_addr + tag) by (bit_vector)
                    requires
                        tagged_addr == ptr_addr | tag,
                        ptr_addr & tag == 0,
                    ;
                    // `ptr_addr` is a multiple of `tag * scale`, hence of `tag`.
                    assert(ptr_addr % tag == 0) by {
                        let big = 1usize << r_align_bits;
                        let q = ptr_addr / big;
                        vstd::arithmetic::div_mod::lemma_fundamental_div_mod(ptr_addr as int, big as int);
                        assert(ptr_addr as int == (q as int * scale as int) * tag as int) by (nonlinear_arith)
                        requires
                            ptr_addr as int == q as int * big as int,
                            big == tag * scale,
                        ;
                        vstd::arithmetic::div_mod::lemma_mod_multiples_basic(q as int * scale as int, tag as int);
                    };
                    // The tagged address remains `tag`-aligned.
                    assert(tagged_addr % tag == 0) by {
                        vstd::arithmetic::div_mod::lemma_mod_add_multiples_vanish(ptr_addr as int, tag as int);
                    }
                }
                (tagged_ptr.cast(), Tracked(Sum::Right(perm)))
            },
        }
    }
}
444
445// A `min` implementation for use in constant evaluation.
446#[vstd::contrib::auto_spec]
447const fn min(a: u32, b: u32) -> u32 {
448    if a < b {
449        a
450    } else {
451        b
452    }
453}
454
/// Clears `bits` from the address of `ptr`, returning the bits that were set
/// (`addr & bits`) together with the untagged pointer (`addr & !bits`).
///
/// # Safety
///
/// The caller must ensure that removing the bits from the non-null pointer will result in another
/// non-null pointer.
#[verus_spec(ret =>
    requires
        (ptr.view_ptr_mut().addr() & bits) < ptr.view_ptr_mut().addr(),
        (ptr.view_ptr_mut().addr() & !bits) != 0,
    ensures
        ret.0 == (ptr.view_ptr_mut().addr() & bits),
        ret.1.view_ptr_mut() == ptr.view_ptr_mut().with_addr((ptr.view_ptr_mut().addr() & !bits) as usize),
)]
unsafe fn remove_bits<T>(ptr: NonNull<T>, bits: usize) -> (usize, NonNull<T>) {
    // use core::num::NonZeroUsize;
    // The vstd-aware wrapper is used instead of `core::num::NonZeroUsize` so
    // that the closure's `ensures` clause can talk about the address view.
    use vstd_extra::external::nonzero::NonZeroUsize;

    let removed_bits = ptr.addr_v().get() & bits;
    let result_ptr = ptr.map_addr_v(
        |addr| -> (ret: NonZeroUsize)
            ensures
                ret@ == addr@ & !bits,
            {
                // SAFETY: The safety is upheld by the caller.
                unsafe { NonZeroUsize::new_unchecked(addr.get() & !bits) }
            },
    );
    (removed_bits, result_ptr)
}

} // verus!
#[cfg(ktest)]
mod test {
    use alloc::{boxed::Box, sync::Arc};

    use super::*;
    use crate::{prelude::ktest, sync::RcuOption};

    // Two probe types: `Either32` pairs two alignment-4 pointees, `Either16`
    // pairs an alignment-4 and an alignment-2 pointee (the minimal case).
    type Either32 = Either<Arc<u32>, Box<u32>>;
    type Either16 = Either<Arc<u32>, Box<u16>>;

    // NOTE(review): these tests call `into_raw`/`from_raw`/`raw_as_ref`/
    // `ref_as_raw` with single-value signatures, without the `Tracked`
    // permission arguments used by the verified impls above — presumably the
    // ktest build compiles against an unverified variant of the traits;
    // confirm.

    // Checks the `ALIGN_BITS` formula: min of the two alignments minus one
    // bit for the tag.
    #[ktest]
    fn alignment() {
        assert_eq!(<Either32 as NonNullPtr>::ALIGN_BITS, 1);
        assert_eq!(<Either16 as NonNullPtr>::ALIGN_BITS, 0);
    }

    // Round-trips a `Left` value through raw/ref conversions and checks that
    // the tag bit (bit 0 for `Either16`) stays clear.
    #[ktest]
    fn left_pointer() {
        let val: Either16 = Either::Left(Arc::new(123));

        let ptr = NonNullPtr::into_raw(val);
        assert_eq!(ptr.addr().get() & 1, 0);

        let ref_ = unsafe { <Either16 as NonNullPtr>::raw_as_ref(ptr) };
        assert!(matches!(ref_, Either::Left(ref r) if ***r == 123));

        let ptr2 = <Either16 as NonNullPtr>::ref_as_raw(ref_);
        assert_eq!(ptr, ptr2);

        let val = unsafe { <Either16 as NonNullPtr>::from_raw(ptr) };
        assert!(matches!(val, Either::Left(ref r) if **r == 123));
        drop(val);
    }

    // Round-trips a `Right` value and checks that the tag bit is set.
    #[ktest]
    fn right_pointer() {
        let val: Either16 = Either::Right(Box::new(456));

        let ptr = NonNullPtr::into_raw(val);
        assert_eq!(ptr.addr().get() & 1, 1);

        let ref_ = unsafe { <Either16 as NonNullPtr>::raw_as_ref(ptr) };
        assert!(matches!(ref_, Either::Right(ref r) if ***r == 456));

        let ptr2 = <Either16 as NonNullPtr>::ref_as_raw(ref_);
        assert_eq!(ptr, ptr2);

        let val = unsafe { <Either16 as NonNullPtr>::from_raw(ptr) };
        assert!(matches!(val, Either::Right(ref r) if **r == 456));
        drop(val);
    }

    // Exercises the `Either` pointer through an RCU cell: store and load both
    // variants in turn.
    #[ktest]
    fn rcu_store_load() {
        let rcu: RcuOption<Either32> = RcuOption::new_none();
        assert!(rcu.read().get().is_none());

        rcu.update(Some(Either::Left(Arc::new(888))));
        assert!(matches!(rcu.read().get().unwrap(), Either::Left(r) if **r == 888));

        rcu.update(Some(Either::Right(Box::new(999))));
        assert!(matches!(rcu.read().get().unwrap(), Either::Right(r) if **r == 999));
    }
}