1use core::marker::PhantomData;
2
3use crate::Matrix;
4
/// A view over a matrix that exposes only its leftmost `truncated_width` columns.
///
/// Every row of the wrapped matrix stays visible; columns at index `truncated_width`
/// and beyond are hidden from the [`Matrix`] methods implemented for this type.
pub struct HorizontallyTruncated<T, Inner> {
    /// Wrapped matrix whose leading columns are exposed.
    inner: Inner,
    /// Count of leading columns of `inner` that remain visible.
    truncated_width: usize,
    /// Associates the element type `T` with the view without holding any `T` value.
    _phantom: PhantomData<T>,
}
16
17impl<T, Inner: Matrix<T>> HorizontallyTruncated<T, Inner>
18where
19 T: Send + Sync + Clone,
20{
21 pub fn new(inner: Inner, truncated_width: usize) -> Option<Self> {
29 (truncated_width <= inner.width()).then(|| Self {
30 inner,
31 truncated_width,
32 _phantom: PhantomData,
33 })
34 }
35}
36
37impl<T, Inner> Matrix<T> for HorizontallyTruncated<T, Inner>
38where
39 T: Send + Sync + Clone,
40 Inner: Matrix<T>,
41{
42 #[inline(always)]
44 fn width(&self) -> usize {
45 self.truncated_width
46 }
47
48 #[inline(always)]
50 fn height(&self) -> usize {
51 self.inner.height()
52 }
53
54 #[inline(always)]
55 unsafe fn get_unchecked(&self, r: usize, c: usize) -> T {
56 unsafe {
57 self.inner.get_unchecked(r, c)
59 }
60 }
61
62 unsafe fn row_unchecked(
63 &self,
64 r: usize,
65 ) -> impl IntoIterator<Item = T, IntoIter = impl Iterator<Item = T> + Send + Sync> {
66 unsafe {
67 self.inner.row_subseq_unchecked(r, 0, self.truncated_width)
69 }
70 }
71
72 unsafe fn row_subseq_unchecked(
73 &self,
74 r: usize,
75 start: usize,
76 end: usize,
77 ) -> impl IntoIterator<Item = T, IntoIter = impl Iterator<Item = T> + Send + Sync> {
78 unsafe {
79 self.inner.row_subseq_unchecked(r, start, end)
81 }
82 }
83
84 unsafe fn row_subslice_unchecked(
85 &self,
86 r: usize,
87 start: usize,
88 end: usize,
89 ) -> impl core::ops::Deref<Target = [T]> {
90 unsafe {
91 self.inner.row_subslice_unchecked(r, start, end)
93 }
94 }
95}
96
#[cfg(test)]
mod tests {
    use alloc::vec;
    use alloc::vec::Vec;

    use super::*;
    use crate::dense::RowMajorMatrix;

    #[test]
    fn test_truncate_width_by_one() {
        // 3x4 inner matrix:
        // [ 1,  2,  3,  4]
        // [ 5,  6,  7,  8]
        // [ 9, 10, 11, 12]
        let inner = RowMajorMatrix::new(vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], 4);

        // Keep the first three columns; the last column (4, 8, 12) is hidden.
        let truncated = HorizontallyTruncated::new(inner, 3).unwrap();

        assert_eq!(truncated.width(), 3);
        assert_eq!(truncated.height(), 3);

        // Checked element access.
        assert_eq!(truncated.get(0, 0), Some(1));
        assert_eq!(truncated.get(1, 1), Some(6));

        // Unchecked element access; every index is inside the truncated bounds.
        unsafe {
            assert_eq!(truncated.get_unchecked(0, 1), 2);
            assert_eq!(truncated.get_unchecked(2, 2), 11);
        }

        // Row iteration stops at the truncated width.
        let row0: Vec<_> = truncated.row(0).unwrap().into_iter().collect();
        assert_eq!(row0, vec![1, 2, 3]);

