mod allocator;
pub(crate) mod util;

use core::{
    marker::PhantomData,
    ops::{Deref, Range},
};

use align_ext::AlignExt;
use inherit_methods_macro::inherit_methods;

pub(crate) use self::allocator::IoMemAllocatorBuilder;
pub(super) use self::allocator::init;
#[cfg(all(target_arch = "x86_64", feature = "cvm_guest"))]
use crate::arch::{if_tdx_enabled, tdx_guest::unprotect_gpa_tdvm_call};
use crate::{
    Error,
    arch::io::io_mem::{read_once, write_once},
    cpu::{AtomicCpuSet, CpuSet},
    mm::{
        Fallible, HasPaddr, HasSize, Infallible, PAGE_SIZE, Paddr, PodOnce, VmIo, VmIoFill,
        VmIoOnce, VmReader, VmWriter,
        io::{
            Io,
            copy::{memcpy, memset},
        },
        kspace::kvirt_area::KVirtArea,
        page_prop::{CachePolicy, PageFlags, PageProperty, PrivilegedPageFlags},
        tlb::{TlbFlushOp, TlbFlusher},
    },
    prelude::*,
    task::disable_preempt,
};

/// A marker type for I/O memory that may contain security-sensitive data.
///
/// Such memory is only accessible via `unsafe`, crate-internal methods.
#[derive(Clone, Debug)]
pub(crate) enum Sensitive {}

/// A marker type for I/O memory that contains no security-sensitive data.
#[derive(Clone, Debug)]
pub enum Insensitive {}

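/// I/O memory, i.e., a range of physical addresses mapped into the kernel
/// address space for memory-mapped I/O (MMIO) accesses.
///
/// The `SecuritySensitivity` type parameter controls which access methods are
/// exposed: [`Insensitive`] I/O memory implements the safe [`VmIo`],
/// [`VmIoOnce`], and [`VmIoFill`] traits, while [`Sensitive`] I/O memory can
/// only be accessed through `unsafe`, crate-internal methods.
///
/// A hypothetical end-to-end sketch (the physical range and register offsets
/// below are assumptions for illustration, not a real device):
///
/// ```ignore
/// let io_mem = IoMem::acquire(0xb000_0000..0xb000_1000)?;
/// io_mem.write_once(0x0, &1u32)?; // Ring a (hypothetical) doorbell register.
/// let status: u32 = io_mem.read_once(0x4)?;
/// ```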
#[derive(Clone, Debug)]
pub struct IoMem<SecuritySensitivity = Insensitive> {
    kvirt_area: Arc<KVirtArea>,
    // The byte offset of the I/O memory within `kvirt_area`.
    offset: usize,
    // The length of the I/O memory in bytes.
    limit: usize,
    pa: Paddr,
    cache_policy: CachePolicy,
    phantom: PhantomData<SecuritySensitivity>,
}

impl<SecuritySensitivity> IoMem<SecuritySensitivity> {
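    /// Returns a new `IoMem` that covers only the given `range` of this one.
    ///
    /// The slice shares the same underlying kernel virtual area, so no new
    /// mapping is created. A hypothetical sketch (the base address and the
    /// 4-byte register at offset `0x10` are assumptions for illustration):
    ///
    /// ```ignore
    /// let io_mem = IoMem::acquire(0xb000_0000..0xb000_1000)?;
    /// // Narrow the handle down to a single 4-byte register.
    /// let reg = io_mem.slice(0x10..0x14);
    /// assert_eq!(reg.size(), 4);
    /// assert_eq!(reg.paddr(), 0xb000_0010);
    /// ```
    ///
    /// # Panics
    ///
    /// This method panics if `range` is empty or out of bounds.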
    pub fn slice(&self, range: Range<usize>) -> Self {
        assert!(!range.is_empty() && range.end <= self.limit);

        Self {
            kvirt_area: self.kvirt_area.clone(),
            offset: self.offset + range.start,
            limit: range.len(),
            pa: self.pa + range.start,
            cache_policy: self.cache_policy,
            phantom: PhantomData,
        }
    }

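    /// Creates a new `IoMem` that maps the given physical address `range`
    /// with the given page `flags` and `cache` policy.
    ///
    /// A minimal sketch of how a caller (e.g., the I/O memory allocator)
    /// might use it, assuming `range` refers to valid MMIO:
    ///
    /// ```ignore
    /// // SAFETY: `range` is a valid MMIO range not used for other purposes.
    /// let io_mem = unsafe {
    ///     IoMem::<Insensitive>::new(range, PageFlags::RW, CachePolicy::Uncacheable)
    /// };
    /// ```
    ///
    /// # Safety
    ///
    /// The caller must ensure that the given physical address range is I/O
    /// memory that can be mapped with the given flags and cache policy, and
    /// that it does not alias memory used for other purposes.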
    pub(crate) unsafe fn new(range: Range<Paddr>, flags: PageFlags, cache: CachePolicy) -> Self {
        let first_page_start = range.start.align_down(PAGE_SIZE);
        let last_page_end = range.end.align_up(PAGE_SIZE);

        let frames_range = first_page_start..last_page_end;
        let area_size = frames_range.len();

        #[cfg(target_arch = "x86_64")]
        let priv_flags = if_tdx_enabled!({
            assert!(
                first_page_start == range.start && last_page_end == range.end,
                "I/O memory is not page aligned, which cannot be unprotected in TDX: {:#x?}..{:#x?}",
                range.start,
                range.end,
            );

            // SAFETY: The range is page-aligned I/O memory, so unprotecting
            // it (i.e., sharing it with the host) does not expose guest RAM.
            unsafe { unprotect_gpa_tdvm_call(first_page_start, area_size).unwrap() };

            PrivilegedPageFlags::SHARED
        } else {
            PrivilegedPageFlags::empty()
        });
        #[cfg(not(target_arch = "x86_64"))]
        let priv_flags = PrivilegedPageFlags::empty();

        let prop = PageProperty {
            flags,
            cache,
            priv_flags,
        };

        let kva = {
            // SAFETY: The caller guarantees that the physical range is I/O
            // memory, so mapping it as untracked frames is sound.
            let kva = unsafe { KVirtArea::map_untracked_frames(area_size, 0, frames_range, prop) };

            // Flush the TLBs on all CPUs so that no stale translations for
            // this virtual range survive the new mapping.
            let target_cpus = AtomicCpuSet::new(CpuSet::new_full());
            let mut flusher = TlbFlusher::new(&target_cpus, disable_preempt());
            flusher.issue_tlb_flush(TlbFlushOp::for_range(kva.range()));
            flusher.dispatch_tlb_flush();
            flusher.sync_tlb_flush();

            kva
        };

        Self {
            kvirt_area: Arc::new(kva),
            offset: range.start - first_page_start,
            limit: range.len(),
            pa: range.start,
            cache_policy: cache,
            phantom: PhantomData,
        }
    }

    /// Returns the cache policy with which the I/O memory is mapped.
    pub fn cache_policy(&self) -> CachePolicy {
        self.cache_policy
    }

    /// Returns the kernel virtual address at which the I/O memory starts.
    fn base(&self) -> usize {
        self.kvirt_area.deref().start() + self.offset
    }

    /// Checks whether `offset..offset + len` falls within the I/O memory.
    fn check_range(&self, offset: usize, len: usize) -> Result<()> {
        if offset.checked_add(len).is_none_or(|end| end > self.limit) {
            return Err(Error::InvalidArgs);
        }
        Ok(())
    }
}

#[cfg_attr(target_arch = "loongarch64", expect(unused))]
impl IoMem<Sensitive> {
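    /// Reads a value of type `T` at `offset` with a single volatile access.
    ///
    /// A hypothetical in-crate usage sketch (the register offset is an
    /// assumption for illustration):
    ///
    /// ```ignore
    /// // SAFETY: Reading this register has no safety or security impact here.
    /// let status: u32 = unsafe { sensitive_io_mem.read_once(0x8) };
    /// ```
    ///
    /// # Safety
    ///
    /// The caller must ensure that reading the security-sensitive I/O memory
    /// at this offset does not cause any safety or security issues.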
    pub(crate) unsafe fn read_once<T: PodOnce>(&self, offset: usize) -> T {
        debug_assert!(offset + size_of::<T>() <= self.limit);
        let ptr = (self.kvirt_area.deref().start() + self.offset + offset) as *const T;
        unsafe { read_once(ptr) }
    }

    /// Writes a value of type `T` at `offset` with a single volatile access.
    ///
    /// # Safety
    ///
    /// The caller must ensure that writing the security-sensitive I/O memory
    /// at this offset does not cause any safety or security issues.
    pub(crate) unsafe fn write_once<T: PodOnce>(&self, offset: usize, value: &T) {
        debug_assert!(offset + size_of::<T>() <= self.limit);
        let ptr = (self.kvirt_area.deref().start() + self.offset + offset) as *mut T;
        unsafe { write_once(ptr, *value) };
    }
}

impl IoMem<Insensitive> {
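    /// Acquires an `IoMem` for the given physical address `range`, mapped as
    /// uncacheable memory (the safe default for device registers).
    ///
    /// If the range is not available from the I/O memory allocator,
    /// `Error::AccessDenied` is returned. A hypothetical sketch (the address
    /// range is an assumption, not a real device):
    ///
    /// ```ignore
    /// let io_mem = IoMem::acquire(0xb000_0000..0xb000_1000)?;
    /// let mut buf = [0u8; 16];
    /// io_mem.read_bytes(0, &mut buf)?;
    /// ```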
    pub fn acquire(range: Range<Paddr>) -> Result<IoMem<Insensitive>> {
        Self::acquire_with_cache_policy(range, CachePolicy::Uncacheable)
    }

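    /// Acquires an `IoMem` for the given physical address `range` with an
    /// explicit cache policy.
    ///
    /// This suits I/O memory that tolerates weaker ordering than device
    /// registers do. A hypothetical sketch (the framebuffer range and the
    /// chosen policy are assumptions for illustration):
    ///
    /// ```ignore
    /// let fb = IoMem::acquire_with_cache_policy(
    ///     0xc000_0000..0xc080_0000,
    ///     CachePolicy::WriteCombining,
    /// )?;
    /// ```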
    pub fn acquire_with_cache_policy(
        range: Range<Paddr>,
        cache_policy: CachePolicy,
    ) -> Result<IoMem<Insensitive>> {
        allocator::IO_MEM_ALLOCATOR
            .get()
            .unwrap()
            .acquire(range, cache_policy)
            .ok_or(Error::AccessDenied)
    }

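    /// Reads from the I/O memory at `offset` into the (fallible) `writer`.
    ///
    /// On success, returns the number of bytes read. On failure, returns the
    /// error together with the number of bytes copied before it occurred.
    /// A hypothetical sketch (the user-space writer is assumed to come from
    /// elsewhere):
    ///
    /// ```ignore
    /// let copied = io_mem
    ///     .read_fallible(0, &mut user_writer)
    ///     .map_err(|(err, _)| err)?;
    /// ```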
    pub fn read_fallible(
        &self,
        offset: usize,
        writer: &mut VmWriter,
    ) -> core::result::Result<usize, (Error, usize)> {
        let len = writer.avail();
        self.check_range(offset, len).map_err(|err| (err, 0))?;

        // SAFETY: The range has been checked, so the source is valid MMIO;
        // the writer's cursor points to at least `len` writable bytes.
        let src = (self.base() + offset) as *const u8;
        let copied = unsafe { memcpy::<Fallible, Io>(writer.cursor(), src, len) };
        writer.skip(copied);

        if copied < len {
            Err((Error::PageFault, copied))
        } else {
            Ok(copied)
        }
    }

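    /// Writes to the I/O memory at `offset` from the (fallible) `reader`.
    ///
    /// On success, returns the number of bytes written. On failure, returns
    /// the error together with the number of bytes copied before it occurred.
    /// A hypothetical sketch (the user-space reader is assumed):
    ///
    /// ```ignore
    /// let copied = io_mem
    ///     .write_fallible(0, &mut user_reader)
    ///     .map_err(|(err, _)| err)?;
    /// ```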
    pub fn write_fallible(
        &self,
        offset: usize,
        reader: &mut VmReader,
    ) -> core::result::Result<usize, (Error, usize)> {
        let len = reader.remain();
        self.check_range(offset, len).map_err(|err| (err, 0))?;

        // SAFETY: The range has been checked, so the destination is valid
        // MMIO; the reader's cursor points to at least `len` readable bytes.
        let dst = (self.base() + offset) as *mut u8;
        let copied = unsafe { memcpy::<Io, Fallible>(dst, reader.cursor(), len) };
        reader.skip(copied);

        if copied < len {
            Err((Error::PageFault, copied))
        } else {
            Ok(copied)
        }
    }
}

impl VmIoOnce for IoMem<Insensitive> {
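    /// Reads a value of type `T` at `offset` using exactly one volatile
    /// access, as device registers with read side effects require. The
    /// offset must be properly aligned for `T`.
    ///
    /// ```ignore
    /// // Hypothetical: poll a 32-bit status register at offset 0x4.
    /// let status: u32 = io_mem.read_once(0x4)?;
    /// ```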
    fn read_once<T: PodOnce>(&self, offset: usize) -> Result<T> {
        self.check_range(offset, size_of::<T>())?;
        let ptr = (self.base() + offset) as *const T;
        if !ptr.is_aligned() {
            return Err(Error::InvalidArgs);
        }

        // SAFETY: The pointer is in bounds and aligned, and points to I/O
        // memory that allows plain reads.
        let val = unsafe { read_once(ptr) };
        Ok(val)
    }

    fn write_once<T: PodOnce>(&self, offset: usize, value: &T) -> Result<()> {
        self.check_range(offset, size_of::<T>())?;
        let ptr = (self.base() + offset) as *mut T;
        if !ptr.is_aligned() {
            return Err(Error::InvalidArgs);
        }

        // SAFETY: The pointer is in bounds and aligned, and points to I/O
        // memory that allows plain writes.
        unsafe { write_once(ptr, *value) };
        Ok(())
    }
}

impl VmIo for IoMem<Insensitive> {
    fn read(&self, offset: usize, writer: &mut VmWriter) -> Result<()> {
        let len = writer.avail();
        self.check_range(offset, len)?;

        // SAFETY: The range has been checked, so the source is valid MMIO;
        // the writer's cursor points to at least `len` writable bytes.
        let src = (self.base() + offset) as *const u8;
        let copied = unsafe { memcpy::<Fallible, Io>(writer.cursor(), src, len) };
        if copied < len {
            return Err(Error::PageFault);
        }

        writer.skip(copied);
        Ok(())
    }

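    /// Reads exactly `buf.len()` bytes from the I/O memory at `offset`.
    ///
    /// A hypothetical sketch (assuming `io_mem` covers at least 8 bytes):
    ///
    /// ```ignore
    /// let mut buf = [0u8; 8];
    /// io_mem.read_bytes(0, &mut buf)?;
    /// ```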
    fn read_bytes(&self, offset: usize, buf: &mut [u8]) -> Result<()> {
        let len = buf.len();
        self.check_range(offset, len)?;
        let src = (self.base() + offset) as *const u8;
        let dst = buf.as_mut_ptr();

        // SAFETY: The range has been checked, and `buf` is a valid
        // destination for `len` bytes; the in-kernel copy is infallible.
        unsafe { memcpy::<Infallible, Io>(dst, src, len) };
        Ok(())
    }

    fn write(&self, offset: usize, reader: &mut VmReader) -> Result<()> {
        let len = reader.remain();
        self.check_range(offset, len)?;

        // SAFETY: The range has been checked, so the destination is valid
        // MMIO; the reader's cursor points to at least `len` readable bytes.
        let dst = (self.base() + offset) as *mut u8;
        let copied = unsafe { memcpy::<Io, Fallible>(dst, reader.cursor(), len) };
        if copied < len {
            return Err(Error::PageFault);
        }

        reader.skip(copied);
        Ok(())
    }

    fn write_bytes(&self, offset: usize, buf: &[u8]) -> Result<()> {
        let len = buf.len();
        self.check_range(offset, len)?;
        let src = buf.as_ptr();
        let dst = (self.base() + offset) as *mut u8;

        // SAFETY: The range has been checked, and `buf` is a valid source of
        // `len` bytes; the in-kernel copy is infallible.
        unsafe { memcpy::<Io, Infallible>(dst, src, len) };
        Ok(())
    }
}

impl VmIoFill for IoMem<Insensitive> {
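    /// Fills `len` bytes of the I/O memory with zeros, starting at `offset`.
    ///
    /// If the range extends past the end of the I/O memory, only the part in
    /// bounds is zeroed and the error reports how many bytes were written.
    /// A hypothetical sketch:
    ///
    /// ```ignore
    /// // Zero a (hypothetical) 64-byte descriptor ring at offset 0x100.
    /// io_mem.fill_zeros(0x100, 64).map_err(|(err, _)| err)?;
    /// ```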
    fn fill_zeros(&self, offset: usize, len: usize) -> core::result::Result<(), (Error, usize)> {
        if len == 0 {
            return Ok(());
        }

        if offset > self.limit {
            return Err((Error::InvalidArgs, 0));
        }

        let available = self.limit - offset;
        let write_len = core::cmp::min(len, available);
        if write_len == 0 {
            return Err((Error::InvalidArgs, 0));
        }

        // SAFETY: `offset..offset + write_len` has been checked to be in
        // bounds, so the destination is valid I/O memory.
        let dst = (self.base() + offset) as *mut u8;
        unsafe { memset::<Io>(dst, 0u8, write_len) };

        if write_len == len {
            Ok(())
        } else {
            Err((Error::InvalidArgs, write_len))
        }
    }
}

/// Forwards the I/O trait impls of `$from` to `$ty`, so that `&IoMem` and
/// `&mut IoMem` can be used wherever the traits are required.
macro_rules! impl_vm_io_pointer {
    ($ty:ty, $from:tt) => {
        #[inherit_methods(from = $from)]
        impl VmIo for $ty {
            fn read(&self, offset: usize, writer: &mut VmWriter) -> Result<()>;
            fn write(&self, offset: usize, reader: &mut VmReader) -> Result<()>;
        }

        #[inherit_methods(from = $from)]
        impl VmIoOnce for $ty {
            fn read_once<T: PodOnce>(&self, offset: usize) -> Result<T>;
            fn write_once<T: PodOnce>(&self, offset: usize, value: &T) -> Result<()>;
        }

        #[inherit_methods(from = $from)]
        impl VmIoFill for $ty {
            fn fill_zeros(
                &self,
                offset: usize,
                len: usize,
            ) -> core::result::Result<(), (Error, usize)>;
        }
    };
}

impl_vm_io_pointer!(&IoMem<Insensitive>, "(**self)");
impl_vm_io_pointer!(&mut IoMem<Insensitive>, "(**self)");

impl<SecuritySensitivity> HasPaddr for IoMem<SecuritySensitivity> {
    fn paddr(&self) -> Paddr {
        self.pa
    }
}

impl<SecuritySensitivity> HasSize for IoMem<SecuritySensitivity> {
    fn size(&self) -> usize {
        self.limit
    }
}

impl<SecuritySensitivity> Drop for IoMem<SecuritySensitivity> {
    fn drop(&mut self) {
        // TODO: Recycle the I/O memory range to the allocator. This cannot
        // be done yet because `IoMem` is cloneable and sliceable, so other
        // instances may still reference the same underlying range.
    }
}

#[cfg(ktest)]
mod test {
    use core::mem::size_of;

    use crate::{
        arch::io::io_mem::{copy_from_mmio, copy_to_mmio, read_once, write_once},
        prelude::ktest,
    };

    #[ktest]
    fn read_write_u8() {
        let mut data: u8 = 0;
        // SAFETY: `data` is a valid, aligned location for volatile accesses.
        unsafe {
            write_once(&mut data, 42u8);
            assert_eq!(read_once(&data), 42u8);
        }
    }

    #[ktest]
    fn read_write_u16() {
        let mut data: u16 = 0;
        let val: u16 = 0x1234;
        // SAFETY: `data` is a valid, aligned location for volatile accesses.
        unsafe {
            write_once(&mut data, val);
            assert_eq!(read_once(&data), val);
        }
    }

    #[ktest]
    fn read_write_u32() {
        let mut data: u32 = 0;
        let val: u32 = 0x12345678;
        // SAFETY: `data` is a valid, aligned location for volatile accesses.
        unsafe {
            write_once(&mut data, val);
            assert_eq!(read_once(&data), val);
        }
    }

    #[ktest]
    fn read_write_u64() {
        let mut data: u64 = 0;
        let val: u64 = 0xDEADBEEFCAFEBABE;
        // SAFETY: `data` is a valid, aligned location for volatile accesses.
        unsafe {
            write_once(&mut data, val);
            assert_eq!(read_once(&data), val);
        }
    }

    #[ktest]
    fn boundary_overlap() {
        let mut data: [u8; 2] = [0xAA, 0xBB];
        // SAFETY: `data[0]` is a valid location; the single-byte write must
        // not spill into the adjacent byte.
        unsafe {
            write_once(&mut data[0], 0x11u8);
            assert_eq!(data[0], 0x11);
            assert_eq!(data[1], 0xBB);
        }
    }

    // Fills `buf` with a pattern of distinct bytes so that copy errors
    // (wrong offset, wrong length) are detectable.
    fn fill_pattern(buf: &mut [u8]) {
        for (idx, byte) in buf.iter_mut().enumerate() {
            *byte = (idx as u8).wrapping_mul(3).wrapping_add(1);
        }
    }

    fn run_copy_from_case(src_offset: usize, dst_offset: usize, len: usize) {
        let mut src = [0u8; 64];
        let mut dst = [0u8; 64];
        fill_pattern(&mut src);

        // SAFETY: The callers only pass offsets and lengths that stay within
        // the 64-byte buffers.
        let src_ptr = unsafe { src.as_ptr().add(src_offset) };
        let dst_ptr = unsafe { dst.as_mut_ptr().add(dst_offset) };

        // SAFETY: Both pointers are valid for `len` bytes and do not overlap.
        unsafe { copy_from_mmio(dst_ptr, src_ptr, len) };

        assert_eq!(
            &dst[dst_offset..dst_offset + len],
            &src[src_offset..src_offset + len]
        );
    }

    fn run_copy_to_case(src_offset: usize, dst_offset: usize, len: usize) {
        let mut src = [0u8; 64];
        let mut dst = [0u8; 64];
        fill_pattern(&mut src);

        // SAFETY: The callers only pass offsets and lengths that stay within
        // the 64-byte buffers.
        let src_ptr = unsafe { src.as_ptr().add(src_offset) };
        let dst_ptr = unsafe { dst.as_mut_ptr().add(dst_offset) };

        // SAFETY: Both pointers are valid for `len` bytes and do not overlap.
        unsafe { copy_to_mmio(src_ptr, dst_ptr, len) };

        assert_eq!(
            &dst[dst_offset..dst_offset + len],
            &src[src_offset..src_offset + len]
        );
    }

    #[ktest]
    fn copy_from_alignment_and_sizes() {
        let word_size = size_of::<usize>();
        let sizes = [
            0,
            1,
            word_size.saturating_sub(1),
            word_size,
            word_size + 1,
            word_size * 2 + 3,
        ];
        let offsets = [0, 1, 2];

        for &len in &sizes {
            for &src_offset in &offsets {
                for &dst_offset in &offsets {
                    if src_offset + len <= 64 && dst_offset + len <= 64 {
                        run_copy_from_case(src_offset, dst_offset, len);
                    }
                }
            }
        }
    }

    #[ktest]
    fn copy_to_alignment_and_sizes() {
        let word_size = size_of::<usize>();
        let sizes = [
            0,
            1,
            word_size.saturating_sub(1),
            word_size,
            word_size + 1,
            word_size * 2 + 3,
        ];
        let offsets = [0, 1, 2];

        for &len in &sizes {
            for &src_offset in &offsets {
                for &dst_offset in &offsets {
                    if src_offset + len <= 64 && dst_offset + len <= 64 {
                        run_copy_to_case(src_offset, dst_offset, len);
                    }
                }
            }
        }
    }
}