1use core::{marker::PhantomData, ptr::NonNull};
3
4use vstd::raw_ptr::group_raw_ptr_axioms;
5use vstd::{bits, prelude::*};
6use vstd_extra::{external::nonzero::*, prelude::*, sum::Sum};
7
8use super::{NonNullPtr, NonNullPtrRef};
9use crate::util::Either;
10
11verus! {
12
13broadcast use {group_nonull_axioms, group_raw_ptr_axioms, group_nonzero_axioms};
14
/// Proves that an address aligned to `1 << ptr_align_bits` has bit `align_bits`
/// clear whenever `align_bits < ptr_align_bits`.
///
/// `tag` is the single-bit mask `1 << align_bits`, so the conclusion
/// `addr & tag == 0` says the tag bit used to pack `Either` into one pointer is
/// unused in any such address.
///
/// # Parameters
/// - `addr`: the (aligned) address.
/// - `tag`: the mask `1 << align_bits`.
/// - `align_bits`: index of the tag bit.
/// - `ptr_align_bits`: log2 of the alignment `addr` is known to satisfy.
proof fn lemma_aligned_addr_clears_tag_bit(
    addr: usize,
    tag: usize,
    align_bits: u32,
    ptr_align_bits: u32,
)
    requires
        addr % (1usize << ptr_align_bits) == 0,
        tag == 1usize << align_bits,
        align_bits < ptr_align_bits < usize::BITS,
    ensures
        addr & tag == 0,
{
    // Discharged by the bit-vector solver alone; a `by (bit_vector)` proof only
    // sees the facts listed in its own `requires`, so the preconditions are
    // restated here.
    assert(addr & tag == 0) by (bit_vector)
        requires
            addr % (1usize << ptr_align_bits) == 0,
            tag == 1usize << align_bits,
            align_bits < ptr_align_bits < usize::BITS,
    ;
}
35
/// Packs an `Either<L, R>` into a single non-null pointer.
///
/// Bit `ALIGN_BITS` of the address (the mask `1 << ALIGN_BITS`, called `tag`
/// below) is the discriminant: clear for `Left`, set for `Right`. Because
/// `ALIGN_BITS` is one less than the smaller of `L::ALIGN_BITS` and
/// `R::ALIGN_BITS`, that bit lies below both sides' alignment and is therefore
/// free in every `L` or `R` pointer.
// SAFETY: NOTE(review): soundness of this `unsafe impl` rests on the tag bit
// being unused by both `L` and `R` pointers, which the proofs in `into_raw`
// establish; confirm against the trait's stated obligations.
unsafe impl<L: NonNullPtr, R: NonNullPtr> NonNullPtr for Either<L, R> {
    /// There is no single pointee type for an `Either`; `PhantomData<Self>` is an
    /// opaque target that is cast back to the side's own target when unpacking.
    type Target = PhantomData<Self>;

    /// The permission of whichever side is stored; its variant agrees with the
    /// pointer's tag bit (see `ptr_perm_match`).
    type Permission = Sum<L::Permission, R::Permission>;

    /// One less than the smaller of the two sides' `ALIGN_BITS`.
    ///
    /// The `expect` makes evaluating this constant fail if either side has
    /// `ALIGN_BITS == 0`, since then no bit would be free for the tag. The value
    /// is opaque to Verus (`external_body`); `lemma_align_bits_range` below
    /// axiomatizes it.
    #[verifier::external_body]
    const ALIGN_BITS: u32 = min(L::ALIGN_BITS, R::ALIGN_BITS).checked_sub(1).expect(
        "`L` and `R` alignments should be at least 2 to pack `Either` into one pointer",
    );

    /// Converts `self` into one tagged pointer plus its permission.
    ///
    /// `Left` pointers are returned unchanged; the proof shows their tag bit is
    /// already clear. `Right` pointers get the tag bit OR-ed in. The proof then
    /// shows three things: the bit was previously clear, clearing it again
    /// recovers the original address, and the tagged address is still a
    /// multiple of `1 << ALIGN_BITS`.
    #[verus_spec]
    // Verify this function in its own solver instance; it carries several
    // nonlinear and bit-vector obligations.
    #[verifier::spinoff_prover]
    fn into_raw(self) -> (ret: (NonNull<Self::Target>, Tracked<Self::Permission>)) {
        proof_decl!{
            let ghost align_bits = Self::ALIGN_BITS;
            let ghost l_align_bits = L::ALIGN_BITS;
            let ghost r_align_bits = R::ALIGN_BITS;
            let ghost tag = 1usize << align_bits;
        }
        proof! {
            // Range facts for all three `ALIGN_BITS`, presumably from each
            // impl's `lemma_align_bits_range`. Then relate each
            // `1usize << k` to a non-overflowing power of two.
            L::lemma_align_bits_range();
            R::lemma_align_bits_range();
            Self::lemma_align_bits_range();
            vstd::bits::lemma_usize_pow2_no_overflow(align_bits as nat);
            vstd::bits::lemma_usize_pow2_no_overflow(l_align_bits as nat);
            vstd::bits::lemma_usize_pow2_no_overflow(r_align_bits as nat);
            vstd::bits::lemma_usize_shl_is_mul(1, align_bits as usize);
            vstd::bits::lemma_usize_shl_is_mul(1, l_align_bits as usize);
            vstd::bits::lemma_usize_shl_is_mul(1, r_align_bits as usize);
        }
        match self {
            Self::Left(left) => {
                let (left, Tracked(perm)) = left.into_raw();
                proof! {
                    // The `L` address is a multiple of `big = 1 << l_align_bits`
                    // (presumably from `L::into_raw`'s postcondition), and
                    // `big == tag * scale` with `scale = 1 << (l_align_bits - align_bits)`.
                    // Hence it is also a multiple of `tag`, and the tag bit is clear.
                    let left_addr = left.cast::<Self::Target>().view_ptr_mut().addr();
                    let extra_bits: u32 = (l_align_bits - align_bits) as u32;
                    let scale = 1usize << extra_bits;
                    vstd::bits::lemma_usize_pow2_no_overflow(extra_bits as nat);
                    vstd::bits::lemma_usize_shl_is_mul(1, extra_bits as usize);
                    vstd::arithmetic::power2::lemma_pow2_adds(align_bits as nat, extra_bits as nat);
                    assert(left_addr % tag == 0) by {
                        let big = 1usize << l_align_bits;
                        let q = left_addr / big;
                        vstd::arithmetic::div_mod::lemma_fundamental_div_mod(left_addr as int, big as int);
                        assert(left_addr as int == (q as int * scale as int) * tag as int) by (nonlinear_arith)
                            requires
                                left_addr as int == q as int * big as int,
                                big == tag * scale,
                        ;
                        vstd::arithmetic::div_mod::lemma_mod_multiples_basic(q as int * scale as int, tag as int);
                    };
                    lemma_aligned_addr_clears_tag_bit(left_addr, tag, align_bits, l_align_bits);
                }
                // `Left` is stored untagged.
                (left.cast(), Tracked(Sum::Left(perm)))
            },
            Self::Right(right) => {
                let (right, Tracked(perm)) = right.into_raw();
                // Set the tag bit to mark the pointer as `Right`.
                let right_tagged = right.map_addr_v(
                    |addr: NonZeroUsize| -> (ret: NonZeroUsize)
                        ensures
                            ret@ == addr@ | (1usize << Self::ALIGN_BITS),
                        { addr | 1usize << Self::ALIGN_BITS },
                );
                proof! {
                    let addr = right.addr_spec()@;
                    let tagged_addr = right_tagged.addr_spec()@;
                    // The tag bit is set in the tagged address.
                    assert(tagged_addr & tag == tag) by (bit_vector)
                        requires
                            tagged_addr == addr | tag,
                            tag == 1usize << align_bits,
                            1 <= r_align_bits < usize::BITS,
                            align_bits < r_align_bits,
                            addr % (1usize << r_align_bits) == 0,
                            addr != 0,
                    ;
                    let extra_bits: u32 = (r_align_bits - align_bits) as u32;
                    let scale = 1usize << extra_bits;
                    vstd::bits::lemma_usize_pow2_no_overflow(extra_bits as nat);
                    vstd::bits::lemma_usize_shl_is_mul(1, extra_bits as usize);
                    vstd::arithmetic::power2::lemma_pow2_adds(align_bits as nat, extra_bits as nat);
                    // The bit was clear beforehand, so OR-ing it in is an addition.
                    assert(tagged_addr == addr + tag) by (bit_vector)
                        requires
                            tagged_addr == addr | tag,
                            tag == 1usize << align_bits,
                            1 <= r_align_bits < usize::BITS,
                            align_bits < r_align_bits,
                            addr % (1usize << r_align_bits) == 0,
                            addr != 0,
                    ;
                    // The untagged `R` address is a multiple of `tag`, by the
                    // same divisibility argument as the `Left` case.
                    assert(addr % tag == 0) by {
                        let big = 1usize << r_align_bits;
                        let q = addr / big;
                        vstd::arithmetic::div_mod::lemma_fundamental_div_mod(addr as int, big as int);
                        assert(addr as int == (q as int * scale as int) * tag as int) by (nonlinear_arith)
                            requires
                                addr as int == q as int * big as int,
                                big == tag * scale,
                        ;
                        vstd::arithmetic::div_mod::lemma_mod_multiples_basic(q as int * scale as int, tag as int);
                    }
                    lemma_aligned_addr_clears_tag_bit(addr, tag, align_bits, r_align_bits);
                    // Clearing the tag bit recovers the original `R` address.
                    assert(tagged_addr & !tag == addr) by (bit_vector)
                        requires
                            tagged_addr == addr | tag,
                            addr & tag == 0,
                    ;
                    // `addr + tag` is still a multiple of `1 << ALIGN_BITS`.
                    assert(tagged_addr % (1usize << align_bits) == 0) by {
                        vstd::arithmetic::div_mod::lemma_mod_add_multiples_vanish(addr as int, tag as int);
                    }
                }
                (right_tagged.cast(), Tracked(Sum::Right(perm)))
            },
        }
    }

    /// Rebuilds the `Either` from a tagged pointer and its permission.
    ///
    /// `remove_bits` splits off the tag bit: zero selects `Left`, nonzero
    /// selects `Right`. The untagged pointer and the matching half of `perm`
    /// are forwarded to `L::from_raw` or `R::from_raw`.
    ///
    /// # Safety
    ///
    /// NOTE(review): the caller contract is the trait's `requires`, which is not
    /// visible here. The proofs below rely on `ptr`'s tag bit agreeing with
    /// `perm`'s variant, as `ptr_perm_match` describes.
    unsafe fn from_raw(
        ptr: NonNull<Self::Target>,
        Tracked(perm): Tracked<Self::Permission>,
    ) -> Self {
        proof! {
            Self::lemma_align_bits_range();
        }
        proof_decl! {
            let ghost align_bits = Self::ALIGN_BITS;
            let ghost tag = 1usize << Self::ALIGN_BITS;
            let ghost ptr_addr = ptr.view_ptr_mut()@.addr;
        }
        proof! {
            // `tag > 0` lets the verifier conclude that a `Right` permission (tag
            // bit set) can never yield `is_right == 0` below.
            assert(tag > 0) by (bit_vector)
                requires
                    tag == 1usize << align_bits,
                    align_bits < usize::BITS,
            ;
            // Establish `remove_bits`'s preconditions for each variant.
            match perm {
                Sum::Left(_) => {
                    assert((ptr_addr & !tag) == ptr_addr) by (bit_vector)
                        requires
                            ptr_addr & tag == 0,
                    ;
                },
                Sum::Right(_) => {
                    assert((ptr_addr & tag) < ptr_addr) by (bit_vector)
                        requires
                            ptr_addr & tag == tag,
                            (ptr_addr & !tag) != 0,
                    ;
                },
            }
        }
        // SAFETY: the proof block above discharges `remove_bits`'s `requires`
        // for both variants of `perm`.
        let (is_right, real_ptr) = unsafe { remove_bits(ptr, 1 << Self::ALIGN_BITS) };

        if is_right == 0 {
            // SAFETY: a clear tag bit means `perm` is `Left`, and `real_ptr` (the
            // original, already-untagged address) is the pointer `L` produced.
            Either::Left(unsafe { L::from_raw(real_ptr.cast(), Tracked(perm.tracked_take_left())) })
        } else {
            Either::Right(
                // SAFETY: a set tag bit means `perm` is `Right`, and `real_ptr` is
                // the address with the tag cleared, i.e. the pointer `R` produced.
                unsafe { R::from_raw(real_ptr.cast(), Tracked(perm.tracked_take_right())) },
            )
        }
    }

    /// Relates a tagged pointer to its permission.
    ///
    /// `Left`: the tag bit is clear and the pointer itself matches `L`'s
    /// permission. `Right`: the tag bit is set, the address with the tag bit
    /// cleared is nonzero, and that untagged pointer matches `R`'s permission.
    open spec fn ptr_perm_match(ptr: NonNull<Self::Target>, perm: Self::Permission) -> bool {
        let tag = 1usize << Self::ALIGN_BITS;
        match perm {
            Sum::Left(left) => {
                &&& ptr.view_ptr_mut().addr() & tag == 0
                &&& L::ptr_perm_match(ptr.cast(), left)
            },
            Sum::Right(right) => {
                let untagged_ptr = ptr.view_ptr_mut().with_addr((ptr.view_ptr_mut().addr() & !tag));
                let right_nonnull = nonnull_from_ptr_mut_spec(untagged_ptr);
                &&& ptr.view_ptr_mut().addr() & tag == tag
                &&& (ptr.view_ptr_mut().addr() & !tag) != 0
                &&& R::ptr_perm_match(right_nonnull.cast(), right)
            },
        }
    }

    /// `self` and `perm` must be the same variant, with the inner value related
    /// to the inner permission by that side's own `rel_perm`.
    open spec fn rel_perm(self, perm: Self::Permission) -> bool {
        match (self, perm) {
            (Either::Left(left), Sum::Left(left_perm)) => left.rel_perm(left_perm),
            (Either::Right(right), Sum::Right(right_perm)) => right.rel_perm(right_perm),
            _ => false,
        }
    }

    /// Axiomatizes the value of `ALIGN_BITS`, which is `external_body` and
    /// therefore opaque to Verus. It mirrors the `min(..).checked_sub(1)`
    /// initializer: one less than the smaller side's `ALIGN_BITS`.
    axiom fn lemma_align_bits_range()
        ensures
            Self::ALIGN_BITS == if L::ALIGN_BITS < R::ALIGN_BITS {
                L::ALIGN_BITS - 1
            } else {
                R::ALIGN_BITS - 1
            },
    ;
}
248
249unsafe impl<'a, L: NonNullPtrRef<'a>, R: NonNullPtrRef<'a>> NonNullPtrRef<'a> for Either<L, R> {
250 type Ref = Either<L::Ref, R::Ref>;
251
252 type RefPermission = Sum<L::RefPermission, R::RefPermission>;
253
254 open spec fn ref_perm_view_permission(perm: Self::RefPermission) -> Self::Permission {
255 match perm {
256 Sum::Left(left) => Sum::Left(L::ref_perm_view_permission(left)),
257 Sum::Right(right) => Sum::Right(R::ref_perm_view_permission(right)),
258 }
259 }
260
261 open spec fn ref_rel_perm(r: Self::Ref, perm: Self::RefPermission) -> bool {
262 true
263 }
264
265 proof fn lemma_ref_perm_inv_impl_perm_inv(perm: Self::RefPermission) {
266 match perm {
267 Sum::Left(left) => L::lemma_ref_perm_inv_impl_perm_inv(left),
268 Sum::Right(right) => R::lemma_ref_perm_inv_impl_perm_inv(right),
269 }
270 }
271
272 unsafe fn raw_as_ref(
273 raw: NonNull<Self::Target>,
274 Tracked(perm): Tracked<Self::RefPermission>,
275 ) -> Self::Ref {
276 proof_decl! {
277 let ghost align_bits = Self::ALIGN_BITS;
278 let ghost tag = 1usize << align_bits;
279 let ghost raw_addr = raw.view_ptr_mut()@.addr;
280 }
281 proof! {
282 Self::lemma_align_bits_range();
283 if perm is Left {
284 assert((raw_addr & !tag) == raw_addr) by (bit_vector)
285 requires
286 raw_addr & tag == 0,
287 ;
288 } else {
289 assert((raw_addr & tag) < raw_addr) by (bit_vector)
290 requires
291 raw_addr & tag == tag,
292 (raw_addr & !tag) != 0,
293 ;
294 }
295 }
296 let (is_right, real_ptr) = unsafe { remove_bits(raw, 1 << Self::ALIGN_BITS) };
299
300 if is_right == 0 {
301 proof!{
302 if perm is Right {
303 assert(tag != 0) by (bit_vector)
304 requires
305 tag == 1usize << align_bits,
306 align_bits < usize::BITS,
307 ;
308 assert(false);
309 }
310 }
311 Either::Left(
314 unsafe { L::raw_as_ref(real_ptr.cast(), Tracked(perm.tracked_take_left())) },
315 )
316 } else {
317 Either::Right(
320 unsafe { R::raw_as_ref(real_ptr.cast(), Tracked(perm.tracked_take_right())) },
321 )
322 }
323 }
324
325 fn ref_as_raw(ptr_ref: Self::Ref) -> (NonNull<Self::Target>, Tracked<Self::RefPermission>) {
326 proof!{
327 Self::lemma_align_bits_range();
328 }
329 proof_decl!{
330 let ghost align_bits = Self::ALIGN_BITS;
331 let ghost tag = 1usize << align_bits;
332 let ghost l_align_bits = L::ALIGN_BITS;
333 let ghost r_align_bits = R::ALIGN_BITS;
334 }
335 proof!{
336 L::lemma_align_bits_range();
337 R::lemma_align_bits_range();
338 vstd::bits::lemma_usize_pow2_no_overflow(align_bits as nat);
339 vstd::bits::lemma_usize_pow2_no_overflow(l_align_bits as nat);
340 vstd::bits::lemma_usize_pow2_no_overflow(r_align_bits as nat);
341 vstd::bits::lemma_usize_shl_is_mul(1, align_bits as usize);
342 vstd::bits::lemma_usize_shl_is_mul(1, l_align_bits as usize);
343 vstd::bits::lemma_usize_shl_is_mul(1, r_align_bits as usize);
344 }
345 match ptr_ref {
346 Either::Left(left) => {
347 let (ptr, Tracked(perm)) = L::ref_as_raw(left);
349 proof! {
350 let ghost ptr_addr = ptr.view_ptr_mut().addr();
351 L::lemma_ref_perm_inv_impl_perm_inv(perm);
352 let extra_bits: u32 = (l_align_bits - align_bits) as u32;
353 let scale = 1usize << extra_bits;
354 vstd::bits::lemma_usize_pow2_no_overflow(extra_bits as nat);
355 vstd::bits::lemma_usize_shl_is_mul(1, extra_bits as usize);
356 vstd::arithmetic::power2::lemma_pow2_adds(align_bits as nat, extra_bits as nat);
357 assert(ptr_addr % tag == 0) by {
358 let big = 1usize << l_align_bits;
359 let q = ptr_addr / big;
360 vstd::arithmetic::div_mod::lemma_fundamental_div_mod(ptr_addr as int, big as int);
361 assert(ptr_addr as int == (q as int * scale as int) * tag as int) by (nonlinear_arith)
362 requires
363 ptr_addr as int == q as int * big as int,
364 big == tag * scale,
365 ;
366 vstd::arithmetic::div_mod::lemma_mod_multiples_basic(q as int * scale as int, tag as int);
367 };
368 assert(ptr_addr & tag == 0) by (bit_vector)
369 requires
370 ptr_addr % (1usize << l_align_bits) == 0,
371 tag == 1usize << align_bits,
372 align_bits < l_align_bits < usize::BITS,
373 ;
374 }
375 (ptr.cast(), Tracked(Sum::Left(perm)))
376 },
377 Either::Right(right) => {
378 let (ptr, Tracked(perm)) = R::ref_as_raw(right);
382 proof! {
383 Self::lemma_align_bits_range();
384 }
385 let tagged_ptr = ptr.map_addr_v(
386 |addr: NonZeroUsize| -> (ret: NonZeroUsize)
387 ensures
388 ret@ == addr@ | (1usize << Self::ALIGN_BITS),
389 { addr | 1usize << Self::ALIGN_BITS },
390 );
391 proof! {
392 let ghost ptr_addr = ptr.view_ptr_mut().addr();
393 let ghost tagged_addr = tagged_ptr.view_ptr_mut().addr();
394 R::lemma_ref_perm_inv_impl_perm_inv(perm);
395 assert(tagged_addr & tag == tag) by (bit_vector)
396 requires
397 tagged_addr == ptr_addr | tag,
398 tag == 1usize << align_bits,
399 align_bits < r_align_bits < usize::BITS,
400 ptr_addr % (1usize << r_align_bits) == 0,
401 ptr_addr != 0,
402 ;
403 assert(ptr_addr & tag == 0) by (bit_vector)
404 requires
405 ptr_addr % (1usize << r_align_bits) == 0,
406 tag == 1usize << align_bits,
407 align_bits < r_align_bits < usize::BITS,
408 ;
409 assert(tagged_addr & !tag == ptr_addr) by (bit_vector)
410 requires
411 tagged_addr == ptr_addr | tag,
412 ptr_addr & tag == 0,
413 ;
414 let extra_bits: u32 = (r_align_bits - align_bits) as u32;
415 let scale = 1usize << extra_bits;
416 vstd::bits::lemma_usize_pow2_no_overflow(extra_bits as nat);
417 vstd::bits::lemma_usize_shl_is_mul(1, extra_bits as usize);
418 vstd::arithmetic::power2::lemma_pow2_adds(align_bits as nat, extra_bits as nat);
419 assert(tagged_addr == ptr_addr + tag) by (bit_vector)
420 requires
421 tagged_addr == ptr_addr | tag,
422 ptr_addr & tag == 0,
423 ;
424 assert(ptr_addr % tag == 0) by {
425 let big = 1usize << r_align_bits;
426 let q = ptr_addr / big;
427 vstd::arithmetic::div_mod::lemma_fundamental_div_mod(ptr_addr as int, big as int);
428 assert(ptr_addr as int == (q as int * scale as int) * tag as int) by (nonlinear_arith)
429 requires
430 ptr_addr as int == q as int * big as int,
431 big == tag * scale,
432 ;
433 vstd::arithmetic::div_mod::lemma_mod_multiples_basic(q as int * scale as int, tag as int);
434 };
435 assert(tagged_addr % tag == 0) by {
436 vstd::arithmetic::div_mod::lemma_mod_add_multiples_vanish(ptr_addr as int, tag as int);
437 }
438 }
439 (tagged_ptr.cast(), Tracked(Sum::Right(perm)))
440 },
441 }
442 }
443}
444
445#[vstd::contrib::auto_spec]
447const fn min(a: u32, b: u32) -> u32 {
448 if a < b {
449 a
450 } else {
451 b
452 }
453}
454
/// Splits the bits selected by the mask `bits` off of `ptr`'s address.
///
/// Returns `(addr & bits, ptr')`, where `addr` is `ptr`'s address and `ptr'` is
/// `ptr` re-addressed (via `with_addr`) to `addr & !bits`.
///
/// # Safety
///
/// NOTE(review): the body performs no memory access. The `requires` below,
/// in particular that `addr & !bits` is nonzero, are what justify the
/// `new_unchecked` call. Confirm whether callers have further obligations.
#[verus_spec(ret =>
    requires
        (ptr.view_ptr_mut().addr() & bits) < ptr.view_ptr_mut().addr(),
        (ptr.view_ptr_mut().addr() & !bits) != 0,
    ensures
        ret.0 == (ptr.view_ptr_mut().addr() & bits),
        ret.1.view_ptr_mut() == ptr.view_ptr_mut().with_addr((ptr.view_ptr_mut().addr() & !bits) as usize),
)]
unsafe fn remove_bits<T>(ptr: NonNull<T>, bits: usize) -> (usize, NonNull<T>) {
    // Already reachable through the module-level `external::nonzero::*` glob;
    // imported explicitly here as well.
    use vstd_extra::external::nonzero::NonZeroUsize;

    // The masked-off bits (for `Either`, the tag bit: zero or `bits`).
    let removed_bits = ptr.addr_v().get() & bits;
    let result_ptr = ptr.map_addr_v(
        |addr| -> (ret: NonZeroUsize)
            ensures
                ret@ == addr@ & !bits,
            {
                // SAFETY: `addr` is presumably `ptr`'s address, for which the
                // function's `requires` guarantees `addr & !bits != 0`.
                unsafe { NonZeroUsize::new_unchecked(addr.get() & !bits) }
            },
    );
    (removed_bits, result_ptr)
}
483
484} #[cfg(ktest)]
// Kernel-mode unit tests (`cfg(ktest)`) for packing `Either` into one pointer.
//
// NOTE(review): these tests appear to target an older API than the impls in
// this file:
// - `NonNullPtr::into_raw` now returns `(NonNull<_>, Tracked<_>)`, not a bare
//   pointer.
// - `from_raw` now takes a `Tracked` permission.
// - `raw_as_ref` and `ref_as_raw` are methods of `NonNullPtrRef`, not
//   `NonNullPtr`, and take or return a `Tracked` ref permission.
// As written they likely do not compile under `cfg(ktest)`. TODO: thread the
// permissions through.
mod test {
    use alloc::{boxed::Box, sync::Arc};

