ostd/io/io_mem/
mod.rs

1// SPDX-License-Identifier: MPL-2.0
2
3//! I/O memory and its allocator that allocates memory I/O (MMIO) to device drivers.
4
5mod allocator;
6pub(crate) mod util;
7
8use core::{
9    marker::PhantomData,
10    ops::{Deref, Range},
11};
12
13use align_ext::AlignExt;
14use inherit_methods_macro::inherit_methods;
15
16pub(crate) use self::allocator::IoMemAllocatorBuilder;
17pub(super) use self::allocator::init;
18#[cfg(all(target_arch = "x86_64", feature = "cvm_guest"))]
19use crate::arch::{if_tdx_enabled, tdx_guest::unprotect_gpa_tdvm_call};
20use crate::{
21    Error,
22    arch::io::io_mem::{read_once, write_once},
23    cpu::{AtomicCpuSet, CpuSet},
24    mm::{
25        Fallible, HasPaddr, HasSize, Infallible, PAGE_SIZE, Paddr, PodOnce, VmIo, VmIoFill,
26        VmIoOnce, VmReader, VmWriter,
27        io::{
28            Io,
29            copy::{memcpy, memset},
30        },
31        kspace::kvirt_area::KVirtArea,
32        page_prop::{CachePolicy, PageFlags, PageProperty, PrivilegedPageFlags},
33        tlb::{TlbFlushOp, TlbFlusher},
34    },
35    prelude::*,
36    task::disable_preempt,
37};
38
/// A marker type used for [`IoMem`],
/// representing that the underlying MMIO is used for security-sensitive operations.
// This enum has no variants, so it can never be instantiated; it exists purely
// as a type-level marker for `IoMem<SecuritySensitivity>`.
#[derive(Clone, Debug)]
pub(crate) enum Sensitive {}
43
/// A marker type used for [`IoMem`],
/// representing that the underlying MMIO is used for security-insensitive operations.
// Like `Sensitive`, this is an uninhabited type-level marker. It is the default
// type parameter of `IoMem`.
#[derive(Clone, Debug)]
pub enum Insensitive {}
48
/// I/O memory.
#[derive(Clone, Debug)]
pub struct IoMem<SecuritySensitivity = Insensitive> {
    /// The kernel virtual area that maps the (page-aligned) MMIO pages.
    kvirt_area: Arc<KVirtArea>,
    // The actually used range for MMIO is `kvirt_area.start + offset..kvirt_area.start + offset + limit`
    /// Byte offset of the first usable MMIO byte within `kvirt_area`.
    offset: usize,
    /// Length of the usable MMIO window in bytes.
    limit: usize,
    /// Physical address of the first usable MMIO byte.
    pa: Paddr,
    /// The cache policy the range was mapped with.
    cache_policy: CachePolicy,
    /// Zero-sized marker carrying the security-sensitivity type parameter.
    phantom: PhantomData<SecuritySensitivity>,
}
60
impl<SecuritySensitivity> IoMem<SecuritySensitivity> {
    /// Slices the `IoMem`, returning another `IoMem` representing the subslice.
    ///
    /// # Panics
    ///
    /// This method will panic if the range is empty or out of bounds.
    pub fn slice(&self, range: Range<usize>) -> Self {
        // This ensures `range.start < range.end` and `range.end <= limit`.
        assert!(!range.is_empty() && range.end <= self.limit);

        // We've checked the range is in bounds, so we can construct the new `IoMem` safely.
        Self {
            // The subslice shares the same underlying mapping; only the
            // window (`offset`/`limit`) and the physical base move.
            kvirt_area: self.kvirt_area.clone(),
            offset: self.offset + range.start,
            limit: range.len(),
            pa: self.pa + range.start,
            cache_policy: self.cache_policy,
            phantom: PhantomData,
        }
    }

