align_ext/
lib.rs

1// SPDX-License-Identifier: MPL-2.0
2#![allow(unused_macros)]
3#![cfg_attr(not(test), no_std)]
4#![allow(non_snake_case)]
5#![allow(unused_parens)]
6#![allow(unused_braces)]
7#![allow(rustdoc::invalid_rust_codeblocks)]
8#![allow(rustdoc::invalid_html_tags)]
9#![allow(rustdoc::broken_intra_doc_links)]
10
11use vstd::arithmetic::div_mod::*;
12use vstd::arithmetic::mul::*;
13use vstd::arithmetic::power2::pow2;
14use vstd::arithmetic::power2::{lemma2_to64, lemma_pow2_strictly_increases};
15use vstd::bits::*;
16use vstd::pervasive::trigger;
17use vstd::prelude::*;
18use vstd_extra::prelude::*;
19
/// An extension trait for Rust integer types, including `u8`, `u16`, `u32`,
/// `u64`, and `usize`, that provides methods to align integers to a power of
/// two.
pub trait AlignExt {
    /// Returns the smallest number that is greater than or equal to `self`
    /// and is a multiple of the given power of two.
    ///
    /// `power_of_two` must be a power of two that is at least 2. The
    /// implementations in this crate state this as a verified precondition
    /// and do not check it at runtime; violating it yields an unspecified
    /// result and may panic.
    ///
    /// # Panics
    ///
    /// Panics if the calculation overflows because `self` is too large, i.e.
    /// if `self + (power_of_two - 1)` is greater than the type's maximum value.
    ///
    /// # Examples
    ///
    /// ```
    /// use align_ext::AlignExt;
    /// assert_eq!(12usize.align_up(2), 12);
    /// assert_eq!(12usize.align_up(4), 12);
    /// assert_eq!(12usize.align_up(8), 16);
    /// assert_eq!(12usize.align_up(16), 16);
    /// ```
    fn align_up(self, power_of_two: Self) -> Self;

