1#![allow(unused_macros)]
3#![cfg_attr(not(test), no_std)]
4#![allow(non_snake_case)]
5#![allow(unused_parens)]
6#![allow(unused_braces)]
7#![allow(rustdoc::invalid_rust_codeblocks)]
8#![allow(rustdoc::invalid_html_tags)]
9#![allow(rustdoc::broken_intra_doc_links)]
10
11use vstd::arithmetic::div_mod::*;
12use vstd::arithmetic::mul::*;
13use vstd::arithmetic::power2::pow2;
14use vstd::arithmetic::power2::{lemma2_to64, lemma_pow2_strictly_increases};
15use vstd::bits::*;
16use vstd::pervasive::trigger;
17use vstd::prelude::*;
18use vstd_extra::prelude::*;
19
/// Rounding of an unsigned integer to a multiple of a power-of-two alignment.
///
/// The implementations provided by this crate (`u8`, `u16`, `u32`, `u64`,
/// `usize`) carry Verus contracts. They require `align` to be a power of two
/// and at least 2. Verified callers discharge that obligation statically;
/// the method bodies themselves do not re-check it at run time.
pub trait AlignExt {
    /// Returns the smallest multiple of `align` that is greater than or equal
    /// to `self`.
    ///
    /// `align` must be a power of two no smaller than 2, and
    /// `self + (align - 1)` must be representable in `Self`.
    ///
    /// # Panics
    ///
    /// The provided implementations panic if `self + (align - 1)` overflows.
    fn align_up(self, align: Self) -> Self;

    /// Returns the largest multiple of `align` that is less than or equal to
    /// `self`.
    ///
    /// `align` must be a power of two no smaller than 2.
    fn align_down(self, align: Self) -> Self;
}
60
verus! {

/// `usize` version of the fixed-width `*_low_bits_mask_is_mod` lemmas:
/// masking `x` with the low `n` bits equals `x % 2^n`.
///
/// The proof dispatches on the word width. It uses the `u64` lemma when
/// `usize::BITS == 64` and the `u32` lemma otherwise.
// NOTE(review): the `else` branch treats every non-64-bit `usize` as 32-bit;
// this presumably relies on Verus's platform word-size assumptions — confirm.
proof fn lemma_usize_low_bits_mask_is_mod(x: usize, n: nat)
    requires
        n < usize::BITS,
    ensures
        (x & (low_bits_mask(n) as usize)) == x % (pow2(n) as usize),
{
    if usize::BITS == 64 {
        lemma_u64_low_bits_mask_is_mod(x as u64, n);
    } else {
        lemma_u32_low_bits_mask_is_mod(x as u32, n);
    }
}

} // verus!