    /// Creates a new `IoMem`.
    ///
    /// # Safety
    ///
    /// 1. This function must be called after the kernel page table is activated.
    /// 2. The given physical address range must be in the I/O memory region.
    /// 3. Reading from or writing to I/O memory regions may have side effects.
    ///    If `SecuritySensitivity` is `Insensitive`, those side effects must
    ///    not cause soundness problems (e.g., they must not corrupt the kernel
    ///    memory).
    pub(crate) unsafe fn new(range: Range<Paddr>, flags: PageFlags, cache: CachePolicy) -> Self {
        // Mapping happens at page granularity, so widen the requested range to
        // whole pages; `offset` below re-narrows it to the requested bytes.
        let first_page_start = range.start.align_down(PAGE_SIZE);
        let last_page_end = range.end.align_up(PAGE_SIZE);

        let frames_range = first_page_start..last_page_end;
        let area_size = frames_range.len();

        #[cfg(target_arch = "x86_64")]
        let priv_flags = if_tdx_enabled!({
            // Under TDX, only whole pages can be unprotected, so a non-page-aligned
            // request cannot be honored and is treated as a bug.
            assert!(
                first_page_start == range.start && last_page_end == range.end,
                "I/O memory is not page aligned, which cannot be unprotected in TDX: {:#x?}..{:#x?}",
                range.start,
                range.end,
            );

            // SAFETY:
            //  - The range `first_page_start..last_page_end` is always page aligned.
            //  - FIXME: We currently do not limit the I/O memory allocator with the maximum GPA,
            //    so the address range may not fall in the GPA limit.
            //  - The caller guarantees that operations on the I/O memory do not have any side
            //    effects that may cause soundness problems, so the pages can safely be viewed as
            //    untyped memory.
            unsafe { unprotect_gpa_tdvm_call(first_page_start, area_size).unwrap() };

            PrivilegedPageFlags::SHARED
        } else {
            PrivilegedPageFlags::empty()
        });
        #[cfg(not(target_arch = "x86_64"))]
        let priv_flags = PrivilegedPageFlags::empty();

        let prop = PageProperty {
            flags,
            cache,
            priv_flags,
        };

        let kva = {
            // SAFETY: The caller of `IoMem::new()` ensures that the given
            // physical address range is I/O memory, so it is safe to map.
            let kva = unsafe { KVirtArea::map_untracked_frames(area_size, 0, frames_range, prop) };

            // Broadcast a TLB flush for the newly mapped virtual range to all
            // CPUs so that no stale translations for this range remain.
            let target_cpus = AtomicCpuSet::new(CpuSet::new_full());
            let mut flusher = TlbFlusher::new(&target_cpus, disable_preempt());
            flusher.issue_tlb_flush(TlbFlushOp::for_range(kva.range()));
            flusher.dispatch_tlb_flush();
            flusher.sync_tlb_flush();

            kva
        };

        Self {
            kvirt_area: Arc::new(kva),
            // Byte offset of `range.start` within the first mapped page.
            offset: range.start - first_page_start,
            limit: range.len(),
            pa: range.start,
            cache_policy: cache,
            phantom: PhantomData,
        }
    }

    /// Returns the cache policy of this `IoMem`.
    pub fn cache_policy(&self) -> CachePolicy {
        self.cache_policy
    }

    /// Returns the base virtual address of the MMIO range.
    fn base(&self) -> usize {
        self.kvirt_area.deref().start() + self.offset
    }

