Skip to main content

align_ext/
lib.rs

1// SPDX-License-Identifier: MPL-2.0
2#![allow(unused_macros)]
3#![cfg_attr(not(test), no_std)]
4#![allow(non_snake_case)]
5#![allow(unused_parens)]
6#![allow(unused_braces)]
7#![allow(rustdoc::invalid_rust_codeblocks)]
8#![allow(rustdoc::invalid_html_tags)]
9#![allow(rustdoc::broken_intra_doc_links)]
10
11use vstd::arithmetic::div_mod::*;
12use vstd::arithmetic::mul::*;
13use vstd::arithmetic::{power::pow, power2::*};
14use vstd::bits::*;
15use vstd::pervasive::trigger;
16use vstd::prelude::*;
17use vstd_extra::panic::*;
18use vstd_extra::prelude::*;
19
/// An extension trait for Rust integer types, including `u8`, `u16`, `u32`,
/// `u64`, and `usize`, to provide methods to make integers aligned to a
/// power of two.
pub trait AlignExt {
    /// Returns the smallest number that is greater than or equal to
    /// `self` and is a multiple of the given power of two.
    ///
    /// # Panics
    ///
    /// Panics if `power_of_two` is not a power of two, if it is smaller
    /// than 2, or if the result overflows because `self` is too large.
    ///
    /// # Examples
    ///
    /// ```
    /// use align_ext::AlignExt;
    /// assert_eq!(12usize.align_up(2), 12);
    /// assert_eq!(12usize.align_up(4), 12);
    /// assert_eq!(12usize.align_up(8), 16);
    /// assert_eq!(12usize.align_up(16), 16);
    /// ```
    fn align_up(self, power_of_two: Self) -> Self;

