1use core::ops::Deref;
2
3use crate::Matrix;
4use crate::dense::RowMajorMatrixView;
5
6pub type ViewPair<'a, T> = VerticalPair<RowMajorMatrixView<'a, T>, RowMajorMatrixView<'a, T>>;
14
/// Two matrices stacked on top of each other and treated as one.
///
/// The `Matrix` implementation in this file reports the width of `top` and
/// the sum of both heights: rows `0..top.height()` are served by `top`, and
/// all later rows are served by `bottom`.
///
/// The type itself places no bounds on its parameters; the width check lives
/// in [`VerticalPair::new`].
#[derive(Clone, Copy, Debug)]
pub struct VerticalPair<Top, Bottom> {
    /// The upper part; it provides the first rows of the pair.
    pub top: Top,
    /// The lower part; its rows follow the rows of `top`.
    pub bottom: Bottom,
}
31
/// Two matrices placed side by side and treated as one.
///
/// The `Matrix` implementation in this file reports the height of `left` and
/// the sum of both widths: columns `0..left.width()` are served by `left`,
/// and all later columns are served by `right`.
///
/// The type itself places no bounds on its parameters; the height check
/// lives in [`HorizontalPair::new`].
#[derive(Clone, Copy, Debug)]
pub struct HorizontalPair<Left, Right> {
    /// The left-hand part; it provides the first columns of the pair.
    pub left: Left,
    /// The right-hand part; its columns follow the columns of `left`.
    pub right: Right,
}
48
49impl<Top, Bottom> VerticalPair<Top, Bottom> {
50 pub fn new<T>(top: Top, bottom: Bottom) -> Self
59 where
60 T: Send + Sync + Clone,
61 Top: Matrix<T>,
62 Bottom: Matrix<T>,
63 {
64 assert_eq!(top.width(), bottom.width());
65 Self { top, bottom }
66 }
67}
68
69impl<Left, Right> HorizontalPair<Left, Right> {
70 pub fn new<T>(left: Left, right: Right) -> Self
79 where
80 T: Send + Sync + Clone,
81 Left: Matrix<T>,
82 Right: Matrix<T>,
83 {
84 assert_eq!(left.height(), right.height());
85 Self { left, right }
86 }
87}
88
89impl<T: Send + Sync + Clone, Top: Matrix<T>, Bottom: Matrix<T>> Matrix<T>
90 for VerticalPair<Top, Bottom>
91{
92 fn width(&self) -> usize {
93 self.top.width()
94 }
95
96 fn height(&self) -> usize {
97 self.top.height() + self.bottom.height()
98 }
99
100 unsafe fn get_unchecked(&self, r: usize, c: usize) -> T {
101 unsafe {
102 if r < self.top.height() {
104 self.top.get_unchecked(r, c)
105 } else {
106 self.bottom.get_unchecked(r - self.top.height(), c)
107 }
108 }
109 }
110
111 unsafe fn row_unchecked(
112 &self,
113 r: usize,
114 ) -> impl IntoIterator<Item = T, IntoIter = impl Iterator<Item = T> + Send + Sync> {
115 unsafe {
116 if r < self.top.height() {
118 EitherRow::Left(self.top.row_unchecked(r).into_iter())
119 } else {
120 EitherRow::Right(self.bottom.row_unchecked(r - self.top.height()).into_iter())
121 }
122 }
123 }
124
125 unsafe fn row_subseq_unchecked(
126 &self,
127 r: usize,
128 start: usize,
129 end: usize,
130 ) -> impl IntoIterator<Item = T, IntoIter = impl Iterator<Item = T> + Send + Sync> {
131 unsafe {
132 if r < self.top.height() {
134 EitherRow::Left(self.top.row_subseq_unchecked(r, start, end).into_iter())
135 } else {
136 EitherRow::Right(
137 self.bottom
138 .row_subseq_unchecked(r - self.top.height(), start, end)
139 .into_iter(),
140 )
141 }
142 }
143 }
144
145 unsafe fn row_slice_unchecked(&self, r: usize) -> impl Deref<Target = [T]> {
146 unsafe {
147 if r < self.top.height() {
149 EitherRow::Left(self.top.row_slice_unchecked(r))
150 } else {
151 EitherRow::Right(self.bottom.row_slice_unchecked(r - self.top.height()))
152 }
153 }
154 }
155
156 unsafe fn row_subslice_unchecked(
157 &self,
158 r: usize,
159 start: usize,
160 end: usize,
161 ) -> impl Deref<Target = [T]> {
162 unsafe {
163 if r < self.top.height() {
165 EitherRow::Left(self.top.row_subslice_unchecked(r, start, end))
166 } else {
167 EitherRow::Right(self.bottom.row_subslice_unchecked(
168 r - self.top.height(),
169 start,
170 end,
171 ))
172 }
173 }
174 }
175}
176
177impl<T: Send + Sync + Clone, Left: Matrix<T>, Right: Matrix<T>> Matrix<T>
178 for HorizontalPair<Left, Right>
179{
180 fn width(&self) -> usize {
181 self.left.width() + self.right.width()
182 }
183
184 fn height(&self) -> usize {
185 self.left.height()
186 }
187
188 unsafe fn get_unchecked(&self, r: usize, c: usize) -> T {
189 unsafe {
190 if c < self.left.width() {
192 self.left.get_unchecked(r, c)
193 } else {
194 self.right.get_unchecked(r, c - self.left.width())
195 }
196 }
197 }
198
199 unsafe fn row_unchecked(
200 &self,
201 r: usize,
202 ) -> impl IntoIterator<Item = T, IntoIter = impl Iterator<Item = T> + Send + Sync> {
203 unsafe {
204 self.left
206 .row_unchecked(r)
207 .into_iter()
208 .chain(self.right.row_unchecked(r))
209 }
210 }
211}
212
/// A row that comes from one of two differently typed sources.
///
/// The vertical pair in this file returns `Left` for rows served by its top
/// matrix and `Right` for rows served by its bottom matrix, so both cases
/// share one concrete return type.
///
/// When both payloads are iterators over the same item type, `EitherRow` is
/// itself an iterator. When both payloads dereference to `[T]`, it
/// dereferences to `[T]` as well.
#[derive(Debug)]
pub enum EitherRow<L, R> {
    /// The row comes from the first source.
    Left(L),
    /// The row comes from the second source.
    Right(R),
}

