// p3_matrix/row_index_mapped.rs

1use core::ops::Deref;
2
3use p3_field::PackedValue;
4
5use crate::Matrix;
6use crate::dense::RowMajorMatrix;
7
8/// A trait for remapping row indices of a matrix.
9///
10/// Implementations can change the number of visible rows (`height`)
11/// and define how a given logical row index maps to a physical one.
12pub trait RowIndexMap: Send + Sync {
13    /// Returns the number of rows exposed by the mapping.
14    fn height(&self) -> usize;
15
16    /// Maps a visible row index `r` to the corresponding row index in the underlying matrix.
17    ///
18    /// The input `r` is assumed to lie in the range `0..self.height()` and the output
19    /// will lie in the range `0..self.inner.height()`.
20    ///
21    /// It is considered undefined behaviour to call `map_row_index` with `r >= self.height()`.
22    fn map_row_index(&self, r: usize) -> usize;
23
24    /// Converts the mapped matrix into a dense row-major matrix.
25    ///
26    /// This default implementation iterates over all mapped rows,
27    /// collects them in order, and builds a dense representation.
28    fn to_row_major_matrix<T: Clone + Send + Sync, Inner: Matrix<T>>(
29        &self,
30        inner: Inner,
31    ) -> RowMajorMatrix<T> {
32        RowMajorMatrix::new(
33            unsafe {
34                // Safety: The output of `map_row_index` is less than `inner.height()` for all inputs in the range `0..self.height()`.
35                (0..self.height())
36                    .flat_map(|r| inner.row_unchecked(self.map_row_index(r)))
37                    .collect()
38            },
39            inner.width(),
40        )
41    }
42}
43
/// A view over an inner matrix whose rows are selected and reordered by an index map.
///
/// Row `r` of the view is row `index_map.map_row_index(r)` of `inner`, and the view exposes
/// `index_map.height()` rows. Columns pass through untouched, so the view is exactly as wide
/// as `inner`.
#[derive(Clone, Copy, Debug)]
pub struct RowIndexMappedView<IndexMap, Inner> {
    /// Decides how many rows are visible and which inner row each visible row refers to.
    pub index_map: IndexMap,
    /// The wrapped matrix that actually stores the data.
    pub inner: Inner,
}

