p3_matrix/
stack.rs

1use core::ops::Deref;
2
3use crate::Matrix;
4use crate::dense::RowMajorMatrixView;
5
6/// A type alias representing a vertical composition of two row-major matrix views.
7///
8/// `ViewPair` combines two [`RowMajorMatrixView`]'s with the same element type `T`
9/// and lifetime `'a` into a single virtual matrix stacked vertically.
10///
11/// Both views must have the same width; the resulting view has a height equal
12/// to the sum of the two original heights.
13pub type ViewPair<'a, T> = VerticalPair<RowMajorMatrixView<'a, T>, RowMajorMatrixView<'a, T>>;
14
/// Two matrices glued together vertically, with `top` sitting directly above `bottom`.
///
/// The inputs are required to share the same `width`. The combined matrix has:
/// - `width`: the common width of `top` and `bottom`.
/// - `height`: the height of `top` plus the height of `bottom`.
///
/// Rows are addressed top-first. Row indices smaller than the height of `top` refer to
/// `top`; the remaining indices refer to `bottom`, shifted down by `top`'s height.
///
/// `VerticalPair::new` asserts the equal-width requirement. Building the struct
/// directly through its public fields skips that check.
#[derive(Copy, Clone, Debug)]
pub struct VerticalPair<Top, Bottom> {
    /// Matrix occupying the upper rows of the composition.
    pub top: Top,
    /// Matrix occupying the lower rows of the composition.
    pub bottom: Bottom,
}
31
/// A matrix composed by placing two matrices side-by-side horizontally.
///
/// Both matrices must have the same `height`.
/// The resulting matrix has dimensions:
/// - `width`: The sum of the `widths` of the input matrices.
/// - `height`: The same as the inputs.
///
/// Element access and iteration for a given row `i` will first access the elements in the
/// `i`th row of the left matrix, followed by the elements in the `i`th row of the right matrix.
///
/// `HorizontalPair::new` asserts the equal-height requirement. Building the struct
/// directly through its public fields skips that check.
#[derive(Copy, Clone, Debug)]
pub struct HorizontalPair<Left, Right> {
    /// The left matrix in the horizontal composition.
    pub left: Left,
    /// The right matrix in the horizontal composition.
    pub right: Right,
}
48
49impl<Top, Bottom> VerticalPair<Top, Bottom> {
50    /// Create a new `VerticalPair` by stacking two matrices vertically.
51    ///
52    /// # Panics
53    /// Panics if the two matrices do not have the same width (i.e., number of columns),
54    /// since vertical composition requires column alignment.
55    ///
56    /// # Returns
57    /// A `VerticalPair` that represents the combined matrix.
58    pub fn new<T>(top: Top, bottom: Bottom) -> Self
59    where
60        T: Send + Sync + Clone,
61        Top: Matrix<T>,
62        Bottom: Matrix<T>,
63    {
64        assert_eq!(top.width(), bottom.width());
65        Self { top, bottom }
66    }
67}
68
69impl<Left, Right> HorizontalPair<Left, Right> {
70    /// Create a new `HorizontalPair` by joining two matrices side by side.
71    ///
72    /// # Panics
73    /// Panics if the two matrices do not have the same height (i.e., number of rows),
74    /// since horizontal composition requires row alignment.
75    ///
76    /// # Returns
77    /// A `HorizontalPair` that represents the combined matrix.
78    pub fn new<T>(left: Left, right: Right) -> Self
79    where
80        T: Send + Sync + Clone,
81        Left: Matrix<T>,
82        Right: Matrix<T>,
83    {
84        assert_eq!(left.height(), right.height());
85        Self { left, right }
86    }
87}
88
89impl<T: Send + Sync + Clone, Top: Matrix<T>, Bottom: Matrix<T>> Matrix<T>
90    for VerticalPair<Top, Bottom>
91{
92    fn width(&self) -> usize {
93        self.top.width()
94    }
95
96    fn height(&self) -> usize {
97        self.top.height() + self.bottom.height()
98    }
99
100    unsafe fn get_unchecked(&self, r: usize, c: usize) -> T {
101        unsafe {
102            // Safety: The caller must ensure that r < self.height() and c < self.width()
103            if r < self.top.height() {
104                self.top.get_unchecked(r, c)
105            } else {
106                self.bottom.get_unchecked(r - self.top.height(), c)
107            }
108        }
109    }
110
111    unsafe fn row_unchecked(
112        &self,
113        r: usize,
114    ) -> impl IntoIterator<Item = T, IntoIter = impl Iterator<Item = T> + Send + Sync> {
115        unsafe {
116            // Safety: The caller must ensure that r < self.height()
117            if r < self.top.height() {
118                EitherRow::Left(self.top.row_unchecked(r).into_iter())
119            } else {
120                EitherRow::Right(self.bottom.row_unchecked(r - self.top.height()).into_iter())
121            }
122        }
123    }
124
125    unsafe fn row_subseq_unchecked(
126        &self,
127        r: usize,
128        start: usize,
129        end: usize,
130    ) -> impl IntoIterator<Item = T, IntoIter = impl Iterator<Item = T> + Send + Sync> {
131        unsafe {
132            // Safety: The caller must ensure that r < self.height() and start <= end <= self.width()
133            if r < self.top.height() {
134                EitherRow::Left(self.top.row_subseq_unchecked(r, start, end).into_iter())
135            } else {
136                EitherRow::Right(
137                    self.bottom
138                        .row_subseq_unchecked(r - self.top.height(), start, end)
139                        .into_iter(),
140                )
141            }
142        }
143    }
144
145    unsafe fn row_slice_unchecked(&self, r: usize) -> impl Deref<Target = [T]> {
146        unsafe {
147            // Safety: The caller must ensure that r < self.height()
148            if r < self.top.height() {
149                EitherRow::Left(self.top.row_slice_unchecked(r))
150            } else {
151                EitherRow::Right(self.bottom.row_slice_unchecked(r - self.top.height()))
152            }
153        }
154    }
155
156    unsafe fn row_subslice_unchecked(
157        &self,
158        r: usize,
159        start: usize,
160        end: usize,
161    ) -> impl Deref<Target = [T]> {
162        unsafe {
163            // Safety: The caller must ensure that r < self.height() and start <= end <= self.width()
164            if r < self.top.height() {
165                EitherRow::Left(self.top.row_subslice_unchecked(r, start, end))
166            } else {
167                EitherRow::Right(self.bottom.row_subslice_unchecked(
168                    r - self.top.height(),
169                    start,
170                    end,
171                ))
172            }
173        }
174    }
175}
176
177impl<T: Send + Sync + Clone, Left: Matrix<T>, Right: Matrix<T>> Matrix<T>
178    for HorizontalPair<Left, Right>
179{
180    fn width(&self) -> usize {
181        self.left.width() + self.right.width()
182    }
183
184    fn height(&self) -> usize {
185        self.left.height()
186    }
187
188    unsafe fn get_unchecked(&self, r: usize, c: usize) -> T {
189        unsafe {
190            // Safety: The caller must ensure that r < self.height() and c < self.width()
191            if c < self.left.width() {
192                self.left.get_unchecked(r, c)
193            } else {
194                self.right.get_unchecked(r, c - self.left.width())
195            }
196        }
197    }
198
199    unsafe fn row_unchecked(
200        &self,
201        r: usize,
202    ) -> impl IntoIterator<Item = T, IntoIter = impl Iterator<Item = T> + Send + Sync> {
203        unsafe {
204            // Safety: The caller must ensure that r < self.height()
205            self.left
206                .row_unchecked(r)
207                .into_iter()
208                .chain(self.right.row_unchecked(r))
209        }
210    }
211}
212
/// A two-way sum type that lets a single method return a value from one of two sources.
///
/// We use this to wrap both the row iterator and the row slice. It implements [`Iterator`]
/// when both variants are iterators over the same item type, and [`Deref`] to `[T]` when
/// both variants deref to the same slice type.
#[derive(Debug)]
pub enum EitherRow<L, R> {
    /// Value produced by the first (left/top) source.
    Left(L),
    /// Value produced by the second (right/bottom) source.
    Right(R),
}