// Invokes the `*_low_bits_mask_is_mod` lemma that matches the given unsigned
// integer type, e.g. `call_lemma_low_bits_mask_is_mod!(u32, x, e)`.
//
// The arms match the literal type-name tokens. Callers must therefore pass a
// bare identifier (`u8`, `u16`, `u32`, `u64`, `usize`). A `$t:ty` fragment
// forwarded from another macro is opaque and will not match these arms.
macro_rules! call_lemma_low_bits_mask_is_mod {
    (u8, $x:expr, $n:expr) => {
        lemma_u8_low_bits_mask_is_mod($x, $n)
    };
    (u16, $x:expr, $n:expr) => {
        lemma_u16_low_bits_mask_is_mod($x, $n)
    };
    (u32, $x:expr, $n:expr) => {
        lemma_u32_low_bits_mask_is_mod($x, $n)
    };
    (u64, $x:expr, $n:expr) => {
        lemma_u64_low_bits_mask_is_mod($x, $n)
    };
    // `usize` has no vstd lemma of its own; use the local width-dispatching one.
    (usize, $x:expr, $n:expr) => {
        lemma_usize_low_bits_mask_is_mod($x, $n)
    };
}
94
// Implements `AlignExt` for each listed unsigned integer type, together with
// Verus contracts and the proofs that the bit-mask implementation meets them.
//
// Usage: `impl_align_ext! { u8, u16, u32, u64, usize, }` (trailing comma optional).
macro_rules! impl_align_ext {
    // `$uint_type` is captured as `ident`, not `ty`, for two reasons:
    //  * it is forwarded to `call_lemma_low_bits_mask_is_mod!`, whose arms
    //    match the literal tokens `u8`/`u16`/...; a forwarded `ty` fragment
    //    is opaque and cannot match them;
    //  * `$uint_type::MAX` / `$uint_type::BITS` are only valid paths when
    //    `$uint_type` is an identifier (a `ty` would need `<$uint_type>::MAX`).
    ($( $uint_type:ident ),+ $(,)?) => {
        $(
            #[verus_verify]
            impl AlignExt for $uint_type {
                /// Rounds `self` up to the nearest multiple of `align`.
                ///
                /// Requires `align` to be a power of two with `align >= 2`, and
                /// `self + (align - 1)` not to overflow. The result is the least
                /// multiple of `align` that is `>= self` (see the `ensures` clause).
                ///
                /// # Panics
                ///
                /// Panics if `self + (align - 1)` overflows (via `checked_add`).
                /// That case cannot arise for callers that satisfy the contract.
                #[inline]
                #[verus_spec(ret =>
                    requires
                        exists |e:nat| pow2(e) == align,
                        align >= 2,
                        self + (align - 1) <= $uint_type::MAX,
                    ensures
                        ret >= self,
                        ret % align == 0,
                        ret == nat_align_up(self as nat, align as nat),
                        forall |n: nat| !(n>=self && #[trigger] (n % align as nat) == 0) || (ret <= n),
                )]
                fn align_up(self, align: Self) -> Self {
                    proof!{
                        // `x` is the intermediate sum `self + (align - 1)`.
                        let x_int = self as int + align as int - 1;
                        let x = x_int as Self;
                        // Establish `x % align`. It is `align - 1` when `self` is
                        // already aligned, and `r - 1` otherwise, where
                        // `self == q * align + r`.
                        if self as int % align as int == 0 {
                            assert((align as int - 1) % align as int == align as int - 1) by {
                                lemma_small_mod((align as int - 1) as nat, align as nat);
                            }
                            assert(x_int % align as int == align as int - 1) by {
                                lemma_mod_adds(self as int, align as int - 1, align as int);
                            }
                        } else {
                            let q = self as int / align as int;
                            let r = self as int % align as int;
                            lemma_fundamental_div_mod(self as int, align as int);

                            // x == (q + 1) * align + (r - 1), with 0 <= r - 1 < align.
                            assert((q + 1) * align as int == q * align as int + align as int) by {
                                lemma_mul_is_distributive_add(align as int, q, 1);
                            }
                            assert(((q + 1) * align as int) % align as int == 0) by {
                                lemma_mod_multiples_basic(q + 1, align as int);
                            }
                            assert((r - 1) % align as int == r - 1) by {
                                lemma_small_mod((r - 1) as nat, align as nat);
                            }
                            assert(x_int % align as int == (r - 1)) by {
                                lemma_mod_adds((q + 1) * align as int, r - 1, align as int);
                            }
                        }

                        // Relate the bit mask `align - 1` to `% align`, using the
                        // exponent `e` with `pow2(e) == align`.
                        lemma_low_bits_mask_values();
                        let mask = (align - 1) as Self;
                        let e = choose |e: nat| pow2(e) == align;
                        // `align` fits in the type, so `e` must be below its bit width;
                        // otherwise `pow2(e) >= pow2(BITS)` would exceed the type's range.
                        assert(e < $uint_type::BITS) by {
                            if e >= $uint_type::BITS {
                                lemma_pow2_strictly_increases($uint_type::BITS as nat, e);
                                lemma2_to64();
                            }
                        }
                        call_lemma_low_bits_mask_is_mod!($uint_type, x, e);
                        // Clearing the low bits removes exactly `x % align`.
                        assert(x == (x & mask) + (x & !mask)) by (bit_vector);
                        lemma_nat_align_up_sound(self as nat, align as nat);
                    }
                    self.checked_add(align - 1).unwrap() & !(align - 1)
                }

                /// Rounds `self` down to the nearest multiple of `align`.
                ///
                /// Requires `align` to be a power of two with `align >= 2`. The
                /// result is the greatest multiple of `align` that is `<= self`
                /// (see the `ensures` clause). It never overflows.
                #[inline]
                #[verus_spec(ret =>
                    requires
                        exists |e:nat| pow2(e) == align,
                        align >= 2,
                    ensures
                        ret <= self,
                        ret % align == 0,
                        ret == nat_align_down(self as nat, align as nat),
                        forall |n: nat| !(n<=self && #[trigger] (n % align as nat) == 0) || (ret >= n),
                )]
                fn align_down(self, align: Self) -> Self {
                    proof!{
                        // Same mask/modulus bridge as in `align_up`, applied to `self`.
                        lemma_low_bits_mask_values();
                        let mask = (align - 1) as Self;
                        let e = choose |e: nat| pow2(e) == align;
                        assert(e < $uint_type::BITS) by {
                            if e >= $uint_type::BITS {
                                lemma_pow2_strictly_increases($uint_type::BITS as nat, e);
                                lemma2_to64();
                            }
                        }
                        call_lemma_low_bits_mask_is_mod!($uint_type, self, e);
                        // `self & !mask` is `self` minus its remainder modulo `align`.
                        assert(self == (self & mask) + (self & !mask)) by (bit_vector);
                        assert((self & !mask) as nat == nat_align_down(self as nat, align as nat));
                        lemma_nat_align_down_sound(self as nat, align as nat);
                    }
                    self & !(align - 1)
                }
            }
        )*
    }
}
212
213impl_align_ext! {
214 u8,
215 u16,
216 u32,
217 u64,
218 usize,
219}
220
#[cfg(test)]
mod test {
    use super::*;