56impl<T: Send + Sync + Clone, IndexMap: RowIndexMap, Inner: Matrix<T>> Matrix<T>
57    for RowIndexMappedView<IndexMap, Inner>
58{
59    fn width(&self) -> usize {
60        self.inner.width()
61    }
62
63    fn height(&self) -> usize {
64        self.index_map.height()
65    }
66
67    unsafe fn get_unchecked(&self, r: usize, c: usize) -> T {
68        unsafe {
69            // Safety: The caller must ensure that r < self.height() and c < self.width().
70            self.inner.get_unchecked(self.index_map.map_row_index(r), c)
71        }
72    }
73
74    unsafe fn row_unchecked(
75        &self,
76        r: usize,
77    ) -> impl IntoIterator<Item = T, IntoIter = impl Iterator<Item = T> + Send + Sync> {
78        unsafe {
79            // Safety: The caller must ensure that r < self.height().
80            self.inner.row_unchecked(self.index_map.map_row_index(r))
81        }
82    }
83
84    unsafe fn row_subseq_unchecked(
85        &self,
86        r: usize,
87        start: usize,
88        end: usize,
89    ) -> impl IntoIterator<Item = T, IntoIter = impl Iterator<Item = T> + Send + Sync> {
90        unsafe {
91            // Safety: The caller must ensure that r < self.height() and start <= end <= self.width().
92            self.inner
93                .row_subseq_unchecked(self.index_map.map_row_index(r), start, end)
94        }
95    }
96
97    unsafe fn row_slice_unchecked(&self, r: usize) -> impl Deref<Target = [T]> {
98        unsafe {
99            // Safety: The caller must ensure that r < self.height().
100            self.inner
101                .row_slice_unchecked(self.index_map.map_row_index(r))
102        }
103    }
104
105    unsafe fn row_subslice_unchecked(
106        &self,
107        r: usize,
108        start: usize,
109        end: usize,
110    ) -> impl Deref<Target = [T]> {
111        unsafe {
112            // Safety: The caller must ensure that r < self.height() and start <= end <= self.width().
113            self.inner
114                .row_subslice_unchecked(self.index_map.map_row_index(r), start, end)
115        }
116    }
117
118    fn to_row_major_matrix(self) -> RowMajorMatrix<T>
119    where
120        Self: Sized,
121        T: Clone,
122    {
123        // Use Perm's optimized permutation routine, if it has one.
124        self.index_map.to_row_major_matrix(self.inner)
125    }
126
127    fn horizontally_packed_row<'a, P>(
128        &'a self,
129        r: usize,
130    ) -> (
131        impl Iterator<Item = P> + Send + Sync,
132        impl Iterator<Item = T> + Send + Sync,
133    )
134    where
135        P: PackedValue<Value = T>,
136        T: Clone + 'a,
137    {
138        self.inner
139            .horizontally_packed_row(self.index_map.map_row_index(r))
140    }
141
142    fn padded_horizontally_packed_row<'a, P>(
143        &'a self,
144        r: usize,
145    ) -> impl Iterator<Item = P> + Send + Sync
146    where
147        P: PackedValue<Value = T>,
148        T: Clone + Default + 'a,
149    {
150        self.inner
151            .padded_horizontally_packed_row(self.index_map.map_row_index(r))
152    }
153}
154
#[cfg(test)]
mod tests {
    use alloc::vec;
    use alloc::vec::Vec;

    use itertools::Itertools;
    use p3_baby_bear::BabyBear;
    use p3_field::FieldArray;

    use super::*;
    use crate::dense::RowMajorMatrix;

    /// Identity `RowIndexMap`: exposes `self.0` rows and maps each row to itself.
    struct IdentityMap(usize);

    impl RowIndexMap for IdentityMap {
        fn height(&self) -> usize {
            self.0
        }

        fn map_row_index(&self, r: usize) -> usize {
            r
        }
    }

    /// Reversing `RowIndexMap`: exposes `self.0` rows and maps row `r` to row `self.0 - 1 - r`.
    struct ReverseMap(usize);

    impl RowIndexMap for ReverseMap {
        fn height(&self) -> usize {
            self.0
        }

        fn map_row_index(&self, r: usize) -> usize {
            self.0 - 1 - r
        }
    }

    /// `RowIndexMap` exposing a single row, which maps to row 0 of the inner matrix.
    struct ConstantMap;

    impl RowIndexMap for ConstantMap {
        fn height(&self) -> usize {
            1
        }

        fn map_row_index(&self, _r: usize) -> usize {
            0
        }
    }

    #[test]
    fn test_identity_row_index_map() {
        // Create an inner matrix.
        // The matrix will be:
        // [ 1  2  3 ]
        // [ 4  5  6 ]
        let inner = RowMajorMatrix::new(vec![1, 2, 3, 4, 5, 6], 3);

        // Create a mapped view using an `IdentityMap`, which does not alter row indices.
        let mapped_view = RowIndexMappedView {
            index_map: IdentityMap(inner.height()),
            inner,
        };

        // Check dimensions.
        assert_eq!(mapped_view.height(), 2);
        assert_eq!(mapped_view.width(), 3);

        // Check values.
        assert_eq!(mapped_view.get(0, 0).unwrap(), 1);
        assert_eq!(mapped_view.get(1, 2).unwrap(), 6);

        unsafe {
            assert_eq!(mapped_view.get_unchecked(0, 1), 2);
            assert_eq!(mapped_view.get_unchecked(1, 0), 4);
        }

        // Check rows.
        let rows: Vec<Vec<_>> = mapped_view.rows().map(|row| row.collect()).collect();
        assert_eq!(rows, vec![vec![1, 2, 3], vec![4, 5, 6]]);

        // Check dense matrix.
        let dense = mapped_view.to_row_major_matrix();
        assert_eq!(dense.values, vec![1, 2, 3, 4, 5, 6]);
    }