    use super::*;
    use crate::{prelude::ktest, sync::RcuOption};

    // Both use `Arc<u32>` on the left; they differ in the right side's pointee
    // alignment, which determines `ALIGN_BITS`.
    type Either32 = Either<Arc<u32>, Box<u32>>;
    type Either16 = Either<Arc<u32>, Box<u16>>;

    /// `ALIGN_BITS` is one less than the smaller side's `ALIGN_BITS`
    /// (presumably 2 for `Box<u32>` and 1 for `Box<u16>`).
    #[ktest]
    fn alignment() {
        assert_eq!(<Either32 as NonNullPtr>::ALIGN_BITS, 1);
        assert_eq!(<Either16 as NonNullPtr>::ALIGN_BITS, 0);
    }

    /// A `Left` value round-trips with its tag bit (bit 0 for `Either16`) clear.
    #[ktest]
    fn left_pointer() {
        let val: Either16 = Either::Left(Arc::new(123));

        let ptr = NonNullPtr::into_raw(val);
        assert_eq!(ptr.addr().get() & 1, 0);

        let ref_ = unsafe { <Either16 as NonNullPtr>::raw_as_ref(ptr) };
        assert!(matches!(ref_, Either::Left(ref r) if ***r == 123));

        let ptr2 = <Either16 as NonNullPtr>::ref_as_raw(ref_);
        assert_eq!(ptr, ptr2);

        let val = unsafe { <Either16 as NonNullPtr>::from_raw(ptr) };
        assert!(matches!(val, Either::Left(ref r) if **r == 123));
        drop(val);
    }