    /// Returns the greatest number that is smaller than or equal to `self`
    /// and is a multiple of the given power of two.
    ///
    /// `power_of_two` must be a power of two that is at least 2. As with
    /// [`AlignExt::align_up`], the implementations in this crate do not check
    /// this at runtime; violating it yields an unspecified result and may
    /// panic.
    ///
    /// Unlike [`AlignExt::align_up`], this calculation cannot overflow: the
    /// result is never larger than `self`.
    ///
    /// # Examples
    ///
    /// ```
    /// use align_ext::AlignExt;
    /// assert_eq!(12usize.align_down(2), 12);
    /// assert_eq!(12usize.align_down(4), 12);
    /// assert_eq!(12usize.align_down(8), 8);
    /// assert_eq!(12usize.align_down(16), 0);
    /// ```
    fn align_down(self, power_of_two: Self) -> Self;
}
60
verus! {

/// `usize` counterpart of the fixed-width `lemma_u32_low_bits_mask_is_mod` /
/// `lemma_u64_low_bits_mask_is_mod` lemmas: masking `x` with the low `n` bits
/// (`low_bits_mask(n)`) is the same as reducing it modulo `pow2(n)`.
///
/// The proof forwards to the fixed-width lemma that matches `usize::BITS`.
proof fn lemma_usize_low_bits_mask_is_mod(x: usize, n: nat)
    requires
        n < usize::BITS,
    ensures
        (x & (low_bits_mask(n) as usize)) == x % (pow2(n) as usize),
{
    // Dispatch on the platform width. The `else` branch treats any
    // non-64-bit `usize` as 32 bits wide — NOTE(review): this presumably
    // mirrors Verus's model of `usize` (32 or 64 bits); confirm if other
    // widths are ever configured.
    if usize::BITS == 64 {
        lemma_u64_low_bits_mask_is_mod(x as u64, n);
    } else {
        lemma_u32_low_bits_mask_is_mod(x as u32, n);
    }
}

} // verus!
// Calls the `*_low_bits_mask_is_mod` lemma whose name matches the given
// integer type, so generic macro code can pick the right lemma per type.
//
// Usage: `call_lemma_low_bits_mask_is_mod!(TYPE, value, bit_count)` where
// `TYPE` is one of `u8`, `u16`, `u32`, `u64` or `usize`.
macro_rules! call_lemma_low_bits_mask_is_mod {
    (usize, $val:expr, $bits:expr) => (lemma_usize_low_bits_mask_is_mod($val, $bits));
    (u64, $val:expr, $bits:expr) => (lemma_u64_low_bits_mask_is_mod($val, $bits));
    (u32, $val:expr, $bits:expr) => (lemma_u32_low_bits_mask_is_mod($val, $bits));
    (u16, $val:expr, $bits:expr) => (lemma_u16_low_bits_mask_is_mod($val, $bits));
    (u8, $val:expr, $bits:expr) => (lemma_u8_low_bits_mask_is_mod($val, $bits));
}
94
// Implements `AlignExt` for each listed type, attaching a Verus specification
// (`verus_spec`) and proof (`proof!`) to every method.
//
// Invocation: `impl_align_ext! { u8, u16, u32, u64, usize, }` — the matcher
// requires a trailing comma after the last type.
//
// NOTE(review): `$uint_type` is captured as a `ty` fragment yet forwarded to
// `call_lemma_low_bits_mask_is_mod!`, whose arms match the literal tokens
// `u8`, `u16`, ... . `macro_rules!` normally cannot match a forwarded `ty`
// fragment against literal tokens, so this relies on the Verus attribute
// macros re-emitting the type as plain tokens. Capturing `$uint_type:ident`
// would remove that dependency — confirm it still verifies before changing.
macro_rules! impl_align_ext {
    ($( $uint_type:ty ),+,) => {
        $(
            /// # Verified Properties
            /// ## Safety
            /// The implementation is written in safe Rust and there is no undefined behavior.
            /// ## Functional correctness
            /// Each method is verified against the `requires`/`ensures` clauses of its
            /// `verus_spec`.
            #[verus_verify]
            impl AlignExt for $uint_type {
                /// ## Preconditions
                /// - `align` is a power of two.
                /// - `align >= 2`.
                /// - `self + (align - 1)` does not overflow.
                /// ## Postconditions
                /// - Given the preconditions, the function will not panic.
                /// - The return value is the smallest number that is greater than or equal to `self` and is a multiple of `align`.
                #[inline]
                #[verus_spec(ret =>
                    requires
                        exists |e:nat| pow2(e) == align,
                        align >= 2,
                        self + (align - 1) <= $uint_type::MAX,
                    ensures
                        ret >= self,
                        ret % align == 0,
                        ret == nat_align_up(self as nat, align as nat),
                        forall |n: nat| !(n>=self && #[trigger] (n % align as nat) == 0) || (ret <= n),
                )]
                fn align_up(self, align: Self) -> Self {
                    // The power-of-two requirement is a Verus precondition
                    // (`requires` above), not a runtime assertion.
                    proof!{
                        // `x` is the intermediate sum `self + (align - 1)`; the
                        // third precondition guarantees it fits in the type.
                        // First compute `x % align`, splitting on whether
                        // `self` is already aligned.
                        let x_int = self as int + align as int - 1;
                        let x = x_int as Self;
                        if self as int % align as int == 0 {
                            // Already aligned: `x % align == align - 1`.
                            assert((align as int - 1) % align as int == align as int - 1) by {
                                lemma_small_mod((align as int - 1) as nat, align as nat);
                            }
                            assert(x_int % align as int == align as int - 1) by {
                                lemma_mod_adds(self as int, align as int - 1, align as int);
                            }
                        } else {
                            // Unaligned: with `self == q * align + r` and `r > 0`,
                            // `x == (q + 1) * align + (r - 1)`, so
                            // `x % align == r - 1`.
                            let q = self as int / align as int;
                            let r = self as int % align as int;
                            lemma_fundamental_div_mod(self as int, align as int);

                            assert((q + 1) * align as int == q * align as int + align as int) by {
                                lemma_mul_is_distributive_add(align as int, q, 1);
                            }
                            assert(((q + 1) * align as int) % align as int == 0) by {
                                lemma_mod_multiples_basic(q + 1, align as int);
                            }
                            assert((r - 1) % align as int == r - 1) by {
                                lemma_small_mod((r - 1) as nat, align as nat);
                            }
                            assert(x_int % align as int == (r - 1)) by {
                                lemma_mod_adds((q + 1) * align as int, r - 1, align as int);
                            }
                        }

                        // Keep this call before the exponent bound below; it
                        // presumably supplies the concrete `low_bits_mask` /
                        // `pow2` values that bound relies on.
                        lemma_low_bits_mask_values();
                        let mask = (align - 1) as Self;
                        // `align == pow2(e)` fits in the type, so `e` must be
                        // smaller than the bit width.
                        let e = choose |e: nat| pow2(e) == align;
                        assert(e < $uint_type::BITS) by {
                            if e >= $uint_type::BITS {
                                lemma_pow2_strictly_increases($uint_type::BITS as nat, e);
                                lemma2_to64();
                            }
                        }
                        // Relate `x & low_bits_mask(e)` to `x % pow2(e)`.
                        call_lemma_low_bits_mask_is_mod!($uint_type, x, e);
                        // Split `x` into its low bits and the rest; the rest is
                        // the value returned below.
                        assert(x == (x & mask) + (x & !mask)) by (bit_vector);
                        // Ties the result to `nat_align_up` (helper from
                        // `vstd_extra`).
                        lemma_nat_align_up_sound(self as nat, align as nat);
                    }
                    self.checked_add(align - 1).unwrap() & !(align - 1)
                }

                /// ## Preconditions
                /// - `align` is a power of two.
                /// - `align >= 2`.
                /// ## Postconditions
                /// - Given the preconditions, the function will not panic.
                /// - The return value is the greatest number that is smaller than or equal to `self` and is a multiple of `align`.
                #[inline]
                #[verus_spec(ret =>
                    requires
                        exists |e:nat| pow2(e) == align,
                        align >= 2,
                    ensures
                        ret <= self,
                        ret % align == 0,
                        ret == nat_align_down(self as nat, align as nat),
                        forall |n: nat|  !(n<=self && #[trigger] (n % align as nat) == 0) || (ret >= n),
                )]
                fn align_down(self, align: Self) -> Self {
                    // The power-of-two requirement is a Verus precondition
                    // (`requires` above), not a runtime assertion.
                    proof!{
                        // Keep this call before the exponent bound below; it
                        // presumably supplies the concrete `low_bits_mask` /
                        // `pow2` values that bound relies on.
                        lemma_low_bits_mask_values();
                        let mask = (align - 1) as Self;
                        // `align == pow2(e)` fits in the type, so `e` must be
                        // smaller than the bit width.
                        let e = choose |e: nat| pow2(e) == align;
                        assert(e < $uint_type::BITS) by {
                            if e >= $uint_type::BITS {
                                lemma_pow2_strictly_increases($uint_type::BITS as nat, e);
                                lemma2_to64();
                            }
                        }
                        // Relate `self & low_bits_mask(e)` to `self % pow2(e)`.
                        call_lemma_low_bits_mask_is_mod!($uint_type, self, e);
                        // Clearing the low bits removes exactly `self % align`,
                        // which is what `nat_align_down` computes.
                        assert(self == (self & mask) + (self & !mask)) by (bit_vector);
                        assert((self & !mask) as nat == nat_align_down(self as nat, align as nat));
                        lemma_nat_align_down_sound(self as nat, align as nat);
                    }
                    self & !(align - 1)
                }
            }
        )*
    }
}
212
213impl_align_ext! {
214    u8,
215    u16,
216    u32,
217    u64,
218    usize,
219}
220
#[cfg(test)]
mod test {
    use super::*;