    #[test]
    fn test_reverse_row_index_map() {
        // Create an inner matrix.
        // The matrix will be:
        // [ 1  2  3 ]
        // [ 4  5  6 ]
        let inner = RowMajorMatrix::new(vec![1, 2, 3, 4, 5, 6], 3);

        // Create a mapped view using a `ReverseMap`, which reverses row indices.
        let mapped_view = RowIndexMappedView {
            index_map: ReverseMap(inner.height()),
            inner,
        };

        // Check dimensions.
        assert_eq!(mapped_view.height(), 2);
        assert_eq!(mapped_view.width(), 3);

        // Check the first element of the mapped view (originally the second row, first column).
        assert_eq!(mapped_view.get(0, 0).unwrap(), 4);
        // Check the last element of the mapped view (originally the first row, last column).
        assert_eq!(mapped_view.get(1, 2).unwrap(), 3);

        unsafe {
            assert_eq!(mapped_view.get_unchecked(0, 1), 5);
            assert_eq!(mapped_view.get_unchecked(1, 0), 1);
        }

        // Check rows.
        let rows: Vec<Vec<_>> = mapped_view.rows().map(|row| row.collect()).collect();
        assert_eq!(rows, vec![vec![4, 5, 6], vec![1, 2, 3]]);

        // Check dense matrix.
        let dense = mapped_view.to_row_major_matrix();
        assert_eq!(dense.values, vec![4, 5, 6, 1, 2, 3]);
    }

    #[test]
    fn test_horizontally_packed_row() {
        // Define the packed type with width 2.
        type Packed = FieldArray<BabyBear, 2>;

        // Create an inner matrix of BabyBear elements.
        // Matrix layout:
        // [ 1  2 ]
        // [ 3  4 ]
        let inner = RowMajorMatrix::new(
            vec![
                BabyBear::new(1),
                BabyBear::new(2),
                BabyBear::new(3),
                BabyBear::new(4),
            ],
            2,
        );

        // Apply a reverse row index mapping.
        let mapped_view = RowIndexMappedView {
            index_map: ReverseMap(inner.height()),
            inner,
        };

        // Extract the packed and suffix iterators from visible row 0 (inner row 1).
        let (packed_iter, mut suffix_iter) = mapped_view.horizontally_packed_row::<Packed>(0);

        // Collect iterators to concrete values.
        let packed: Vec<_> = packed_iter.collect();

        // The packed values must match the inner matrix's second row.
        assert_eq!(
            packed,
            &[Packed::from([BabyBear::new(3), BabyBear::new(4)])]
        );

        // The row width is a multiple of the packing width, so there is no suffix.
        assert!(suffix_iter.next().is_none());
    }

    #[test]
    fn test_padded_horizontally_packed_row() {
        // Define a packed type with width 3.
        type Packed = FieldArray<BabyBear, 3>;

        // Create a 2x2 matrix of BabyBear elements:
        // [ 1  2 ]
        // [ 3  4 ]
        let inner = RowMajorMatrix::new(
            vec![
                BabyBear::new(1),
                BabyBear::new(2),
                BabyBear::new(3),
                BabyBear::new(4),
            ],
            2,
        );

        // Use identity mapping (rows remain unchanged).
        let mapped_view = RowIndexMappedView {
            index_map: IdentityMap(inner.height()),
            inner,
        };

        // Pack the second row (row 1) into chunks of size 3, padding the last chunk.
        let packed: Vec<_> = mapped_view
            .padded_horizontally_packed_row::<Packed>(1)
            .collect();

        // Verify the packed result is padded with zero at the end.
        assert_eq!(
            packed,
            vec![Packed::from([
                BabyBear::new(3),
                BabyBear::new(4),
                BabyBear::new(0),
            ])]
        );
    }