    /// Rounding up `usize` values to small power-of-two alignments.
    #[test]
    fn test_align_up() {
        // (input, align, expected)
        let cases: [(usize, usize, usize); 9] = [
            (0, 2, 0),
            (1, 2, 2),
            (2, 2, 2),
            (9, 2, 10),
            (15, 4, 16),
            (21, 4, 24),
            (32, 8, 32),
            (47, 8, 48),
            (50, 8, 56),
        ];
        for (n, align, expected) in cases {
            assert_eq!(n.align_up(align), expected, "{}.align_up({})", n, align);
        }
    }

    /// Rounding down `usize` values to small power-of-two alignments.
    #[test]
    fn test_align_down() {
        // (input, align, expected)
        let cases: [(usize, usize, usize); 9] = [
            (0, 2, 0),
            (1, 2, 0),
            (2, 2, 2),
            (9, 2, 8),
            (15, 4, 12),
            (21, 4, 20),
            (32, 8, 32),
            (47, 8, 40),
            (50, 8, 48),
        ];
        for (n, align, expected) in cases {
            assert_eq!(n.align_down(align), expected, "{}.align_down({})", n, align);
        }
    }

    /// Every macro-generated impl, including values near each type's maximum.
    #[test]
    fn test_other_widths() {
        assert_eq!(13u8.align_up(4), 16);
        assert_eq!(13u8.align_down(4), 12);
        // 248 + 7 == u8::MAX: the largest input that still fits for align 8.
        assert_eq!(248u8.align_up(8), 248);
        assert_eq!(u8::MAX.align_down(16), 240);

        assert_eq!(1000u16.align_up(64), 1024);
        assert_eq!(1000u16.align_down(64), 960);

        assert_eq!(0x1234_5678u32.align_up(0x1000), 0x1234_6000);
        assert_eq!(0x1234_5678u32.align_down(0x1000), 0x1234_5000);

        let big = (1u64 << 40) + 1;
        assert_eq!(big.align_up(1 << 12), (1u64 << 40) + (1 << 12));
        assert_eq!(big.align_down(1 << 12), 1u64 << 40);
        assert_eq!(u64::MAX.align_down(2), u64::MAX - 1);
    }

    /// `align_up` panics (via `checked_add(..).unwrap()`) instead of wrapping
    /// when `self + (align - 1)` overflows.
    #[test]
    #[should_panic]
    fn test_align_up_overflow_panics() {
        let _ = u8::MAX.align_up(2);
    }
}