    /// Validates that the offset range lies within the MMIO window.
    fn check_range(&self, offset: usize, len: usize) -> Result<()> {
        // `checked_add` also rejects ranges whose `offset + len` overflows.
        if offset.checked_add(len).is_none_or(|end| end > self.limit) {
            return Err(Error::InvalidArgs);
        }
        Ok(())
    }
}
172
173#[cfg_attr(target_arch = "loongarch64", expect(unused))]
174impl IoMem<Sensitive> {
175    /// Reads a value of the `PodOnce` type at the specified offset using one
176    /// non-tearing memory load.
177    ///
178    /// Except that the offset is specified explicitly, the semantics of this
179    /// method is the same as [`VmReader::read_once`].
180    ///
181    /// # Safety
182    ///
183    /// The caller must ensure that the offset and the read operation is valid,
184    /// e.g., follows the specification when used for implementing drivers, does
185    /// not cause any out-of-bounds access, and does not cause unsound side
186    /// effects (e.g., corrupting the kernel memory).
187    pub(crate) unsafe fn read_once<T: PodOnce>(&self, offset: usize) -> T {
188        debug_assert!(offset + size_of::<T>() <= self.limit);
189        let ptr = (self.kvirt_area.deref().start() + self.offset + offset) as *const T;
190        // SAFETY: The safety of the read operation's semantics is upheld by the caller.
191        unsafe { read_once(ptr) }
192    }
193
194    /// Writes a value of the `PodOnce` type at the specified offset using one
195    /// non-tearing memory store.
196    ///
197    /// Except that the offset is specified explicitly, the semantics of this
198    /// method is the same as [`VmWriter::write_once`].
199    ///
200    /// # Safety
201    ///
202    /// The caller must ensure that the offset and the write operation is valid,
203    /// e.g., follows the specification when used for implementing drivers, does
204    /// not cause any out-of-bounds access, and does not cause unsound side
205    /// effects (e.g., corrupting the kernel memory).
206    pub(crate) unsafe fn write_once<T: PodOnce>(&self, offset: usize, value: &T) {
207        debug_assert!(offset + size_of::<T>() <= self.limit);
208        let ptr = (self.kvirt_area.deref().start() + self.offset + offset) as *mut T;
209        // SAFETY: The safety of the write operation's semantics is upheld by the caller.
210        unsafe { write_once(ptr, *value) };
211    }
212}
213
impl IoMem<Insensitive> {
    /// Acquires an `IoMem` instance for the given range.
    ///
    /// The I/O memory cache policy is set to uncacheable by default.
    pub fn acquire(range: Range<Paddr>) -> Result<IoMem<Insensitive>> {
        Self::acquire_with_cache_policy(range, CachePolicy::Uncacheable)
    }

    /// Acquires an `IoMem` instance for the given range with the specified cache policy.
    pub fn acquire_with_cache_policy(
        range: Range<Paddr>,
        cache_policy: CachePolicy,
    ) -> Result<IoMem<Insensitive>> {
        // The `unwrap` assumes the global allocator has been initialized
        // before any acquisition; acquiring earlier is a bug and panics here.
        allocator::IO_MEM_ALLOCATOR
            .get()
            .unwrap()
            .acquire(range, cache_policy)
            .ok_or(Error::AccessDenied)
    }

    /// Reads from MMIO into fallible memory and returns the copied length.
    ///
    /// This method performs the same low-level copy primitive as [`VmIo::read`],
    /// but exposes partial progress instead of enforcing no-short-read semantics.
    pub fn read_fallible(
        &self,
        offset: usize,
        writer: &mut VmWriter,
    ) -> core::result::Result<usize, (Error, usize)> {
        let len = writer.avail();
        // A range error reports zero progress.
        self.check_range(offset, len).map_err(|err| (err, 0))?;

        let src = (self.base() + offset) as *const u8;
        // SAFETY: `src` points to a validated MMIO range and `writer.cursor()` points to
        // fallible destination memory tracked by `writer`.
        let copied = unsafe { memcpy::<Fallible, Io>(writer.cursor(), src, len) };
        // Advance the cursor by the bytes actually copied, even on a short
        // copy, so the caller can observe partial progress.
        writer.skip(copied);

        if copied < len {
            Err((Error::PageFault, copied))
        } else {
            Ok(copied)
        }
    }

