1#![allow(unused_macros)]
3#![cfg_attr(not(test), no_std)]
4
5#![allow(non_snake_case)]
6#![allow(unused_parens)]
7#![allow(unused_braces)]
8#![allow(rustdoc::invalid_rust_codeblocks)]
9#![allow(rustdoc::invalid_html_tags)]
10#![allow(rustdoc::broken_intra_doc_links)]
11
12use vstd::arithmetic::div_mod::*;
13use vstd::arithmetic::mul::*;
14use vstd::arithmetic::power2::pow2;
15use vstd::arithmetic::power2::{lemma2_to64, lemma_pow2_strictly_increases};
16use vstd::bits::*;
17use vstd::pervasive::trigger;
18use vstd::prelude::*;
19use vstd_extra::prelude::*;
20
/// Extension methods that round unsigned integers to a power-of-two alignment.
///
/// Implemented in this module for `u8`, `u16`, `u32`, `u64` and `usize`. Those
/// implementations are verified with Verus: their preconditions (power-of-two
/// alignment of at least 2, and no overflow for `align_up`) are the contract
/// that callers must uphold.
pub trait AlignExt {
    /// Rounds `self` up to the nearest multiple of `power_of_two`.
    ///
    /// Returns the smallest `r` such that `r >= self` and `r % power_of_two == 0`.
    ///
    /// # Requirements
    ///
    /// `power_of_two` must be a power of two and at least 2. Also,
    /// `self + (power_of_two - 1)` must not exceed `Self::MAX`. These are Verus
    /// preconditions. The power-of-two property is *not* checked at run time,
    /// and a non-power-of-two argument produces a meaningless result.
    ///
    /// # Panics
    ///
    /// Panics if `self + (power_of_two - 1)` overflows `Self`.
    ///
    /// # Examples
    ///
    /// ```ignore
    /// assert_eq!(13usize.align_up(8), 16);
    /// assert_eq!(16usize.align_up(8), 16);
    /// ```
    #[must_use = "`align_up` returns the aligned value and does not modify `self`"]
    fn align_up(self, power_of_two: Self) -> Self;

    /// Rounds `self` down to the nearest multiple of `power_of_two`.
    ///
    /// Returns the largest `r` such that `r <= self` and `r % power_of_two == 0`.
    ///
    /// # Requirements
    ///
    /// `power_of_two` must be a power of two and at least 2 (a Verus
    /// precondition that is not checked at run time).
    ///
    /// # Examples
    ///
    /// ```ignore
    /// assert_eq!(13usize.align_down(8), 8);
    /// assert_eq!(16usize.align_down(8), 16);
    /// ```
    #[must_use = "`align_down` returns the aligned value and does not modify `self`"]
    fn align_down(self, power_of_two: Self) -> Self;
}
61
verus! {

/// Proves that masking a `usize` with its low `n` bits equals `x % 2^n`.
///
/// The dispatch macro `call_lemma_low_bits_mask_is_mod!` below needs one such
/// lemma per integer type. For `u8`..`u64` it calls the
/// `lemma_*_low_bits_mask_is_mod` lemmas brought in from vstd (presumably
/// `vstd::bits`); this lemma supplies the `usize` case by delegating to the
/// `u64` or `u32` lemma according to the platform's `usize::BITS`.
proof fn lemma_usize_low_bits_mask_is_mod(x: usize, n: nat)
    requires
        n < usize::BITS,
    ensures
        (x & (low_bits_mask(n) as usize)) == x % (pow2(n) as usize),
{
    if usize::BITS == 64 {
        lemma_u64_low_bits_mask_is_mod(x as u64, n);
    } else {
        // NOTE(review): assumes the only non-64-bit `usize` width Verus
        // targets is 32 bits — confirm if other widths become possible.
        lemma_u32_low_bits_mask_is_mod(x as u32, n);
    }
}

} // verus!

