1#![allow(unused_macros)]
3#![cfg_attr(not(test), no_std)]
4#![allow(non_snake_case)]
5#![allow(unused_parens)]
6#![allow(unused_braces)]
7#![allow(rustdoc::invalid_rust_codeblocks)]
8#![allow(rustdoc::invalid_html_tags)]
9#![allow(rustdoc::broken_intra_doc_links)]
10
11use vstd::arithmetic::div_mod::*;
12use vstd::arithmetic::mul::*;
13use vstd::arithmetic::{power::pow, power2::*};
14use vstd::bits::*;
15use vstd::pervasive::trigger;
16use vstd::prelude::*;
17use vstd_extra::panic::*;
18use vstd_extra::prelude::*;
19
/// Alignment helpers for unsigned integers.
///
/// This file implements the trait for `u8`, `u16`, `u32`, `u64`, and `usize`
/// via `impl_align_ext!`. The alignment argument of both methods is expected
/// to be a power of two that is at least 2.
pub trait AlignExt {
    /// Rounds `self` up to the nearest multiple of `power_of_two`.
    ///
    /// Returns the smallest value `r` with `r >= self` and
    /// `r % power_of_two == 0`. For example, `13usize.align_up(4) == 16`
    /// and `16usize.align_up(4) == 16`.
    ///
    /// # Panics
    ///
    /// The implementations generated in this file assert that `power_of_two`
    /// is a power of two and at least 2. They also unwrap
    /// `self.checked_add(power_of_two - 1)`, so they panic if that addition
    /// overflows.
    fn align_up(self, power_of_two: Self) -> Self;

    /// Rounds `self` down to the nearest multiple of `power_of_two`.
    ///
    /// Returns the largest value `r` with `r <= self` and
    /// `r % power_of_two == 0`. For example, `13usize.align_down(4) == 12`
    /// and `16usize.align_down(4) == 16`.
    ///
    /// # Panics
    ///
    /// The implementations generated in this file assert that `power_of_two`
    /// is a power of two and at least 2. This method cannot overflow.
    fn align_down(self, power_of_two: Self) -> Self;
}
60
verus! {

/// Proves, for `usize`, that masking `x` with the low `n` bits equals
/// `x % 2^n`.
///
/// The proof dispatches on the platform word size. When `usize::BITS == 64`
/// it calls `lemma_u64_low_bits_mask_is_mod` on `x as u64`. Otherwise it
/// calls `lemma_u32_low_bits_mask_is_mod` on `x as u32`.
/// NOTE(review): the `else` branch relies on `usize` being 32 bits whenever
/// it is not 64 bits.
proof fn lemma_usize_low_bits_mask_is_mod(x: usize, n: nat)
    requires
        // The mask width must fit inside the machine word.
        n < usize::BITS,
    ensures
        (x & (low_bits_mask(n) as usize)) == x % (pow2(n) as usize),
{
    if usize::BITS == 64 {
        lemma_u64_low_bits_mask_is_mod(x as u64, n);
    } else {
        lemma_u32_low_bits_mask_is_mod(x as u32, n);
    }
}

}

