1use core::{marker::PhantomData, ptr::NonNull};
3
4use vstd::raw_ptr::group_raw_ptr_axioms;
5use vstd::{bits, prelude::*};
6use vstd_extra::{external::nonzero::*, prelude::*, sum::Sum};
7
8use super::{NonNullPtr, NonNullPtrRef};
9use crate::util::Either;
10
11verus! {
12
13broadcast use {group_nonull_axioms, group_raw_ptr_axioms, group_nonzero_axioms};
/// Packs an [`Either`] of two non-null pointers into a single tagged pointer.
///
/// The bit with value `tag = 1 << Self::ALIGN_BITS` acts as the discriminant:
/// it is clear for `Left` and set for `Right`. `Self::ALIGN_BITS` is strictly
/// smaller than both `L::ALIGN_BITS` and `R::ALIGN_BITS`, so — assuming each
/// side's raw addresses are aligned to its own `1 << ALIGN_BITS`, as the proofs
/// below rely on — that bit is always zero in an untagged pointer of either side.
unsafe impl<L: NonNullPtr, R: NonNullPtr> NonNullPtr for Either<L, R> {
    // The packed pointer does not point at a real object of this type; it is
    // always cast back to `L`'s or `R`'s own target before being used.
    type Target = PhantomData<Self>;

    // Records which side is held, together with that side's own permission.
    type Permission = Sum<L::Permission, R::Permission>;

    // One bit fewer than the smaller of the two sides' alignment bits; the
    // freed bit is the tag. If either side has `ALIGN_BITS == 0` the
    // `checked_sub` yields `None` and the `expect` panics, which during const
    // evaluation surfaces as a compile-time error for that `L`/`R`.
    // `external_body` hides this definition from Verus; proofs learn its
    // value from the `lemma_align_bits_range` axiom at the end of this impl.
    #[verifier::external_body]
    const ALIGN_BITS: u32 = min(L::ALIGN_BITS, R::ALIGN_BITS).checked_sub(1).expect(
        "`L` and `R` alignments should be at least 2 to pack `Either` into one pointer",
    );

    /// Converts `self` into a tagged raw pointer plus its tracked permission.
    ///
    /// * `Left`: the pointer returned by `L::into_raw` is only cast; the proof
    ///   shows its tag bit is already clear.
    /// * `Right`: the tag bit is OR-ed into the pointer returned by
    ///   `R::into_raw`. The proof shows that the tag bit is now set, that this
    ///   equals `addr + tag`, that `& !tag` recovers the original address, and
    ///   that the result is still `Self::ALIGN_BITS`-aligned.
    #[verifier::spinoff_prover]
    fn into_raw(self) -> (ret: (NonNull<Self::Target>, Tracked<Self::Permission>)) {
        proof_decl!{
            let ghost align_bits = Self::ALIGN_BITS;
            let ghost l_align_bits = L::ALIGN_BITS;
            let ghost r_align_bits = R::ALIGN_BITS;
            let ghost tag = 1usize << align_bits;
        }
        proof! {
            // Bring the alignment-bit facts into scope and relate each
            // `1 << k` used below to a power of two that fits in a `usize`.
            L::lemma_align_bits_range();
            R::lemma_align_bits_range();
            Self::lemma_align_bits_range();
            vstd::bits::lemma_usize_pow2_no_overflow(align_bits as nat);
            vstd::bits::lemma_usize_pow2_no_overflow(l_align_bits as nat);
            vstd::bits::lemma_usize_pow2_no_overflow(r_align_bits as nat);
            vstd::bits::lemma_usize_shl_is_mul(1, align_bits as usize);
            vstd::bits::lemma_usize_shl_is_mul(1, l_align_bits as usize);
            vstd::bits::lemma_usize_shl_is_mul(1, r_align_bits as usize);
        }
        match self {
            Self::Left(left) => {
                let (left, Tracked(perm)) = left.into_raw();
                proof! {
                    let left_addr = left.cast::<Self::Target>().view_ptr_mut().addr();
                    // `l_align_bits = align_bits + extra_bits`, hence
                    // `2^l_align_bits == tag * scale`.
                    let extra_bits: u32 = (l_align_bits - align_bits) as u32;
                    let scale = 1usize << extra_bits;
                    vstd::bits::lemma_usize_pow2_no_overflow(extra_bits as nat);
                    vstd::bits::lemma_usize_shl_is_mul(1, extra_bits as usize);
                    vstd::arithmetic::power2::lemma_pow2_adds(align_bits as nat, extra_bits as nat);
                    // An `L`-aligned address is a multiple of `tag`.
                    // NOTE(review): `left_addr == q * big` presumably comes
                    // from `L::into_raw`'s alignment postcondition.
                    assert(left_addr % tag == 0) by {
                        let big = 1usize << l_align_bits;
                        let q = left_addr / big;
                        vstd::arithmetic::div_mod::lemma_fundamental_div_mod(left_addr as int, big as int);
                        assert(left_addr == q * scale * tag) by (nonlinear_arith)
                            requires
                                left_addr == q * big,
                                big == tag * scale,
                        ;
                        vstd::arithmetic::div_mod::lemma_mod_multiples_basic(q * scale, tag as int);
                    };
                    // The tag bit is clear, so the untouched pointer already
                    // encodes `Left`.
                    lemma_aligned_addr_clears_tag_bit(left_addr, tag, align_bits, l_align_bits);
                }
                (left.cast(), Tracked(Sum::Left(perm)))
            },
            Self::Right(right) => {
                let (right, Tracked(perm)) = right.into_raw();
                // Set the tag bit to mark the pointer as `Right`.
                let right_tagged = right.map_addr_v(
                    |addr: NonZeroUsize| -> (ret: NonZeroUsize)
                        ensures
                            ret@ == addr@ | (1usize << Self::ALIGN_BITS),
                        { addr | 1usize << Self::ALIGN_BITS },
                );
                proof! {
                    let addr = right.addr_spec()@;
                    let tagged_addr = right_tagged.addr_spec()@;
                    // The tag bit is set in the tagged address.
                    assert(tagged_addr & tag == tag) by (bit_vector)
                        requires
                            tagged_addr == addr | tag,
                            tag == 1usize << align_bits,
                            1 <= r_align_bits < usize::BITS,
                            align_bits < r_align_bits,
                            addr % (1usize << r_align_bits) == 0,
                            addr != 0,
                    ;
                    // `2^r_align_bits == tag * scale`.
                    let extra_bits: u32 = (r_align_bits - align_bits) as u32;
                    let scale = 1usize << extra_bits;
                    vstd::bits::lemma_usize_pow2_no_overflow(extra_bits as nat);
                    vstd::bits::lemma_usize_shl_is_mul(1, extra_bits as usize);
                    vstd::arithmetic::power2::lemma_pow2_adds(align_bits as nat, extra_bits as nat);
                    // The tag bit was clear in `addr`, so OR-ing it in is an
                    // addition.
                    assert(tagged_addr == addr + tag) by (bit_vector)
                        requires
                            tagged_addr == addr | tag,
                            tag == 1usize << align_bits,
                            1 <= r_align_bits < usize::BITS,
                            align_bits < r_align_bits,
                            addr % (1usize << r_align_bits) == 0,
                            addr != 0,
                    ;
                    // An `R`-aligned address is a multiple of `tag`.
                    assert(addr % tag == 0) by {
                        let big = 1usize << r_align_bits;
                        let q = addr / big;
                        vstd::arithmetic::div_mod::lemma_fundamental_div_mod(addr as int, big as int);
                        assert(addr == q * scale * tag) by (nonlinear_arith)
                            requires
                                addr == q * big,
                                big == tag * scale,
                        ;
                        vstd::arithmetic::div_mod::lemma_mod_multiples_basic(q * scale, tag as int);
                    }
                    lemma_aligned_addr_clears_tag_bit(addr, tag, align_bits, r_align_bits);
                    // Clearing the tag bit recovers the original address.
                    assert(tagged_addr & !tag == addr) by (bit_vector)
                        requires
                            tagged_addr == addr | tag,
                            addr & tag == 0,
                    ;
                    // `addr + tag` keeps `Self::ALIGN_BITS` alignment.
                    assert(tagged_addr % (1usize << align_bits) == 0) by {
                        vstd::arithmetic::div_mod::lemma_mod_add_multiples_vanish(addr as int, tag as int);
                    }
                }
                (right_tagged.cast(), Tracked(Sum::Right(perm)))
            },
        }
    }

    /// Reconstructs the `Either` from a tagged pointer and its permission.
    ///
    /// `remove_bits` strips the tag bit. A zero tag selects `Left` and any
    /// other value selects `Right`. The untagged pointer is cast back and
    /// handed, with the matching side's permission moved out of `perm`, to that
    /// side's `from_raw`.
    unsafe fn from_raw(
        ptr: NonNull<Self::Target>,
        Tracked(perm): Tracked<Self::Permission>,
    ) -> Self {
        proof! {
            Self::lemma_align_bits_range();
        }
        proof_decl! {
            let ghost align_bits = Self::ALIGN_BITS;
            let ghost tag = 1usize << Self::ALIGN_BITS;
            let ghost ptr_addr = ptr.view_ptr_mut()@.addr;
        }
        proof! {
            assert(tag > 0) by (bit_vector)
                requires
                    tag == 1usize << align_bits,
                    align_bits < usize::BITS,
            ;
            // Discharge `remove_bits`'s preconditions for each possible side.
            match perm {
                Sum::Left(_) => {
                    // Tag bit clear: stripping it leaves the address unchanged.
                    assert((ptr_addr & !tag) == ptr_addr) by (bit_vector)
                        requires
                            ptr_addr & tag == 0,
                    ;
                },
                Sum::Right(_) => {
                    // Tag bit set: the removed bits are strictly below the
                    // (non-zero once untagged) address.
                    assert((ptr_addr & tag) < ptr_addr) by (bit_vector)
                        requires
                            ptr_addr & tag == tag,
                            (ptr_addr & !tag) != 0,
                    ;
                },
            }
        }
        // SAFETY: the proof block above establishes `remove_bits`'s `requires`
        // clauses for the address of `ptr` and `bits == tag`.
        let (is_right, real_ptr) = unsafe { remove_bits(ptr, 1 << Self::ALIGN_BITS) };

        if is_right == 0 {
            // SAFETY: NOTE(review) relies on the caller's `from_raw` contract
            // (presumably `ptr_perm_match`), which for `Sum::Left` requires the
            // pointer to satisfy `L::ptr_perm_match` after casting.
            Either::Left(unsafe { L::from_raw(real_ptr.cast(), Tracked(perm.tracked_take_left())) })
        } else {
            Either::Right(
                // SAFETY: NOTE(review) as above; for `Sum::Right`,
                // `ptr_perm_match` requires the untagged pointer to satisfy
                // `R::ptr_perm_match`.
                unsafe { R::from_raw(real_ptr.cast(), Tracked(perm.tracked_take_right())) },
            )
        }
    }

    /// `ptr` is a valid encoding of `perm`.
    ///
    /// * `Left`: the tag bit is clear and `ptr`, cast, matches `L`.
    /// * `Right`: the tag bit is set, the untagged address is non-zero, and the
    ///   untagged pointer, cast, matches `R`.
    open spec fn ptr_perm_match(ptr: *mut Self::Target, perm: Self::Permission) -> bool {
        let tag = 1usize << Self::ALIGN_BITS;
        match perm {
            Sum::Left(left) => {
                &&& ptr.addr() & tag == 0
                &&& L::ptr_perm_match(ptr.cast(), left)
            },
            Sum::Right(right) => {
                let untagged_ptr = ptr.with_addr((ptr.addr() & !tag));
                let right_nonnull = nonnull_from_ptr_mut_spec(untagged_ptr);
                &&& ptr.addr() & tag == tag
                &&& (ptr.addr() & !tag) != 0
                &&& R::ptr_perm_match(right_nonnull.cast().view_ptr_mut(), right)
            },
        }
    }

    /// `self` and `perm` hold the same side, and that side's value is related
    /// to its permission; mismatched sides are never related.
    open spec fn rel_perm(self, perm: Self::Permission) -> bool {
        match (self, perm) {
            (Either::Left(left), Sum::Left(left_perm)) => left.rel_perm(left_perm),
            (Either::Right(right), Sum::Right(right_perm)) => right.rel_perm(right_perm),
            _ => false,
        }
    }

    /// Axiom exposing the value of the `external_body` constant `ALIGN_BITS`:
    /// one less than the smaller of `L::ALIGN_BITS` and `R::ALIGN_BITS`.
    ///
    /// NOTE(review): spec arithmetic does not wrap. If the smaller side has
    /// `ALIGN_BITS == 0`, the right-hand side is -1, which no `u32` equals, so
    /// this axiom would be contradictory for such `L`/`R`. The exec constant
    /// rejects that case at const evaluation, but nothing here guards the
    /// axiom itself. Confirm that it cannot be reached for alignment-1 types.
    axiom fn lemma_align_bits_range()
        ensures
            Self::ALIGN_BITS == if L::ALIGN_BITS < R::ALIGN_BITS {
                L::ALIGN_BITS - 1
            } else {
                R::ALIGN_BITS - 1
            },
    ;
}
226
/// Borrowed access to a packed `Either` pointer.
///
/// Uses the same tagging scheme as the owning `NonNullPtr` impl: the bit
/// `tag = 1 << Self::ALIGN_BITS` is clear for `Left` and set for `Right`.
unsafe impl<'a, L: NonNullPtrRef<'a>, R: NonNullPtrRef<'a>> NonNullPtrRef<'a> for Either<L, R> {
    // A borrowed `Either` is an `Either` of the two sides' borrowed forms.
    type Ref = Either<L::Ref, R::Ref>;

    // Records which side is borrowed, with that side's reference permission.
    type RefPermission = Sum<L::RefPermission, R::RefPermission>;

    /// Maps a reference permission to the owning-permission view, side by
    /// side.
    open spec fn ref_perm_view_permission(perm: Self::RefPermission) -> Self::Permission {
        match perm {
            Sum::Left(left) => Sum::Left(L::ref_perm_view_permission(left)),
            Sum::Right(right) => Sum::Right(R::ref_perm_view_permission(right)),
        }
    }

    // NOTE(review): this relation is unconstrained (`true`). The owning
    // `rel_perm` instead requires the variants to agree. Confirm that callers
    // need no tighter link between `r` and `perm`.
    open spec fn ref_rel_perm(r: Self::Ref, perm: Self::RefPermission) -> bool {
        true
    }

    /// Delegates to the lemma of whichever side `perm` holds.
    proof fn lemma_ref_perm_inv_impl_perm_inv(perm: Self::RefPermission) {
        match perm {
            Sum::Left(left) => L::lemma_ref_perm_inv_impl_perm_inv(left),
            Sum::Right(right) => R::lemma_ref_perm_inv_impl_perm_inv(right),
        }
    }

    /// Re-borrows the reference permission of whichever side is held.
    proof fn borrow_ref_perm(tracked perm: &Self::RefPermission) -> (tracked ret:
        Self::RefPermission) {
        if perm is Left {
            Sum::Left(L::borrow_ref_perm(perm.tracked_borrow_left()))
        } else {
            Sum::Right(R::borrow_ref_perm(perm.tracked_borrow_right()))
        }
    }

    /// Borrows an owning permission as a reference permission, per side.
    proof fn borrow_perm_as_ref_perm(tracked perm: &'a Self::Permission) -> (tracked ret:
        Self::RefPermission) {
        if perm is Left {
            Sum::Left(L::borrow_perm_as_ref_perm(perm.tracked_borrow_left()))
        } else {
            Sum::Right(R::borrow_perm_as_ref_perm(perm.tracked_borrow_right()))
        }
    }

    /// Turns a tagged raw pointer into a borrowed `Either`.
    ///
    /// This mirrors `from_raw`: strip the tag with `remove_bits`, then dispatch
    /// on the removed bits to the matching side's `raw_as_ref`.
    unsafe fn raw_as_ref(
        raw: NonNull<Self::Target>,
        Tracked(perm): Tracked<Self::RefPermission>,
    ) -> Self::Ref {
        proof_decl! {
            let ghost align_bits = Self::ALIGN_BITS;
            let ghost tag = 1usize << align_bits;
            let ghost raw_addr = raw.view_ptr_mut()@.addr;
        }
        proof! {
            Self::lemma_align_bits_range();
            // Discharge `remove_bits`'s preconditions for each possible side.
            if perm is Left {
                assert((raw_addr & !tag) == raw_addr) by (bit_vector)
                    requires
                        raw_addr & tag == 0,
                ;
            } else {
                assert((raw_addr & tag) < raw_addr) by (bit_vector)
                    requires
                        raw_addr & tag == tag,
                        (raw_addr & !tag) != 0,
                ;
            }
        }
        // SAFETY: the proof block above establishes `remove_bits`'s `requires`
        // clauses for the address of `raw` and `bits == tag`.
        let (is_right, real_ptr) = unsafe { remove_bits(raw, 1 << Self::ALIGN_BITS) };

        if is_right == 0 {
            proof!{
                // A `Right` permission implies the tag bit is set, so
                // `is_right == tag != 0`, contradicting this branch.
                if perm is Right {
                    assert(tag != 0) by (bit_vector)
                        requires
                            tag == 1usize << align_bits,
                            align_bits < usize::BITS,
                    ;
                    assert(false);
                }
            }
            Either::Left(
                // SAFETY: NOTE(review) relies on the caller's `raw_as_ref`
                // contract for the untagged pointer and the `Left` permission.
                unsafe { L::raw_as_ref(real_ptr.cast(), Tracked(perm.tracked_take_left())) },
            )
        } else {
            Either::Right(
                // SAFETY: NOTE(review) relies on the caller's `raw_as_ref`
                // contract for the untagged pointer and the `Right` permission.
                unsafe { R::raw_as_ref(real_ptr.cast(), Tracked(perm.tracked_take_right())) },
            )
        }
    }

    /// Converts a borrowed `Either` back into a tagged raw pointer.
    ///
    /// This mirrors `into_raw`. A `Left` pointer is only cast, and the proof
    /// shows its tag bit is clear. A `Right` pointer gets the tag bit OR-ed in,
    /// and the proof shows the bit is set, that `& !tag` recovers the original,
    /// and that alignment to `tag` is kept.
    #[verifier::spinoff_prover]
    fn ref_as_raw(ptr_ref: Self::Ref) -> (NonNull<Self::Target>, Tracked<Self::RefPermission>) {
        proof!{
            Self::lemma_align_bits_range();
        }
        proof_decl!{
            let ghost align_bits = Self::ALIGN_BITS;
            let ghost tag = 1usize << align_bits;
            let ghost l_align_bits = L::ALIGN_BITS;
            let ghost r_align_bits = R::ALIGN_BITS;
        }
        proof!{
            // Alignment-bit facts, and `1 << k` as non-overflowing powers of two.
            L::lemma_align_bits_range();
            R::lemma_align_bits_range();
            vstd::bits::lemma_usize_pow2_no_overflow(align_bits as nat);
            vstd::bits::lemma_usize_pow2_no_overflow(l_align_bits as nat);
            vstd::bits::lemma_usize_pow2_no_overflow(r_align_bits as nat);
            vstd::bits::lemma_usize_shl_is_mul(1, align_bits as usize);
            vstd::bits::lemma_usize_shl_is_mul(1, l_align_bits as usize);
            vstd::bits::lemma_usize_shl_is_mul(1, r_align_bits as usize);
        }
        match ptr_ref {
            Either::Left(left) => {
                let (ptr, Tracked(perm)) = L::ref_as_raw(left);
                proof! {
                    let ghost ptr_addr = ptr.view_ptr_mut().addr();
                    // Presumably supplies `L`'s permission invariant, including
                    // the alignment of `ptr_addr` used below.
                    L::lemma_ref_perm_inv_impl_perm_inv(perm);
                    // `2^l_align_bits == tag * scale`.
                    let extra_bits: u32 = (l_align_bits - align_bits) as u32;
                    let scale = 1usize << extra_bits;
                    vstd::bits::lemma_usize_pow2_no_overflow(extra_bits as nat);
                    vstd::bits::lemma_usize_shl_is_mul(1, extra_bits as usize);
                    vstd::arithmetic::power2::lemma_pow2_adds(align_bits as nat, extra_bits as nat);
                    // An `L`-aligned address is a multiple of `tag`.
                    assert(ptr_addr % tag == 0) by {
                        let big = 1usize << l_align_bits;
                        let q = ptr_addr / big;
                        vstd::arithmetic::div_mod::lemma_fundamental_div_mod(ptr_addr as int, big as int);
                        assert(ptr_addr == (q * scale) * tag) by (nonlinear_arith)
                            requires
                                ptr_addr == q * big,
                                big == tag * scale,
                        ;
                        vstd::arithmetic::div_mod::lemma_mod_multiples_basic(q * scale, tag as int);
                    };
                    // The tag bit is clear, so the pointer already encodes `Left`.
                    assert(ptr_addr & tag == 0) by (bit_vector)
                        requires
                            ptr_addr % (1usize << l_align_bits) == 0,
                            tag == 1usize << align_bits,
                            align_bits < l_align_bits < usize::BITS,
                    ;
                }
                (ptr.cast(), Tracked(Sum::Left(perm)))
            },
            Either::Right(right) => {
                let (ptr, Tracked(perm)) = R::ref_as_raw(right);
                proof! {
                    Self::lemma_align_bits_range();
                }
                // Set the tag bit to mark the pointer as `Right`.
                let tagged_ptr = ptr.map_addr_v(
                    |addr: NonZeroUsize| -> (ret: NonZeroUsize)
                        ensures
                            ret@ == addr@ | (1usize << Self::ALIGN_BITS),
                        { addr | 1usize << Self::ALIGN_BITS },
                );
                proof! {
                    let ghost ptr_addr = ptr.view_ptr_mut().addr();
                    let ghost tagged_addr = tagged_ptr.view_ptr_mut().addr();
                    R::lemma_ref_perm_inv_impl_perm_inv(perm);
                    // The tag bit is set in the tagged address.
                    assert(tagged_addr & tag == tag) by (bit_vector)
                        requires
                            tagged_addr == ptr_addr | tag,
                            tag == 1usize << align_bits,
                            align_bits < r_align_bits < usize::BITS,
                            ptr_addr % (1usize << r_align_bits) == 0,
                            ptr_addr != 0,
                    ;
                    // ...and it was clear in the untagged address.
                    assert(ptr_addr & tag == 0) by (bit_vector)
                        requires
                            ptr_addr % (1usize << r_align_bits) == 0,
                            tag == 1usize << align_bits,
                            align_bits < r_align_bits < usize::BITS,
                    ;
                    // Clearing the tag bit recovers the original address.
                    assert(tagged_addr & !tag == ptr_addr) by (bit_vector)
                        requires
                            tagged_addr == ptr_addr | tag,
                            ptr_addr & tag == 0,
                    ;
                    // `2^r_align_bits == tag * scale`.
                    let extra_bits: u32 = (r_align_bits - align_bits) as u32;
                    let scale = 1usize << extra_bits;
                    vstd::bits::lemma_usize_pow2_no_overflow(extra_bits as nat);
                    vstd::bits::lemma_usize_shl_is_mul(1, extra_bits as usize);
                    vstd::arithmetic::power2::lemma_pow2_adds(align_bits as nat, extra_bits as nat);
                    // OR-ing in a clear bit is an addition.
                    assert(tagged_addr == ptr_addr + tag) by (bit_vector)
                        requires
                            tagged_addr == ptr_addr | tag,
                            ptr_addr & tag == 0,
                    ;
                    // An `R`-aligned address is a multiple of `tag`.
                    assert(ptr_addr % tag == 0) by {
                        let big = 1usize << r_align_bits;
                        let q = ptr_addr / big;
                        vstd::arithmetic::div_mod::lemma_fundamental_div_mod(ptr_addr as int, big as int);
                        assert(ptr_addr == (q * scale) * tag) by (nonlinear_arith)
                            requires
                                ptr_addr == q * big,
                                big == tag * scale,
                        ;
                        vstd::arithmetic::div_mod::lemma_mod_multiples_basic(q * scale, tag as int);
                    };
                    // `ptr_addr + tag` stays a multiple of `tag`.
                    assert(tagged_addr % tag == 0) by {
                        vstd::arithmetic::div_mod::lemma_mod_add_multiples_vanish(ptr_addr as int, tag as int);
                    }
                }
                (tagged_ptr.cast(), Tracked(Sum::Right(perm)))
            },
        }
    }
}
441
442} #[verus_verify(dual_spec)]
/// Returns whichever of `a` and `b` is smaller.
///
/// Declared `const` so that it can be evaluated while computing the
/// `ALIGN_BITS` constant of the packed `Either` pointer.
const fn min(a: u32, b: u32) -> u32 {
    if b <= a {
        b
    } else {
        a
    }
}
448
449verus! {
450
/// Clears `bits` from the address of `ptr`.
///
/// Returns `(ptr.addr() & bits, p)`, where `p` is `ptr` with its address
/// replaced by `ptr.addr() & !bits`. The postcondition expresses this with
/// `with_addr`, i.e. `p` keeps `ptr`'s provenance.
///
/// # Safety
///
/// The new address is built with `NonZeroUsize::new_unchecked`, so the caller
/// must guarantee that `ptr.addr() & !bits` is non-zero (see `requires`).
#[verus_spec(ret =>
    requires
        // The removed bits are strictly below the address...
        (ptr.view_ptr_mut().addr() & bits) < ptr.view_ptr_mut().addr(),
        // ...and a non-zero address remains once they are cleared.
        (ptr.view_ptr_mut().addr() & !bits) != 0,
    ensures
        ret.0 == (ptr.view_ptr_mut().addr() & bits),
        ret.1.view_ptr_mut() == ptr.view_ptr_mut().with_addr((ptr.view_ptr_mut().addr() & !bits) as usize),
)]
unsafe fn remove_bits<T>(ptr: NonNull<T>, bits: usize) -> (usize, NonNull<T>) {
    use vstd_extra::external::nonzero::NonZeroUsize;

    // The bits being removed, returned to the caller (e.g. as a tag).
    let removed_bits = ptr.addr_v().get() & bits;
    let result_ptr = ptr.map_addr_v(
        |addr| -> (ret: NonZeroUsize)
            ensures
                ret@ == addr@ & !bits,
            {
                // SAFETY: `addr` is the address of `ptr`, and the function's
                // `requires` clause guarantees `addr & !bits != 0`.
                unsafe { NonZeroUsize::new_unchecked(addr.get() & !bits) }
            },
    );
    (removed_bits, result_ptr)
}
480
/// An address aligned to `1 << ptr_align_bits` has every lower bit clear; in
/// particular, the tag bit `tag == 1 << align_bits` is zero whenever
/// `align_bits < ptr_align_bits`.
#[verifier::spinoff_prover]
proof fn lemma_aligned_addr_clears_tag_bit(
    addr: usize,
    tag: usize,
    align_bits: u32,
    ptr_align_bits: u32,
)
    requires
        addr % (1usize << ptr_align_bits) == 0,
        tag == 1usize << align_bits,
        align_bits < ptr_align_bits < usize::BITS,
    ensures
        addr & tag == 0,
{
    // The constant bounds (`<= 63`, `< 64`) are implied by
    // `ptr_align_bits < usize::BITS`. NOTE(review): they appear to be spelled
    // out as literals for the bit-vector solver, and are written with a 64-bit
    // `usize` in mind.
    assert(addr & (1usize << align_bits) == 0) by (bit_vector)
        requires
            addr % (1usize << ptr_align_bits) == 0,
            0u32 < ptr_align_bits,
            align_bits < ptr_align_bits,
            ptr_align_bits <= 63u32,
    ;
    // Rewrite `1usize << align_bits` as `tag`.
    assert(addr & tag == 0) by (bit_vector)
        requires
            addr & (1usize << align_bits) == 0,
            tag == 1usize << align_bits,
            align_bits < 64u32,
    ;
}
509
510} #[cfg(ktest)]
// Kernel-mode tests for packing `Either` into a single tagged pointer.
//
// NOTE(review): these tests use single-pointer signatures that differ from
// the Verus API declared above in this file:
// - `into_raw` there returns `(NonNull<_>, Tracked<Permission>)`;
// - `from_raw` and `raw_as_ref` take an extra `Tracked` permission argument;
// - `raw_as_ref`/`ref_as_raw` are methods of `NonNullPtrRef`, not
//   `NonNullPtr`.
// Confirm whether this module still compiles under `ktest`.
mod test {
    use alloc::{boxed::Box, sync::Arc};

