p3_matrix/
horizontally_truncated.rs

1use core::marker::PhantomData;
2
3use crate::Matrix;
4
/// A matrix view that exposes only the leftmost `truncated_width` columns of an inner matrix.
///
/// Rows are passed through unchanged; columns at index `truncated_width` and beyond are hidden
/// from callers. Instances are built with [`HorizontallyTruncated::new`], which rejects widths
/// larger than the inner matrix's width.
#[derive(Clone, Debug)]
pub struct HorizontallyTruncated<T, Inner> {
    /// The wrapped matrix; every read is forwarded to it.
    inner: Inner,
    /// Number of leading columns exposed to callers.
    ///
    /// Invariant (checked by `new`): `truncated_width <= inner.width()`.
    truncated_width: usize,
    /// Ties the element type `T` to the struct; carries no data at runtime.
    _phantom: PhantomData<T>,
}
16
17impl<T, Inner: Matrix<T>> HorizontallyTruncated<T, Inner>
18where
19    T: Send + Sync + Clone,
20{
21    /// Construct a new horizontally truncated view of a matrix.
22    ///
23    /// # Arguments
24    /// - `inner`: The full inner matrix to be wrapped.
25    /// - `truncated_width`: The number of columns to expose (must be ≤ `inner.width()`).
26    ///
27    /// Returns `None` if `truncated_width` is greater than the width of the inner matrix.
28    pub fn new(inner: Inner, truncated_width: usize) -> Option<Self> {
29        (truncated_width <= inner.width()).then(|| Self {
30            inner,
31            truncated_width,
32            _phantom: PhantomData,
33        })
34    }
35}
36
37impl<T, Inner> Matrix<T> for HorizontallyTruncated<T, Inner>
38where
39    T: Send + Sync + Clone,
40    Inner: Matrix<T>,
41{
42    /// Returns the number of columns exposed by the truncated matrix.
43    #[inline(always)]
44    fn width(&self) -> usize {
45        self.truncated_width
46    }
47
48    /// Returns the number of rows in the matrix (same as the inner matrix).
49    #[inline(always)]
50    fn height(&self) -> usize {
51        self.inner.height()
52    }
53
54    #[inline(always)]
55    unsafe fn get_unchecked(&self, r: usize, c: usize) -> T {
56        unsafe {
57            // Safety: The caller must ensure that `c < truncated_width` and `r < self.height()`.
58            self.inner.get_unchecked(r, c)
59        }
60    }
61
62    unsafe fn row_unchecked(
63        &self,
64        r: usize,
65    ) -> impl IntoIterator<Item = T, IntoIter = impl Iterator<Item = T> + Send + Sync> {
66        unsafe {
67            // Safety: The caller must ensure that `r < self.height()`.
68            self.inner.row_subseq_unchecked(r, 0, self.truncated_width)
69        }
70    }
71
72    unsafe fn row_subseq_unchecked(
73        &self,
74        r: usize,
75        start: usize,
76        end: usize,
77    ) -> impl IntoIterator<Item = T, IntoIter = impl Iterator<Item = T> + Send + Sync> {
78        unsafe {
79            // Safety: The caller must ensure that r < self.height() and start <= end <= self.width().
80            self.inner.row_subseq_unchecked(r, start, end)
81        }
82    }
83
84    unsafe fn row_subslice_unchecked(
85        &self,
86        r: usize,
87        start: usize,
88        end: usize,
89    ) -> impl core::ops::Deref<Target = [T]> {
90        unsafe {
91            // Safety: The caller must ensure that `r < self.height()` and `start <= end <= self.width()`.
92            self.inner.row_subslice_unchecked(r, start, end)
93        }
94    }
95}
96
#[cfg(test)]
mod tests {
    use alloc::vec;
    use alloc::vec::Vec;

    use super::*;
    use crate::dense::RowMajorMatrix;

    #[test]
    fn test_truncate_width_by_one() {
        // Create a 3x4 matrix:
        // [ 1  2  3  4]
        // [ 5  6  7  8]
        // [ 9 10 11 12]
        let inner = RowMajorMatrix::new(vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], 4);

        // Truncate to width 3, hiding the last column.
        let truncated = HorizontallyTruncated::new(inner, 3).unwrap();

        // Width should be 3.
        assert_eq!(truncated.width(), 3);

        // Height remains unchanged.
        assert_eq!(truncated.height(), 3);

        // Check individual elements.
        assert_eq!(truncated.get(0, 0), Some(1)); // row 0, col 0
        assert_eq!(truncated.get(1, 1), Some(6)); // row 1, col 1
        unsafe {
            assert_eq!(truncated.get_unchecked(0, 1), 2); // row 0, col 1
            assert_eq!(truncated.get_unchecked(2, 2), 11); // row 2, col 2
        }

