align_ext/
lib.rs

1// SPDX-License-Identifier: MPL-2.0
2#![allow(unused_macros)]
3#![cfg_attr(not(test), no_std)]
4
5#![allow(non_snake_case)]
6#![allow(unused_parens)]
7#![allow(unused_braces)]
8#![allow(rustdoc::invalid_rust_codeblocks)]
9#![allow(rustdoc::invalid_html_tags)]
10#![allow(rustdoc::broken_intra_doc_links)]
11
12use vstd::arithmetic::div_mod::*;
13use vstd::arithmetic::mul::*;
14use vstd::arithmetic::power2::pow2;
15use vstd::arithmetic::power2::{lemma2_to64, lemma_pow2_strictly_increases};
16use vstd::bits::*;
17use vstd::pervasive::trigger;
18use vstd::prelude::*;
19use vstd_extra::prelude::*;
20
/// An extension trait for Rust integer types, including `u8`, `u16`, `u32`,
/// `u64`, and `usize`, to provide methods to make integers aligned to a
/// power of two.
pub trait AlignExt {
    /// Returns the smallest number that is greater than or equal to
    /// `self` and is a multiple of the given power of two.
    ///
    /// `power_of_two` must be a power of two and at least 2. The
    /// implementations in this crate state this as a verified precondition
    /// and do not check it at runtime; passing any other value may yield an
    /// incorrect result or a panic.
    ///
    /// # Panics
    ///
    /// Panics if the calculation overflows because `self` is too large,
    /// i.e. if `self + (power_of_two - 1)` does not fit in `Self`.
    ///
    /// # Examples
    ///
    /// ```
    /// use align_ext::AlignExt;
    /// assert_eq!(12usize.align_up(2), 12);
    /// assert_eq!(12usize.align_up(4), 12);
    /// assert_eq!(12usize.align_up(8), 16);
    /// assert_eq!(12usize.align_up(16), 16);
    /// ```
    fn align_up(self, power_of_two: Self) -> Self;