    use super::*;
    use crate::{prelude::ktest, sync::RcuOption};

    // Same `Left` type; the `Right` pointee differs so that the two aliases
    // presumably end up with different `ALIGN_BITS`.
    type Either32 = Either<Arc<u32>, Box<u32>>;
    type Either16 = Either<Arc<u32>, Box<u16>>;

    // `ALIGN_BITS` is one less than the smaller side's `ALIGN_BITS`;
    // presumably `Box<u32>` contributes 2 and `Box<u16>` contributes 1.
    #[ktest]
    fn alignment() {
        assert_eq!(<Either32 as NonNullPtr>::ALIGN_BITS, 1);
        assert_eq!(<Either16 as NonNullPtr>::ALIGN_BITS, 0);
    }

    // A `Left` value has the tag bit (bit 0 for `Either16`) clear and survives
    // a raw -> ref -> raw -> owned round trip.
    #[ktest]
    fn left_pointer() {
        let val: Either16 = Either::Left(Arc::new(123));

        let ptr = NonNullPtr::into_raw(val);
        assert_eq!(ptr.addr().get() & 1, 0);

        let ref_ = unsafe { <Either16 as NonNullPtr>::raw_as_ref(ptr) };
        assert!(matches!(ref_, Either::Left(ref r) if ***r == 123));

        let ptr2 = <Either16 as NonNullPtr>::ref_as_raw(ref_);
        assert_eq!(ptr, ptr2);

        let val = unsafe { <Either16 as NonNullPtr>::from_raw(ptr) };
        assert!(matches!(val, Either::Left(ref r) if **r == 123));
        drop(val);
    }

