ostd/specs/mm/
virt_mem.rs

1use vstd::pervasive::arbitrary;
2use vstd::prelude::*;
3
4use vstd::layout;
5use vstd::raw_ptr;
6use vstd::set;
7use vstd::set_lib;
8
9use core::marker::PhantomData;
10use core::ops::Range;
11
12use crate::mm::{Paddr, Vaddr};
13use crate::prelude::Inv;
14use crate::specs::mm::page_table::Mapping;
15
16verus! {
17
/// Concrete representation of a pointer.
///
/// Couples an executable virtual address with a ghost-only address range.
/// The `Inv` implementation and `is_defined`/`is_valid` in this file relate
/// `vaddr` to the bounds of `range`.
pub struct VirtPtr {
    /// The virtual address this pointer currently holds.
    pub vaddr: Vaddr,
    /// Ghost range of virtual addresses associated with this pointer.
    pub ghost range: Ghost<Range<Vaddr>>,
}
23
/// Byte contents keyed by physical address, together with a ghost range of
/// physical addresses.
///
/// NOTE(review): nothing in this file states how the keys of `contents`
/// relate to `range` — confirm against the users of this type.
pub struct FrameContents {
    /// Per-address byte cells, keyed by physical address.
    pub contents: Map<Paddr, raw_ptr::MemContents<u8>>,
    /// Ghost range of physical addresses.
    pub ghost range: Ghost<Range<Paddr>>,
}
28
/// Tracked view of memory: a set of virtual-to-physical mappings plus the
/// byte contents stored at physical addresses.
pub tracked struct MemView {
    /// The mappings used by `addr_transl` to translate virtual addresses.
    pub mappings: Set<Mapping>,
    /// Byte cells keyed by physical address.
    pub memory: Map<Paddr, raw_ptr::MemContents<u8>>
}
33
34impl MemView {
    /// Translates virtual address `va` to a physical address.
    ///
    /// Collects every mapping whose `va_range` contains `va`. If that set has
    /// positive length, one mapping is chosen and the result is its
    /// `pa_range.start` plus the offset of `va` within its `va_range`;
    /// otherwise the result is `None`.
    pub open spec fn addr_transl(self, va: usize) -> Option<usize> {
        let mappings = self.mappings.filter(|m: Mapping| m.va_range.start <= va < m.va_range.end);
        if 0 < mappings.len() {
            let m = mappings.choose();  // In a well-formed PT there will only be one, but if malformed this is non-deterministic!
            // Offset of `va` from the start of the chosen mapping's virtual range.
            let off = va - m.va_range.start;
            Some((m.pa_range.start + off) as usize)
        } else {
            None
        }
    }
45
46    pub open spec fn read(self, va: usize) -> Option<raw_ptr::MemContents<u8>> {
47        let pa = self.addr_transl(va);
48        if pa is Some {
49            Some(self.memory[pa.unwrap()])
50        } else {
51            None
52        }
53    }
54
55    pub open spec fn write(self, va: usize, x: u8) -> Option<Self> {
56        let pa = self.addr_transl(va);
57        if pa is Some {
58            Some(
59                MemView {
60                    memory: self.memory.insert(pa.unwrap(), raw_ptr::MemContents::Init(x)),
61                    ..self
62                },
63            )
64        } else {
65            None
66        }
67    }
68
69    pub open spec fn eq_at(self, va1: usize, va2: usize) -> bool {
70        let pa1 = self.addr_transl(va1);
71        let pa2 = self.addr_transl(va2);
72        if pa1 is Some && pa2 is Some {
73            self.memory[pa1.unwrap()] == self.memory[pa2.unwrap()]
74        } else {
75            false
76        }
77    }
78
    /// Specifies the sub-view for the virtual range `[vaddr, vaddr + len)`.
    ///
    /// The result keeps the mappings whose `va_range` overlaps the range, and
    /// restricts `memory` to the physical addresses that some virtual address
    /// in the range translates to under `self.addr_transl`.
    pub open spec fn borrow_at_spec(&self, vaddr: usize, len: usize) -> MemView {
        let range_end = vaddr + len;

        // Physical addresses reachable from some virtual address in the range.
        let valid_pas = Set::new(
            |pa: usize|
                exists|va: usize|
                    vaddr <= va < range_end && #[trigger] self.addr_transl(va) == Some(pa),
        );

        MemView {
            // Keep mappings that intersect `[vaddr, range_end)`.
            mappings: self.mappings.filter(
                |m: Mapping| m.va_range.start < range_end && m.va_range.end > vaddr,
            ),
            memory: self.memory.restrict(valid_pas),
        }
    }
95
    /// Holds when any two distinct mappings in `self.mappings` have
    /// non-overlapping virtual ranges, i.e. one ends at or before the start
    /// of the other.
    pub open spec fn mappings_are_disjoint(self) -> bool {
        forall|m1: Mapping, m2: Mapping|
            #![trigger self.mappings.contains(m1), self.mappings.contains(m2)]
            self.mappings.contains(m1) && self.mappings.contains(m2) && m1 != m2 ==> {
                m1.va_range.end <= m2.va_range.start || m2.va_range.end <= m1.va_range.start
            }
    }
103
    /// Specifies the pair of views produced by splitting at
    /// `[vaddr, vaddr + len)`.
    ///
    /// The left view keeps mappings overlapping `[vaddr, vaddr + len)` and
    /// the memory entries reachable from virtual addresses in that range.
    /// The right view keeps mappings whose `va_range.end` exceeds
    /// `vaddr + len` and the memory entries reachable from virtual addresses
    /// at or above `vaddr + len`.
    ///
    /// A mapping that straddles `vaddr + len` satisfies both filters and so
    /// appears in both halves; a mapping ending at or before `vaddr`
    /// satisfies neither.
    pub open spec fn split_spec(self, vaddr: usize, len: usize) -> (MemView, MemView) {
        let split_end = vaddr + len;

        // The left part: mappings that intersect `[vaddr, split_end)`.
        let left_mappings = self.mappings.filter(
            |m: Mapping| m.va_range.start < split_end && m.va_range.end > vaddr,
        );
        // The right part: mappings that extend past `split_end`.
        let right_mappings = self.mappings.filter(|m: Mapping| m.va_range.end > split_end);

        // Physical addresses reachable from the left virtual range.
        let left_pas = Set::new(
            |pa: usize|
                exists|va: usize| vaddr <= va < split_end && self.addr_transl(va) == Some(pa),
        );
        // Physical addresses reachable from virtual addresses at or above `split_end`.
        let right_pas = Set::new(
            |pa: usize| exists|va: usize| va >= split_end && self.addr_transl(va) == Some(pa),
        );

        (
            MemView { mappings: left_mappings, memory: self.memory.restrict(left_pas) },
            MemView { mappings: right_mappings, memory: self.memory.restrict(right_pas) },
        )
    }
126
    /// States that, when the mappings are pairwise disjoint, at most one
    /// mapping covers any given virtual address `va`.
    ///
    /// NOTE(review): the body is `admit()`, so this statement is currently
    /// assumed rather than proved.
    pub proof fn lemma_disjoint_filter_at_one(&self, va: usize)
        requires
            self.mappings_are_disjoint(),
        ensures
            self.mappings.filter(
                |m: Mapping| m.va_range.start <= va < m.va_range.end,
            ).len() <= 1,
    {
        admit();
    }
137
    /// Borrows a memory view for a sub-range.
    ///
    /// The returned view equals `self.borrow_at_spec(vaddr, len)`. The
    /// function is `external_body`, so this postcondition is trusted rather
    /// than checked against the body.
    #[verifier::external_body]
    pub proof fn borrow_at(tracked &self, vaddr: usize, len: usize) -> (tracked r: &MemView)
        ensures
            r == self.borrow_at_spec(vaddr, len),
    {
        unimplemented!()
    }
146
    /// Splits the memory view into two views.
    ///
    /// Returns the split memory views where the first is
    /// for `[vaddr, vaddr + len)` and the second is for the rest; the exact
    /// contents of both are given by `split_spec`, whose second half covers
    /// virtual addresses at or above `vaddr + len`.
    ///
    /// The function is `external_body`, so the postcondition is trusted
    /// rather than checked against the body.
    #[verifier::external_body]
    pub proof fn split(tracked self, vaddr: usize, len: usize) -> (tracked r: (Self, Self))
        ensures
            r == self.split_spec(vaddr, len),
    {
        unimplemented!()
    }
158
159    /// This proves that if split is performed and we have
160    /// (lhs, rhs) = self.split(vaddr, len), then we have
161    /// all translations preserved in lhs and rhs.
162    pub proof fn lemma_split_preserves_transl(
163        original: MemView,
164        vaddr: usize,
165        len: usize,
166        left: MemView,
167        right: MemView,
168    )
169        requires
170            original.split_spec(vaddr, len) == (left, right),
171        ensures
172            right.memory.dom().subset_of(original.memory.dom()),
173            forall|va: usize|
174                vaddr <= va < vaddr + len ==> {
175                    #[trigger] original.addr_transl(va) == left.addr_transl(va)
176                },
177            forall|va: usize|
178                va >= vaddr + len ==> {
179                    #[trigger] original.addr_transl(va) == right.addr_transl(va)
180                },
181    {
182        // Auto.
183        assert(right.memory.dom().subset_of(original.memory.dom()));
184
185        assert forall|va: usize| vaddr <= va < vaddr + len implies original.addr_transl(va)
186            == left.addr_transl(va) by {
187            assert(left.mappings =~= original.mappings.filter(
188                |m: Mapping| m.va_range.start < vaddr + len && m.va_range.end > vaddr,
189            ));
190            let o_mappings = original.mappings.filter(
191                |m: Mapping| m.va_range.start <= va < m.va_range.end,
192            );
193            let l_mappings = left.mappings.filter(
194                |m: Mapping| m.va_range.start <= va < m.va_range.end,
195            );
196
197            assert(l_mappings.subset_of(o_mappings));
198            assert(o_mappings.subset_of(l_mappings)) by {
199                assert forall|m: Mapping| #[trigger]
200                    o_mappings.contains(m) implies l_mappings.contains(m) by {
201                    assume(o_mappings.contains(m));
202                    assert(m.va_range.start < vaddr + len);
203                    assert(m.va_range.end > vaddr);
204                    assert(m.va_range.start <= va < m.va_range.end);
205                    assert(left.mappings.contains(m));
206                }
207            };
208
209            assert(o_mappings =~= l_mappings);
210        }
211
212        assert forall|va: usize| va >= vaddr + len implies original.addr_transl(va)
213            == right.addr_transl(va) by {
214            let split_end = vaddr + len;
215
216            let o_mappings = original.mappings.filter(
217                |m: Mapping| m.va_range.start <= va < m.va_range.end,
218            );
219            let r_mappings = right.mappings.filter(
220                |m: Mapping| m.va_range.start <= va < m.va_range.end,
221            );
222
223            assert forall|m: Mapping| o_mappings.contains(m) implies r_mappings.contains(m) by {
224                assert(m.va_range.end > va);
225                assert(va >= split_end);
226                assert(m.va_range.end > split_end);
227
228                assert(right.mappings.contains(m));
229                assert(r_mappings.contains(m));
230            }
231
232            assert(o_mappings =~= r_mappings);
233        }
234    }
235
    /// Specifies the view obtained by merging `self` with `other`.
    ///
    /// Mappings are the union of both sets. Memory is the union of both maps,
    /// with `other`'s entry taking precedence where both define a key.
    pub open spec fn join_spec(self, other: MemView) -> MemView {
        MemView {
            mappings: self.mappings.union(other.mappings),
            memory: self.memory.union_prefer_right(other.memory),
        }
    }
242
    /// Merges two disjoint memory views back into one.
    ///
    /// Requires the two mapping sets to share no element. Afterwards `self`
    /// equals `old(self).join_spec(other)`.
    ///
    /// The function is `external_body`, so the postcondition is trusted
    /// rather than checked against the body.
    #[verifier::external_body]
    pub proof fn join(tracked &mut self, tracked other: Self)
        requires
            old(self).mappings.disjoint(other.mappings),
        ensures
            *self == old(self).join_spec(other),
    {
        unimplemented!()
    }
253
    /// Claims that joining the two halves produced by `split_spec` yields the
    /// original view.
    ///
    /// This lemma is `external_body`: the `ensures` clause is accepted
    /// without a checked proof, so it acts as an axiom.
    ///
    /// NOTE(review): as written, `split_spec` keeps a mapping only if
    /// `va_range.end > vaddr` (left) or `va_range.end > vaddr + len` (right),
    /// and keeps a memory entry only if some `va >= vaddr` translates to it.
    /// A view containing a mapping that ends at or before `vaddr` therefore
    /// does not appear to satisfy `this == lhs.join_spec(rhs)` — confirm
    /// whether an additional precondition is intended.
    #[verifier::external_body]
    pub proof fn lemma_split_join_identity(
        this: MemView,
        lhs: MemView,
        rhs: MemView,
        vaddr: usize,
        len: usize,
    )
        requires
            this.split_spec(vaddr, len) == (lhs, rhs),
        ensures
            this == lhs.join_spec(rhs),
    {
        // Auto.
    }
269}
270
/// Well-formedness invariant for `VirtPtr`.
impl Inv for VirtPtr {
    /// Holds when `vaddr` lies within `[range.start, range.end]` (both ends
    /// inclusive), the range start is non-zero, and the range is not
    /// inverted.
    open spec fn inv(self) -> bool {
        &&& self.range@.start <= self.vaddr <= self.range@.end
        &&& self.range@.start > 0
        &&& self.range@.end >= self.range@.start
    }
}
278
/// Field-wise clone of a `VirtPtr`.
impl Clone for VirtPtr {
    /// Returns a pointer with the same address and the same ghost range;
    /// the result is specified to equal `self`.
    fn clone(&self) -> (res: Self)
        ensures
            res == self,
    {
        Self { vaddr: self.vaddr, range: Ghost(self.range@) }
    }
}
287
/// Marks `VirtPtr` as `Copy`, so values can be duplicated implicitly.
impl Copy for VirtPtr {

}
291
292impl VirtPtr {
    /// Holds when the pointer is non-null and its address lies within
    /// `[range.start, range.end]`, with the end bound inclusive.
    pub open spec fn is_defined(self) -> bool {
        &&& self.vaddr != 0
        &&& self.range@.start <= self.vaddr <= self.range@.end
    }
297
    /// Holds when the pointer is defined and its address is strictly below
    /// `range.end`; `read` and `write` require this.
    pub open spec fn is_valid(self) -> bool {
        &&& self.is_defined()
        &&& self.vaddr < self.range@.end
    }
302
    /// Reads the byte at this pointer's address through the view `mem`.
    ///
    /// Requires that the address translates in `mem`, that the memory cell at
    /// the translated address is initialised, and that the pointer is valid.
    /// The returned byte is the value of `mem.read(self.vaddr)`.
    ///
    /// The function is `external_body`: the contract is trusted and the body
    /// is a placeholder.
    #[verifier::external_body]
    pub fn read(self, Tracked(mem): Tracked<&MemView>) -> u8
        requires
            mem.addr_transl(self.vaddr) is Some,
            mem.memory[mem.addr_transl(self.vaddr).unwrap()] is Init,
            self.is_valid(),
        returns
            mem.read(self.vaddr).unwrap().value(),
    {
        unimplemented!()
    }
314
    /// Writes byte `x` at this pointer's address through the view `mem`.
    ///
    /// Requires that the address translates in `mem` and that the pointer is
    /// valid. Afterwards `mem` equals the result of
    /// `old(mem).write(self.vaddr, x)`.
    ///
    /// The function is `external_body`: the contract is trusted and the body
    /// is a placeholder.
    #[verifier::external_body]
    pub fn write(self, Tracked(mem): Tracked<&mut MemView>, x: u8)
        requires
            old(mem).addr_transl(self.vaddr) is Some,
            self.is_valid(),
        ensures
            *mem == old(mem).write(self.vaddr, x).unwrap(),
    {
        unimplemented!()
    }
325
    /// Specifies the pointer obtained by advancing `vaddr` by `n`; the ghost
    /// range is left unchanged.
    pub open spec fn add_spec(self, n: usize) -> Self {
        VirtPtr { vaddr: (self.vaddr + n) as usize, range: self.range }
    }
329
330    pub fn add(&mut self, n: usize)
331        requires
332    // Option 1: strict C standard compliance
333    // old(self).range@.start <= self.vaddr + n <= old(self).range@.end,
334    // Option 2: just make sure it doesn't overflow
335
336            0 <= old(self).vaddr + n < usize::MAX,
337        ensures
338            *self == old(self).add_spec(
339                n,
340            ),
341    // If we take option 1, we can also ensure:
342    // self.is_defined()
343
344    {
345        self.vaddr = self.vaddr + n
346    }
347
    /// Specifies the byte value read through `mem` at address `vaddr + n`.
    pub open spec fn read_offset_spec(self, mem: MemView, n: usize) -> u8 {
        mem.read((self.vaddr + n) as usize).unwrap().value()
    }
351
    /// Reads the byte at offset `n` from this pointer without modifying it.
    ///
    /// Unlike `add`, we just create a temporary pointer value and read that.
    /// When `self.vaddr == self.range.start` this acts like array index notation.
    ///
    /// Requires that `vaddr + n` is non-zero, below `usize::MAX`, inside
    /// `[range.start, range.end)`, translates in `mem`, and refers to an
    /// initialised memory cell. Returns `self.read_offset_spec(*mem, n)`.
    pub fn read_offset(&self, Tracked(mem): Tracked<&MemView>, n: usize) -> u8
        requires
            0 < self.vaddr + n < usize::MAX,
            self.range@.start <= self.vaddr + n < self.range@.end,
            mem.addr_transl((self.vaddr + n) as usize) is Some,
            mem.memory[mem.addr_transl((self.vaddr + n) as usize).unwrap()] is Init,
        returns
            self.read_offset_spec(*mem, n),
    {
        // Work on a copy so that `self` keeps its original address.
        let mut tmp = self.clone();
        tmp.add(n);
        tmp.read(Tracked(mem))
    }
367
    /// Specifies the view obtained by writing byte `x` through `mem` at
    /// address `vaddr + n`.
    pub open spec fn write_offset_spec(self, mem: MemView, n: usize, x: u8) -> MemView {
        mem.write((self.vaddr + n) as usize, x).unwrap()
    }
371
372    pub fn write_offset(&self, Tracked(mem): Tracked<&mut MemView>, n: usize, x: u8)
373        requires
374            self.inv(),
375            self.range@.start <= self.vaddr + n < self.range@.end,
376            old(mem).addr_transl((self.vaddr + n) as usize) is Some,
377    {
378        let mut tmp = self.clone();
379        tmp.add(n);
380        tmp.write(Tracked(mem), x)
381    }
382
383    pub open spec fn copy_offset_spec(src: Self, dst: Self, mem: MemView, n: usize) -> MemView {
384        let x = src.read_offset_spec(mem, n);
385        dst.write_offset_spec(mem, n, x)
386    }
387
388    pub fn copy_offset(src: &Self, dst: &Self, Tracked(mem): Tracked<&mut MemView>, n: usize)
389        requires
390            src.inv(),
391            dst.inv(),
392            src.range@.start <= src.vaddr + n < src.range@.end,
393            dst.range@.start <= dst.vaddr + n < dst.range@.end,
394            old(mem).addr_transl((src.vaddr + n) as usize) is Some,
395            old(mem).addr_transl((dst.vaddr + n) as usize) is Some,
396            old(mem).memory.contains_key(old(mem).addr_transl((src.vaddr + n) as usize).unwrap()),
397            old(mem).memory[old(mem).addr_transl((src.vaddr + n) as usize).unwrap()] is Init,
398        ensures
399            *mem == Self::copy_offset_spec(*src, *dst, *old(mem), n),
400    {
401        let x = src.read_offset(Tracked(mem), n);
402        proof { admit() }
403        ;
404        dst.write_offset(Tracked(mem), n, x)
405    }
406
    /// Specifies the view obtained by copying `n` bytes from `src` to `dst`.
    ///
    /// Defined recursively on `n`: for `n > 0` the byte at offset `n - 1` is
    /// copied first, and the remaining `n - 1` bytes are then copied in the
    /// resulting view. For `n == 0` the view is returned unchanged.
    pub open spec fn memcpy_spec(src: Self, dst: Self, mem: MemView, n: usize) -> MemView
        decreases n,
    {
        if n == 0 {
            mem
        } else {
            let mem = Self::copy_offset_spec(src, dst, mem, (n - 1) as usize);
            Self::memcpy_spec(src, dst, mem, (n - 1) as usize)
        }
    }
417
418    /// Copies `n` bytes from `src` to `dst` in the given memory view.
419    ///
420    /// The source and destination must *not* overlap.
421    /// `copy_nonoverlapping` is semantically equivalent to C’s `memcpy`,
422    /// but with the source and destination arguments swapped.
423    pub fn copy_nonoverlapping(
424        src: &Self,
425        dst: &Self,
426        Tracked(mem): Tracked<&mut MemView>,
427        n: usize,
428    )
429        requires
430            src.inv(),
431            dst.inv(),
432            src.range@.start <= src.vaddr,
433            src.vaddr + n <= src.range@.end,
434            dst.range@.start <= dst.vaddr,
435            dst.vaddr + n < dst.range@.end,
436            src.range@.end <= dst.range@.start || dst.range@.end <= src.range@.start,
437            forall|i: usize|
438                src.vaddr <= i < src.vaddr + n ==> {
439                    &&& #[trigger] old(mem).addr_transl(i) is Some
440                    &&& old(mem).memory.contains_key(old(mem).addr_transl(i).unwrap())
441                    &&& old(mem).memory[old(mem).addr_transl(i).unwrap()] is Init
442                },
443            forall|i: usize|
444                dst.vaddr <= i < dst.vaddr + n ==> {
445                    &&& old(mem).addr_transl(i) is Some
446                },
447        ensures
448            *mem == Self::memcpy_spec(*src, *dst, *old(mem), n),
449        decreases n,
450    {
451        let ghost mem0 = *mem;
452
453        if n == 0 {
454            return ;
455        } else {
456            Self::copy_offset(src, dst, Tracked(mem), n - 1);
457            assert(forall|i: usize|
458                src.vaddr <= i < src.vaddr + n - 1 ==> mem.addr_transl(i) == mem0.addr_transl(i));
459            Self::copy_nonoverlapping(src, dst, Tracked(mem), n - 1);
460        }
461    }
462
    /// Creates a pointer at `vaddr` whose ghost range is
    /// `[vaddr, vaddr + len)`.
    ///
    /// Requires a non-zero address and a non-zero length such that
    /// `vaddr + len` does not exceed `usize::MAX`. The result is valid and
    /// points at the start of its range.
    pub fn from_vaddr(vaddr: usize, len: usize) -> (r: Self)
        requires
            vaddr != 0,
            0 < len <= usize::MAX - vaddr,
        ensures
            r.is_valid(),
            r.range@.start == vaddr,
            r.range@.end == (vaddr + len) as usize,
    {
        Self { vaddr, range: Ghost(Range { start: vaddr, end: (vaddr + len) as usize }) }
    }
474
475    /// Executable helper to split the VirtPtr struct
476    /// This updates the ghost ranges to match a MemView::split operation
477    #[verus_spec(r =>
478        requires
479            self.is_valid(),
480            0 <= n <= (self.range@.end - self.range@.start),
481            self.vaddr == self.range@.start,
482        ensures
483            r.0.range@.start == self.range@.start,
484            r.0.range@.end == self.range@.start + n,
485            r.0.vaddr == self.range@.start,
486            r.1.range@.start == self.range@.start + n,
487            r.1.range@.end == self.range@.end,
488            r.1.vaddr == self.range@.start + n,
489    )]
490    pub fn split(self, n: usize) -> (Self, Self) {
491        let left = VirtPtr {
492            vaddr: self.vaddr,
493            range: Ghost(Range { start: self.vaddr, end: (self.vaddr + n) as usize }),
494        };
495
496        let right = VirtPtr {
497            vaddr: self.vaddr + n,
498            range: Ghost(Range { start: (self.vaddr + n) as usize, end: self.range@.end }),
499        };
500
501        (left, right)
502    }
503}
504
505} // verus!