p3_util/
transpose.rs

1use core::ptr::{swap, swap_nonoverlapping};
2#[cfg(feature = "parallel")]
3use core::sync::atomic::{AtomicPtr, Ordering};
4
/// Log2 of the largest square dimension handled by the direct element-wise swap loop.
///
/// Square submatrices with `log_size <= BASE_CASE_LOG` are transposed by the base case,
/// e.g. `BASE_CASE_LOG = 3` means the base case handles submatrices up to 8×8.
const BASE_CASE_LOG: usize = 3;

/// Element-count threshold for the recursive rectangular swap.
///
/// Blocks with strictly fewer than this many elements (`rows * cols`) are swapped
/// directly instead of being split further.
const BASE_CASE_ELEMENT_THRESHOLD: usize = 1 << (2 * BASE_CASE_LOG);

/// Element-count threshold used to decide when the recursion is run in parallel.
#[cfg(feature = "parallel")]
const PARALLEL_RECURSION_THRESHOLD: usize = 1 << 10;
15
/// Transpose a small square matrix in-place using element-wise swaps.
///
/// # Parameters
/// - `arr`: A mutable reference to a 1D array representing a larger row-major matrix.
/// - `log_stride`: Log2 of the stride between rows in the array.
/// - `log_size`: Log2 of the dimension of the square matrix to transpose.
/// - `x`: Offset (in rows and columns) from the top-left corner of the full array.
///
/// The matrix occupies a logical square region starting at `(x, x)` and of size `1 << log_size`.
/// Element `M[i, j]` of the full array lives at index `(i << log_stride) + j`.
///
/// # Safety
/// - Every index `(i << log_stride) + j` with `i, j` in `x..x + (1 << log_size)` must be
///   in bounds for `arr`; no bounds checks are performed.
/// - `log_size <= log_stride` must hold to prevent overlapping indices during swaps.
unsafe fn transpose_in_place_square_small<T>(
    arr: &mut [T],
    log_stride: usize,
    log_size: usize,
    x: usize,
) {
    // Derive a single raw pointer for the whole loop. Taking two separate
    // `get_unchecked_mut` reborrows of `arr` per swap would make the second reborrow
    // invalidate the pointer obtained from the first one.
    let base = arr.as_mut_ptr();
    let end = x + (1 << log_size);

    // Visit every off-diagonal pair exactly once (`j < i`); the diagonal stays in place.
    for i in (x + 1)..end {
        for j in x..i {
            // SAFETY: the caller guarantees both indices are in bounds for `arr`, and
            // both pointers are derived from the same live `base` pointer.
            unsafe {
                // Swap M[j, i] <-> M[i, j].
                swap(
                    base.add((j << log_stride) + i),
                    base.add((i << log_stride) + j),
                );
            }
        }
    }
}
48
49/// Recursively swaps two submatrices across the main diagonal as part of a larger transposition.
50///
51/// Given:
52/// - Submatrix `A` of shape `(rows × cols)`
53/// - Submatrix `B` of shape `(cols × rows)`
54///
55/// This function swaps element `A[i, j]` with `B[j, i]`, effectively transposing them
56/// relative to each other.
57///
58/// `A` is assumed to be row-major, starting at pointer `a`, where A[i,j] = a[i * width_outer_mat + j].
59/// `B` is assumed to be row-major, starting at pointer `b`, where B[j,i] = b[j * width_outer_mat + i].
60///
61/// The recursion always splits along the longer dimension to balance cache and workload.
62///
63/// # Safety
64/// - `a` and `b` must be valid for `rows * cols` reads and writes.
65/// - The regions pointed to by `a` and `b` must be disjoint.
66/// - `width_outer_mat` must be large enough to avoid overlapping accesses during index calculations.
67pub(super) unsafe fn transpose_swap<T: Copy>(
68    a: *mut T,
69    b: *mut T,
70    width_outer_mat: usize,
71    (rows, cols): (usize, usize),
72) {
73    let size = rows * cols;
74
75    // Base case: directly swap A[i,j] with B[j,i] using pointer offsets
76    if size < BASE_CASE_ELEMENT_THRESHOLD {
77        for i in 0..rows {
78            for j in 0..cols {
79                let ai = i * width_outer_mat + j;
80                let bi = j * width_outer_mat + i;
81                unsafe {
82                    swap_nonoverlapping(a.add(ai), b.add(bi), 1);
83                }
84            }
85        }
86        return;
87    }
88
89    #[cfg(feature = "parallel")]
90    {
91        // If large enough, split work recursively in parallel
92        if size > PARALLEL_RECURSION_THRESHOLD {
93            let a = AtomicPtr::new(a);
94            let b = AtomicPtr::new(b);
95
96            // Prefer splitting the longer dimension for better balance and locality
97            if rows > cols {
98                let top = rows / 2;
99                let bottom = rows - top;
100                rayon::join(
101                    || {
102                        let a = a.load(Ordering::Relaxed);
103                        let b = b.load(Ordering::Relaxed);
104                        unsafe {
105                            transpose_swap(a, b, width_outer_mat, (top, cols));
106                        }
107                    },
108                    || {
109                        let a = a.load(Ordering::Relaxed);
110                        let b = b.load(Ordering::Relaxed);
111                        unsafe {
112                            transpose_swap(
113                                a.add(top * width_outer_mat),
114                                b.add(top),
115                                width_outer_mat,
116                                (bottom, cols),
117                            );
118                        }
119                    },
120                );
121            } else {
122                let left = cols / 2;
123                let right = cols - left;
124                rayon::join(
125                    || {
126                        let a = a.load(Ordering::Relaxed);
127                        let b = b.load(Ordering::Relaxed);
128                        unsafe {
129                            transpose_swap(a, b, width_outer_mat, (rows, left));
130                        }
131                    },
132                    || {
133                        let a = a.load(Ordering::Relaxed);
134                        let b = b.load(Ordering::Relaxed);
135                        unsafe {
136                            transpose_swap(
137                                a.add(left),
138                                b.add(left * width_outer_mat),
139                                width_outer_mat,
140                                (rows, right),
141                            );
142                        }
143                    },
144                );
145            }
146            return;
147        }
148    }
149
150    // Sequential case: same recursive logic without threading
151    if rows > cols {
152        let top = rows / 2;
153        let bottom = rows - top;
154        unsafe {
155            transpose_swap(a, b, width_outer_mat, (top, cols));
156            transpose_swap(
157                a.add(top * width_outer_mat),
158                b.add(top),
159                width_outer_mat,
160                (bottom, cols),
161            );
162        }
163    } else {
164        let left = cols / 2;
165        let right = cols - left;
166        unsafe {
167            transpose_swap(a, b, width_outer_mat, (rows, left));
168            transpose_swap(
169                a.add(left),
170                b.add(left * width_outer_mat),
171                width_outer_mat,
172                (rows, right),
173            );
174        }
175    }
176}
177
178/// In-place recursive transposition of a square matrix of size `2^log_size × 2^log_size`,
179/// embedded inside a larger row-major array at offset `(x, x)`.
180///
181/// Each matrix element `M[i,j]` is stored at:
182/// ```text
183/// \begin{equation}
184///     \text{index}(i,j) = ((i + x) << log_stride) + (j + x)
185/// \end{equation}
186/// ```
187///
188/// The matrix is recursively split into four quadrants:
189/// ```text
190/// +----+----+
191/// | TL | TR |
192/// +----+----+
193/// | BL | BR |
194/// +----+----+
195/// ```
196/// Transposition proceeds by:
197/// 1. Recursively transposing `TL`
198/// 2. Swapping `TR` and `BL` across the diagonal
199/// 3. Recursively transposing `BR`
200///
201/// # Safety
202/// - Assumes all accesses via `((i + x) << log_stride) + (j + x)` are in-bounds.
203/// - Requires `log_size <= log_stride` to avoid index overlap.
204pub(crate) unsafe fn transpose_in_place_square<T>(
205    arr: &mut [T],
206    log_stride: usize,
207    log_size: usize,
208    x: usize,
209) where
210    T: Copy + Send + Sync,
211{
212    // If small, switch to base case
213    if log_size <= BASE_CASE_LOG {
214        unsafe {
215            transpose_in_place_square_small(arr, log_stride, log_size, x);
216        }
217        return;
218    }
219
220    #[cfg(feature = "parallel")]
221    {
222        // Log2 of half the matrix dimension
223        let log_half_size = log_size - 1;
224        // Half the matrix size (e.g. 8 for 16×16)
225        let half = 1 << log_half_size;
226        // Total number of elements in the full square matrix
227        let elements = 1 << (2 * log_size);
228
229        if elements >= PARALLEL_RECURSION_THRESHOLD {
230            // Shared base pointer for parallel recursion
231            let base = AtomicPtr::new(arr.as_mut_ptr());
232            // Total length of the backing array
233            let len = arr.len();
234            // Row stride in physical memory
235            let stride = 1 << log_stride;
236            // Size of each quadrant (half x half)
237            let dim = 1 << log_half_size;
238
239            // Coordinate each quadrant via `rayon::join`:
240            // - TL and BR are recursive calls
241            // - TR and BL are swapped directly
242            rayon::join(
243                || unsafe {
244                    transpose_in_place_square(
245                        core::slice::from_raw_parts_mut(base.load(Ordering::Relaxed), len),
246                        log_stride,
247                        log_half_size,
248                        x,
249                    );
250                },
251                || {
252                    rayon::join(
253                        // TR: starts at (x, x + half)
254                        // BL: starts at (x + half, x)
255                        || unsafe {
256                            let ptr = base.load(Ordering::Relaxed);
257                            transpose_swap(
258                                ptr.add((x << log_stride) + (x + half)),
259                                ptr.add(((x + half) << log_stride) + x),
260                                stride,
261                                (dim, dim),
262                            );
263                        },
264                        || unsafe {
265                            transpose_in_place_square(
266                                core::slice::from_raw_parts_mut(base.load(Ordering::Relaxed), len),
267                                log_stride,
268                                log_half_size,
269                                x + half,
270                            );
271                        },
272                    )
273                },
274            );
275            return;
276        }
277    }
278
279    // Sequential version of above logic
280    // Log2 of the new quadrant size (we're splitting the matrix in half)
281    let log_block_size = log_size - 1;
282    // Actual size of each quadrant (i.e., half the current matrix size)
283    let block_size = 1 << log_block_size;
284    // Physical stride between rows in memory (in elements)
285    let stride = 1 << log_stride;
286    // The size of each submatrix (used as a dimension for swapping TR/BL)
287    let dim = block_size;
288    // Raw pointer to the base of the array for manual offset calculations
289    let ptr = arr.as_mut_ptr();
290
291    unsafe {
292        // Transpose TL quadrant (top-left)
293        transpose_in_place_square(arr, log_stride, log_block_size, x);
294        // Swap TR (top-right) with BL (bottom-left)
295        transpose_swap(
296            ptr.add((x << log_stride) + (x + block_size)),
297            ptr.add(((x + block_size) << log_stride) + x),
298            stride,
299            (dim, dim),
300        );
301        // Transpose BR quadrant (bottom-right)
302        transpose_in_place_square(arr, log_stride, log_block_size, x + block_size);
303    }
304}
305
#[cfg(test)]
mod tests {
    use alloc::vec;
    use alloc::vec::Vec;