    /// Writes from fallible memory to MMIO and returns the copied length.
    ///
    /// This method performs the same low-level copy primitive as [`VmIo::write`],
    /// but exposes partial progress instead of enforcing no-short-write semantics.
    pub fn write_fallible(
        &self,
        offset: usize,
        reader: &mut VmReader,
    ) -> core::result::Result<usize, (Error, usize)> {
        let len = reader.remain();
        // A range error reports zero progress.
        self.check_range(offset, len).map_err(|err| (err, 0))?;

        let dst = (self.base() + offset) as *mut u8;
        // SAFETY: `dst` points to a validated MMIO range and `reader.cursor()` points to
        // fallible source memory tracked by `reader`.
        let copied = unsafe { memcpy::<Io, Fallible>(dst, reader.cursor(), len) };
        // Advance the cursor by the bytes actually copied, even on a short
        // copy, so the caller can observe partial progress.
        reader.skip(copied);

        if copied < len {
            Err((Error::PageFault, copied))
        } else {
            Ok(copied)
        }
    }
}
284
285impl VmIoOnce for IoMem<Insensitive> {
286    fn read_once<T: PodOnce>(&self, offset: usize) -> Result<T> {
287        self.check_range(offset, size_of::<T>())?;
288        let ptr = (self.base() + offset) as *const T;
289        if !ptr.is_aligned() {
290            return Err(Error::InvalidArgs);
291        }
292
293        // SAFETY: The pointer is properly aligned and within the validated range.
294        let val = unsafe { read_once(ptr) };
295        Ok(val)
296    }
297
298    fn write_once<T: PodOnce>(&self, offset: usize, value: &T) -> Result<()> {
299        self.check_range(offset, size_of::<T>())?;
300        let ptr = (self.base() + offset) as *mut T;
301        if !ptr.is_aligned() {
302            return Err(Error::InvalidArgs);
303        }
304
305        // SAFETY: The pointer is properly aligned and within the validated range.
306        unsafe { write_once(ptr, *value) };
307        Ok(())
308    }
309}
310
impl VmIo for IoMem<Insensitive> {
    fn read(&self, offset: usize, writer: &mut VmWriter) -> Result<()> {
        let len = writer.avail();
        self.check_range(offset, len)?;

        let src = (self.base() + offset) as *const u8;
        // SAFETY: `src` points to a validated MMIO range and `writer.cursor()` points to
        // fallible destination memory tracked by `writer`.
        let copied = unsafe { memcpy::<Fallible, Io>(writer.cursor(), src, len) };
        // Unlike `IoMem::read_fallible`, a short copy is an outright error and
        // the writer's cursor is left unadvanced.
        if copied < len {
            return Err(Error::PageFault);
        }

        writer.skip(copied);
        Ok(())
    }

    fn read_bytes(&self, offset: usize, buf: &mut [u8]) -> Result<()> {
        let len = buf.len();
        self.check_range(offset, len)?;
        let src = (self.base() + offset) as *const u8;
        let dst = buf.as_mut_ptr();

        // The copied length is ignored here; NOTE(review): this assumes a copy
        // into `Infallible` destination memory never comes up short — confirm
        // against `memcpy`'s contract.
        // SAFETY: The `dst` and `src` buffers are valid to write and read for `len` bytes.
        unsafe { memcpy::<Infallible, Io>(dst, src, len) };
        Ok(())
    }

    fn write(&self, offset: usize, reader: &mut VmReader) -> Result<()> {
        let len = reader.remain();
        self.check_range(offset, len)?;

        let dst = (self.base() + offset) as *mut u8;
        // SAFETY: `dst` points to a validated MMIO range and `reader.cursor()` points to
        // fallible source memory tracked by `reader`.
        let copied = unsafe { memcpy::<Io, Fallible>(dst, reader.cursor(), len) };
        // Unlike `IoMem::write_fallible`, a short copy is an outright error and
        // the reader's cursor is left unadvanced.
        if copied < len {
            return Err(Error::PageFault);
        }

        reader.skip(copied);
        Ok(())
    }