/// Invokes the `lemma_*_low_bits_mask_is_mod` lemma that matches the
/// integer type named by the first argument.
///
/// The first argument must be one of the literal tokens `u8`, `u16`, `u32`,
/// `u64`, or `usize`. `$x` is the value being masked and `$n` is the mask
/// width in bits. This lets macro-generated code (see `impl_align_ext!`)
/// pick the width-specific lemma from a type name.
macro_rules! call_lemma_low_bits_mask_is_mod {
    (u8, $x:expr, $n:expr) => {
        lemma_u8_low_bits_mask_is_mod($x, $n)
    };
    (u16, $x:expr, $n:expr) => {
        lemma_u16_low_bits_mask_is_mod($x, $n)
    };
    (u32, $x:expr, $n:expr) => {
        lemma_u32_low_bits_mask_is_mod($x, $n)
    };
    (u64, $x:expr, $n:expr) => {
        lemma_u64_low_bits_mask_is_mod($x, $n)
    };
    (usize, $x:expr, $n:expr) => {
        lemma_usize_low_bits_mask_is_mod($x, $n)
    };
}
94
/// Implements `AlignExt` (with Verus specifications and proofs) for each
/// listed unsigned integer type.
///
/// Each type must be written as a bare identifier (`u8`, `u16`, `u32`,
/// `u64`, `usize`), and a trailing comma is optional.
///
/// The identifier is used in two ways:
/// - It is forwarded to `call_lemma_low_bits_mask_is_mod!`, which matches
///   literal type-name tokens.
/// - It appears in associated-item paths such as `$uint_type::MAX`.
///
/// An `ident` fragment supports both uses. A `ty` fragment is opaque to
/// literal-token matching when forwarded to another macro, and it needs angle
/// brackets (`<$t>::MAX`) in such paths.
macro_rules! impl_align_ext {
    ($( $uint_type:ident ),+ $(,)?) => {
        $(
            verus!{
            impl AlignExt for $uint_type {
                /// Rounds `self` up to the nearest multiple of `align`.
                ///
                /// Asserts at run time that `align` is a power of two that is
                /// at least 2. The verified precondition rules out overflow
                /// of `self + (align - 1)`.
                #[inline]
                #[verus_spec(ret =>
                    requires
                        self + (align - 1) <= $uint_type::MAX,
                        !(align >= 2 && is_pow2(align as int)) ==> may_panic(),
                    ensures
                        align >= 2,
                        is_pow2(align as int),
                        ret >= self,
                        ret % align == 0,
                        ret == nat_align_up(self as nat, align as nat),
                        forall |n: nat| !(n>=self && #[trigger] (n % align as nat) == 0) || (ret <= n),
                )]
                fn align_up(self, align: Self) -> Self {
                    vstd_extra::assert!(align.is_power_of_two() && align >= 2);
                    proof!{
                        is_pow2_equiv(align as int);
                        // x = self + align - 1: the value the mask is applied to.
                        let x_int = self as int + align as int - 1;
                        let x = x_int as Self;
                        // Establish `x_int % align`, split on whether `self`
                        // is already aligned.
                        if self as int % align as int == 0 {
                            assert((align as int - 1) % align as int == align as int - 1) by {
                                lemma_small_mod((align as int - 1) as nat, align as nat);
                            }
                            assert(x_int % align as int == align as int - 1) by {
                                lemma_mod_adds(self as int, align as int - 1, align as int);
                            }
                        } else {
                            let q = self as int / align as int;
                            let r = self as int % align as int;
                            lemma_fundamental_div_mod(self as int, align as int);
                            assert(self as int == q * align as int + r) by {
                                lemma_mul_is_commutative(align as int, q);
                            }

                            assert((q + 1) * align as int == q * align as int + align as int) by {
                                lemma_mul_is_distributive_add_other_way(align as int, q, 1);
                            }
                            assert(((q + 1) * align as int) % align as int == 0) by {
                                lemma_mod_multiples_basic(q + 1, align as int);
                            }
                            assert((r - 1) % align as int == r - 1) by {
                                lemma_small_mod((r - 1) as nat, align as nat);
                            }
                            assert(x_int == (q + 1) * align as int + (r - 1));
                            assert(x_int % align as int == (r - 1)) by {
                                lemma_mod_adds((q + 1) * align as int, r - 1, align as int);
                            }
                        }

                        // Relate `x & (align - 1)` to `x % align`. This needs
                        // the exponent `e` with 2^e == align to be below the
                        // type's bit width.
                        lemma_low_bits_mask_values();
                        let mask = (align - 1) as Self;
                        let e = choose |e: nat| pow(2, e) == align;
                        lemma_pow2(e);
                        assert(e < $uint_type::BITS) by {
                            if e >= $uint_type::BITS {
                                lemma_pow2_strictly_increases($uint_type::BITS as nat, e);
                                lemma2_to64();
                            }
                        }
                        call_lemma_low_bits_mask_is_mod!($uint_type, x, e);
                        assert(x == (x & mask) + (x & !mask)) by (bit_vector);
                        lemma_nat_align_up_sound(self as nat, align as nat);
                    }
                    self.checked_add(align - 1).unwrap() & !(align - 1)
                }

                /// Rounds `self` down to the nearest multiple of `align`.
                ///
                /// Asserts at run time that `align` is a power of two that is
                /// at least 2. The result is obtained by clearing the low
                /// bits of `self`.
                #[inline]
                #[verus_spec(ret =>
                    requires
                        !(is_pow2(align as int) && align >= 2) ==> may_panic(),
                    ensures
                        align >= 2,
                        is_pow2(align as int),
                        ret <= self,
                        ret % align == 0,
                        ret == nat_align_down(self as nat, align as nat),
                        forall |n: nat| !(n<=self && #[trigger] (n % align as nat) == 0) || (ret >= n),
                )]
                fn align_down(self, align: Self) -> Self {
                    vstd_extra::assert!(align.is_power_of_two() && align >= 2);
                    proof!{
                        is_pow2_equiv(align as int);
                        lemma_low_bits_mask_values();
                        let mask = (align - 1) as Self;
                        // Exponent with 2^e == align; bounded by the bit width
                        // so the low-bits-mask lemma applies.
                        let e = choose |e: nat| pow(2, e) == align;
                        lemma_pow2(e);
                        assert(e < $uint_type::BITS) by {
                            if e >= $uint_type::BITS {
                                lemma_pow2_strictly_increases($uint_type::BITS as nat, e);
                                lemma2_to64();
                            }
                        }
                        call_lemma_low_bits_mask_is_mod!($uint_type, self, e);
                        assert(self == (self & mask) + (self & !mask)) by (bit_vector);
                        assert((self & !mask) as nat == nat_align_down(self as nat, align as nat));
                        lemma_nat_align_down_sound(self as nat, align as nat);
                    }
                    self & !(align - 1)
                }
            }
            }
        )*
    }
}
223
224impl_align_ext! {
225 u8,
226 u16,
227 u32,
228 u64,
229 usize,
230}
231
#[cfg(test)]
mod test {
    use super::*;