/// Dispatches to the `lemma_<T>_low_bits_mask_is_mod($x, $n)` lemma for the
/// integer type named by the first argument.
///
/// The first argument is matched against literal type names, so only `u8`,
/// `u16`, `u32`, `u64` and `usize` are accepted. `impl_align_ext!` forwards its
/// `$uint_type` here, which limits that macro to the same set of types.
macro_rules! call_lemma_low_bits_mask_is_mod {
    (u8, $x:expr, $n:expr) => {
        lemma_u8_low_bits_mask_is_mod($x, $n)
    };
    (u16, $x:expr, $n:expr) => {
        lemma_u16_low_bits_mask_is_mod($x, $n)
    };
    (u32, $x:expr, $n:expr) => {
        lemma_u32_low_bits_mask_is_mod($x, $n)
    };
    (u64, $x:expr, $n:expr) => {
        lemma_u64_low_bits_mask_is_mod($x, $n)
    };
    (usize, $x:expr, $n:expr) => {
        lemma_usize_low_bits_mask_is_mod($x, $n)
    };
}
95
/// Implements [`AlignExt`] for each listed unsigned integer type.
///
/// Each generated `align_up` / `align_down` carries a Verus `verus_spec`. The
/// spec says the result is aligned, lies on the correct side of `self`, equals
/// `nat_align_up` / `nat_align_down`, and is the tightest such multiple. A
/// `proof!` block proves that the bit-mask computation meets this spec.
///
/// The matcher requires a trailing comma after the last type. Every type is
/// also forwarded to `call_lemma_low_bits_mask_is_mod!`, which only accepts
/// `u8`, `u16`, `u32`, `u64` and `usize`.
macro_rules! impl_align_ext {
    ($( $uint_type:ty ),+,) => {
        $(
            #[verus_verify]
            impl AlignExt for $uint_type {
                // Runtime: add `align - 1` (panicking on overflow), then clear the
                // low bits. The spec's `requires` rules out the overflow.
                #[inline]
                #[verus_spec(ret =>
                    requires
                        exists |e:nat| pow2(e) == align,
                        align >= 2,
                        self + (align - 1) <= $uint_type::MAX,
                    ensures
                        ret >= self,
                        ret % align == 0,
                        ret == nat_align_up(self as nat, align as nat),
                        forall |n: nat| !(n>=self && #[trigger] (n % align as nat) == 0) || (ret <= n),
                )]
                fn align_up(self, align: Self) -> Self {
                    proof!{
                        // Step 1: compute x % align for x = self + (align - 1).
                        let x_int = self as int + align as int - 1;
                        let x = x_int as Self;
                        if self as int % align as int == 0 {
                            // `self` is already aligned, so x % align == align - 1.
                            assert((align as int - 1) % align as int == align as int - 1) by {
                                lemma_small_mod((align as int - 1) as nat, align as nat);
                            }
                            assert(x_int % align as int == align as int - 1) by {
                                lemma_mod_adds(self as int, align as int - 1, align as int);
                            }
                        } else {
                            // self == q * align + r with r != 0, so
                            // x == (q + 1) * align + (r - 1), and x % align == r - 1.
                            let q = self as int / align as int;
                            let r = self as int % align as int;
                            lemma_fundamental_div_mod(self as int, align as int);

                            assert((q + 1) * align as int == q * align as int + align as int) by {
                                lemma_mul_is_distributive_add(align as int, q, 1);
                            }
                            assert(((q + 1) * align as int) % align as int == 0) by {
                                lemma_mod_multiples_basic(q + 1, align as int);
                            }
                            assert((r - 1) % align as int == r - 1) by {
                                lemma_small_mod((r - 1) as nat, align as nat);
                            }
                            assert(x_int % align as int == (r - 1)) by {
                                lemma_mod_adds((q + 1) * align as int, r - 1, align as int);
                            }
                        }

                        // Step 2: relate the bit mask `align - 1` to `% align`, using
                        // the exponent e with pow2(e) == align.
                        lemma_low_bits_mask_values();
                        let mask = (align - 1) as Self;
                        let e = choose |e: nat| pow2(e) == align;
                        // e must be below the bit width, because pow2(e) == align fits in the type.
                        // NOTE(review): when e == BITS, the call below passes equal
                        // exponents; confirm it verifies if that lemma requires a strict `<`.
                        assert(e < $uint_type::BITS) by {
                            if e >= $uint_type::BITS {
                                lemma_pow2_strictly_increases($uint_type::BITS as nat, e);
                                lemma2_to64();
                            }
                        }
                        // x & mask == x % pow2(e) == x % align.
                        call_lemma_low_bits_mask_is_mod!($uint_type, x, e);
                        // Split x into its low bits (the remainder) and high bits (the result).
                        assert(x == (x & mask) + (x & !mask)) by (bit_vector);
                        lemma_nat_align_up_sound(self as nat, align as nat);
                    }
                    self.checked_add(align - 1).unwrap() & !(align - 1)
                }

                // Runtime: clear the low `log2(align)` bits of `self`.
                #[inline]
                #[verus_spec(ret =>
                    requires
                        exists |e:nat| pow2(e) == align,
                        align >= 2,
                    ensures
                        ret <= self,
                        ret % align == 0,
                        ret == nat_align_down(self as nat, align as nat),
                        forall |n: nat| !(n<=self && #[trigger] (n % align as nat) == 0) || (ret >= n),
                )]

                fn align_down(self, align: Self) -> Self {
                    proof!{
                        // Same mask-to-modulo argument as `align_up`, applied to `self` directly.
                        lemma_low_bits_mask_values();
                        let mask = (align - 1) as Self;
                        let e = choose |e: nat| pow2(e) == align;
                        assert(e < $uint_type::BITS) by {
                            if e >= $uint_type::BITS {
                                lemma_pow2_strictly_increases($uint_type::BITS as nat, e);
                                lemma2_to64();
                            }
                        }
                        // self & mask == self % align.
                        call_lemma_low_bits_mask_is_mod!($uint_type, self, e);
                        // Hence self & !mask == self - self % align.
                        assert(self == (self & mask) + (self & !mask)) by (bit_vector);
                        assert((self & !mask) as nat == nat_align_down(self as nat, align as nat));
                        lemma_nat_align_down_sound(self as nat, align as nat);
                    }
                    self & !(align - 1)
                }
            }
        )*
    }
}
213
214impl_align_ext! {
215 u8,
216 u16,
217 u32,
218 u64,
219 usize,
220}
221
#[cfg(test)]
mod test {
    use super::*;

