1use core::ops::Deref;
2
3use p3_field::PackedValue;
4
5use crate::Matrix;
6use crate::dense::RowMajorMatrix;
7
8pub trait RowIndexMap: Send + Sync {
13 fn height(&self) -> usize;
15
16 fn map_row_index(&self, r: usize) -> usize;
23
24 fn to_row_major_matrix<T: Clone + Send + Sync, Inner: Matrix<T>>(
29 &self,
30 inner: Inner,
31 ) -> RowMajorMatrix<T> {
32 RowMajorMatrix::new(
33 unsafe {
34 (0..self.height())
36 .flat_map(|r| inner.row_unchecked(self.map_row_index(r)))
37 .collect()
38 },
39 inner.width(),
40 )
41 }
42}
43
/// Pairs an inner value with an index map that selects and reorders its rows.
///
/// Both fields are public, so a view is built directly with struct-literal syntax.
#[derive(Debug, Clone, Copy)]
pub struct RowIndexMappedView<IndexMap, Inner> {
    /// Decides the view's height and which inner row backs each view row.
    pub index_map: IndexMap,
    /// The underlying value whose rows are being remapped.
    pub inner: Inner,
}
55
56impl<T: Send + Sync + Clone, IndexMap: RowIndexMap, Inner: Matrix<T>> Matrix<T>
57 for RowIndexMappedView<IndexMap, Inner>
58{
59 fn width(&self) -> usize {
60 self.inner.width()
61 }
62
63 fn height(&self) -> usize {
64 self.index_map.height()
65 }
66
67 unsafe fn get_unchecked(&self, r: usize, c: usize) -> T {
68 unsafe {
69 self.inner.get_unchecked(self.index_map.map_row_index(r), c)
71 }
72 }
73
74 unsafe fn row_unchecked(
75 &self,
76 r: usize,
77 ) -> impl IntoIterator<Item = T, IntoIter = impl Iterator<Item = T> + Send + Sync> {
78 unsafe {
79 self.inner.row_unchecked(self.index_map.map_row_index(r))
81 }
82 }
83
84 unsafe fn row_subseq_unchecked(
85 &self,
86 r: usize,
87 start: usize,
88 end: usize,
89 ) -> impl IntoIterator<Item = T, IntoIter = impl Iterator<Item = T> + Send + Sync> {
90 unsafe {
91 self.inner
93 .row_subseq_unchecked(self.index_map.map_row_index(r), start, end)
94 }
95 }
96
97 unsafe fn row_slice_unchecked(&self, r: usize) -> impl Deref<Target = [T]> {
98 unsafe {
99 self.inner
101 .row_slice_unchecked(self.index_map.map_row_index(r))
102 }
103 }
104
105 unsafe fn row_subslice_unchecked(
106 &self,
107 r: usize,
108 start: usize,
109 end: usize,
110 ) -> impl Deref<Target = [T]> {
111 unsafe {
112 self.inner
114 .row_subslice_unchecked(self.index_map.map_row_index(r), start, end)
115 }
116 }
117
118 fn to_row_major_matrix(self) -> RowMajorMatrix<T>
119 where
120 Self: Sized,
121 T: Clone,
122 {
123 self.index_map.to_row_major_matrix(self.inner)
125 }
126
127 fn horizontally_packed_row<'a, P>(
128 &'a self,
129 r: usize,
130 ) -> (
131 impl Iterator<Item = P> + Send + Sync,
132 impl Iterator<Item = T> + Send + Sync,
133 )
134 where
135 P: PackedValue<Value = T>,
136 T: Clone + 'a,
137 {
138 self.inner
139 .horizontally_packed_row(self.index_map.map_row_index(r))
140 }
141
142 fn padded_horizontally_packed_row<'a, P>(
143 &'a self,
144 r: usize,
145 ) -> impl Iterator<Item = P> + Send + Sync
146 where
147 P: PackedValue<Value = T>,
148 T: Clone + Default + 'a,
149 {
150 self.inner
151 .padded_horizontally_packed_row(self.index_map.map_row_index(r))
152 }
153}
154
#[cfg(test)]
mod tests {
    use alloc::vec;
    use alloc::vec::Vec;