    /// Table-driven check of `align_up` on `usize`.
    ///
    /// Each case is `(value, alignment, expected)`, so inputs and outputs
    /// cannot drift out of step the way parallel arrays can.
    #[test]
    fn test_align_up() {
        let cases: [(usize, usize, usize); 9] = [
            (0, 2, 0),
            (1, 2, 2),
            (2, 2, 2),
            (9, 2, 10),
            (15, 4, 16),
            (21, 4, 24),
            (32, 8, 32),
            (47, 8, 48),
            (50, 8, 56),
        ];
        for (n, align, expected) in cases {
            assert_eq!(n.align_up(align), expected, "{}.align_up({})", n, align);
        }
    }

    /// Table-driven check of `align_down` on `usize`.
    #[test]
    fn test_align_down() {
        let cases: [(usize, usize, usize); 9] = [
            (0, 2, 0),
            (1, 2, 0),
            (2, 2, 2),
            (9, 2, 8),
            (15, 4, 12),
            (21, 4, 20),
            (32, 8, 32),
            (47, 8, 40),
            (50, 8, 48),
        ];
        for (n, align, expected) in cases {
            assert_eq!(n.align_down(align), expected, "{}.align_down({})", n, align);
        }
    }

    /// Exercises the other generated impls and values near the top of the
    /// range. None of these cases overflow in `align_up`.
    #[test]
    fn test_align_other_widths_and_bounds() {
        assert_eq!(5u8.align_up(4), 8);
        assert_eq!(5u8.align_down(4), 4);
        assert_eq!(u8::MAX.align_down(2), 254);
        assert_eq!(1000u16.align_up(256), 1024);
        assert_eq!(0x1234u32.align_down(0x100), 0x1200);
        assert_eq!(0x1001u64.align_up(0x1000), 0x2000);
        // self + (align - 1) == usize::MAX exactly, which must not overflow.
        assert_eq!((usize::MAX - 3).align_up(4), usize::MAX - 3);
        assert_eq!(usize::MAX - usize::MAX.align_down(4096), 4095);
    }
}