        // Row 0 stops before the hidden column: [1, 2, 3].
        let row0: Vec<_> = truncated.row(0).unwrap().into_iter().collect();
        assert_eq!(row0, vec![1, 2, 3]);
        unsafe {
            // Row 1: [5, 6, 7].
            let row1: Vec<_> = truncated.row_unchecked(1).into_iter().collect();
            assert_eq!(row1, vec![5, 6, 7]);

            // Row 2 is [9, 10, 11]; columns 1..2 select just [10].
            let row2_subseq: Vec<_> = truncated.row_subseq_unchecked(2, 1, 2).into_iter().collect();
            assert_eq!(row2_subseq, vec![10]);
        }

        // `row_slice` is a safe, checked accessor.
        let row1 = truncated.row_slice(1).unwrap();
        assert_eq!(&*row1, &[5, 6, 7]);

        unsafe {
            let row2 = truncated.row_slice_unchecked(2);
            assert_eq!(&*row2, &[9, 10, 11]);

            let row0_subslice = truncated.row_subslice_unchecked(0, 0, 2);
            assert_eq!(&*row0_subslice, &[1, 2]);

            // A subslice ending exactly at the truncated width: row 2, columns 1..3.
            let row2_subslice = truncated.row_subslice_unchecked(2, 1, 3);
            assert_eq!(&*row2_subslice, &[10, 11]);
        }

        assert!(truncated.get(0, 3).is_none()); // Width out of bounds (hidden column)
        assert!(truncated.get(3, 0).is_none()); // Height out of bounds
        assert!(truncated.row(3).is_none()); // Height out of bounds
        assert!(truncated.row_slice(3).is_none()); // Height out of bounds

        // Convert the truncated view to a RowMajorMatrix and check contents.
        let as_matrix = truncated.to_row_major_matrix();

        // The expected matrix after truncation:
        // [1  2  3]
        // [5  6  7]
        // [9 10 11]
        let expected = RowMajorMatrix::new(vec![1, 2, 3, 5, 6, 7, 9, 10, 11], 3);

        assert_eq!(as_matrix, expected);
    }

    #[test]
    fn test_no_truncation() {
        // 2x2 matrix:
        // [ 7  8 ]
        // [ 9 10 ]
        let inner = RowMajorMatrix::new(vec![7, 8, 9, 10], 2);

        // Truncate to full width (no change).
        let truncated = HorizontallyTruncated::new(inner, 2).unwrap();

        assert_eq!(truncated.width(), 2);
        assert_eq!(truncated.height(), 2);
        assert_eq!(truncated.get(0, 1).unwrap(), 8);
        assert_eq!(truncated.get(1, 0).unwrap(), 9);

        unsafe {
            assert_eq!(truncated.get_unchecked(0, 0), 7);
            assert_eq!(truncated.get_unchecked(1, 1), 10);
        }

        let row0: Vec<_> = truncated.row(0).unwrap().into_iter().collect();
        assert_eq!(row0, vec![7, 8]);

        let row1: Vec<_> = unsafe { truncated.row_unchecked(1).into_iter().collect() };
        assert_eq!(row1, vec![9, 10]);

        assert!(truncated.get(0, 2).is_none()); // Width out of bounds
        assert!(truncated.get(2, 0).is_none()); // Height out of bounds
        assert!(truncated.row(2).is_none()); // Height out of bounds
        assert!(truncated.row_slice(2).is_none()); // Height out of bounds
    }

    #[test]
    fn test_truncate_to_zero_width() {
        // 1x3 matrix: [11 12 13]
        let inner = RowMajorMatrix::new(vec![11, 12, 13], 3);

        // Truncate to width 0.
        let truncated = HorizontallyTruncated::new(inner, 0).unwrap();

        assert_eq!(truncated.width(), 0);
        assert_eq!(truncated.height(), 1);

        // Row should be empty.
        assert!(truncated.row(0).unwrap().into_iter().next().is_none());

        assert!(truncated.get(0, 0).is_none()); // Width out of bounds
        assert!(truncated.get(1, 0).is_none()); // Height out of bounds
        assert!(truncated.row(1).is_none()); // Height out of bounds
        assert!(truncated.row_slice(1).is_none()); // Height out of bounds
    }

    #[test]
    fn test_invalid_truncation_width() {
        // 2x2 matrix:
        // [1 2]
        // [3 4]
        let inner = RowMajorMatrix::new(vec![1, 2, 3, 4], 2);

        // Attempt to truncate well beyond inner width (invalid).
        assert!(HorizontallyTruncated::new(inner, 5).is_none());

        // Boundary: one column more than the inner width is already invalid.
        let inner = RowMajorMatrix::new(vec![1, 2, 3, 4], 2);
        assert!(HorizontallyTruncated::new(inner, 3).is_none());
    }
}