    use super::*;

    /// Helper to create a square matrix of size `2^log_size` with elements `0..n^2`
    fn generate_matrix(log_size: usize) -> Vec<u32> {
        let size = 1 << log_size;
        (0..size * size).collect()
    }

    /// Reference transpose that returns a new vector (row-major layout)
    fn transpose_reference(input: &[u32], log_size: usize) -> Vec<u32> {
        let size = 1 << log_size;
        let mut transposed = vec![0; size * size];
        for i in 0..size {
            for j in 0..size {
                transposed[j * size + i] = input[i * size + j];
            }
        }
        transposed
    }

    /// Whole-array transposition (`x = 0`, `log_stride == log_size`), from 1×1 up to
    /// 1024×1024 so that the base case and the recursive paths are all exercised.
    #[test]
    fn transpose_square() {
        for log_size in 0..=10 {
            // Compute the actual dimension: size = 2^log_size
            let size = 1 << log_size;

            // Generate a flat matrix of size×size elements
            let mut mat = generate_matrix(log_size);

            // Compute the reference result using a naive transpose implementation
            let expected = transpose_reference(&mat, log_size);

            // Perform the in-place transpose on `mat`.
            unsafe {
                transpose_in_place_square(&mut mat, log_size, log_size, 0);
            }

            // Compare the transposed matrix against the reference.
            assert_eq!(mat, expected, "Transpose failed for {size}x{size} matrix");
        }
    }

    /// Transposition of a diagonal sub-square embedded in a larger matrix
    /// (`x > 0`, `log_size < log_stride`): the block at `(x, x)` must be transposed
    /// and every element outside it must be left untouched.
    #[test]
    fn transpose_square_embedded() {
        // (log_stride, log_size, x)
        for (log_stride, log_size, x) in
            [(3, 1, 3), (4, 3, 5), (6, 4, 16), (6, 5, 32), (7, 5, 17)]
        {
            let stride = 1usize << log_stride;
            let size = 1usize << log_size;
            assert!(x + size <= stride, "test case does not fit in the outer matrix");

            let original = generate_matrix(log_stride);

            // Expected result: only the block rows/cols `x..x + size` are transposed.
            let mut expected = original.clone();
            for i in x..x + size {
                for j in x..x + size {
                    expected[i * stride + j] = original[j * stride + i];
                }
            }

            let mut mat = original.clone();
            unsafe {
                transpose_in_place_square(&mut mat, log_stride, log_size, x);
            }

            assert_eq!(
                mat, expected,
                "Embedded transpose failed for {size}x{size} block at offset {x} \
                 in {stride}x{stride} matrix"
            );
        }
    }
}