p3_util/transpose.rs
use core::ptr::{swap, swap_nonoverlapping};
#[cfg(feature = "parallel")]
use core::sync::atomic::{AtomicPtr, Ordering};

/// Log2 of the matrix dimension below which we use the base-case direct swap loop.
/// e.g. `BASE_CASE_LOG = 3` means the base case is used for submatrices of size ≤ 8×8.
const BASE_CASE_LOG: usize = 3;

/// Absolute size threshold (in elements) below which the recursive swap switches to the
/// direct base-case loop.
const BASE_CASE_ELEMENT_THRESHOLD: usize = 1 << (2 * BASE_CASE_LOG);

#[cfg(feature = "parallel")]
/// Threshold (in number of elements) beyond which we enable parallel recursion
const PARALLEL_RECURSION_THRESHOLD: usize = 1 << 10;

/// Transpose a small square matrix in-place using element-wise swaps.
///
/// # Parameters
/// - `arr`: A mutable slice backing a larger row-major matrix.
/// - `log_stride`: Log2 of the stride between rows in the array.
/// - `log_size`: Log2 of the dimension of the square matrix to transpose.
/// - `x`: Offset (in rows and columns) from the top-left corner of the full array.
///
/// The matrix occupies a logical square region starting at `(x, x)` and of size `1 << log_size`.
///
/// # Safety
/// - All accesses to `arr` must be in-bounds.
/// - `log_size <= log_stride` must hold to prevent overlapping indices during swaps.
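///
/// For example, with `log_stride = log_size = 1` and `x = 0`, the 2×2 matrix stored as
/// `[0, 1, 2, 3]` has a single off-diagonal pair; the loop swaps indices `1` and `2`,
/// leaving `[0, 2, 1, 3]`.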
unsafe fn transpose_in_place_square_small<T>(
    arr: &mut [T],
    log_stride: usize,
    log_size: usize,
    x: usize,
) {
    unsafe {
        // Loop over all off-diagonal pairs (j < i) of the square at (x, x)
        for i in (x + 1)..(x + (1 << log_size)) {
            for j in x..i {
                // Compute memory offsets and swap M[i, j] <-> M[j, i]
                swap(
                    arr.get_unchecked_mut(i + (j << log_stride)),
                    arr.get_unchecked_mut((i << log_stride) + j),
                );
            }
        }
    }
}

/// Recursively swaps two submatrices across the main diagonal as part of a larger transposition.
///
/// Given:
/// - Submatrix `A` of shape `(rows × cols)`
/// - Submatrix `B` of shape `(cols × rows)`
///
/// This function swaps element `A[i, j]` with `B[j, i]`, effectively transposing them
/// relative to each other.
///
/// `A` is assumed to be row-major, starting at pointer `a`, with `A[i, j] = a[i * width_outer_mat + j]`.
/// `B` is assumed to be row-major, starting at pointer `b`, with `B[j, i] = b[j * width_outer_mat + i]`.
///
/// The recursion always splits along the longer dimension to balance cache use and workload.
///
/// # Safety
/// - `a` and `b` must be valid for `rows * cols` reads and writes.
/// - The regions pointed to by `a` and `b` must be disjoint.
/// - `width_outer_mat` must be large enough to avoid overlapping accesses during index calculations.
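///
/// For example, with `width_outer_mat = 4` and `(rows, cols) = (1, 2)`, the base case swaps
/// `a[0] <-> b[0]` and `a[1] <-> b[4]`, i.e. the 1×2 row block `A` with the matching 2×1
/// column block of `B`.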
pub(super) unsafe fn transpose_swap<T: Copy>(
    a: *mut T,
    b: *mut T,
    width_outer_mat: usize,
    (rows, cols): (usize, usize),
) {
    let size = rows * cols;

    // Base case: directly swap A[i, j] with B[j, i] using pointer offsets
    if size < BASE_CASE_ELEMENT_THRESHOLD {
        for i in 0..rows {
            for j in 0..cols {
                let ai = i * width_outer_mat + j;
                let bi = j * width_outer_mat + i;
                unsafe {
                    swap_nonoverlapping(a.add(ai), b.add(bi), 1);
                }
            }
        }
        return;
    }

    #[cfg(feature = "parallel")]
    {
        // If large enough, split the work recursively in parallel
        if size > PARALLEL_RECURSION_THRESHOLD {
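            // `*mut T` is not `Send`, so the raw pointers are wrapped in `AtomicPtr` to move
            // them into the `rayon::join` closures. `Relaxed` ordering suffices: the pointers
            // are only stored once and loaded, and `rayon::join` handles synchronization.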
            let a = AtomicPtr::new(a);
            let b = AtomicPtr::new(b);

            // Prefer splitting the longer dimension for better balance and locality
            if rows > cols {
                let top = rows / 2;
                let bottom = rows - top;
                rayon::join(
                    || {
                        let a = a.load(Ordering::Relaxed);
                        let b = b.load(Ordering::Relaxed);
                        unsafe {
                            transpose_swap(a, b, width_outer_mat, (top, cols));
                        }
                    },
                    || {
                        let a = a.load(Ordering::Relaxed);
                        let b = b.load(Ordering::Relaxed);
                        unsafe {
                            transpose_swap(
                                a.add(top * width_outer_mat),
                                b.add(top),
                                width_outer_mat,
                                (bottom, cols),
                            );
                        }
                    },
                );
            } else {
                let left = cols / 2;
                let right = cols - left;
                rayon::join(
                    || {
                        let a = a.load(Ordering::Relaxed);
                        let b = b.load(Ordering::Relaxed);
                        unsafe {
                            transpose_swap(a, b, width_outer_mat, (rows, left));
                        }
                    },
                    || {
                        let a = a.load(Ordering::Relaxed);
                        let b = b.load(Ordering::Relaxed);
                        unsafe {
                            transpose_swap(
                                a.add(left),
                                b.add(left * width_outer_mat),
                                width_outer_mat,
                                (rows, right),
                            );
                        }
                    },
                );
            }
            return;
        }
    }

    // Sequential case: same recursive logic without threading
    if rows > cols {
        let top = rows / 2;
        let bottom = rows - top;
        unsafe {
            transpose_swap(a, b, width_outer_mat, (top, cols));
            transpose_swap(
                a.add(top * width_outer_mat),
                b.add(top),
                width_outer_mat,
                (bottom, cols),
            );
        }
    } else {
        let left = cols / 2;
        let right = cols - left;
        unsafe {
            transpose_swap(a, b, width_outer_mat, (rows, left));
            transpose_swap(
                a.add(left),
                b.add(left * width_outer_mat),
                width_outer_mat,
                (rows, right),
            );
        }
    }
}

/// In-place recursive transposition of a square matrix of size `2^log_size × 2^log_size`,
/// embedded inside a larger row-major array at offset `(x, x)`.
///
/// Each matrix element `M[i, j]` is stored at:
/// ```text
/// index(i, j) = ((i + x) << log_stride) + (j + x)
/// ```
///
/// The matrix is recursively split into four quadrants:
/// ```text
/// +----+----+
/// | TL | TR |
/// +----+----+
/// | BL | BR |
/// +----+----+
/// ```
/// Transposition proceeds by:
/// 1. Recursively transposing `TL`
/// 2. Swapping `TR` and `BL` across the diagonal
/// 3. Recursively transposing `BR`
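///
/// For example, when `log_size = 4` and `x = 0`, the quadrants are 8×8: `TL` starts at linear
/// index `0`, `TR` at `8`, `BL` at `8 << log_stride`, and `BR` at `(8 << log_stride) + 8`.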
///
/// # Safety
/// - Assumes all accesses via `((i + x) << log_stride) + (j + x)` are in-bounds.
/// - Requires `log_size <= log_stride` to avoid index overlap.
pub(crate) unsafe fn transpose_in_place_square<T>(
    arr: &mut [T],
    log_stride: usize,
    log_size: usize,
    x: usize,
) where
    T: Copy + Send + Sync,
{
    // If small, switch to the base case
    if log_size <= BASE_CASE_LOG {
        unsafe {
            transpose_in_place_square_small(arr, log_stride, log_size, x);
        }
        return;
    }

    #[cfg(feature = "parallel")]
    {
        // Log2 of half the matrix dimension
        let log_half_size = log_size - 1;
        // Half the matrix dimension (e.g. 8 for a 16×16 matrix)
        let half = 1 << log_half_size;
        // Total number of elements in the full square matrix
        let elements = 1 << (2 * log_size);

        if elements >= PARALLEL_RECURSION_THRESHOLD {
            // Shared base pointer for parallel recursion
            let base = AtomicPtr::new(arr.as_mut_ptr());
            // Total length of the backing array
            let len = arr.len();
            // Row stride in physical memory
            let stride = 1 << log_stride;
            // Size of each quadrant (half × half)
            let dim = 1 << log_half_size;

            // Coordinate the quadrants via `rayon::join`:
            // - TL and BR are recursive calls
            // - TR and BL are swapped directly
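            // Each recursive call below rebuilds a full-length slice from the shared base
            // pointer; the three tasks only touch disjoint quadrants of the square at (x, x).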
            rayon::join(
                || unsafe {
                    transpose_in_place_square(
                        core::slice::from_raw_parts_mut(base.load(Ordering::Relaxed), len),
                        log_stride,
                        log_half_size,
                        x,
                    );
                },
                || {
                    rayon::join(
                        // TR: starts at (x, x + half)
                        // BL: starts at (x + half, x)
                        || unsafe {
                            let ptr = base.load(Ordering::Relaxed);
                            transpose_swap(
                                ptr.add((x << log_stride) + (x + half)),
                                ptr.add(((x + half) << log_stride) + x),
                                stride,
                                (dim, dim),
                            );
                        },
                        || unsafe {
                            transpose_in_place_square(
                                core::slice::from_raw_parts_mut(base.load(Ordering::Relaxed), len),
                                log_stride,
                                log_half_size,
                                x + half,
                            );
                        },
                    )
                },
            );
            return;
        }
    }

    // Sequential version of the above logic
    // Log2 of the new quadrant size (we're splitting the matrix in half)
    let log_block_size = log_size - 1;
    // Actual size of each quadrant (i.e., half the current matrix dimension)
    let block_size = 1 << log_block_size;
    // Physical stride between rows in memory (in elements)
    let stride = 1 << log_stride;
    // The size of each submatrix (used as a dimension for swapping TR/BL)
    let dim = block_size;
    // Raw pointer to the base of the array for manual offset calculations
    let ptr = arr.as_mut_ptr();

    unsafe {
        // Transpose TL quadrant (top-left)
        transpose_in_place_square(arr, log_stride, log_block_size, x);
        // Swap TR (top-right) with BL (bottom-left)
        transpose_swap(
            ptr.add((x << log_stride) + (x + block_size)),
            ptr.add(((x + block_size) << log_stride) + x),
            stride,
            (dim, dim),
        );
        // Transpose BR quadrant (bottom-right)
        transpose_in_place_square(arr, log_stride, log_block_size, x + block_size);
    }
}

#[cfg(test)]
mod tests {
    use alloc::vec;
    use alloc::vec::Vec;

    use super::*;

    /// Helper to create a flattened square matrix of dimension `2^log_size`,
    /// filled with the elements `0..size * size`.
    fn generate_matrix(log_size: usize) -> Vec<u32> {
        let size = 1 << log_size;
        (0..size * size).collect()
    }

    /// Reference transpose that returns a new vector (row-major layout)
    fn transpose_reference(input: &[u32], log_size: usize) -> Vec<u32> {
        let size = 1 << log_size;
        let mut transposed = vec![0; size * size];
        for i in 0..size {
            for j in 0..size {
                transposed[j * size + i] = input[i * size + j];
            }
        }
        transposed
    }

    #[test]
    fn transpose_square() {
        // Loop over matrix sizes:
        // Each size is of the form 2^log_size × 2^log_size
        for log_size in 1..=10 {
            // Compute the actual dimension: size = 2^log_size
            let size = 1 << log_size;

            // Generate a flat matrix of size×size elements
            let mut mat = generate_matrix(log_size);

            // Compute the reference result using a naive transpose implementation
            let expected = transpose_reference(&mat, log_size);

            // Perform the in-place transpose on `mat`.
            unsafe {
                transpose_in_place_square(&mut mat, log_size, log_size, 0);
            }

            // Compare the transposed matrix against the reference.
            assert_eq!(mat, expected, "Transpose failed for {size}x{size} matrix");
        }
    }
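
    /// A minimal additional check of `transpose_in_place_square` on a square embedded in a
    /// larger array (`log_size < log_stride`, `x != 0`): elements inside the `(x, x)` square
    /// must be transposed, and everything outside it must be left untouched. The concrete
    /// sizes (16×16 outer, 4×4 inner at offset 8) are arbitrary choices for illustration.
    #[test]
    fn transpose_embedded_square() {
        let log_stride = 4;
        let log_size = 2;
        let x = 8;
        let stride = 1 << log_stride;
        let size = 1 << log_size;

        let mut mat: Vec<u32> = (0..(stride * stride) as u32).collect();
        let original = mat.clone();

        unsafe {
            transpose_in_place_square(&mut mat, log_stride, log_size, x);
        }

        for i in 0..stride {
            for j in 0..stride {
                let inside = (x..x + size).contains(&i) && (x..x + size).contains(&j);
                let expected = if inside {
                    // Inside the square, M[i, j] now holds the old M[j, i].
                    original[j * stride + i]
                } else {
                    // Outside the square, nothing should have moved.
                    original[i * stride + j]
                };
                assert_eq!(mat[i * stride + j], expected, "mismatch at ({i}, {j})");
            }
        }
    }

    /// A minimal sketch of a direct test for `transpose_swap`, assuming only the block layout
    /// documented on that function: a 2×5 block `A` at `(0, 0)` is swapped element-wise with
    /// the 5×2 block `B` at `(0, 5)` inside an 8×8 backing matrix, so afterwards
    /// `A[i, j]` holds the old `B[j, i]` and vice versa.
    #[test]
    fn transpose_swap_rectangular_blocks() {
        let width = 8;
        let (rows, cols) = (2, 5);

        let mut mat: Vec<u32> = (0..(width * width) as u32).collect();
        let original = mat.clone();

        // `A` starts at (0, 0); `B` starts at (0, cols), so the two blocks are disjoint.
        let b_off = cols;

        unsafe {
            let ptr = mat.as_mut_ptr();
            transpose_swap(ptr, ptr.add(b_off), width, (rows, cols));
        }

        for i in 0..rows {
            for j in 0..cols {
                let ai = i * width + j;
                let bi = b_off + j * width + i;
                assert_eq!(mat[ai], original[bi], "A[{i}, {j}] should hold the old B[{j}, {i}]");
                assert_eq!(mat[bi], original[ai], "B[{j}, {i}] should hold the old A[{i}, {j}]");
            }
        }
    }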
}