1use core::{marker::PhantomData, ptr::NonNull};
3
4use vstd::raw_ptr::group_raw_ptr_axioms;
5use vstd::{bits, prelude::*};
6use vstd_extra::{external::nonzero::*, prelude::*, sum::Sum};
7
8use super::{NonNullPtr, NonNullPtrRef};
9use crate::util::Either;
10
11verus! {
12
13broadcast use {group_nonull_axioms, group_raw_ptr_axioms, group_nonzero_axioms};
/// Packs an `Either<L, R>` into one non-null pointer by using a spare low
/// address bit as a tag.
///
/// With `tag = 1 << ALIGN_BITS` (where `ALIGN_BITS` is one less than the
/// smaller of `L::ALIGN_BITS` and `R::ALIGN_BITS`), a `Left` pointer is stored
/// with its address unchanged (tag bit clear) and a `Right` pointer is stored
/// with the tag bit OR-ed in. `from_raw` reads the tag bit to pick the variant
/// and clears it before handing the address back to `L` or `R`.
unsafe impl<L: NonNullPtr, R: NonNullPtr> NonNullPtr for Either<L, R> {
    // Opaque, zero-sized target: the stored address may carry the tag bit, so
    // as-is it is not the address of an `L::Target` or `R::Target`.
    type Target = PhantomData<Self>;

    // Records which side the pointer came from, together with that side's
    // permission.
    type Permission = Sum<L::Permission, R::Permission>;

    // `min(L::ALIGN_BITS, R::ALIGN_BITS) - 1`. The initializer is
    // `external_body`, so proofs learn its value from the
    // `lemma_align_bits_range` axiom below. If the smaller side has
    // `ALIGN_BITS == 0` there is no spare bit and evaluating this panics.
    #[verifier::external_body]
    const ALIGN_BITS: u32 = min(L::ALIGN_BITS, R::ALIGN_BITS).checked_sub(1).expect(
        "`L` and `R` alignments should be at least 2 to pack `Either` into one pointer",
    );

    /// Converts `self` into a raw pointer plus the tracked permission needed
    /// to turn it back into an `Either`.
    ///
    /// `Left` addresses are returned unchanged; `Right` addresses have the tag
    /// bit OR-ed in. The proofs show that the tag bit is clear in the inner
    /// address (it is aligned to more than `tag`) and that the returned
    /// address is still a multiple of `tag`.
    #[verus_spec]
    #[verifier::spinoff_prover]
    fn into_raw(self) -> (ret: (NonNull<Self::Target>, Tracked<Self::Permission>)) {
        proof_decl!{
            let ghost align_bits = Self::ALIGN_BITS;
            let ghost l_align_bits = L::ALIGN_BITS;
            let ghost r_align_bits = R::ALIGN_BITS;
            let ghost tag = 1usize << align_bits;
        }
        proof! {
            // Facts about the three alignment exponents, plus shift facts
            // (`1 << k` does not overflow and equals a power of two) that the
            // arithmetic below relies on.
            L::lemma_align_bits_range();
            R::lemma_align_bits_range();
            Self::lemma_align_bits_range();
            vstd::bits::lemma_usize_pow2_no_overflow(align_bits as nat);
            vstd::bits::lemma_usize_pow2_no_overflow(l_align_bits as nat);
            vstd::bits::lemma_usize_pow2_no_overflow(r_align_bits as nat);
            vstd::bits::lemma_usize_shl_is_mul(1, align_bits as usize);
            vstd::bits::lemma_usize_shl_is_mul(1, l_align_bits as usize);
            vstd::bits::lemma_usize_shl_is_mul(1, r_align_bits as usize);
        }
        match self {
            Self::Left(left) => {
                let (left, Tracked(perm)) = left.into_raw();
                proof! {
                    // Show `left_addr` is a multiple of `tag`. It is a multiple of
                    // `big = 1 << L::ALIGN_BITS` (presumably guaranteed by
                    // `L::into_raw`), and `big == tag * scale`.
                    let left_addr = left.cast::<Self::Target>().view_ptr_mut().addr();
                    let extra_bits: u32 = (l_align_bits - align_bits) as u32;
                    let scale = 1usize << extra_bits;
                    vstd::bits::lemma_usize_pow2_no_overflow(extra_bits as nat);
                    vstd::bits::lemma_usize_shl_is_mul(1, extra_bits as usize);
                    vstd::arithmetic::power2::lemma_pow2_adds(align_bits as nat, extra_bits as nat);
                    assert(left_addr % tag == 0) by {
                        let big = 1usize << l_align_bits;
                        let q = left_addr / big;
                        vstd::arithmetic::div_mod::lemma_fundamental_div_mod(left_addr as int, big as int);
                        assert(left_addr as int == (q as int * scale as int) * tag as int) by (nonlinear_arith)
                            requires
                                left_addr as int == q as int * big as int,
                                big == tag * scale,
                        ;
                        vstd::arithmetic::div_mod::lemma_mod_multiples_basic(q as int * scale as int, tag as int);
                    };
                    // The tag bit is therefore clear, which marks the pointer as `Left`.
                    lemma_aligned_addr_clears_tag_bit(left_addr, tag, align_bits, l_align_bits);
                }
                (left.cast(), Tracked(Sum::Left(perm)))
            },
            Self::Right(right) => {
                let (right, Tracked(perm)) = right.into_raw();
                // Set the tag bit to mark the pointer as `Right`.
                let right_tagged = right.map_addr_v(
                    |addr: NonZeroUsize| -> (ret: NonZeroUsize)
                        ensures
                            ret@ == addr@ | (1usize << Self::ALIGN_BITS),
                        { addr | 1usize << Self::ALIGN_BITS },
                );
                proof! {
                    let addr = right.addr_spec()@;
                    let tagged_addr = right_tagged.addr_spec()@;
                    // The tag bit is now set.
                    assert(tagged_addr & tag == tag) by (bit_vector)
                        requires
                            tagged_addr == addr | tag,
                            tag == 1usize << align_bits,
                            1 <= r_align_bits < usize::BITS,
                            align_bits < r_align_bits,
                            addr % (1usize << r_align_bits) == 0,
                            addr != 0,
                    ;
                    let extra_bits: u32 = (r_align_bits - align_bits) as u32;
                    let scale = 1usize << extra_bits;
                    vstd::bits::lemma_usize_pow2_no_overflow(extra_bits as nat);
                    vstd::bits::lemma_usize_shl_is_mul(1, extra_bits as usize);
                    vstd::arithmetic::power2::lemma_pow2_adds(align_bits as nat, extra_bits as nat);
                    // `addr` is aligned past the tag bit, so OR-ing `tag` in
                    // equals adding it.
                    assert(tagged_addr == addr + tag) by (bit_vector)
                        requires
                            tagged_addr == addr | tag,
                            tag == 1usize << align_bits,
                            1 <= r_align_bits < usize::BITS,
                            align_bits < r_align_bits,
                            addr % (1usize << r_align_bits) == 0,
                            addr != 0,
                    ;
                    // `addr` is a multiple of `tag` (same argument as the `Left` arm).
                    assert(addr % tag == 0) by {
                        let big = 1usize << r_align_bits;
                        let q = addr / big;
                        vstd::arithmetic::div_mod::lemma_fundamental_div_mod(addr as int, big as int);
                        assert(addr as int == (q as int * scale as int) * tag as int) by (nonlinear_arith)
                            requires
                                addr as int == q as int * big as int,
                                big == tag * scale,
                        ;
                        vstd::arithmetic::div_mod::lemma_mod_multiples_basic(q as int * scale as int, tag as int);
                    }
                    lemma_aligned_addr_clears_tag_bit(addr, tag, align_bits, r_align_bits);
                    // Clearing the tag bit recovers the original `R` address.
                    assert(tagged_addr & !tag == addr) by (bit_vector)
                        requires
                            tagged_addr == addr | tag,
                            addr & tag == 0,
                    ;
                    // `(addr + tag) % tag == addr % tag == 0`: the tagged address stays
                    // aligned to `1 << ALIGN_BITS`.
                    assert(tagged_addr % (1usize << align_bits) == 0) by {
                        vstd::arithmetic::div_mod::lemma_mod_add_multiples_vanish(addr as int, tag as int);
                    }
                }
                (right_tagged.cast(), Tracked(Sum::Right(perm)))
            },
        }
    }

    /// Rebuilds the `Either` from a pointer produced by `into_raw`.
    ///
    /// `remove_bits` reads and clears the tag bit. A clear bit selects `L`,
    /// and a set bit selects `R`. The proof shows this agrees with the variant
    /// of `perm`.
    unsafe fn from_raw(
        ptr: NonNull<Self::Target>,
        Tracked(perm): Tracked<Self::Permission>,
    ) -> Self {
        proof! {
            Self::lemma_align_bits_range();
        }
        proof_decl! {
            let ghost align_bits = Self::ALIGN_BITS;
            let ghost tag = 1usize << Self::ALIGN_BITS;
            let ghost ptr_addr = ptr.view_ptr_mut()@.addr;
        }
        proof! {
            // `tag` is non-zero, so a set tag bit is distinguishable from 0.
            assert(tag > 0) by (bit_vector)
                requires
                    tag == 1usize << align_bits,
                    align_bits < usize::BITS,
            ;
            // Discharge `remove_bits`'s preconditions for each variant. The
            // per-variant tag facts presumably come from the `ptr_perm_match`
            // precondition of the trait method.
            match perm {
                Sum::Left(_) => {
                    assert((ptr_addr & !tag) == ptr_addr) by (bit_vector)
                        requires
                            ptr_addr & tag == 0,
                    ;
                },
                Sum::Right(_) => {
                    assert((ptr_addr & tag) < ptr_addr) by (bit_vector)
                        requires
                            ptr_addr & tag == tag,
                            (ptr_addr & !tag) != 0,
                    ;
                },
            }
        }
        // SAFETY: the proof block above establishes `remove_bits`'s preconditions.
        let (is_right, real_ptr) = unsafe { remove_bits(ptr, 1 << Self::ALIGN_BITS) };

        if is_right == 0 {
            // Tag bit clear: the address is the untouched `L` pointer.
            // SAFETY: forwards this function's own contract for `ptr`/`perm` to `L`.
            Either::Left(unsafe { L::from_raw(real_ptr.cast(), Tracked(perm.tracked_take_left())) })
        } else {
            // Tag bit set: `real_ptr` has it cleared, restoring the `R` pointer.
            Either::Right(
                // SAFETY: forwards this function's own contract for `ptr`/`perm` to `R`.
                unsafe { R::from_raw(real_ptr.cast(), Tracked(perm.tracked_take_right())) },
            )
        }
    }

    // Ties a (possibly tagged) raw pointer to its permission:
    // - `Left`: the tag bit is clear, and the pointer (cast) matches `L`'s permission.
    // - `Right`: the tag bit is set, the untagged address is non-zero, and the
    //   untagged pointer matches `R`'s permission.
    open spec fn ptr_perm_match(ptr: *mut Self::Target, perm: Self::Permission) -> bool {
        let tag = 1usize << Self::ALIGN_BITS;
        match perm {
            Sum::Left(left) => {
                &&& ptr.addr() & tag == 0
                &&& L::ptr_perm_match(ptr.cast(), left)
            },
            Sum::Right(right) => {
                let untagged_ptr = ptr.with_addr((ptr.addr() & !tag));
                let right_nonnull = nonnull_from_ptr_mut_spec(untagged_ptr);
                &&& ptr.addr() & tag == tag
                &&& (ptr.addr() & !tag) != 0
                &&& R::ptr_perm_match(right_nonnull.cast().view_ptr_mut(), right)
            },
        }
    }

    // A value and a permission are related only when their variants agree and
    // the inner value is related to the inner permission.
    open spec fn rel_perm(self, perm: Self::Permission) -> bool {
        match (self, perm) {
            (Either::Left(left), Sum::Left(left_perm)) => left.rel_perm(left_perm),
            (Either::Right(right), Sum::Right(right_perm)) => right.rel_perm(right_perm),
            _ => false,
        }
    }

    // Trusted (axiom): states the value of the `external_body` `ALIGN_BITS`
    // constant, i.e. `min(L::ALIGN_BITS, R::ALIGN_BITS) - 1`.
    axiom fn lemma_align_bits_range()
        ensures
            Self::ALIGN_BITS == if L::ALIGN_BITS < R::ALIGN_BITS {
                L::ALIGN_BITS - 1
            } else {
                R::ALIGN_BITS - 1
            },
    ;
}
227
228unsafe impl<'a, L: NonNullPtrRef<'a>, R: NonNullPtrRef<'a>> NonNullPtrRef<'a> for Either<L, R> {
229 type Ref = Either<L::Ref, R::Ref>;
230
231 type RefPermission = Sum<L::RefPermission, R::RefPermission>;
232
233 open spec fn ref_perm_view_permission(perm: Self::RefPermission) -> Self::Permission {
234 match perm {
235 Sum::Left(left) => Sum::Left(L::ref_perm_view_permission(left)),
236 Sum::Right(right) => Sum::Right(R::ref_perm_view_permission(right)),
237 }
238 }
239
240 open spec fn ref_rel_perm(r: Self::Ref, perm: Self::RefPermission) -> bool {
241 true
242 }
243
244 proof fn lemma_ref_perm_inv_impl_perm_inv(perm: Self::RefPermission) {
245 match perm {
246 Sum::Left(left) => L::lemma_ref_perm_inv_impl_perm_inv(left),
247 Sum::Right(right) => R::lemma_ref_perm_inv_impl_perm_inv(right),
248 }
249 }
250
251 proof fn borrow_ref_perm(tracked perm: &Self::RefPermission) -> (tracked ret:
252 Self::RefPermission) {
253 if perm is Left {
254 Sum::Left(L::borrow_ref_perm(perm.tracked_borrow_left()))
255 } else {
256 Sum::Right(R::borrow_ref_perm(perm.tracked_borrow_right()))
257 }
258 }
259
260 proof fn borrow_perm_as_ref_perm(tracked perm: &'a Self::Permission) -> (tracked ret:
261 Self::RefPermission) {
262 if perm is Left {
263 Sum::Left(L::borrow_perm_as_ref_perm(perm.tracked_borrow_left()))
264 } else {
265 Sum::Right(R::borrow_perm_as_ref_perm(perm.tracked_borrow_right()))
266 }
267 }
268
269 unsafe fn raw_as_ref(
270 raw: NonNull<Self::Target>,
271 Tracked(perm): Tracked<Self::RefPermission>,
272 ) -> Self::Ref {
273 proof_decl! {
274 let ghost align_bits = Self::ALIGN_BITS;
275 let ghost tag = 1usize << align_bits;
276 let ghost raw_addr = raw.view_ptr_mut()@.addr;
277 }
278 proof! {
279 Self::lemma_align_bits_range();
280 if perm is Left {
281 assert((raw_addr & !tag) == raw_addr) by (bit_vector)
282 requires
283 raw_addr & tag == 0,
284 ;
285 } else {
286 assert((raw_addr & tag) < raw_addr) by (bit_vector)
287 requires
288 raw_addr & tag == tag,
289 (raw_addr & !tag) != 0,
290 ;
291 }
292 }
293 let (is_right, real_ptr) = unsafe { remove_bits(raw, 1 << Self::ALIGN_BITS) };
296
297 if is_right == 0 {
298 proof!{
299 if perm is Right {
300 assert(tag != 0) by (bit_vector)
301 requires
302 tag == 1usize << align_bits,
303 align_bits < usize::BITS,
304 ;
305 assert(false);
306 }
307 }
308 Either::Left(
311 unsafe { L::raw_as_ref(real_ptr.cast(), Tracked(perm.tracked_take_left())) },
312 )
313 } else {
314 Either::Right(
317 unsafe { R::raw_as_ref(real_ptr.cast(), Tracked(perm.tracked_take_right())) },
318 )
319 }
320 }
321
322 fn ref_as_raw(ptr_ref: Self::Ref) -> (NonNull<Self::Target>, Tracked<Self::RefPermission>) {
323 proof!{
324 Self::lemma_align_bits_range();
325 }
326 proof_decl!{
327 let ghost align_bits = Self::ALIGN_BITS;
328 let ghost tag = 1usize << align_bits;
329 let ghost l_align_bits = L::ALIGN_BITS;
330 let ghost r_align_bits = R::ALIGN_BITS;
331 }
332 proof!{
333 L::lemma_align_bits_range();
334 R::lemma_align_bits_range();
335 vstd::bits::lemma_usize_pow2_no_overflow(align_bits as nat);
336 vstd::bits::lemma_usize_pow2_no_overflow(l_align_bits as nat);
337 vstd::bits::lemma_usize_pow2_no_overflow(r_align_bits as nat);
338 vstd::bits::lemma_usize_shl_is_mul(1, align_bits as usize);
339 vstd::bits::lemma_usize_shl_is_mul(1, l_align_bits as usize);
340 vstd::bits::lemma_usize_shl_is_mul(1, r_align_bits as usize);
341 }
342 match ptr_ref {
343 Either::Left(left) => {
344 let (ptr, Tracked(perm)) = L::ref_as_raw(left);
346 proof! {
347 let ghost ptr_addr = ptr.view_ptr_mut().addr();
348 L::lemma_ref_perm_inv_impl_perm_inv(perm);
349 let extra_bits: u32 = (l_align_bits - align_bits) as u32;
350 let scale = 1usize << extra_bits;
351 vstd::bits::lemma_usize_pow2_no_overflow(extra_bits as nat);
352 vstd::bits::lemma_usize_shl_is_mul(1, extra_bits as usize);
353 vstd::arithmetic::power2::lemma_pow2_adds(align_bits as nat, extra_bits as nat);
354 assert(ptr_addr % tag == 0) by {
355 let big = 1usize << l_align_bits;
356 let q = ptr_addr / big;
357 vstd::arithmetic::div_mod::lemma_fundamental_div_mod(ptr_addr as int, big as int);
358 assert(ptr_addr as int == (q as int * scale as int) * tag as int) by (nonlinear_arith)
359 requires
360 ptr_addr as int == q as int * big as int,
361 big == tag * scale,
362 ;
363 vstd::arithmetic::div_mod::lemma_mod_multiples_basic(q as int * scale as int, tag as int);
364 };
365 assert(ptr_addr & tag == 0) by (bit_vector)
366 requires
367 ptr_addr % (1usize << l_align_bits) == 0,
368 tag == 1usize << align_bits,
369 align_bits < l_align_bits < usize::BITS,
370 ;
371 }
372 (ptr.cast(), Tracked(Sum::Left(perm)))
373 },
374 Either::Right(right) => {
375 let (ptr, Tracked(perm)) = R::ref_as_raw(right);
379 proof! {
380 Self::lemma_align_bits_range();
381 }
382 let tagged_ptr = ptr.map_addr_v(
383 |addr: NonZeroUsize| -> (ret: NonZeroUsize)
384 ensures
385 ret@ == addr@ | (1usize << Self::ALIGN_BITS),
386 { addr | 1usize << Self::ALIGN_BITS },
387 );
388 proof! {
389 let ghost ptr_addr = ptr.view_ptr_mut().addr();
390 let ghost tagged_addr = tagged_ptr.view_ptr_mut().addr();
391 R::lemma_ref_perm_inv_impl_perm_inv(perm);
392 assert(tagged_addr & tag == tag) by (bit_vector)
393 requires
394 tagged_addr == ptr_addr | tag,
395 tag == 1usize << align_bits,
396 align_bits < r_align_bits < usize::BITS,
397 ptr_addr % (1usize << r_align_bits) == 0,
398 ptr_addr != 0,
399 ;
400 assert(ptr_addr & tag == 0) by (bit_vector)
401 requires
402 ptr_addr % (1usize << r_align_bits) == 0,
403 tag == 1usize << align_bits,
404 align_bits < r_align_bits < usize::BITS,
405 ;
406 assert(tagged_addr & !tag == ptr_addr) by (bit_vector)
407 requires
408 tagged_addr == ptr_addr | tag,
409 ptr_addr & tag == 0,
410 ;
411 let extra_bits: u32 = (r_align_bits - align_bits) as u32;
412 let scale = 1usize << extra_bits;
413 vstd::bits::lemma_usize_pow2_no_overflow(extra_bits as nat);
414 vstd::bits::lemma_usize_shl_is_mul(1, extra_bits as usize);
415 vstd::arithmetic::power2::lemma_pow2_adds(align_bits as nat, extra_bits as nat);
416 assert(tagged_addr == ptr_addr + tag) by (bit_vector)
417 requires
418 tagged_addr == ptr_addr | tag,
419 ptr_addr & tag == 0,
420 ;
421 assert(ptr_addr % tag == 0) by {
422 let big = 1usize << r_align_bits;
423 let q = ptr_addr / big;
424 vstd::arithmetic::div_mod::lemma_fundamental_div_mod(ptr_addr as int, big as int);
425 assert(ptr_addr as int == (q as int * scale as int) * tag as int) by (nonlinear_arith)
426 requires
427 ptr_addr as int == q as int * big as int,
428 big == tag * scale,
429 ;
430 vstd::arithmetic::div_mod::lemma_mod_multiples_basic(q as int * scale as int, tag as int);
431 };
432 assert(tagged_addr % tag == 0) by {
433 vstd::arithmetic::div_mod::lemma_mod_add_multiples_vanish(ptr_addr as int, tag as int);
434 }
435 }
436 (tagged_ptr.cast(), Tracked(Sum::Right(perm)))
437 },
438 }
439 }
440}
441
442#[vstd::contrib::auto_spec]
444const fn min(a: u32, b: u32) -> u32 {
445 if a < b {
446 a
447 } else {
448 b
449 }
450}
451
/// Splits `ptr`'s address into the bits selected by `bits` and the rest.
///
/// Returns `(addr & bits, ptr.with_addr(addr & !bits))`. In this file it is
/// called with `bits = 1 << ALIGN_BITS` to read and strip the `Either` tag.
///
/// # Safety
///
/// The verified precondition `(addr & !bits) != 0` is what makes the
/// `NonZeroUsize::new_unchecked` call below sound.
///
/// NOTE(review): the first `requires` clause (`(addr & bits) < addr`) does not
/// appear to be needed by the body itself; confirm whether any proof relies on it.
#[verus_spec(ret =>
    requires
        (ptr.view_ptr_mut().addr() & bits) < ptr.view_ptr_mut().addr(),
        (ptr.view_ptr_mut().addr() & !bits) != 0,
    ensures
        ret.0 == (ptr.view_ptr_mut().addr() & bits),
        ret.1.view_ptr_mut() == ptr.view_ptr_mut().with_addr((ptr.view_ptr_mut().addr() & !bits) as usize),
)]
unsafe fn remove_bits<T>(ptr: NonNull<T>, bits: usize) -> (usize, NonNull<T>) {
    use vstd_extra::external::nonzero::NonZeroUsize;

    // The selected bits; for a single-bit mask this is either 0 or `bits`.
    let removed_bits = ptr.addr_v().get() & bits;
    let result_ptr = ptr.map_addr_v(
        |addr| -> (ret: NonZeroUsize)
            ensures
                ret@ == addr@ & !bits,
            {
                // SAFETY: `addr` is presumably `ptr`'s address (per `map_addr_v`),
                // and the precondition guarantees `addr & !bits != 0`.
                unsafe { NonZeroUsize::new_unchecked(addr.get() & !bits) }
            },
    );
    (removed_bits, result_ptr)
}
480
/// If `addr` is a multiple of `2^ptr_align_bits`, `tag == 1 << align_bits`,
/// and `align_bits < ptr_align_bits`, then `addr` has the `tag` bit clear.
///
/// Used to show that an untagged inner pointer leaves room for the `Either` tag.
#[verifier::spinoff_prover]
proof fn lemma_aligned_addr_clears_tag_bit(
    addr: usize,
    tag: usize,
    align_bits: u32,
    ptr_align_bits: u32,
)
    requires
        addr % (1usize << ptr_align_bits) == 0,
        tag == 1usize << align_bits,
        align_bits < ptr_align_bits < usize::BITS,
    ensures
        addr & tag == 0,
{
    // A `bit_vector` assertion only sees its own `requires`, so the
    // function's hypotheses are restated here.
    assert(addr & tag == 0) by (bit_vector)
        requires
            addr % (1usize << ptr_align_bits) == 0,
            tag == 1usize << align_bits,
            align_bits < ptr_align_bits < usize::BITS,
    ;
}
502
503} #[cfg(ktest)]
// Tests for the tagged `Either` pointer encoding. They are compiled only with
// `cfg(ktest)` and run through the `#[ktest]` harness.
//
// NOTE(review): these tests look stale relative to the signatures in this file:
// - they treat `NonNullPtr::into_raw(val)` as returning a bare `NonNull`, but
//   `into_raw` here returns `(NonNull<_>, Tracked<_>)`;
// - `from_raw` here also takes a `Tracked` permission;
// - `raw_as_ref`/`ref_as_raw` are implemented on `NonNullPtrRef` (with a
//   `Tracked` argument or result), yet the tests call them through `NonNullPtr`.
// They likely no longer compile under `cfg(ktest)`. TODO: confirm and update.
mod test {
    use alloc::{boxed::Box, sync::Arc};