impl<T, L, R> Iterator for EitherRow<L, R>
where
    L: Iterator<Item = T>,
    R: Iterator<Item = T>,
{
    type Item = T;

    fn next(&mut self) -> Option<Self::Item> {
        match self {
            Self::Left(l) => l.next(),
            Self::Right(r) => r.next(),
        }
    }

    // Forward the wrapped iterator's bounds so consumers such as `collect` can
    // preallocate; the default `(0, None)` would discard that information.
    fn size_hint(&self) -> (usize, Option<usize>) {
        match self {
            Self::Left(l) => l.size_hint(),
            Self::Right(r) => r.size_hint(),
        }
    }
}

impl<T, L, R> Deref for EitherRow<L, R>
where
    L: Deref<Target = [T]>,
    R: Deref<Target = [T]>,
{
    type Target = [T];

    fn deref(&self) -> &Self::Target {
        match self {
            Self::Left(l) => l,
            Self::Right(r) => r,
        }
    }
}
248
#[cfg(test)]
mod tests {
    use alloc::vec;
    use alloc::vec::Vec;

    use itertools::Itertools;

    use super::*;
    use crate::RowMajorMatrix;

    #[test]
    fn test_vertical_pair_empty_top() {
        let top = RowMajorMatrix::new(vec![], 2); // 0x2
        let bottom = RowMajorMatrix::new(vec![1, 2, 3, 4], 2); // 2x2
        let vpair = VerticalPair::new::<i32>(top, bottom);
        assert_eq!(vpair.height(), 2);
        assert_eq!(vpair.get(1, 1), Some(4));
        unsafe {
            assert_eq!(vpair.get_unchecked(0, 0), 1);
        }
    }