    fn write_bytes(&self, offset: usize, buf: &[u8]) -> Result<()> {
        let len = buf.len();
        self.check_range(offset, len)?;
        let src = buf.as_ptr();
        let dst = (self.base() + offset) as *mut u8;

        // SAFETY: The `dst` and `src` buffers are valid to write and read for `len` bytes.
        unsafe { memcpy::<Io, Infallible>(dst, src, len) };
        Ok(())
    }
}
366
367impl VmIoFill for IoMem<Insensitive> {
368    fn fill_zeros(&self, offset: usize, len: usize) -> core::result::Result<(), (Error, usize)> {
369        if len == 0 {
370            return Ok(());
371        }
372
373        if offset > self.limit {
374            return Err((Error::InvalidArgs, 0));
375        }
376
377        let available = self.limit - offset;
378        let write_len = core::cmp::min(len, available);
379        if write_len == 0 {
380            return Err((Error::InvalidArgs, 0));
381        }
382
383        let dst = (self.base() + offset) as *mut u8;
384        // SAFETY: `dst` points to the validated MMIO subrange of `write_len` bytes.
385        unsafe { memset::<Io>(dst, 0u8, write_len) };
386
387        if write_len == len {
388            Ok(())
389        } else {
390            Err((Error::InvalidArgs, write_len))
391        }
392    }
393}
394
// Generates `VmIo`, `VmIoOnce`, and `VmIoFill` implementations for a
// pointer-like wrapper type `$ty` by delegating every listed method to the
// pointee expression `$from` (e.g. `"(**self)"`), via `inherit_methods`.
macro_rules! impl_vm_io_pointer {
    ($ty:ty, $from:tt) => {
        #[inherit_methods(from = $from)]
        impl VmIo for $ty {
            fn read(&self, offset: usize, writer: &mut VmWriter) -> Result<()>;
            fn write(&self, offset: usize, reader: &mut VmReader) -> Result<()>;
        }

        #[inherit_methods(from = $from)]
        impl VmIoOnce for $ty {
            fn read_once<T: PodOnce>(&self, offset: usize) -> Result<T>;
            fn write_once<T: PodOnce>(&self, offset: usize, value: &T) -> Result<()>;
        }

        #[inherit_methods(from = $from)]
        impl VmIoFill for $ty {
            fn fill_zeros(
                &self,
                offset: usize,
                len: usize,
            ) -> core::result::Result<(), (Error, usize)>;
        }
    };
}

// Let `&IoMem` and `&mut IoMem` be used wherever the I/O traits are expected,
// forwarding to the underlying `IoMem`.
impl_vm_io_pointer!(&IoMem<Insensitive>, "(**self)");
impl_vm_io_pointer!(&mut IoMem<Insensitive>, "(**self)");
422
impl<SecuritySensitivity> HasPaddr for IoMem<SecuritySensitivity> {
    /// Returns the physical address of the first byte of the MMIO window.
    fn paddr(&self) -> Paddr {
        self.pa
    }
}
428
impl<SecuritySensitivity> HasSize for IoMem<SecuritySensitivity> {
    /// Returns the length of the MMIO window in bytes.
    fn size(&self) -> usize {
        self.limit
    }
}
434
impl<SecuritySensitivity> Drop for IoMem<SecuritySensitivity> {
    // Intentionally a no-op for now; see the TODO below for why the mapping is
    // not yet recycled on drop.
    fn drop(&mut self) {
        // TODO: Multiple `IoMem` instances should not overlap, we should refactor the driver code and
        // remove the `Clone` and `IoMem::slice`. After refactoring, the `Drop` can be implemented to recycle
        // the `IoMem`.
    }
}
442
#[cfg(ktest)]
mod test {
    use core::mem::size_of;

    use crate::{
        arch::io::io_mem::{copy_from_mmio, copy_to_mmio, read_once, write_once},
        prelude::ktest,
    };

    /// Size of the scratch buffers used by the copy-test helpers.
    const BUF_LEN: usize = 64;

    #[ktest]
    fn read_write_u8() {
        let mut data: u8 = 0;
        // SAFETY: `data` is valid for a single MMIO read/write.
        unsafe {
            write_once(&mut data, 42u8);
            assert_eq!(read_once(&data), 42u8);
        }
    }

    #[ktest]
    fn read_write_u16() {
        let mut data: u16 = 0;
        let val: u16 = 0x1234;
        // SAFETY: `data` is valid for a single MMIO read/write.
        unsafe {
            write_once(&mut data, val);
            assert_eq!(read_once(&data), val);
        }
    }

    #[ktest]
    fn read_write_u32() {
        let mut data: u32 = 0;
        let val: u32 = 0x12345678;
        // SAFETY: `data` is valid for a single MMIO read/write.
        unsafe {
            write_once(&mut data, val);
            assert_eq!(read_once(&data), val);
        }
    }