        unsafe {
            let row1: Vec<_> = truncated.row_unchecked(1).into_iter().collect();
            assert_eq!(row1, vec![5, 6, 7]);

            // Columns 1..2 of row index 2.
            let row2_subseq: Vec<_> = truncated
                .row_subseq_unchecked(2, 1, 2)
                .into_iter()
                .collect();
            assert_eq!(row2_subseq, vec![10]);
        }

        // Row slices are truncated too; the slice guards stay scoped to this block.
        unsafe {
            let row1 = truncated.row_slice(1).unwrap();
            assert_eq!(&*row1, &[5, 6, 7]);

            let row2 = truncated.row_slice_unchecked(2);
            assert_eq!(&*row2, &[9, 10, 11]);

            let row0_subslice = truncated.row_subslice_unchecked(0, 0, 2);
            assert_eq!(&*row0_subslice, &[1, 2]);
        }

        // Column 3 is hidden and row 3 does not exist, so checked accessors reject both.
        assert!(truncated.get(0, 3).is_none());
        assert!(truncated.get(3, 0).is_none());
        assert!(truncated.row(3).is_none());
        assert!(truncated.row_slice(3).is_none());

        // Materializing the view keeps only the visible columns.
        let as_matrix = truncated.to_row_major_matrix();
        let expected = RowMajorMatrix::new(vec![1, 2, 3, 5, 6, 7, 9, 10, 11], 3);
        assert_eq!(as_matrix, expected);
    }

    #[test]
    fn test_no_truncation() {
        // 2x2 inner matrix:
        // [ 7,  8]
        // [ 9, 10]
        let inner = RowMajorMatrix::new(vec![7, 8, 9, 10], 2);

        // Truncating to the full width leaves every column visible.
        let truncated = HorizontallyTruncated::new(inner, 2).unwrap();

        assert_eq!(truncated.width(), 2);
        assert_eq!(truncated.height(), 2);
        assert_eq!(truncated.get(0, 1).unwrap(), 8);
        assert_eq!(truncated.get(1, 0).unwrap(), 9);

        unsafe {
            assert_eq!(truncated.get_unchecked(0, 0), 7);
            assert_eq!(truncated.get_unchecked(1, 1), 10);
        }

        let row0: Vec<_> = truncated.row(0).unwrap().into_iter().collect();
        assert_eq!(row0, vec![7, 8]);

        let row1: Vec<_> = unsafe { truncated.row_unchecked(1).into_iter().collect() };
        assert_eq!(row1, vec![9, 10]);

        // Indices just past the bounds are rejected.
        assert!(truncated.get(0, 2).is_none());
        assert!(truncated.get(2, 0).is_none());
        assert!(truncated.row(2).is_none());
        assert!(truncated.row_slice(2).is_none());
    }

    #[test]
    fn test_truncate_to_zero_width() {
        // 1x3 inner matrix: [11, 12, 13]
        let inner = RowMajorMatrix::new(vec![11, 12, 13], 3);

        // A zero width hides every column but keeps the row count.
        let truncated = HorizontallyTruncated::new(inner, 0).unwrap();

        assert_eq!(truncated.width(), 0);
        assert_eq!(truncated.height(), 1);

        // The single row exists but is empty through both the iterator and slice paths.
        assert!(truncated.row(0).unwrap().into_iter().next().is_none());
        assert!(truncated.row_slice(0).unwrap().is_empty());

        assert!(truncated.get(0, 0).is_none());
        assert!(truncated.get(1, 0).is_none());
        assert!(truncated.row(1).is_none());
        assert!(truncated.row_slice(1).is_none());
    }

    #[test]
    fn test_invalid_truncation_width() {
        // 2x2 inner matrix, so any truncated width above 2 must be rejected.
        let inner = RowMajorMatrix::new(vec![1, 2, 3, 4], 2);
        assert!(HorizontallyTruncated::new(inner, 5).is_none());

        // Off-by-one boundary: a single column more than the inner width.
        let inner = RowMajorMatrix::new(vec![1, 2, 3, 4], 2);
        assert!(HorizontallyTruncated::new(inner, 3).is_none());
    }
}