    /// Spot checks on `usize`: already-aligned values stay put, others round up.
    #[test]
    fn test_align_up() {
        // (value, alignment, expected) triples.
        let cases: [(usize, usize, usize); 9] = [
            (0, 2, 0),
            (1, 2, 2),
            (2, 2, 2),
            (9, 2, 10),
            (15, 4, 16),
            (21, 4, 24),
            (32, 8, 32),
            (47, 8, 48),
            (50, 8, 56),
        ];
        for (n, align, expected) in cases {
            assert_eq!(n.align_up(align), expected, "{n}.align_up({align})");
        }
    }

    /// Spot checks on `usize`: already-aligned values stay put, others round down.
    #[test]
    fn test_align_down() {
        // (value, alignment, expected) triples.
        let cases: [(usize, usize, usize); 9] = [
            (0, 2, 0),
            (1, 2, 0),
            (2, 2, 2),
            (9, 2, 8),
            (15, 4, 12),
            (21, 4, 20),
            (32, 8, 32),
            (47, 8, 40),
            (50, 8, 48),
        ];
        for (n, align, expected) in cases {
            assert_eq!(n.align_down(align), expected, "{n}.align_down({align})");
        }
    }

    /// The macro also implements the trait for the fixed-width types; exercise
    /// each of them, including values at the top of their range.
    #[test]
    fn test_align_other_widths() {
        assert_eq!(250u8.align_up(4), 252);
        assert_eq!(240u8.align_up(16), 240);
        assert_eq!(255u8.align_down(16), 240);
        assert_eq!(255u8.align_down(128), 128);

        assert_eq!(1000u16.align_up(64), 1024);
        assert_eq!(1000u16.align_down(64), 960);

        assert_eq!(4097u32.align_up(4096), 8192);
        assert_eq!(4096u32.align_up(4096), 4096);
        assert_eq!(4097u32.align_down(4096), 4096);

        assert_eq!((u64::MAX - 1).align_up(2), u64::MAX - 1);
        assert_eq!(u64::MAX.align_down(2), u64::MAX - 1);
    }

    /// Exhaustively checks the spec's postconditions on a small `u16` domain.
    /// The result is aligned and on the right side of `n`, and it is less than
    /// one alignment away from `n`, i.e. it is the nearest such multiple.
    #[test]
    fn test_align_postconditions_u16() {
        for shift in 1..=10u32 {
            let align = 1u16 << shift;
            for n in 0u16..2048 {
                let up = n.align_up(align);
                assert!(
                    up >= n && up % align == 0 && up - n < align,
                    "{n}.align_up({align}) = {up}"
                );
                let down = n.align_down(align);
                assert!(
                    down <= n && down % align == 0 && n - down < align,
                    "{n}.align_down({align}) = {down}"
                );
            }
        }
    }

    /// `align_up` uses `checked_add(..).unwrap()`, so an overflowing input must
    /// panic rather than silently wrap to a wrong value.
    #[test]
    #[should_panic]
    fn test_align_up_overflow_panics() {
        let _ = 255u8.align_up(2);
    }
}
253}