    /// Returns the greatest number that is smaller than or equal to
    /// `self` and is a multiple of the given power of two.
    ///
    /// `power_of_two` must be a power of two and at least 2. The
    /// implementations in this crate state this as a verified precondition
    /// and do not check it at runtime; passing any other value may yield an
    /// incorrect result or a panic.
    ///
    /// Unlike [`AlignExt::align_up`], rounding down never exceeds `self`,
    /// so the calculation cannot overflow.
    ///
    /// # Examples
    ///
    /// ```
    /// use align_ext::AlignExt;
    /// assert_eq!(12usize.align_down(2), 12);
    /// assert_eq!(12usize.align_down(4), 12);
    /// assert_eq!(12usize.align_down(8), 8);
    /// assert_eq!(12usize.align_down(16), 0);
    /// ```
    fn align_down(self, power_of_two: Self) -> Self;
}
61
verus! {

/// Lemma: for a `usize` value `x` and a bit count `n` below the word size,
/// masking `x` with the low `n` bits equals `x % 2^n`.
///
/// This is the `usize` counterpart of the fixed-width
/// `lemma_u{32,64}_low_bits_mask_is_mod` lemmas; `call_lemma_low_bits_mask_is_mod!`
/// dispatches here for `usize`. The proof case-splits on `usize::BITS`:
/// 64-bit targets reuse the `u64` lemma, any other width uses the `u32` lemma.
proof fn lemma_usize_low_bits_mask_is_mod(x: usize, n: nat)
    requires
        n < usize::BITS,
    ensures
        (x & (low_bits_mask(n) as usize)) == x % (pow2(n) as usize),
{
    if usize::BITS == 64 {
        lemma_u64_low_bits_mask_is_mod(x as u64, n);
    } else {
        // NOTE(review): assumes the only non-64-bit `usize` width Verus
        // considers is 32, so the `as u32` cast is lossless — confirm.
        lemma_u32_low_bits_mask_is_mod(x as u32, n);
    }
}

} // verus!
/// Invokes the "low-bits mask equals modulo" lemma that matches an unsigned
/// integer type name.
///
/// `call_lemma_low_bits_mask_is_mod!(T, x, n)` expands to
/// `lemma_T_low_bits_mask_is_mod(x, n)` for `T` in `usize`, `u64`, `u32`,
/// `u16` and `u8`. The first argument is matched as a literal token, so it
/// must be spelled exactly as one of those type names.
macro_rules! call_lemma_low_bits_mask_is_mod {
    (usize, $value:expr, $bits:expr) => (lemma_usize_low_bits_mask_is_mod($value, $bits));
    (u64, $value:expr, $bits:expr) => (lemma_u64_low_bits_mask_is_mod($value, $bits));
    (u32, $value:expr, $bits:expr) => (lemma_u32_low_bits_mask_is_mod($value, $bits));
    (u16, $value:expr, $bits:expr) => (lemma_u16_low_bits_mask_is_mod($value, $bits));
    (u8, $value:expr, $bits:expr) => (lemma_u8_low_bits_mask_is_mod($value, $bits));
}
95
96macro_rules! impl_align_ext {
97    ($( $uint_type:ty ),+,) => {
98        $(
99                /// # Verified Properties
100                /// ## Safety
101                /// The implementation is written in safe Rust and there is no undefined behavior.
102                /// ## Functional correctness
103                /// The implementation meets the specification given in the trait `AlignExt`.
104            #[verus_verify]
105            impl AlignExt for $uint_type {
106                /// ## Preconditions
107                /// - `align` is a power of two.
108                /// - `align >= 2`.
109                /// - `self + (align - 1)` does not overflow.
110                /// ## Postconditions
111                /// - The function will not panic.
112                /// - The return value is the smallest number that is greater than or equal to `self` and is a multiple of `align`.
113                #[inline]
114                #[verus_spec(ret =>
115                    requires
116                        exists |e:nat| pow2(e) == align,
117                        align >= 2,
118                        self + (align - 1) <= $uint_type::MAX,
119                    ensures
120                        ret >= self,
121                        ret % align == 0,
122                        ret == nat_align_up(self as nat, align as nat),
123                        forall |n: nat| !(n>=self && #[trigger] (n % align as nat) == 0) || (ret <= n),
124                )]
125                fn align_up(self, align: Self) -> Self {
126                    //assert!(align.is_power_of_two() && align >= 2);
127                    proof!{
128                        let x_int = self as int + align as int - 1;
129                        let x = x_int as Self;
130                        if self as int % align as int == 0 {
131                            assert((align as int - 1) % align as int == align as int - 1) by {
132                                lemma_small_mod((align as int - 1) as nat, align as nat);
133                            }
134                            assert(x_int % align as int == align as int - 1) by {
135                                lemma_mod_adds(self as int, align as int - 1, align as int);
136                            }
137                        } else {
138                            let q = self as int / align as int;
139                            let r = self as int % align as int;
140                            lemma_fundamental_div_mod(self as int, align as int);
141
142                            assert((q + 1) * align as int == q * align as int + align as int) by {
143                                lemma_mul_is_distributive_add(align as int, q, 1);
144                            }
145                            assert(((q + 1) * align as int) % align as int == 0) by {
146                                lemma_mod_multiples_basic(q + 1, align as int);
147                            }
148                            assert((r - 1) % align as int == r - 1) by {
149                                lemma_small_mod((r - 1) as nat, align as nat);
150                            }
151                            assert(x_int % align as int == (r - 1)) by {
152                                lemma_mod_adds((q + 1) * align as int, r - 1, align as int);
153                            }
154                        }
155
156                        lemma_low_bits_mask_values();
157                        let mask = (align - 1) as Self;
158                        let e = choose |e: nat| pow2(e) == align;
159                        assert(e < $uint_type::BITS) by {
160                            if e >= $uint_type::BITS {
161                                lemma_pow2_strictly_increases($uint_type::BITS as nat, e);
162                                lemma2_to64();
163                            }
164                        }
165                        call_lemma_low_bits_mask_is_mod!($uint_type, x, e);
166                        assert(x == (x & mask) + (x & !mask)) by (bit_vector);
167                        lemma_nat_align_up_sound(self as nat, align as nat);
168                    }
169                    self.checked_add(align - 1).unwrap() & !(align - 1)
170                }
171
172                #[inline]
173                #[verus_spec(ret =>
174                    requires
175                        exists |e:nat| pow2(e) == align,
176                        align >= 2,
177                    ensures
178                        ret <= self,
179                        ret % align == 0,
180                        ret == nat_align_down(self as nat, align as nat),
181                        forall |n: nat|  !(n<=self && #[trigger] (n % align as nat) == 0) || (ret >= n),
182                )]
183
184                /// ## Preconditions
185                /// - `align` is a power of two.
186                /// - `align >= 2`.
187                /// ## Postconditions
188                /// - The function will not panic.
189                /// - The return value is the greatest number that is smaller than or equal to `self` and is a multiple of `align`.
190                fn align_down(self, align: Self) -> Self {
191                    //assert!(align.is_power_of_two() && align >= 2);
192                    proof!{
193                        lemma_low_bits_mask_values();
194                        let mask = (align - 1) as Self;
195                        let e = choose |e: nat| pow2(e) == align;
196                        assert(e < $uint_type::BITS) by {
197                            if e >= $uint_type::BITS {
198                                lemma_pow2_strictly_increases($uint_type::BITS as nat, e);
199                                lemma2_to64();
200                            }
201                        }
202                        call_lemma_low_bits_mask_is_mod!($uint_type, self, e);
203                        assert(self == (self & mask) + (self & !mask)) by (bit_vector);
204                        assert((self & !mask) as nat == nat_align_down(self as nat, align as nat));
205                        lemma_nat_align_down_sound(self as nat, align as nat);
206                    }
207                    self & !(align - 1)
208                }
209            }
210        )*
211    }
212}
213
214impl_align_ext! {
215    u8,
216    u16,
217    u32,
218    u64,
219    usize,
220}
221
#[cfg(test)]
mod test {
    use super::*;