    #[test]
    fn test_vertical_pair_empty_bottom() {
        let top = RowMajorMatrix::new(vec![1, 2, 3, 4], 2); // 2x2
        let bottom = RowMajorMatrix::new(vec![], 2); // 0x2
        let vpair = VerticalPair::new::<i32>(top, bottom);
        assert_eq!(vpair.height(), 2);
        assert_eq!(vpair.get(1, 1), Some(4));
        assert_eq!(vpair.get(2, 0), None); // Height out of bounds
        assert!(vpair.row(2).is_none()); // Height out of bounds
    }

    #[test]
    fn test_vertical_pair_composition() {
        let top = RowMajorMatrix::new(vec![1, 2, 3, 4], 2); // 2x2
        let bottom = RowMajorMatrix::new(vec![5, 6, 7, 8], 2); // 2x2
        let vertical = VerticalPair::new::<i32>(top, bottom);

        // Dimensions
        assert_eq!(vertical.width(), 2);
        assert_eq!(vertical.height(), 4);

        // Values from top
        assert_eq!(vertical.get(0, 0), Some(1));
        assert_eq!(vertical.get(1, 1), Some(4));

        // Values from bottom
        unsafe {
            assert_eq!(vertical.get_unchecked(2, 0), 5);
            assert_eq!(vertical.get_unchecked(3, 1), 8);
        }

        // Row iter from bottom
        let row = vertical.row(3).unwrap().into_iter().collect_vec();
        assert_eq!(row, vec![7, 8]);

        unsafe {
            // Row iter from top
            let row = vertical.row_unchecked(1).into_iter().collect_vec();
            assert_eq!(row, vec![3, 4]);

            let row = vertical
                .row_subseq_unchecked(0, 0, 1)
                .into_iter()
                .collect_vec();
            assert_eq!(row, vec![1]);
        }

        // Row slice
        assert_eq!(vertical.row_slice(2).unwrap().deref(), &[5, 6]);

        unsafe {
            // Row slice unchecked
            assert_eq!(vertical.row_slice_unchecked(3).deref(), &[7, 8]);
            assert_eq!(vertical.row_subslice_unchecked(1, 1, 2).deref(), &[4]);
        }

        assert_eq!(vertical.get(0, 2), None); // Width out of bounds
        assert_eq!(vertical.get(4, 0), None); // Height out of bounds
        assert!(vertical.row(4).is_none()); // Height out of bounds
        assert!(vertical.row_slice(4).is_none()); // Height out of bounds
    }