    #[ktest]
    fn read_write_u64() {
        let mut data: u64 = 0;
        let val: u64 = 0xDEADBEEFCAFEBABE;
        // SAFETY: `data` is valid for a single MMIO read/write.
        unsafe {
            write_once(&mut data, val);
            assert_eq!(read_once(&data), val);
        }
    }

    #[ktest]
    fn boundary_overlap() {
        let mut data: [u8; 2] = [0xAA, 0xBB];
        // SAFETY: `data` is valid for a single MMIO read/write.
        unsafe {
            write_once(&mut data[0], 0x11u8);
            assert_eq!(data[0], 0x11);
            // The neighboring byte must be untouched by the single-byte write.
            assert_eq!(data[1], 0xBB);
        }
    }

    /// Fills `buf` with a deterministic, position-dependent byte pattern.
    fn fill_pattern(buf: &mut [u8]) {
        for (idx, byte) in buf.iter_mut().enumerate() {
            *byte = (idx as u8).wrapping_mul(3).wrapping_add(1);
        }
    }

    /// Checks that `copy_from_mmio` copies exactly `len` bytes between the
    /// given offsets of two scratch buffers.
    fn run_copy_from_case(src_offset: usize, dst_offset: usize, len: usize) {
        let mut src = [0u8; BUF_LEN];
        let mut dst = [0u8; BUF_LEN];
        fill_pattern(&mut src);

        // SAFETY: Offsets are validated by callers before this helper is invoked.
        let src_ptr = unsafe { src.as_ptr().add(src_offset) };
        // SAFETY: Offsets are validated by callers before this helper is invoked.
        let dst_ptr = unsafe { dst.as_mut_ptr().add(dst_offset) };

        // SAFETY: The test buffers are valid for the requested range.
        unsafe { copy_from_mmio(dst_ptr, src_ptr, len) };

        assert_eq!(
            &dst[dst_offset..dst_offset + len],
            &src[src_offset..src_offset + len]
        );
    }

    /// Checks that `copy_to_mmio` copies exactly `len` bytes between the
    /// given offsets of two scratch buffers.
    fn run_copy_to_case(src_offset: usize, dst_offset: usize, len: usize) {
        let mut src = [0u8; BUF_LEN];
        let mut dst = [0u8; BUF_LEN];
        fill_pattern(&mut src);

        // SAFETY: Offsets are validated by callers before this helper is invoked.
        let src_ptr = unsafe { src.as_ptr().add(src_offset) };
        // SAFETY: Offsets are validated by callers before this helper is invoked.
        let dst_ptr = unsafe { dst.as_mut_ptr().add(dst_offset) };

        // SAFETY: The test buffers are valid for the requested range.
        unsafe { copy_to_mmio(src_ptr, dst_ptr, len) };

        assert_eq!(
            &dst[dst_offset..dst_offset + len],
            &src[src_offset..src_offset + len]
        );
    }

    /// Drives `case` over a grid of copy lengths and source/destination
    /// offsets exercising sub-word, word-sized, and unaligned copies.
    ///
    /// Shared by the `copy_from_*` and `copy_to_*` tests so the two cannot
    /// drift apart in coverage.
    fn for_each_copy_case(case: fn(usize, usize, usize)) {
        let word_size = size_of::<usize>();
        let sizes = [
            0,
            1,
            word_size.saturating_sub(1),
            word_size,
            word_size + 1,
            word_size * 2 + 3,
        ];
        let offsets = [0, 1, 2];

        for &len in &sizes {
            for &src_offset in &offsets {
                for &dst_offset in &offsets {
                    // Skip combinations that would run past the scratch buffers.
                    if src_offset + len <= BUF_LEN && dst_offset + len <= BUF_LEN {
                        case(src_offset, dst_offset, len);
                    }
                }
            }
        }
    }

    #[ktest]
    fn copy_from_alignment_and_sizes() {
        for_each_copy_case(run_copy_from_case);
    }

    #[ktest]
    fn copy_to_alignment_and_sizes() {
        for_each_copy_case(run_copy_to_case);
    }
}
597}