ostd_pod/
array_helper.rs

1// SPDX-License-Identifier: MPL-2.0
2
3//! Aligned array helpers for Pod types.
4//!
5//! This module provides type-level utilities
6//! for creating arrays with specific alignment requirements.
7//! It's primarily used internally to support Pod unions
8//! that need to maintain precise memory layouts with guaranteed alignment.
9
10use zerocopy::{FromBytes, Immutable, IntoBytes, KnownLayout};
11
12/// A transparent wrapper around `[u8; N]` with guaranteed 1-byte alignment.
13///
14/// This type implements the zerocopy traits (`FromBytes`, `IntoBytes`, `Immutable`, `KnownLayout`)
15/// making it safe to transmute to/from byte arrays. It is primarily used internally by the
16/// `ArrayFactory` type system to provide aligned arrays for POD unions.
17#[derive(FromBytes, IntoBytes, Immutable, KnownLayout, Clone, Copy)]
18#[repr(transparent)]
19pub struct U8Array<const N: usize>([u8; N]);
20
21const _: () = assert!(align_of::<U8Array<0>>() == 1);
22
23/// A transparent wrapper around `[u16; N]` with guaranteed 2-byte alignment.
24#[derive(FromBytes, IntoBytes, Immutable, KnownLayout, Clone, Copy)]
25#[repr(transparent)]
26pub struct U16Array<const N: usize>([u16; N]);
27
28const _: () = assert!(align_of::<U16Array<0>>() == 2);
29
30/// A transparent wrapper around `[u32; N]` with guaranteed 4-byte alignment.
31#[derive(FromBytes, IntoBytes, Immutable, KnownLayout, Clone, Copy)]
32#[repr(transparent)]
33pub struct U32Array<const N: usize>([u32; N]);
34
35const _: () = assert!(align_of::<U32Array<0>>() == 4);
36
37/// A transparent wrapper around `[u64; N]` with guaranteed 8-byte alignment.
38#[derive(FromBytes, IntoBytes, Immutable, KnownLayout, Clone, Copy)]
39#[repr(transparent)]
40pub struct U64Array<const N: usize>([u64; N]);
41
42const _: () = assert!(align_of::<U64Array<0>>() == 8);
43
/// Type-level selector that maps an alignment `A` and an element count `N` to an aligned
/// array type.
///
/// `ArrayFactory` is an empty enum, so no value of it can ever be constructed. It exists
/// purely so that [`ArrayManufacture`] can be implemented for particular `(A, N)` pairs.
/// Resolving `<ArrayFactory<A, N> as ArrayManufacture>::Array` yields `U8Array<N>`,
/// `U16Array<N>`, `U32Array<N>`, or `U64Array<N>` when `A` is 1, 2, 4, or 8 respectively.
///
/// # Type Parameters
///
/// * `A` - The required alignment in bytes. Only 1, 2, 4, and 8 have an
///   [`ArrayManufacture`] implementation; requesting the `Array` type for any other value
///   is a compile error.
/// * `N` - The number of elements (not bytes) in the resulting array. Each element is `A`
///   bytes wide, so the array occupies `A * N` bytes.
///
/// # Examples
///
/// ```rust
/// use ostd_pod::array_helper::{ArrayFactory, ArrayManufacture};
///
/// // Resolves to `U32Array<8>`: eight `u32` elements, 4-byte aligned, 32 bytes in total.
/// type MyArray = <ArrayFactory<4, 8> as ArrayManufacture>::Array;
/// ```
pub enum ArrayFactory<const A: usize, const N: usize> {}
64
65/// Trait that associates an `ArrayFactory` with its corresponding aligned array type.
66///
67/// This trait is implemented for `ArrayFactory<A, N>` where `A` is 1, 2, 4, or 8,
68/// mapping to `U8Array`, `U16Array`, `U32Array`, and `U64Array` respectively.
69pub trait ArrayManufacture {
70    /// The aligned array type produced by this factory.
71    type Array: FromBytes + IntoBytes + Immutable;
72}
73
74impl<const N: usize> ArrayManufacture for ArrayFactory<1, N> {
75    type Array = U8Array<N>;
76}
77
78impl<const N: usize> ArrayManufacture for ArrayFactory<2, N> {
79    type Array = U16Array<N>;
80}
81
82impl<const N: usize> ArrayManufacture for ArrayFactory<4, N> {
83    type Array = U32Array<N>;
84}
85
86impl<const N: usize> ArrayManufacture for ArrayFactory<8, N> {
87    type Array = U64Array<N>;
88}
89
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn u8array_alignment() {
        assert_eq!(align_of::<U8Array<0>>(), 1);
        assert_eq!(align_of::<U8Array<1>>(), 1);
        assert_eq!(align_of::<U8Array<10>>(), 1);
    }

    #[test]
    fn u8array_size() {
        assert_eq!(size_of::<U8Array<0>>(), 0);
        assert_eq!(size_of::<U8Array<1>>(), 1);
        assert_eq!(size_of::<U8Array<4>>(), 4);
        assert_eq!(size_of::<U8Array<10>>(), 10);
    }

    #[test]
    fn u16array_alignment() {
        assert_eq!(align_of::<U16Array<0>>(), 2);
        assert_eq!(align_of::<U16Array<1>>(), 2);
        assert_eq!(align_of::<U16Array<10>>(), 2);
    }

    #[test]
    fn u16array_size() {
        assert_eq!(size_of::<U16Array<0>>(), 0);
        assert_eq!(size_of::<U16Array<1>>(), 2);
        assert_eq!(size_of::<U16Array<4>>(), 8);
        assert_eq!(size_of::<U16Array<10>>(), 20);
    }

    #[test]
    fn u32array_alignment() {
        assert_eq!(align_of::<U32Array<0>>(), 4);
        assert_eq!(align_of::<U32Array<1>>(), 4);
        assert_eq!(align_of::<U32Array<10>>(), 4);
    }

    #[test]
    fn u32array_size() {
        assert_eq!(size_of::<U32Array<0>>(), 0);
        assert_eq!(size_of::<U32Array<1>>(), 4);
        assert_eq!(size_of::<U32Array<4>>(), 16);
        assert_eq!(size_of::<U32Array<10>>(), 40);
    }

    #[test]
    fn u64array_alignment() {
        assert_eq!(align_of::<U64Array<0>>(), 8);
        assert_eq!(align_of::<U64Array<1>>(), 8);
        assert_eq!(align_of::<U64Array<10>>(), 8);
    }

    #[test]
    fn u64array_size() {
        assert_eq!(size_of::<U64Array<0>>(), 0);
        assert_eq!(size_of::<U64Array<1>>(), 8);
        assert_eq!(size_of::<U64Array<4>>(), 32);
        assert_eq!(size_of::<U64Array<10>>(), 80);
    }

    #[test]
    fn array_factory_1byte_alignment() {
        type Array = <ArrayFactory<1, 5> as ArrayManufacture>::Array;
        assert_eq!(align_of::<Array>(), 1);
        assert_eq!(size_of::<Array>(), 5);
    }

    #[test]
    fn array_factory_2byte_alignment() {
        type Array = <ArrayFactory<2, 5> as ArrayManufacture>::Array;
        assert_eq!(align_of::<Array>(), 2);
        assert_eq!(size_of::<Array>(), 10);
    }

    #[test]
    fn array_factory_4byte_alignment() {
        type Array = <ArrayFactory<4, 5> as ArrayManufacture>::Array;
        assert_eq!(align_of::<Array>(), 4);
        assert_eq!(size_of::<Array>(), 20);
    }

    #[test]
    fn array_factory_8byte_alignment() {
        type Array = <ArrayFactory<8, 5> as ArrayManufacture>::Array;
        assert_eq!(align_of::<Array>(), 8);
        assert_eq!(size_of::<Array>(), 40);
    }

    #[test]
    fn array_factory_zero_elements() {
        // An empty array still carries the alignment of its element type.
        type Array = <ArrayFactory<8, 0> as ArrayManufacture>::Array;
        assert_eq!(align_of::<Array>(), 8);
        assert_eq!(size_of::<Array>(), 0);
    }

    #[test]
    fn zerocopy_traits() {
        // Every wrapper type must implement all the zerocopy traits it derives,
        // including the 1-byte-aligned `U8Array`.
        fn assert_traits<T: FromBytes + IntoBytes + Immutable + KnownLayout>() {}

        assert_traits::<U8Array<4>>();
        assert_traits::<U16Array<4>>();
        assert_traits::<U32Array<4>>();
        assert_traits::<U64Array<4>>();
    }

    #[test]
    fn array_factory_outputs_are_copyable_pod() {
        // Whatever the factory resolves to must be usable as a plain-old-data field.
        fn assert_pod<T: FromBytes + IntoBytes + Immutable + KnownLayout + Copy>() {}

        assert_pod::<<ArrayFactory<1, 3> as ArrayManufacture>::Array>();
        assert_pod::<<ArrayFactory<2, 3> as ArrayManufacture>::Array>();
        assert_pod::<<ArrayFactory<4, 3> as ArrayManufacture>::Array>();
        assert_pod::<<ArrayFactory<8, 3> as ArrayManufacture>::Array>();
    }

    #[test]
    fn byte_round_trip() {
        // Reading an array from bytes and viewing it as bytes again must be lossless.
        let bytes: [u8; 8] = [1, 2, 3, 4, 5, 6, 7, 8];
        let Ok(array) = U32Array::<2>::read_from_bytes(&bytes[..]) else {
            panic!("8 bytes must exactly fill a U32Array<2>");
        };
        assert_eq!(array.as_bytes(), &bytes[..]);
    }
}