    #[test]
    fn test_row_and_row_slice_methods() {
        // Create a 2x3 matrix of integers:
        // [ 10  20  30 ]
        // [ 40  50  60 ]
        let inner = RowMajorMatrix::new(vec![10, 20, 30, 40, 50, 60], 3);

        // Apply reverse row mapping (visible row 0 is inner row 1, visible row 1 is inner row 0).
        let mapped_view = RowIndexMappedView {
            index_map: ReverseMap(inner.height()),
            inner,
        };

        // Get row slices through dereferencing and verify content.
        assert_eq!(mapped_view.row_slice(0).unwrap().deref(), &[40, 50, 60]); // was row 1
        assert_eq!(
            mapped_view.row(1).unwrap().into_iter().collect_vec(),
            vec![10, 20, 30]
        ); // was row 0

        unsafe {
            // Check the unchecked row accessors.
            assert_eq!(
                mapped_view.row_unchecked(0).into_iter().collect_vec(),
                vec![40, 50, 60]
            ); // was row 1
            assert_eq!(mapped_view.row_slice_unchecked(1).deref(), &[10, 20, 30]); // was row 0

            assert_eq!(
                mapped_view.row_subslice_unchecked(0, 1, 3).deref(),
                &[50, 60]
            ); // was row 1
            assert_eq!(
                mapped_view
                    .row_subseq_unchecked(1, 0, 2)
                    .into_iter()
                    .collect_vec(),
                vec![10, 20]
            ); // was row 0
        }

        assert!(mapped_view.row(2).is_none()); // Height out of bounds.
        assert!(mapped_view.row_slice(2).is_none()); // Height out of bounds.
    }

    #[test]
    fn test_out_of_bounds_access() {
        // Create a 2x2 matrix:
        // [ 1  2 ]
        // [ 3  4 ]
        let inner = RowMajorMatrix::new(vec![1, 2, 3, 4], 2);

        // Use identity mapping.
        let mapped_view = RowIndexMappedView {
            index_map: IdentityMap(inner.height()),
            inner,
        };

        // Out-of-bounds accesses through the checked API return `None` instead of panicking.
        assert_eq!(mapped_view.get(2, 1), None);
        assert!(mapped_view.row(5).is_none());
        assert!(mapped_view.row_slice(11).is_none());
        assert_eq!(mapped_view.get(0, 20), None);
    }

    #[test]
    fn test_out_of_bounds_access_with_bad_map() {
        // Create a 1x4 matrix:
        // [ 1  2  3  4 ]
        let inner = RowMajorMatrix::new(vec![1, 2, 3, 4], 4);

        // Use a constant mapping: the view has a single row, which is inner row 0.
        let mapped_view = RowIndexMappedView {
            index_map: ConstantMap,
            inner,
        };

        assert_eq!(mapped_view.get(0, 2), Some(3));

        // Row 1 is beyond the view's height of 1, so the checked accessors return `None`.
        assert_eq!(mapped_view.get(1, 0), None);
        assert!(mapped_view.row(1).is_none());
        assert!(mapped_view.row_slice(1).is_none());
    }

    #[test]
    fn test_constant_map_to_row_major_matrix() {
        // Create a 1x4 matrix:
        // [ 1  2  3  4 ]
        let inner = RowMajorMatrix::new(vec![1, 2, 3, 4], 4);

        let mapped_view = RowIndexMappedView {
            index_map: ConstantMap,
            inner,
        };

        // The dense copy contains exactly the single visible row.
        let dense = mapped_view.to_row_major_matrix();
        assert_eq!(dense.width(), 4);
        assert_eq!(dense.height(), 1);
        assert_eq!(dense.values, vec![1, 2, 3, 4]);
    }
}