    use super::*;
    use crate::{prelude::ktest, sync::RcuOption};

    type Either32 = Either<Arc<u32>, Box<u32>>;
    type Either16 = Either<Arc<u32>, Box<u16>>;

    // `ALIGN_BITS` is one less than the smaller side's. Presumably `Box<u32>`
    // has 2 and `Box<u16>` has 1, with `Arc<u32>` larger than both.
    #[ktest]
    fn alignment() {
        assert_eq!(<Either32 as NonNullPtr>::ALIGN_BITS, 1);
        assert_eq!(<Either16 as NonNullPtr>::ALIGN_BITS, 0);
    }

    // With `ALIGN_BITS == 0` the tag is bit 0, so a `Left` address must be even.
    #[ktest]
    fn left_pointer() {
        let val: Either16 = Either::Left(Arc::new(123));

        let ptr = NonNullPtr::into_raw(val);
        assert_eq!(ptr.addr().get() & 1, 0);

        let ref_ = unsafe { <Either16 as NonNullPtr>::raw_as_ref(ptr) };
        assert!(matches!(ref_, Either::Left(ref r) if ***r == 123));

        // Round-tripping through a reference must reproduce the same encoding.
        let ptr2 = <Either16 as NonNullPtr>::ref_as_raw(ref_);
        assert_eq!(ptr, ptr2);

        let val = unsafe { <Either16 as NonNullPtr>::from_raw(ptr) };
        assert!(matches!(val, Either::Left(ref r) if **r == 123));
        drop(val);
    }