impl<T, L, R> Iterator for EitherRow<L, R>
where
    L: Iterator<Item = T>,
    R: Iterator<Item = T>,
{
    type Item = T;

    /// Advances whichever iterator is wrapped.
    fn next(&mut self) -> Option<Self::Item> {
        match self {
            Self::Left(l) => l.next(),
            Self::Right(r) => r.next(),
        }
    }

    /// Forwards the wrapped iterator's size hint.
    ///
    /// Without this override the default `(0, None)` would be reported, and
    /// consumers such as `collect` could not pre-size their buffers even
    /// when the underlying row knows its exact length.
    fn size_hint(&self) -> (usize, Option<usize>) {
        match self {
            Self::Left(l) => l.size_hint(),
            Self::Right(r) => r.size_hint(),
        }
    }
}

impl<T, L, R> Deref for EitherRow<L, R>
where
    L: Deref<Target = [T]>,
    R: Deref<Target = [T]>,
{
    type Target = [T];

    /// Borrows the wrapped row as a slice, whichever side it came from.
    fn deref(&self) -> &Self::Target {
        match self {
            Self::Left(l) => l,
            Self::Right(r) => r,
        }
    }
}
248
#[cfg(test)]
mod tests {
    use alloc::vec;
    use alloc::vec::Vec;

    use itertools::Itertools;

    use super::*;
    use crate::RowMajorMatrix;

    /// With a zero-row top matrix, every row of the pair comes from the bottom.
    #[test]
    fn test_vertical_pair_empty_top() {
        let top = RowMajorMatrix::new(vec![], 2);
        let bottom = RowMajorMatrix::new(vec![1, 2, 3, 4], 2);
        let vpair = VerticalPair::new::<i32>(top, bottom);

        assert_eq!(vpair.height(), 2);
        assert_eq!(vpair.get(1, 1), Some(4));

        // SAFETY: (0, 0) lies inside the 2x2 pair.
        let first = unsafe { vpair.get_unchecked(0, 0) };
        assert_eq!(first, 1);
    }