    #[test]
    fn test_horizontal_pair_composition() {
        let left = RowMajorMatrix::new(vec![1, 2, 3, 4], 2); // 2x2
        let right = RowMajorMatrix::new(vec![5, 6, 7, 8], 2); // 2x2
        let horizontal = HorizontalPair::new::<i32>(left, right);

        // Dimensions
        assert_eq!(horizontal.height(), 2);
        assert_eq!(horizontal.width(), 4);

        // Left values
        assert_eq!(horizontal.get(0, 0), Some(1));
        assert_eq!(horizontal.get(1, 1), Some(4));

        // Right values
        unsafe {
            assert_eq!(horizontal.get_unchecked(0, 2), 5);
            assert_eq!(horizontal.get_unchecked(1, 3), 8);
        }

        // Row iter
        let row = horizontal.row(0).unwrap().into_iter().collect_vec();
        assert_eq!(row, vec![1, 2, 5, 6]);

        unsafe {
            let row = horizontal.row_unchecked(1).into_iter().collect_vec();
            assert_eq!(row, vec![3, 4, 7, 8]);
        }

        assert_eq!(horizontal.get(0, 4), None); // Width out of bounds
        assert_eq!(horizontal.get(2, 0), None); // Height out of bounds
        assert!(horizontal.row(2).is_none()); // Height out of bounds
    }

    #[test]
    fn test_horizontal_pair_row_subseq() {
        let left = RowMajorMatrix::new(vec![1, 2, 3, 4], 2); // 2x2
        let right = RowMajorMatrix::new(vec![5, 6, 7, 8], 2); // 2x2
        let horizontal = HorizontalPair::new::<i32>(left, right);

        unsafe {
            // Whole row.
            let row = horizontal.row_subseq_unchecked(0, 0, 4).into_iter().collect_vec();
            assert_eq!(row, vec![1, 2, 5, 6]);

            // Range straddling the left/right boundary.
            let row = horizontal.row_subseq_unchecked(1, 1, 3).into_iter().collect_vec();
            assert_eq!(row, vec![4, 7]);

            // Entirely within the left matrix.
            let row = horizontal.row_subseq_unchecked(1, 0, 2).into_iter().collect_vec();
            assert_eq!(row, vec![3, 4]);

            // Entirely within the right matrix.
            let row = horizontal.row_subseq_unchecked(0, 2, 4).into_iter().collect_vec();
            assert_eq!(row, vec![5, 6]);

            // Empty range exactly at the boundary.
            let row = horizontal.row_subseq_unchecked(0, 2, 2).into_iter().collect_vec();
            assert!(row.is_empty());
        }
    }

    #[test]
    fn test_either_row_iterator_behavior() {
        type Iter = alloc::vec::IntoIter<i32>;

        // Left variant
        let left: EitherRow<Iter, Iter> = EitherRow::Left(vec![10, 20].into_iter());
        assert_eq!(left.collect::<Vec<_>>(), vec![10, 20]);

        // Right variant
        let right: EitherRow<Iter, Iter> = EitherRow::Right(vec![30, 40].into_iter());
        assert_eq!(right.collect::<Vec<_>>(), vec![30, 40]);
    }

    #[test]
    fn test_either_row_deref_behavior() {
        let left: EitherRow<&[i32], &[i32]> = EitherRow::Left(&[1, 2, 3]);
        let right: EitherRow<&[i32], &[i32]> = EitherRow::Right(&[4, 5]);

        assert_eq!(&*left, &[1, 2, 3]);
        assert_eq!(&*right, &[4, 5]);
    }

    #[test]
    #[should_panic]
    fn test_vertical_pair_width_mismatch_should_panic() {
        let a = RowMajorMatrix::new(vec![1, 2, 3], 1); // 3x1
        let b = RowMajorMatrix::new(vec![4, 5], 2); // 1x2
        let _ = VerticalPair::new::<i32>(a, b);
    }

    #[test]
    #[should_panic]
    fn test_horizontal_pair_height_mismatch_should_panic() {
        let a = RowMajorMatrix::new(vec![1, 2, 3], 3); // 1x3
        let b = RowMajorMatrix::new(vec![4, 5], 1); // 2x1
        let _ = HorizontalPair::new::<i32>(a, b);
    }
}