    // A `Right` value has the tag bit (bit 0 for `Either16`) set and survives
    // the same round trip.
    #[ktest]
    fn right_pointer() {
        let val: Either16 = Either::Right(Box::new(456));

        let ptr = NonNullPtr::into_raw(val);
        assert_eq!(ptr.addr().get() & 1, 1);

        let ref_ = unsafe { <Either16 as NonNullPtr>::raw_as_ref(ptr) };
        assert!(matches!(ref_, Either::Right(ref r) if ***r == 456));

        let ptr2 = <Either16 as NonNullPtr>::ref_as_raw(ref_);
        assert_eq!(ptr, ptr2);

        let val = unsafe { <Either16 as NonNullPtr>::from_raw(ptr) };
        assert!(matches!(val, Either::Right(ref r) if **r == 456));
        drop(val);
    }

    // Stores each variant in an `RcuOption` and reads it back.
    #[ktest]
    fn rcu_store_load() {
        let rcu: RcuOption<Either32> = RcuOption::new_none();
        assert!(rcu.read().get().is_none());

        rcu.update(Some(Either::Left(Arc::new(888))));
        assert!(matches!(rcu.read().get().unwrap(), Either::Left(r) if **r == 888));

        rcu.update(Some(Either::Right(Box::new(999))));
        assert!(matches!(rcu.read().get().unwrap(), Either::Right(r) if **r == 999));
    }
}