    /// A `Right` value round-trips with its tag bit (bit 0 for `Either16`) set.
    #[ktest]
    fn right_pointer() {
        let val: Either16 = Either::Right(Box::new(456));

        let ptr = NonNullPtr::into_raw(val);
        assert_eq!(ptr.addr().get() & 1, 1);

        let ref_ = unsafe { <Either16 as NonNullPtr>::raw_as_ref(ptr) };
        assert!(matches!(ref_, Either::Right(ref r) if ***r == 456));

        let ptr2 = <Either16 as NonNullPtr>::ref_as_raw(ref_);
        assert_eq!(ptr, ptr2);

        let val = unsafe { <Either16 as NonNullPtr>::from_raw(ptr) };
        assert!(matches!(val, Either::Right(ref r) if **r == 456));
        drop(val);
    }

    /// Stores both variants in an `RcuOption` and reads them back.
    #[ktest]
    fn rcu_store_load() {
        let rcu: RcuOption<Either32> = RcuOption::new_none();
        assert!(rcu.read().get().is_none());

        rcu.update(Some(Either::Left(Arc::new(888))));
        assert!(matches!(rcu.read().get().unwrap(), Either::Left(r) if **r == 888));

        rcu.update(Some(Either::Right(Box::new(999))));
        assert!(matches!(rcu.read().get().unwrap(), Either::Right(r) if **r == 999));
    }
}