    /// `(value, align)` inputs shared by the `usize` round-trip tests.
    const USIZE_INPUTS: [(usize, usize); 9] = [
        (0, 2),
        (1, 2),
        (2, 2),
        (9, 2),
        (15, 4),
        (21, 4),
        (32, 8),
        (47, 8),
        (50, 8),
    ];

    #[test]
    fn test_align_up() {
        let expected = [0usize, 2, 2, 10, 16, 24, 32, 48, 56];
        for (&(n, align), &want) in USIZE_INPUTS.iter().zip(expected.iter()) {
            assert_eq!(n.align_up(align), want, "{}.align_up({})", n, align);
        }
    }

    #[test]
    fn test_align_down() {
        let expected = [0usize, 0, 2, 8, 12, 20, 32, 40, 48];
        for (&(n, align), &want) in USIZE_INPUTS.iter().zip(expected.iter()) {
            assert_eq!(n.align_down(align), want, "{}.align_down({})", n, align);
        }
    }

    /// Exercises every implemented width, including values near `MAX`.
    #[test]
    fn test_all_widths_and_boundaries() {
        assert_eq!(250u8.align_up(4), 252);
        assert_eq!(200u8.align_down(64), 192);
        assert_eq!(u8::MAX.align_down(128), 128);

        assert_eq!(1000u16.align_up(512), 1024);
        assert_eq!(1000u16.align_down(512), 512);

        assert_eq!(4097u32.align_up(4096), 8192);
        assert_eq!(4097u32.align_down(4096), 4096);

        assert_eq!(0x1234_5678u64.align_up(0x1000), 0x1234_6000);
        assert_eq!(0x1234_5678u64.align_down(0x1000), 0x1234_5000);
        assert_eq!(u64::MAX.align_down(0x1000), 0xFFFF_FFFF_FFFF_F000);

        // `self + (align - 1)` lands exactly on `MAX`: no overflow.
        assert_eq!((usize::MAX - 1).align_up(2), usize::MAX - 1);
    }

    /// `align_up` documents a panic when `self + (align - 1)` overflows.
    #[test]
    #[should_panic]
    fn test_align_up_overflow_panics() {
        let _ = u8::MAX.align_up(2);
    }
}