    use itertools::Itertools;
    use p3_baby_bear::BabyBear;
    use p3_field::FieldArray;

    use super::*;
    use crate::dense::RowMajorMatrix;

    /// Map of the given height that sends every row index to itself.
    struct IdentityMap(usize);

    impl RowIndexMap for IdentityMap {
        fn height(&self) -> usize {
            self.0
        }

        fn map_row_index(&self, r: usize) -> usize {
            r
        }
    }

    /// Map of the given height that sends row `r` to `height - 1 - r`.
    struct ReverseMap(usize);

    impl RowIndexMap for ReverseMap {
        fn height(&self) -> usize {
            self.0
        }

        fn map_row_index(&self, r: usize) -> usize {
            self.0 - 1 - r
        }
    }

    /// Map of height 1 that sends every row index to inner row 0.
    struct ConstantMap;

    impl RowIndexMap for ConstantMap {
        fn height(&self) -> usize {
            1
        }

        fn map_row_index(&self, _r: usize) -> usize {
            0
        }
    }

    #[test]
    fn test_identity_row_index_map() {
        // Two rows of width 3: [1, 2, 3] and [4, 5, 6].
        let base = RowMajorMatrix::new(vec![1, 2, 3, 4, 5, 6], 3);
        let view = RowIndexMappedView {
            index_map: IdentityMap(base.height()),
            inner: base,
        };

        // Dimensions come straight from the inner matrix under the identity map.
        assert_eq!(view.height(), 2);
        assert_eq!(view.width(), 3);

        // Checked element access.
        assert_eq!(view.get(0, 0), Some(1));
        assert_eq!(view.get(1, 2), Some(6));

        // Unchecked element access with in-range indices.
        unsafe {
            assert_eq!(view.get_unchecked(0, 1), 2);
            assert_eq!(view.get_unchecked(1, 0), 4);
        }

        // Rows come out in their original order.
        let collected: Vec<Vec<_>> = view.rows().map(|row| row.collect()).collect();
        assert_eq!(collected, vec![vec![1, 2, 3], vec![4, 5, 6]]);

        // Materializing leaves the values untouched.
        let materialized = view.to_row_major_matrix();
        assert_eq!(materialized.values, vec![1, 2, 3, 4, 5, 6]);
    }

    #[test]
    fn test_reverse_row_index_map() {
        // Two rows of width 3: [1, 2, 3] and [4, 5, 6].
        let base = RowMajorMatrix::new(vec![1, 2, 3, 4, 5, 6], 3);
        let view = RowIndexMappedView {
            index_map: ReverseMap(base.height()),
            inner: base,
        };

        assert_eq!(view.height(), 2);
        assert_eq!(view.width(), 3);

        // View row 0 is inner row 1 and vice versa.
        assert_eq!(view.get(0, 0), Some(4));
        assert_eq!(view.get(1, 2), Some(3));

        unsafe {
            assert_eq!(view.get_unchecked(0, 1), 5);
            assert_eq!(view.get_unchecked(1, 0), 1);
        }

        // Rows come out in reversed order.
        let collected: Vec<Vec<_>> = view.rows().map(|row| row.collect()).collect();
        assert_eq!(collected, vec![vec![4, 5, 6], vec![1, 2, 3]]);

        // Materializing bakes the reversed order into the dense values.
        let materialized = view.to_row_major_matrix();
        assert_eq!(materialized.values, vec![4, 5, 6, 1, 2, 3]);
    }