    /// Exercises every accessor of a 4x2 vertical pair.
    #[test]
    fn test_vertical_pair_composition() {
        // Rows [1, 2], [3, 4] stacked on top of rows [5, 6], [7, 8].
        let top = RowMajorMatrix::new(vec![1, 2, 3, 4], 2);
        let bottom = RowMajorMatrix::new(vec![5, 6, 7, 8], 2);
        let vertical = VerticalPair::new::<i32>(top, bottom);

        // Dimensions.
        assert_eq!(vertical.width(), 2);
        assert_eq!(vertical.height(), 4);

        // Checked element access, in and out of bounds.
        let checked_cases = [
            (0, 0, Some(1)),
            (1, 1, Some(4)),
            (0, 2, None),
            (4, 0, None),
        ];
        for (r, c, expected) in checked_cases {
            assert_eq!(vertical.get(r, c), expected);
        }

        // Unchecked element access into the bottom half.
        // SAFETY: both positions lie inside the 4x2 pair.
        unsafe {
            assert_eq!(vertical.get_unchecked(2, 0), 5);
            assert_eq!(vertical.get_unchecked(3, 1), 8);
        }

        // Row iteration, checked then unchecked.
        let last_row = vertical.row(3).unwrap().into_iter().collect_vec();
        assert_eq!(last_row, vec![7, 8]);
        assert!(vertical.row(4).is_none());

        // SAFETY: rows 0 and 1 exist and the column range 0..1 is within the
        // width.
        unsafe {
            let second_row = vertical.row_unchecked(1).into_iter().collect_vec();
            assert_eq!(second_row, vec![3, 4]);

            let first_cell = vertical
                .row_subseq_unchecked(0, 0, 1)
                .into_iter()
                .collect_vec();
            assert_eq!(first_cell, vec![1]);
        }

        // Row slices, checked then unchecked.
        assert_eq!(vertical.row_slice(2).unwrap().deref(), &[5, 6]);
        assert!(vertical.row_slice(4).is_none());

        // SAFETY: rows 3 and 1 exist and the column range 1..2 is within the
        // width.
        unsafe {
            assert_eq!(vertical.row_slice_unchecked(3).deref(), &[7, 8]);
            assert_eq!(vertical.row_subslice_unchecked(1, 1, 2).deref(), &[4]);
        }
    }

    /// Exercises every accessor of a 2x4 horizontal pair.
    #[test]
    fn test_horizontal_pair_composition() {
        // Rows [1, 2], [3, 4] placed to the left of rows [5, 6], [7, 8].
        let left = RowMajorMatrix::new(vec![1, 2, 3, 4], 2);
        let right = RowMajorMatrix::new(vec![5, 6, 7, 8], 2);
        let horizontal = HorizontalPair::new::<i32>(left, right);

        // Dimensions.
        assert_eq!(horizontal.height(), 2);
        assert_eq!(horizontal.width(), 4);

        // Checked element access, in and out of bounds.
        let checked_cases = [
            (0, 0, Some(1)),
            (1, 1, Some(4)),
            (0, 4, None),
            (2, 0, None),
        ];
        for (r, c, expected) in checked_cases {
            assert_eq!(horizontal.get(r, c), expected);
        }

        // Unchecked element access into the right half.
        // SAFETY: both positions lie inside the 2x4 pair.
        unsafe {
            assert_eq!(horizontal.get_unchecked(0, 2), 5);
            assert_eq!(horizontal.get_unchecked(1, 3), 8);
        }

        // Rows are the left row followed by the right row.
        let first_row = horizontal.row(0).unwrap().into_iter().collect_vec();
        assert_eq!(first_row, vec![1, 2, 5, 6]);
        assert!(horizontal.row(2).is_none());

        // SAFETY: row 1 exists.
        let second_row = unsafe { horizontal.row_unchecked(1).into_iter().collect_vec() };
        assert_eq!(second_row, vec![3, 4, 7, 8]);
    }

    /// `EitherRow` yields the items of whichever iterator it wraps.
    #[test]
    fn test_either_row_iterator_behavior() {
        type Iter = alloc::vec::IntoIter<i32>;

        let left: EitherRow<Iter, Iter> = EitherRow::Left(vec![10, 20].into_iter());
        let right: EitherRow<Iter, Iter> = EitherRow::Right(vec![30, 40].into_iter());

        assert_eq!(left.collect::<Vec<_>>(), vec![10, 20]);
        assert_eq!(right.collect::<Vec<_>>(), vec![30, 40]);
    }

    /// `EitherRow` dereferences to the slice of whichever side it wraps.
    #[test]
    fn test_either_row_deref_behavior() {
        let left: EitherRow<&[i32], &[i32]> = EitherRow::Left(&[1, 2, 3]);
        assert_eq!(&*left, &[1, 2, 3]);

        let right: EitherRow<&[i32], &[i32]> = EitherRow::Right(&[4, 5]);
        assert_eq!(&*right, &[4, 5]);
    }

    /// Stacking matrices of widths 1 and 2 must be rejected.
    #[test]
    #[should_panic]
    fn test_vertical_pair_width_mismatch_should_panic() {
        let a = RowMajorMatrix::new(vec![1, 2, 3], 1);
        let b = RowMajorMatrix::new(vec![4, 5], 2);
        let _ = VerticalPair::new::<i32>(a, b);
    }

    /// Joining matrices of heights 1 and 2 must be rejected.
    #[test]
    #[should_panic]
    fn test_horizontal_pair_height_mismatch_should_panic() {
        let a = RowMajorMatrix::new(vec![1, 2, 3], 3);
        let b = RowMajorMatrix::new(vec![4, 5], 1);
        let _ = HorizontalPair::new::<i32>(a, b);
    }
}