    /// `(value, power_of_two, expected align_up, expected align_down)`.
    const USIZE_CASES: [(usize, usize, usize, usize); 9] = [
        (0, 2, 0, 0),
        (1, 2, 2, 0),
        (2, 2, 2, 2),
        (9, 2, 10, 8),
        (15, 4, 16, 12),
        (21, 4, 24, 20),
        (32, 8, 32, 32),
        (47, 8, 48, 40),
        (50, 8, 56, 48),
    ];

    #[test]
    fn test_align_up() {
        for &(n, align, expected, _) in &USIZE_CASES {
            assert_eq!(n.align_up(align), expected, "{n}.align_up({align})");
        }
    }

    #[test]
    fn test_align_down() {
        for &(n, align, _, expected) in &USIZE_CASES {
            assert_eq!(n.align_down(align), expected, "{n}.align_down({align})");
        }
    }

    /// Exercises the other types instantiated by `impl_align_ext!`.
    #[test]
    fn test_other_widths() {
        assert_eq!(12u8.align_up(8), 16);
        assert_eq!(12u8.align_down(8), 8);
        assert_eq!(12u16.align_up(16), 16);
        assert_eq!(12u16.align_down(16), 0);
        assert_eq!(0x1001u32.align_up(0x1000), 0x2000);
        assert_eq!(0x1fffu32.align_down(0x1000), 0x1000);
        assert_eq!(0x1_0000_0001u64.align_up(0x1000), 0x1_0000_1000);
        assert_eq!(u64::MAX.align_down(0x1000), 0xffff_ffff_ffff_f000);
    }

    /// Values at the top of the type's range.
    #[test]
    fn test_near_max() {
        // `254 + (2 - 1)` is exactly `u8::MAX`, so this must not overflow.
        assert_eq!(254u8.align_up(2), 254);
        assert_eq!(u8::MAX.align_down(128), 128);
        assert_eq!(u8::MAX.align_down(2), 254);
    }

    /// `align_up` panics when `self` is too large to be rounded up.
    #[test]
    #[should_panic]
    fn test_align_up_overflow_panics() {
        let _ = u8::MAX.align_up(2);
    }
}