    #[test]
    fn test_horizontally_packed_row() {
        // Packing width 2 matches the matrix width, so one packed value per row.
        type Packed = FieldArray<BabyBear, 2>;

        // Two rows of width 2: [1, 2] and [3, 4].
        let base = RowMajorMatrix::new(
            vec![
                BabyBear::new(1),
                BabyBear::new(2),
                BabyBear::new(3),
                BabyBear::new(4),
            ],
            2,
        );
        let view = RowIndexMappedView {
            index_map: ReverseMap(base.height()),
            inner: base,
        };

        // View row 0 is inner row 1 under the reverse map.
        let (packed_part, mut leftover) = view.horizontally_packed_row::<Packed>(0);
        let packed: Vec<_> = packed_part.collect();

        assert_eq!(
            packed,
            &[Packed::from([BabyBear::new(3), BabyBear::new(4)])]
        );

        // Nothing remains outside the packed part.
        assert!(leftover.next().is_none());
    }

    #[test]
    fn test_padded_horizontally_packed_row() {
        // Packing width 3 exceeds the matrix width 2, so the last lane is padding.
        type Packed = FieldArray<BabyBear, 3>;

        // Two rows of width 2: [1, 2] and [3, 4].
        let base = RowMajorMatrix::new(
            vec![
                BabyBear::new(1),
                BabyBear::new(2),
                BabyBear::new(3),
                BabyBear::new(4),
            ],
            2,
        );
        let view = RowIndexMappedView {
            index_map: IdentityMap(base.height()),
            inner: base,
        };

        let packed: Vec<_> = view.padded_horizontally_packed_row::<Packed>(1).collect();

        // Row 1 is [3, 4], followed by a zero in the padded lane.
        assert_eq!(
            packed,
            vec![Packed::from([
                BabyBear::new(3),
                BabyBear::new(4),
                BabyBear::new(0),
            ])]
        );
    }

    #[test]
    fn test_row_and_row_slice_methods() {
        // Two rows of width 3: [10, 20, 30] and [40, 50, 60].
        let base = RowMajorMatrix::new(vec![10, 20, 30, 40, 50, 60], 3);
        let view = RowIndexMappedView {
            index_map: ReverseMap(base.height()),
            inner: base,
        };

        // Checked row accessors see the reversed order.
        assert_eq!(view.row_slice(0).unwrap().deref(), &[40, 50, 60]);
        assert_eq!(
            view.row(1).unwrap().into_iter().collect_vec(),
            vec![10, 20, 30]
        );

        // Unchecked accessors, all with in-range arguments.
        unsafe {
            assert_eq!(
                view.row_unchecked(0).into_iter().collect_vec(),
                vec![40, 50, 60]
            );
            assert_eq!(view.row_slice_unchecked(1).deref(), &[10, 20, 30]);
            assert_eq!(view.row_subslice_unchecked(0, 1, 3).deref(), &[50, 60]);
            assert_eq!(
                view.row_subseq_unchecked(1, 0, 2)
                    .into_iter()
                    .collect_vec(),
                vec![10, 20]
            );
        }

        // Row index 2 is past the view's height of 2.
        assert!(view.row(2).is_none());
        assert!(view.row_slice(2).is_none());
    }

    #[test]
    fn test_out_of_bounds_access() {
        // Two rows of width 2.
        let base = RowMajorMatrix::new(vec![1, 2, 3, 4], 2);
        let view = RowIndexMappedView {
            index_map: IdentityMap(base.height()),
            inner: base,
        };

        // Out-of-range rows and columns are rejected by the checked accessors.
        assert_eq!(view.get(2, 1), None);
        assert!(view.row(5).is_none());
        assert!(view.row_slice(11).is_none());
        assert_eq!(view.get(0, 20), None);
    }

    #[test]
    fn test_out_of_bounds_access_with_bad_map() {
        // A single row of width 4.
        let base = RowMajorMatrix::new(vec![1, 2, 3, 4], 4);
        let view = RowIndexMappedView {
            index_map: ConstantMap,
            inner: base,
        };

        // Row 0 is within the map's height of 1.
        assert_eq!(view.get(0, 2), Some(3));

        // Row 1 is beyond the map's height, even though the map would send it to row 0.
        assert_eq!(view.get(1, 0), None);
        assert!(view.row(1).is_none());
        assert!(view.row_slice(1).is_none());
    }
}