    // A `Right` address must carry the tag (bit 0 here).
    #[ktest]
    fn right_pointer() {
        let val: Either16 = Either::Right(Box::new(456));

        let ptr = NonNullPtr::into_raw(val);
        assert_eq!(ptr.addr().get() & 1, 1);

        let ref_ = unsafe { <Either16 as NonNullPtr>::raw_as_ref(ptr) };
        assert!(matches!(ref_, Either::Right(ref r) if ***r == 456));

        let ptr2 = <Either16 as NonNullPtr>::ref_as_raw(ref_);
        assert_eq!(ptr, ptr2);

        let val = unsafe { <Either16 as NonNullPtr>::from_raw(ptr) };
        assert!(matches!(val, Either::Right(ref r) if **r == 456));
        drop(val);
    }

    // Stores both variants in an `RcuOption` and reads them back.
    #[ktest]
    fn rcu_store_load() {
        let rcu: RcuOption<Either32> = RcuOption::new_none();
        assert!(rcu.read().get().is_none());

        rcu.update(Some(Either::Left(Arc::new(888))));
        assert!(matches!(rcu.read().get().unwrap(), Either::Left(r) if **r == 888));

        rcu.update(Some(Either::Right(Box::new(999))));
        assert!(matches!(rcu.read().get().unwrap(), Either::Right(r) if **r == 999));
    }
}