    /// Returns the greatest number that is smaller than or equal to
    /// `self` and is a multiple of the given power of two.
    ///
    /// # Panics
    ///
    /// Panics if `power_of_two` is not a power of two or if it is smaller
    /// than 2. Unlike [`AlignExt::align_up`], rounding down cannot overflow.
    ///
    /// # Examples
    ///
    /// ```
    /// use align_ext::AlignExt;
    /// assert_eq!(12usize.align_down(2), 12);
    /// assert_eq!(12usize.align_down(4), 12);
    /// assert_eq!(12usize.align_down(8), 8);
    /// assert_eq!(12usize.align_down(16), 0);
    /// ```
    fn align_down(self, power_of_two: Self) -> Self;
}
60
61verus! {
62
63proof fn lemma_usize_low_bits_mask_is_mod(x: usize, n: nat)
64    requires
65        n < usize::BITS,
66    ensures
67        (x & (low_bits_mask(n) as usize)) == x % (pow2(n) as usize),
68{
69    if usize::BITS == 64 {
70        lemma_u64_low_bits_mask_is_mod(x as u64, n);
71    } else {
72        lemma_u32_low_bits_mask_is_mod(x as u32, n);
73    }
74}
75
76} // verus!
/// Invokes the low-bits-mask-is-mod lemma that matches the integer type
/// named by the first token.
///
/// `call_lemma_low_bits_mask_is_mod!(u32, x, n)` expands to
/// `lemma_u32_low_bits_mask_is_mod(x, n)`, and likewise for `u8`, `u16`,
/// `u64` and `usize`. The type must be passed as a bare identifier token:
/// the arms match those tokens literally.
macro_rules! call_lemma_low_bits_mask_is_mod {
    (u8, $x:expr, $n:expr) => {
        lemma_u8_low_bits_mask_is_mod($x, $n)
    };
    (u16, $x:expr, $n:expr) => {
        lemma_u16_low_bits_mask_is_mod($x, $n)
    };
    (u32, $x:expr, $n:expr) => {
        lemma_u32_low_bits_mask_is_mod($x, $n)
    };
    (u64, $x:expr, $n:expr) => {
        lemma_u64_low_bits_mask_is_mod($x, $n)
    };
    (usize, $x:expr, $n:expr) => {
        lemma_usize_low_bits_mask_is_mod($x, $n)
    };
}
94
/// Implements [`AlignExt`] (with its Verus specification and proof) for each
/// listed unsigned integer type.
///
/// The type names are captured as `ident`, not `ty`. They are forwarded to
/// `call_lemma_low_bits_mask_is_mod!`, whose arms match the literal tokens
/// `u8`, `u16`, …, and they are used in paths such as `$uint_type::MAX`.
/// A forwarded `ty` fragment is opaque and cannot be matched against literal
/// tokens; an `ident` fragment can. A trailing comma is optional.
macro_rules! impl_align_ext {
    ($( $uint_type:ident ),+ $(,)?) => {
        $(
            verus! {
            /// # Verified Properties
            /// ## Safety
            /// The implementation is written in safe Rust and there is no undefined behavior.
            /// ## Functional correctness
            /// The implementation meets the specification given in the trait `AlignExt`.
            impl AlignExt for $uint_type {
                /// ## Preconditions
                /// - `self + (align - 1)` does not overflow.
                /// ## Postconditions
                /// - `align` is a power of two `>= 2` (panic-enforced; the
                ///   function panics on invalid `align`, so a returning call
                ///   guarantees validity).
                /// - The return value is the smallest number that is greater
                ///   than or equal to `self` and is a multiple of `align`.
                #[inline]
                #[verus_spec(ret =>
                    requires
                        self + (align - 1) <= $uint_type::MAX,
                        !(align >= 2 && is_pow2(align as int)) ==> may_panic(),
                    ensures
                        align >= 2,
                        is_pow2(align as int),
                        ret >= self,
                        ret % align == 0,
                        ret == nat_align_up(self as nat, align as nat),
                        forall |n: nat| !(n>=self && #[trigger] (n % align as nat) == 0) || (ret <= n),
                )]
                fn align_up(self, align: Self) -> Self {
                    vstd_extra::assert!(align.is_power_of_two() && align >= 2);
                    proof!{
                        is_pow2_equiv(align as int);
                        // x = self + (align - 1): the value that gets masked below.
                        let x_int = self as int + align as int - 1;
                        let x = x_int as Self;
                        // Establish x % align in both cases: `align - 1` when
                        // `self` is already aligned, `r - 1` otherwise.
                        if self as int % align as int == 0 {
                            assert((align as int - 1) % align as int == align as int - 1) by {
                                lemma_small_mod((align as int - 1) as nat, align as nat);
                            }
                            assert(x_int % align as int == align as int - 1) by {
                                lemma_mod_adds(self as int, align as int - 1, align as int);
                            }
                        } else {
                            let q = self as int / align as int;
                            let r = self as int % align as int;
                            lemma_fundamental_div_mod(self as int, align as int);
                            assert(self as int == q * align as int + r) by {
                                lemma_mul_is_commutative(align as int, q);
                            }

                            assert((q + 1) * align as int == q * align as int + align as int) by {
                                lemma_mul_is_distributive_add_other_way(align as int, q, 1);
                            }
                            assert(((q + 1) * align as int) % align as int == 0) by {
                                lemma_mod_multiples_basic(q + 1, align as int);
                            }
                            assert((r - 1) % align as int == r - 1) by {
                                lemma_small_mod((r - 1) as nat, align as nat);
                            }
                            assert(x_int == (q + 1) * align as int + (r - 1));
                            assert(x_int % align as int == (r - 1)) by {
                                lemma_mod_adds((q + 1) * align as int, r - 1, align as int);
                            }
                        }

                        lemma_low_bits_mask_values();
                        let mask = (align - 1) as Self;
                        // `e` is the exponent with 2^e == align; it must fit the type width.
                        let e = choose |e: nat| pow(2, e) == align;
                        lemma_pow2(e);
                        assert(e < $uint_type::BITS) by {
                            if e >= $uint_type::BITS {
                                lemma_pow2_strictly_increases($uint_type::BITS as nat, e);
                                lemma2_to64();
                            }
                        }
                        call_lemma_low_bits_mask_is_mod!($uint_type, x, e);
                        // x splits into its low (x % align) and high (aligned) parts.
                        assert(x == (x & mask) + (x & !mask)) by (bit_vector);
                        lemma_nat_align_up_sound(self as nat, align as nat);
                    }
                    self.checked_add(align - 1).unwrap() & !(align - 1)
                }

                /// ## Preconditions
                /// - None beyond the panic condition on `align` below.
                /// ## Postconditions
                /// - `align` is a power of two `>= 2` (panic-enforced; the
                ///   function panics on invalid `align`, so a returning call
                ///   guarantees validity).
                /// - The return value is the greatest number that is smaller
                ///   than or equal to `self` and is a multiple of `align`.
                #[inline]
                #[verus_spec(ret =>
                    requires
                        !(is_pow2(align as int) && align >= 2) ==> may_panic(),
                    ensures
                        align >= 2,
                        is_pow2(align as int),
                        ret <= self,
                        ret % align == 0,
                        ret == nat_align_down(self as nat, align as nat),
                        forall |n: nat|  !(n<=self && #[trigger] (n % align as nat) == 0) || (ret >= n),
                )]
                fn align_down(self, align: Self) -> Self {
                    vstd_extra::assert!(align.is_power_of_two() && align >= 2);
                    proof!{
                        is_pow2_equiv(align as int);
                        lemma_low_bits_mask_values();
                        let mask = (align - 1) as Self;
                        // `e` is the exponent with 2^e == align; it must fit the type width.
                        let e = choose |e: nat| pow(2, e) == align;
                        lemma_pow2(e);
                        assert(e < $uint_type::BITS) by {
                            if e >= $uint_type::BITS {
                                lemma_pow2_strictly_increases($uint_type::BITS as nat, e);
                                lemma2_to64();
                            }
                        }
                        call_lemma_low_bits_mask_is_mod!($uint_type, self, e);
                        // self splits into its low (self % align) and high (aligned) parts.
                        assert(self == (self & mask) + (self & !mask)) by (bit_vector);
                        assert((self & !mask) as nat == nat_align_down(self as nat, align as nat));
                        lemma_nat_align_down_sound(self as nat, align as nat);
                    }
                    self & !(align - 1)
                }
            }
            }
        )*
    };
}
223
224impl_align_ext! {
225    u8,
226    u16,
227    u32,
228    u64,
229    usize,
230}
231
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn test_align_up() {
        // (value, align, expected result)
        let cases: [(usize, usize, usize); 9] = [
            (0, 2, 0),
            (1, 2, 2),
            (2, 2, 2),
            (9, 2, 10),
            (15, 4, 16),
            (21, 4, 24),
            (32, 8, 32),
            (47, 8, 48),
            (50, 8, 56),
        ];
        for &(n, a, want) in &cases {
            assert_eq!(n.align_up(a), want, "{}.align_up({})", n, a);
        }
    }

    #[test]
    fn test_align_down() {
        // (value, align, expected result)
        let cases: [(usize, usize, usize); 9] = [
            (0, 2, 0),
            (1, 2, 0),
            (2, 2, 2),
            (9, 2, 8),
            (15, 4, 12),
            (21, 4, 20),
            (32, 8, 32),
            (47, 8, 40),
            (50, 8, 48),
        ];
        for &(n, a, want) in &cases {
            assert_eq!(n.align_down(a), want, "{}.align_down({})", n, a);
        }
    }

    #[test]
    fn test_other_widths() {
        assert_eq!(250u8.align_up(4), 252);
        assert_eq!((u8::MAX - 3).align_up(4), 252);
        assert_eq!(255u8.align_down(16), 240);
        assert_eq!(0x1234u16.align_up(0x100), 0x1300);
        assert_eq!(u32::MAX.align_down(8), u32::MAX - 7);
        assert_eq!(u64::MAX.align_down(1 << 63), 1 << 63);
    }

    #[test]
    #[should_panic]
    fn test_align_up_overflow_panics() {
        let _ = u8::MAX.align_up(2);
    }

    #[test]
    #[should_panic]
    fn test_align_up_non_power_of_two_panics() {
        let _ = 12usize.align_up(3);
    }

    #[test]
    #[should_panic]
    fn test_align_down_align_one_panics() {
        let _ = 12usize.align_down(1);
    }
}