p3_util/transpose/rectangular.rs
1//! High-performance matrix transpose for generic `Copy` types.
2//!
3//! This module provides an optimized **out-of-place** matrix transpose.
4//!
5//! # Overview
6//!
7//! Matrix transposition converts a row-major matrix into its column-major equivalent.
8//! For a matrix `A` with dimensions `height × width`:
9//!
10//! ```text
11//! A[i][j] → A^T[j][i]
12//! ```
13//!
14//! In memory (row-major layout), element at position `(row, col)` is stored at:
15//! - **Input**: `input[row * width + col]`
16//! - **Output**: `output[col * height + row]`
17//!
18//! # Architecture-Specific Optimizations
19//!
20//! On **ARM64** (aarch64), this module uses NEON SIMD intrinsics for:
21//! - **4-byte elements** (typical for 32-bit field elements like `MontyField31`, `BabyBear`)
22//! using a 2-stage butterfly (`vtrn1q_u32`/`vtrn2q_u32` then `vtrn1q_u64`/`vtrn2q_u64`).
23//! - **8-byte elements** (typical for 64-bit field elements like `Goldilocks`)
24//! using a simpler 1-stage butterfly (`vtrn1q_u64`/`vtrn2q_u64`) on pairs of registers.
25//!
26//! On other architectures or for other element sizes, it falls back to the `transpose` crate.
27//!
28//! # Key Optimizations
29//!
30//! ## NEON SIMD Registers (128-bit)
31//!
32//! ARM64 NEON provides 32 vector registers, each holding 128 bits.
33//! - For 32-bit (4-byte) elements, each register holds exactly **`BLOCK_SIZE` elements**.
34//! - A `BLOCK_SIZE`×`BLOCK_SIZE` block (`BLOCK_SIZE`^2 elements) fits perfectly in **`BLOCK_SIZE` registers**.
35//!
36//! ```text
37//! ┌─────────────────────────────────┐
38//! │ q0 = [ a00, a01, a02, a03 ] │ ← 128 bits = 4 × 32-bit
39//! │ q1 = [ a10, a11, a12, a13 ] │
40//! │ q2 = [ a20, a21, a22, a23 ] │
41//! │ q3 = [ a30, a31, a32, a33 ] │
42//! └─────────────────────────────────┘
43//! ```
44//!
45//! ## In-Register Transpose (Butterfly Network)
46//!
47//! We transpose a `BLOCK_SIZE`×`BLOCK_SIZE` block entirely in registers using a 2-stage butterfly:
48//!
49//! **Stage 1**: Swap pairs of 32-bit elements using `TRN1`/`TRN2`
50//! **Stage 2**: Swap pairs of 64-bit elements using `TRN1`/`TRN2` on reinterpreted u64
51//!
52//! ```text
53//! Input: After Stage 1: After Stage 2 (Output):
54//! ┌─────────────┐ ┌─────────────┐ ┌─────────────┐
55//! │ a b │ c d │ │ a e │ c g │ │ a e │ i m │
56//! │ e f │ g h │ │ b f │ d h │ │ b f │ j n │
57//! │─────┼───────│ │─────┼───────│ │─────┼───────│
58//! │ i j │ k l │ │ i m │ k o │ │ c g │ k o │
59//! │ m n │ o p │ │ j n │ l p │ │ d h │ l p │
60//! └─────────────┘ └─────────────┘ └─────────────┘
61//! ```
62//!
63//! ## Multi-Level Tiling Strategy
64//!
65//! Different strategies for different matrix sizes:
//! - **Small (≤ `SMALL_LEN` elements)**: Scalar transpose - fits in L1, no overhead
//! - **Medium (≤ `MEDIUM_LEN` elements)**: `TILE_SIZE`×`TILE_SIZE` Tiled - L2-friendly tiles
//! - **Large (> `MEDIUM_LEN` elements)**: Recursive + Tiled - Cache-oblivious
69
70#[cfg(target_arch = "aarch64")]
71use core::arch::aarch64::*;
72#[cfg(target_arch = "aarch64")]
73use core::mem::MaybeUninit;
74#[cfg(all(target_arch = "aarch64", feature = "parallel"))]
75use core::sync::atomic::{AtomicUsize, Ordering};
76
/// Software prefetch for write (`PRFM PSTL1KEEP`).
///
/// Brings the cache line containing `ptr` into the L1 data cache in exclusive
/// state, preparing for a subsequent store. This avoids Read-For-Ownership
/// (RFO) stalls when writing to memory not already in L1.
///
/// # Safety
///
/// `PRFM` is architecturally a hint: it never faults and has no visible
/// effect on memory contents, so any pointer value is tolerated. Callers
/// should still pass an address they are about to write through, otherwise
/// the prefetch is wasted work.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
unsafe fn prefetch_write(ptr: *const u8) {
    // PRFM PSTL1KEEP: Prefetch for Store, L1 cache, temporal (keep in cache).
    //
    // `readonly` is sound here: a prefetch hint performs no architecturally
    // visible memory write, so the asm block does not modify memory.
    unsafe {
        core::arch::asm!(
            "prfm pstl1keep, [{ptr}]",
            ptr = in(reg) ptr,
            options(readonly, nostack, preserves_flags),
        );
    }
}
94
/// Maximum number of elements for the simple scalar transpose.
///
/// For matrices with **at most** `SMALL_LEN` elements (~1KB for 4-byte
/// elements; the dispatcher compares with `len <= SMALL_LEN`), the overhead
/// of tiling isn't worth it.
///
/// Direct element-by-element copy is faster because:
/// - The entire matrix fits in L1 cache (32-64KB on most CPUs)
/// - No tile boundary calculations needed
/// - Branch prediction works well for small loops
#[cfg(any(target_arch = "aarch64", test))]
const SMALL_LEN: usize = 255;
106
/// Maximum number of elements for the single-level tiled transpose.
///
/// For matrices up to `MEDIUM_LEN` elements (~4MB for 4-byte elements), we use
/// a simple tiled approach with `TILE_SIZE`×`TILE_SIZE` tiles.
///
/// Although the whole matrix may exceed L2 (256KB-512KB), the *working set*
/// of each tile — one input tile plus the output lines it scatters into —
/// stays cache-resident, so a single level of tiling is sufficient.
///
/// Beyond this threshold, we switch to recursive subdivision to ensure
/// cache-oblivious behavior for very large matrices.
#[cfg(any(target_arch = "aarch64", test))]
const MEDIUM_LEN: usize = 1024 * 1024;
118
/// Side length of a tile in elements.
///
/// We use `TILE_SIZE`×`TILE_SIZE` tiles because:
/// - `TILE_SIZE`×`TILE_SIZE` × 4 bytes = 16×16×4 = 1KB per tile, fitting in L1 cache
/// - `TILE_SIZE` is divisible by `BLOCK_SIZE`, allowing exactly (`TILE_SIZE`/`BLOCK_SIZE`)^2 NEON blocks per tile
/// - Good balance between tile overhead and cache utilization
#[cfg(any(target_arch = "aarch64", test))]
const TILE_SIZE: usize = 16;
127
/// Maximum dimension for recursive base case.
///
/// When recursively subdividing large matrices, we stop when both
/// dimensions are ≤ `RECURSIVE_LIMIT` elements.
///
/// At this point, the sub-matrix (up to `RECURSIVE_LIMIT`×`RECURSIVE_LIMIT`
/// = 128×128×4 bytes = 64KB) fits in L2 cache, so we switch to the tiled
/// region transpose.
#[cfg(target_arch = "aarch64")]
const RECURSIVE_LIMIT: usize = 128;
137
/// Minimum number of elements before enabling parallel processing.
///
/// Parallel transpose only pays off for large matrices because:
/// - Thread spawn/join overhead (~1-10μs)
/// - Cache coherency traffic between cores
/// - Memory bandwidth becomes the bottleneck, not compute
///
/// At `PARALLEL_THRESHOLD` elements (16MB of 4-byte data), the work per
/// thread is large enough that parallelism overhead is amortized.
#[cfg(all(target_arch = "aarch64", feature = "parallel"))]
const PARALLEL_THRESHOLD: usize = 4 * 1024 * 1024;
149
150/// Transpose a matrix from row-major `input` to row-major `output`.
151///
152/// Given an input matrix with `height` rows and `width` columns, produces
153/// an output matrix with `width` rows and `height` columns.
154///
155/// # Memory Layout
156///
157/// Both input and output are stored in **row-major order**.
158///
159/// ```text
160/// Input (height=2, width=3): Output (height=3, width=2):
161///
162/// Row 0: [ a, b, c ] Row 0: [ a, d ]
163/// Row 1: [ d, e, f ] Row 1: [ b, e ]
164/// Row 2: [ c, f ]
165///
166/// Memory: [a, b, c, d, e, f] Memory: [a, d, b, e, c, f]
167/// ```
168///
169/// # Index Transformation
170///
171/// - Initial element at `input[row * width + col]`,
172/// - Transposed position is `output[col * height + row]`.
173///
174/// # Arguments
175///
176/// * `input` - Source matrix in row-major order
177/// * `output` - Destination buffer in row-major order
178/// * `width` - Number of columns in the input matrix
179/// * `height` - Number of rows in the input matrix
180///
181/// # Panics
182///
183/// Panics if:
184/// - `input.len() != width * height`
185/// - `output.len() != width * height`
186#[inline]
187pub fn transpose<T: Copy + Send + Sync>(
188 input: &[T],
189 output: &mut [T],
190 width: usize,
191 height: usize,
192) {
193 // Input validation
194 assert_eq!(
195 input.len(),
196 width * height,
197 "Input length {} doesn't match width*height = {}",
198 input.len(),
199 width * height
200 );
201 assert_eq!(
202 output.len(),
203 width * height,
204 "Output length {} doesn't match width*height = {}",
205 output.len(),
206 width * height
207 );
208
209 // Handle empty matrices
210 if width == 0 || height == 0 {
211 return;
212 }
213
214 // Architecture dispatch
215 #[cfg(target_arch = "aarch64")]
216 {
217 // Use NEON-optimized path for 4-byte elements.
218 //
219 // This covers common field types like MontyField31.
220 if core::mem::size_of::<T>() == 4 {
221 // SAFETY:
222 // - input/output lengths verified above
223 // - T is 4 bytes, matching u32 size and alignment
224 // - Pointers derived from valid slices
225 unsafe {
226 transpose_neon_4b(
227 input.as_ptr().cast::<u32>(),
228 output.as_mut_ptr().cast::<u32>(),
229 width,
230 height,
231 );
232 }
233 return;
234 }
235
236 // Use NEON-optimized path for 8-byte elements.
237 //
238 // This covers 64-bit field types like Goldilocks.
239 // A 128-bit NEON register holds 2 u64 elements, so we use
240 // pairs of registers per row and a 1-stage butterfly.
241 if core::mem::size_of::<T>() == 8 {
242 // SAFETY:
243 // - input/output lengths verified above
244 // - T is 8 bytes, matching u64 size and alignment
245 // - Pointers derived from valid slices
246 unsafe {
247 transpose_neon_8b(
248 input.as_ptr().cast::<u64>(),
249 output.as_mut_ptr().cast::<u64>(),
250 width,
251 height,
252 );
253 }
254 return;
255 }
256 }
257
258 // Fallback for non-ARM64 or unsupported element sizes.
259 transpose::transpose(input, output, width, height);
260}
261
/// Top-level NEON transpose dispatcher for 4-byte elements.
///
/// Selects the appropriate strategy based on matrix size (`len = width * height`):
///
/// ```text
/// transpose_neon_4b
///   ├─ len ≥ PARALLEL_THRESHOLD (feature "parallel")  → parallel stripes
///   ├─ len ≤ SMALL_LEN                                → scalar
///   ├─ SMALL_LEN < len ≤ MEDIUM_LEN                   → tiled TILE_SIZE×TILE_SIZE
///   └─ len > MEDIUM_LEN                               → recursive (→ tiled at leaves)
/// ```
///
/// Note the boundaries are inclusive on the small side: `len == SMALL_LEN`
/// takes the scalar path and `len == MEDIUM_LEN` takes the tiled path.
///
/// # Safety
///
/// Caller must ensure `input` and `output` point to valid memory regions
/// of at least `width * height` elements each.
#[cfg(target_arch = "aarch64")]
#[inline]
unsafe fn transpose_neon_4b(input: *const u32, output: *mut u32, width: usize, height: usize) {
    // Total number of elements in the matrix.
    let len = width * height;

    #[cfg(feature = "parallel")]
    {
        // Parallel path (if enabled and matrix is large enough).
        // Checked first so very large matrices never fall through to the
        // sequential strategies below.
        if len >= PARALLEL_THRESHOLD {
            // SAFETY: Caller guarantees valid pointers.
            unsafe {
                transpose_neon_4b_parallel(input, output, width, height);
            }
            return;
        }
    }

    // Sequential path - choose strategy based on size
    if len <= SMALL_LEN {
        // Small matrix (len ≤ SMALL_LEN): simple scalar transpose.
        //
        // SAFETY: Caller guarantees valid pointers.
        unsafe {
            transpose_small_4b(input, output, width, height);
        }
    } else if len <= MEDIUM_LEN {
        // Medium matrix (len ≤ MEDIUM_LEN): single-level `TILE_SIZE`×`TILE_SIZE` tiling.
        //
        // SAFETY: Caller guarantees valid pointers.
        unsafe {
            transpose_tiled_4b(input, output, width, height);
        }
    } else {
        // Large matrix: recursive subdivision then tiling.
        //
        // This is the cache-oblivious approach.
        // SAFETY: Caller guarantees valid pointers.
        unsafe {
            transpose_recursive_4b(input, output, 0, height, 0, width, width, height);
        }
    }
}
333
/// Parallel transpose for very large matrices (≥ `PARALLEL_THRESHOLD` elements).
///
/// Splits the input into horizontal stripes of `height.div_ceil(num_threads)`
/// rows and hands one stripe to each rayon worker. Each stripe is transposed
/// independently with the tiled region algorithm.
///
/// # Data Race Safety
///
/// The stripe covering input rows `[r_start, r_end)` writes only output
/// columns `[r_start, r_end)` of the transposed matrix, so the output
/// regions of different workers are disjoint and no synchronization is
/// needed once the stripes are assigned.
///
/// # Safety
///
/// Caller must ensure valid pointers for `width * height` elements.
#[cfg(all(target_arch = "aarch64", feature = "parallel"))]
#[inline]
unsafe fn transpose_neon_4b_parallel(
    input: *const u32,
    output: *mut u32,
    width: usize,
    height: usize,
) {
    use rayon::prelude::*;

    // One stripe per worker in the current rayon pool.
    let num_threads = rayon::current_num_threads();

    // Ceiling division so all `height` rows are covered even when the row
    // count isn't a multiple of the thread count.
    let rows_per_thread = height.div_ceil(num_threads);

    // Raw pointers are not `Send`; smuggle the addresses across the closure
    // boundary as integers inside atomics. Relaxed ordering suffices: the
    // addresses are written once, before the parallel region starts, and
    // only read afterwards.
    let inp = AtomicUsize::new(input as usize);
    let out = AtomicUsize::new(output as usize);

    (0..num_threads).into_par_iter().for_each(|thread_idx| {
        // Row range owned by this worker.
        let row_start = thread_idx * rows_per_thread;
        let row_end = height.min(row_start + rows_per_thread);

        // With more workers than stripes, trailing workers get nothing.
        if row_start >= row_end {
            return;
        }

        // Reconstruct the typed pointers from the shared addresses.
        let input_ptr = inp.load(Ordering::Relaxed) as *const u32;
        let output_ptr = out.load(Ordering::Relaxed) as *mut u32;

        // SAFETY:
        // - Pointers are valid (from caller)
        // - Each worker writes a disjoint set of output columns
        unsafe {
            transpose_region_tiled_4b(
                input_ptr, output_ptr, row_start, row_end, 0, width, width, height,
            );
        }
    });
}
416
/// Simple element-by-element transpose for small matrices.
///
/// For matrices with at most `SMALL_LEN` elements, tiling overhead isn't
/// justified: the whole matrix already fits in L1, so a plain double loop
/// is fastest.
///
/// # Loop Order
///
/// The outer loop walks input columns `x`, i.e. output rows. All writes for
/// one output row therefore land at consecutive addresses, which suits the
/// core's write-combining buffers.
///
/// # Safety
///
/// Caller must ensure valid pointers for `width * height` elements.
#[cfg(target_arch = "aarch64")]
#[inline]
unsafe fn transpose_small_4b(input: *const u32, output: *mut u32, width: usize, height: usize) {
    for x in 0..width {
        // Base of output row `x`, which collects input column `x`.
        //
        // SAFETY: `x < width`, so `x * height` is within the output buffer.
        let out_row = unsafe { output.add(x * height) };

        for y in 0..height {
            // SAFETY: `y < height` and `x < width`, so both the read at
            // `y * width + x` and the write at `x * height + y` are in bounds.
            unsafe {
                *out_row.add(y) = *input.add(y * width + x);
            }
        }
    }
}
458
/// Tiled transpose using `TILE_SIZE`×`TILE_SIZE` tiles composed of NEON blocks.
///
/// Strategy for medium-sized matrices (`SMALL_LEN` to `MEDIUM_LEN` elements):
/// the matrix is split into complete `TILE_SIZE`×`TILE_SIZE` tiles, each
/// transposed with the fully SIMD 16×16 kernel. Elements that don't fill a
/// complete tile form three edge regions handled with the scalar kernel:
///
/// - a right-edge strip of width `remainder_x` per tile row,
/// - a bottom-edge strip of height `remainder_y` per tile column,
/// - the bottom-right corner of size `remainder_x`×`remainder_y`.
///
/// # Safety
///
/// Caller must ensure valid pointers for `width * height` elements.
#[cfg(target_arch = "aarch64")]
#[inline]
unsafe fn transpose_tiled_4b(input: *const u32, output: *mut u32, width: usize, height: usize) {
    // Number of complete tiles along each axis.
    let x_tile_count = width / TILE_SIZE;
    let y_tile_count = height / TILE_SIZE;

    // First coordinate past the last complete tile in each dimension.
    let x_edge = x_tile_count * TILE_SIZE;
    let y_edge = y_tile_count * TILE_SIZE;

    // Widths of the leftover edge strips (0 when dimensions divide evenly).
    let remainder_x = width - x_edge;
    let remainder_y = height - y_edge;

    for y_tile in 0..y_tile_count {
        let y_start = y_tile * TILE_SIZE;

        // All complete tiles in this tile row, via the NEON 16×16 kernel.
        for x_tile in 0..x_tile_count {
            // SAFETY: x_tile * TILE_SIZE + TILE_SIZE <= width and
            // y_start + TILE_SIZE <= height by construction of the counts.
            unsafe {
                transpose_tile_16x16_neon(
                    input,
                    output,
                    width,
                    height,
                    x_tile * TILE_SIZE,
                    y_start,
                );
            }
        }

        // Right-edge strip for this tile row (columns [x_edge, width)).
        if remainder_x > 0 {
            // SAFETY: the block spans columns [x_edge, width) and rows
            // [y_start, y_start + TILE_SIZE), all within bounds.
            unsafe {
                transpose_block_scalar(
                    input,
                    output,
                    width,
                    height,
                    x_edge,       // x_start
                    y_start,      // y_start
                    remainder_x,  // block_width
                    TILE_SIZE,    // block_height
                );
            }
        }
    }

    // Bottom-edge strips (rows [y_edge, height)).
    if remainder_y > 0 {
        // One strip beneath each complete tile column.
        for x_tile in 0..x_tile_count {
            // SAFETY: the block spans rows [y_edge, height) and TILE_SIZE
            // columns starting at x_tile * TILE_SIZE, all within bounds.
            unsafe {
                transpose_block_scalar(
                    input,
                    output,
                    width,
                    height,
                    x_tile * TILE_SIZE, // x_start
                    y_edge,             // y_start
                    TILE_SIZE,          // block_width
                    remainder_y,        // block_height
                );
            }
        }

        // Bottom-right corner: intersection of the two edge strips.
        if remainder_x > 0 {
            // SAFETY: the corner spans [x_edge, width) × [y_edge, height).
            unsafe {
                transpose_block_scalar(
                    input,
                    output,
                    width,
                    height,
                    x_edge,       // x_start
                    y_edge,       // y_start
                    remainder_x,  // block_width
                    remainder_y,  // block_height
                );
            }
        }
    }
}
581
/// Recursive cache-oblivious transpose for large matrices.
///
/// Recursively halves the sub-matrix along its longer dimension until it is
/// small enough, then hands the leaf to the tiled region transpose. Because
/// the problem shrinks geometrically, some recursion level fits each cache
/// (L1/L2/L3) without knowing cache sizes explicitly.
///
/// # Base Case
///
/// Recursion stops when both dimensions are ≤ `RECURSIVE_LIMIT` (the region
/// — up to 128×128×4 bytes = 64KB — fits in L2), or when either dimension is
/// ≤ 2 (degenerate: further splitting buys nothing).
///
/// # Parameters
///
/// The sub-matrix is described by coordinate ranges into the original matrix,
/// not by sub-slices:
/// - `row_start..row_end` / `col_start..col_end`: the region
/// - `total_cols`, `total_rows`: full matrix dimensions (stride computation)
///
/// # Safety
///
/// Caller must ensure valid pointers and that coordinate ranges are within bounds.
#[cfg(target_arch = "aarch64")]
#[allow(clippy::too_many_arguments)]
unsafe fn transpose_recursive_4b(
    input: *const u32,
    output: *mut u32,
    row_start: usize,
    row_end: usize,
    col_start: usize,
    col_end: usize,
    total_cols: usize,
    total_rows: usize,
) {
    // Current sub-matrix dimensions.
    let nbr_rows = row_end - row_start;
    let nbr_cols = col_end - col_start;

    // Leaf conditions: cache-resident, or too thin to split profitably.
    let fits_in_cache = nbr_rows <= RECURSIVE_LIMIT && nbr_cols <= RECURSIVE_LIMIT;
    let degenerate = nbr_rows <= 2 || nbr_cols <= 2;

    if fits_in_cache || degenerate {
        // Leaf: transpose the region with the tiled kernel.
        //
        // SAFETY: Caller ensures valid pointers and bounds.
        unsafe {
            transpose_region_tiled_4b(
                input, output, row_start, row_end, col_start, col_end, total_cols, total_rows,
            );
        }
    } else if nbr_rows >= nbr_cols {
        // Taller than wide: split the row range in half.
        let mid = row_start + nbr_rows / 2;

        // SAFETY: row_start <= mid <= row_end, so both halves are valid
        // sub-ranges of a range the caller vouched for.
        unsafe {
            transpose_recursive_4b(
                input, output, row_start, mid, col_start, col_end, total_cols, total_rows,
            );
            transpose_recursive_4b(
                input, output, mid, row_end, col_start, col_end, total_cols, total_rows,
            );
        }
    } else {
        // Wider than tall: split the column range in half.
        let mid = col_start + nbr_cols / 2;

        // SAFETY: col_start <= mid <= col_end, so both halves are valid
        // sub-ranges of a range the caller vouched for.
        unsafe {
            transpose_recursive_4b(
                input, output, row_start, row_end, col_start, mid, total_cols, total_rows,
            );
            transpose_recursive_4b(
                input, output, row_start, row_end, mid, col_end, total_cols, total_rows,
            );
        }
    }
}
708
/// Tiled transpose for a rectangular region within a larger matrix.
///
/// Identical tiling scheme to the whole-matrix tiled path, but restricted to
/// the region `[col_start, col_end) × [row_start, row_end)` of a matrix with
/// dimensions `total_cols` × `total_rows` (used for stride computation).
///
/// Serves as the base case of the recursive transpose and as the per-stripe
/// worker of the parallel path. Complete tiles go through the *buffered*
/// 16×16 NEON kernel: for large matrices the output is likely in L3/RAM, so
/// staging in an L1 buffer plus write prefetching avoids RFO stalls on
/// scattered output writes. Edge strips fall back to the scalar kernel.
///
/// # Safety
///
/// Caller must ensure:
/// - Valid pointers for `total_cols * total_rows` elements
/// - `row_start < row_end <= total_rows`
/// - `col_start < col_end <= total_cols`
#[cfg(target_arch = "aarch64")]
#[inline]
#[allow(clippy::too_many_arguments)]
unsafe fn transpose_region_tiled_4b(
    input: *const u32,
    output: *mut u32,
    row_start: usize,
    row_end: usize,
    col_start: usize,
    col_end: usize,
    total_cols: usize,
    total_rows: usize,
) {
    // Region dimensions.
    let nbr_cols = col_end - col_start;
    let nbr_rows = row_end - row_start;

    // Number of complete tiles along each axis of the region.
    let x_tile_count = nbr_cols / TILE_SIZE;
    let y_tile_count = nbr_rows / TILE_SIZE;

    // First matrix coordinate past the last complete tile in each dimension.
    let x_edge = col_start + x_tile_count * TILE_SIZE;
    let y_edge = row_start + y_tile_count * TILE_SIZE;

    // Widths of the leftover edge strips (0 when the region divides evenly).
    let remainder_x = nbr_cols % TILE_SIZE;
    let remainder_y = nbr_rows % TILE_SIZE;

    for y_tile in 0..y_tile_count {
        let row = row_start + y_tile * TILE_SIZE;

        // All complete tiles in this tile row, via the buffered NEON kernel.
        for x_tile in 0..x_tile_count {
            let col = col_start + x_tile * TILE_SIZE;

            // SAFETY: col + TILE_SIZE <= col_end <= total_cols and
            // row + TILE_SIZE <= row_end <= total_rows by construction.
            unsafe {
                transpose_tile_16x16_neon_buffered(input, output, total_cols, total_rows, col, row);
            }
        }

        // Right-edge strip of this tile row (columns [x_edge, col_end)).
        if remainder_x > 0 {
            // SAFETY: the block lies within the region, hence within bounds.
            unsafe {
                transpose_block_scalar(
                    input,
                    output,
                    total_cols,
                    total_rows,
                    x_edge,      // x_start
                    row,         // y_start
                    remainder_x, // block_width
                    TILE_SIZE,   // block_height
                );
            }
        }
    }

    // Bottom-edge strips (rows [y_edge, row_end)).
    if remainder_y > 0 {
        for x_tile in 0..x_tile_count {
            // SAFETY: the block lies within the region, hence within bounds.
            unsafe {
                transpose_block_scalar(
                    input,
                    output,
                    total_cols,
                    total_rows,
                    col_start + x_tile * TILE_SIZE, // x_start
                    y_edge,                         // y_start
                    TILE_SIZE,                      // block_width
                    remainder_y,                    // block_height
                );
            }
        }

        // Bottom-right corner: intersection of the two edge strips.
        if remainder_x > 0 {
            // SAFETY: the corner lies within the region, hence within bounds.
            unsafe {
                transpose_block_scalar(
                    input,
                    output,
                    total_cols,
                    total_rows,
                    x_edge,      // x_start
                    y_edge,      // y_start
                    remainder_x, // block_width
                    remainder_y, // block_height
                );
            }
        }
    }
}
836
/// Transpose a complete 16×16 tile using NEON SIMD (direct-to-output).
///
/// A 16×16 tile is processed as a 4×4 grid of 4×4 NEON blocks.
/// This function is fully unrolled for maximum performance.
///
/// # Addressing
///
/// For grid block (r, c) — block row `r`, block column `c`, each in 0..4:
/// - input base: `(y_start + 4r) * width + (x_start + 4c)`
/// - output base: `(x_start + 4c) * height + (y_start + 4r)`
///
/// Moving one block to the right in the input (`+4` columns) moves four rows
/// down in the output, which is why the output pointer steps by `4 * height`,
/// `8 * height`, `12 * height` across each block row below.
///
/// Used by the **medium tiled path** where the output likely fits in L2 cache
/// and the overhead of L1 buffering isn't justified.
///
/// # Safety
///
/// Caller must ensure:
/// - Valid pointers for the full matrix
/// - `x_start + 16 <= width`
/// - `y_start + 16 <= height`
#[cfg(target_arch = "aarch64")]
#[inline]
unsafe fn transpose_tile_16x16_neon(
    input: *const u32,
    output: *mut u32,
    width: usize,
    height: usize,
    x_start: usize,
    y_start: usize,
) {
    unsafe {
        // Block Row 0 (input rows y_start..y_start+4)
        let inp = input.add(y_start * width + x_start);
        let out = output.add(x_start * height + y_start);
        transpose_4x4_neon(inp, out, width, height);
        transpose_4x4_neon(inp.add(4), out.add(4 * height), width, height);
        transpose_4x4_neon(inp.add(8), out.add(8 * height), width, height);
        transpose_4x4_neon(inp.add(12), out.add(12 * height), width, height);

        // Block Row 1 (input rows y_start+4..y_start+8)
        let inp = input.add((y_start + 4) * width + x_start);
        let out = output.add(x_start * height + y_start + 4);
        transpose_4x4_neon(inp, out, width, height);
        transpose_4x4_neon(inp.add(4), out.add(4 * height), width, height);
        transpose_4x4_neon(inp.add(8), out.add(8 * height), width, height);
        transpose_4x4_neon(inp.add(12), out.add(12 * height), width, height);

        // Block Row 2 (input rows y_start+8..y_start+12)
        let inp = input.add((y_start + 8) * width + x_start);
        let out = output.add(x_start * height + y_start + 8);
        transpose_4x4_neon(inp, out, width, height);
        transpose_4x4_neon(inp.add(4), out.add(4 * height), width, height);
        transpose_4x4_neon(inp.add(8), out.add(8 * height), width, height);
        transpose_4x4_neon(inp.add(12), out.add(12 * height), width, height);

        // Block Row 3 (input rows y_start+12..y_start+16)
        let inp = input.add((y_start + 12) * width + x_start);
        let out = output.add(x_start * height + y_start + 12);
        transpose_4x4_neon(inp, out, width, height);
        transpose_4x4_neon(inp.add(4), out.add(4 * height), width, height);
        transpose_4x4_neon(inp.add(8), out.add(8 * height), width, height);
        transpose_4x4_neon(inp.add(12), out.add(12 * height), width, height);
    }
}
895
/// Transpose a complete 16×16 tile using NEON SIMD with L1 buffering.
///
/// Same grid of 4×4 NEON blocks, but transposed into a stack-allocated buffer
/// first, then flushed to the output with write prefetching (`PRFM PSTL1KEEP`).
///
/// Used by the **recursive/parallel path** for large matrices (≥ `MEDIUM_LEN`)
/// where the output is in L3/RAM and direct scattered writes would stall on
/// Read-For-Ownership (RFO) cache line fetches.
///
/// # Buffer Layout
///
/// The buffer holds the already-transposed tile in row-major order:
/// `buf[col * TILE_SIZE + row]`. Grid block (r, c) — reading input rows
/// `y_start+4r..` and columns `x_start+4c..` — therefore lands at
/// `buf[4c * TILE_SIZE + 4r]`, which explains the `4 * TILE_SIZE + 4`-style
/// offsets below. The buffer write stride is `TILE_SIZE` (contiguous, L1-hot)
/// instead of `height` (scattered).
///
/// # Safety
///
/// Caller must ensure:
/// - Valid pointers for the full matrix
/// - `x_start + 16 <= width`
/// - `y_start + 16 <= height`
#[cfg(target_arch = "aarch64")]
#[inline]
unsafe fn transpose_tile_16x16_neon_buffered(
    input: *const u32,
    output: *mut u32,
    width: usize,
    height: usize,
    x_start: usize,
    y_start: usize,
) {
    // Stack buffer for L1-hot transpose (1 KB for u32).
    // MaybeUninit avoids unnecessary zero-initialization; every element
    // is written by the NEON blocks before the copy reads it.
    let mut buffer = MaybeUninit::<[u32; TILE_SIZE * TILE_SIZE]>::uninit();
    let buf = buffer.as_mut_ptr().cast::<u32>();

    unsafe {
        // Transpose 4×4 grid of NEON blocks into the buffer
        // (layout described in the doc comment above).

        // Block Row 0 (input rows y_start..y_start+4) → buffer rows 0..4
        let inp = input.add(y_start * width + x_start);
        transpose_4x4_neon(inp, buf, width, TILE_SIZE);
        transpose_4x4_neon(inp.add(4), buf.add(4 * TILE_SIZE), width, TILE_SIZE);
        transpose_4x4_neon(inp.add(8), buf.add(8 * TILE_SIZE), width, TILE_SIZE);
        transpose_4x4_neon(inp.add(12), buf.add(12 * TILE_SIZE), width, TILE_SIZE);

        // Block Row 1 (input rows y_start+4..y_start+8) → buffer rows 4..8
        let inp = input.add((y_start + 4) * width + x_start);
        transpose_4x4_neon(inp, buf.add(4), width, TILE_SIZE);
        transpose_4x4_neon(inp.add(4), buf.add(4 * TILE_SIZE + 4), width, TILE_SIZE);
        transpose_4x4_neon(inp.add(8), buf.add(8 * TILE_SIZE + 4), width, TILE_SIZE);
        transpose_4x4_neon(inp.add(12), buf.add(12 * TILE_SIZE + 4), width, TILE_SIZE);

        // Block Row 2 (input rows y_start+8..y_start+12) → buffer rows 8..12
        let inp = input.add((y_start + 8) * width + x_start);
        transpose_4x4_neon(inp, buf.add(8), width, TILE_SIZE);
        transpose_4x4_neon(inp.add(4), buf.add(4 * TILE_SIZE + 8), width, TILE_SIZE);
        transpose_4x4_neon(inp.add(8), buf.add(8 * TILE_SIZE + 8), width, TILE_SIZE);
        transpose_4x4_neon(inp.add(12), buf.add(12 * TILE_SIZE + 8), width, TILE_SIZE);

        // Block Row 3 (input rows y_start+12..y_start+16) → buffer rows 12..16
        let inp = input.add((y_start + 12) * width + x_start);
        transpose_4x4_neon(inp, buf.add(12), width, TILE_SIZE);
        transpose_4x4_neon(inp.add(4), buf.add(4 * TILE_SIZE + 12), width, TILE_SIZE);
        transpose_4x4_neon(inp.add(8), buf.add(8 * TILE_SIZE + 12), width, TILE_SIZE);
        transpose_4x4_neon(inp.add(12), buf.add(12 * TILE_SIZE + 12), width, TILE_SIZE);

        // Flush buffer to output with write prefetching.
        // Each iteration copies TILE_SIZE u32s (64 bytes — a full line on
        // typical 64-byte-cache-line cores) from the L1-hot buffer to one
        // output row. The first output row is prefetched before the loop;
        // each iteration then prefetches row c+1 while copying row c, hiding
        // the RFO latency behind the copy.
        prefetch_write(output.add(x_start * height + y_start) as *const u8);
        for c in 0..TILE_SIZE {
            if c + 1 < TILE_SIZE {
                prefetch_write(output.add((x_start + c + 1) * height + y_start) as *const u8);
            }
            core::ptr::copy_nonoverlapping(
                buf.add(c * TILE_SIZE),
                output.add((x_start + c) * height + y_start),
                TILE_SIZE,
            );
        }
    }
}
977
/// Scalar transpose for an arbitrary rectangular block.
///
/// Handles edge cases where dimensions don't align to tile boundaries by
/// copying element-by-element.
///
/// # When Used
///
/// - Right edge: `block_width < TILE_SIZE`
/// - Bottom edge: `block_height < TILE_SIZE`
/// - Bottom-right corner: both dimensions < `TILE_SIZE`
///
/// # Safety
///
/// Caller must ensure:
/// - Valid pointers for the full matrix
/// - `x_start + block_width <= width`
/// - `y_start + block_height <= height`
#[cfg(target_arch = "aarch64")]
#[inline]
#[allow(clippy::too_many_arguments)]
unsafe fn transpose_block_scalar(
    input: *const u32,
    output: *mut u32,
    width: usize,
    height: usize,
    x_start: usize,
    y_start: usize,
    block_width: usize,
    block_height: usize,
) {
    // Each input column of the block becomes one output row.
    for inner_x in 0..block_width {
        let x = x_start + inner_x;
        // Base of output row `x`: consecutive inner_y values land at
        // consecutive output offsets from here (`y + x * height`).
        let out_base = x * height + y_start;
        // Base of input column `x` within the block's first row; successive
        // rows are `width` elements apart (`x + y * width`).
        let in_base = y_start * width + x;

        for inner_y in 0..block_height {
            // SAFETY: (x, y) stays inside the block, which the caller
            // guarantees lies within the matrix bounds.
            unsafe {
                *output.add(out_base + inner_y) = *input.add(in_base + inner_y * width);
            }
        }
    }
}
1029
/// Transpose a 4×4 block of 32-bit elements using NEON SIMD.
///
/// The fundamental building block of the 4-byte transpose: a 4×4 block is
/// transposed entirely within NEON registers via a two-stage butterfly
/// network (no scalar shuffles, no stack traffic).
///
/// # Memory Layout
///
/// Input (4 rows, stride = `src_stride`):
/// ```text
/// src + 0*stride: [ a00, a01, a02, a03 ]
/// src + 1*stride: [ a10, a11, a12, a13 ]
/// src + 2*stride: [ a20, a21, a22, a23 ]
/// src + 3*stride: [ a30, a31, a32, a33 ]
/// ```
///
/// Output (4 rows, stride = `dst_stride`): row `i` holds input column `i`,
/// i.e. `dst + i*stride = [ a0i, a1i, a2i, a3i ]`.
///
/// # Algorithm
///
/// Stage 1 interleaves 32-bit lanes with `TRN1`/`TRN2` (even/odd lanes),
/// transposing the four 2×2 sub-blocks. Stage 2 reinterprets the vectors as
/// u64×2 and applies the same trick to swap the 64-bit halves, completing
/// the 4×4 transpose — the classic butterfly with swap distances 1 then 2,
/// analogous to bit-reversal stages in an FFT.
///
/// # Performance
///
/// 4 loads + 8 permutes + 4 stores ≈ 16 cycles for 16 elements
/// (~1 cycle/element, fully pipelined).
///
/// # Safety
///
/// Caller must ensure:
/// - `src` is valid for reading 4 rows of `src_stride` elements each
/// - `dst` is valid for writing 4 rows of `dst_stride` elements each
/// - The first 4 elements of each row are accessible
#[cfg(target_arch = "aarch64")]
#[inline(always)]
unsafe fn transpose_4x4_neon(src: *const u32, dst: *mut u32, src_stride: usize, dst_stride: usize) {
    unsafe {
        // Load the four input rows. Each vld1q_u32 reads 4 consecutive u32s
        // (16 bytes = 128 bits); all four together cover one cache line on
        // most ARM64 CPUs.
        let row0 = vld1q_u32(src); // [a00, a01, a02, a03]
        let row1 = vld1q_u32(src.add(src_stride)); // [a10, a11, a12, a13]
        let row2 = vld1q_u32(src.add(2 * src_stride)); // [a20, a21, a22, a23]
        let row3 = vld1q_u32(src.add(3 * src_stride)); // [a30, a31, a32, a33]

        // Stage 1: transpose 2×2 blocks of 32-bit lanes.
        // vtrn1q_u32(a, b) = [a[0], b[0], a[2], b[2]] (even lanes)
        // vtrn2q_u32(a, b) = [a[1], b[1], a[3], b[3]] (odd lanes)
        let even01 = vtrn1q_u32(row0, row1); // [a00, a10, a02, a12]
        let odd01 = vtrn2q_u32(row0, row1); // [a01, a11, a03, a13]
        let even23 = vtrn1q_u32(row2, row3); // [a20, a30, a22, a32]
        let odd23 = vtrn2q_u32(row2, row3); // [a21, a31, a23, a33]

        // Stage 2: transpose 2×2 blocks of 64-bit halves.
        // Reinterpret u32x4 as u64x2, swap halves with vtrn1q_u64/vtrn2q_u64
        // ([a.lo, b.lo] / [a.hi, b.hi]), then reinterpret back. The
        // reinterprets are free (register-level type puns).

        // Input column 0 → output row 0
        let col0 = vreinterpretq_u32_u64(vtrn1q_u64(
            vreinterpretq_u64_u32(even01),
            vreinterpretq_u64_u32(even23),
        )); // [a00, a10, a20, a30]

        // Input column 2 → output row 2
        let col2 = vreinterpretq_u32_u64(vtrn2q_u64(
            vreinterpretq_u64_u32(even01),
            vreinterpretq_u64_u32(even23),
        )); // [a02, a12, a22, a32]

        // Input column 1 → output row 1
        let col1 = vreinterpretq_u32_u64(vtrn1q_u64(
            vreinterpretq_u64_u32(odd01),
            vreinterpretq_u64_u32(odd23),
        )); // [a01, a11, a21, a31]

        // Input column 3 → output row 3
        let col3 = vreinterpretq_u32_u64(vtrn2q_u64(
            vreinterpretq_u64_u32(odd01),
            vreinterpretq_u64_u32(odd23),
        )); // [a03, a13, a23, a33]

        // Store the four transposed rows.
        vst1q_u32(dst, col0);
        vst1q_u32(dst.add(dst_stride), col1);
        vst1q_u32(dst.add(2 * dst_stride), col2);
        vst1q_u32(dst.add(3 * dst_stride), col3);
    }
}
1178
1179// ============================================================================
1180// 8-byte (u64) NEON transpose functions
1181//
1182// These are analogous to the 4-byte functions above, but operate on u64
1183// elements. Since a 128-bit NEON register holds 2 u64 elements, each row
1184// of a 4×4 block requires 2 registers (8 total for a block). The transpose
1185// uses a single-stage butterfly with vtrn1q_u64/vtrn2q_u64 on four 2×2
1186// sub-blocks.
1187// ============================================================================
1188
/// Top-level NEON transpose dispatcher for 8-byte elements.
///
/// Mirrors `transpose_neon_4b`: selects the parallel, small (scalar),
/// medium (tiled), or large (recursive) strategy from the element count.
///
/// # Safety
///
/// Caller must ensure `input` and `output` point to valid memory regions
/// of at least `width * height` elements each.
#[cfg(target_arch = "aarch64")]
#[inline]
unsafe fn transpose_neon_8b(input: *const u64, output: *mut u64, width: usize, height: usize) {
    let len = width * height;

    // Very large matrices are striped across threads when rayon is available.
    #[cfg(feature = "parallel")]
    {
        if len >= PARALLEL_THRESHOLD {
            // SAFETY: caller's contract is forwarded unchanged.
            unsafe { transpose_neon_8b_parallel(input, output, width, height) };
            return;
        }
    }

    // SAFETY: caller's contract is forwarded unchanged to each strategy.
    unsafe {
        if len <= SMALL_LEN {
            // Fits in cache: plain scalar loop wins.
            transpose_small_8b(input, output, width, height);
        } else if len <= MEDIUM_LEN {
            // L2-sized: tiled NEON path.
            transpose_tiled_8b(input, output, width, height);
        } else {
            // Larger: cache-oblivious recursive subdivision.
            transpose_recursive_8b(input, output, 0, height, 0, width, width, height);
        }
    }
}
1227
/// Parallel transpose for very large matrices of 8-byte elements.
///
/// Divides the matrix into horizontal stripes, one per rayon thread; each
/// stripe is transposed with the tiled region routine. Stripes are disjoint
/// row ranges, so the threads write to disjoint output locations.
///
/// # Safety
///
/// Caller must ensure valid pointers for `width * height` elements.
#[cfg(all(target_arch = "aarch64", feature = "parallel"))]
#[inline]
unsafe fn transpose_neon_8b_parallel(
    input: *const u64,
    output: *mut u64,
    width: usize,
    height: usize,
) {
    use rayon::prelude::*;

    let num_threads = rayon::current_num_threads();
    // Ceiling division so all `height` rows are covered even when `height`
    // is not a multiple of the thread count.
    let rows_per_thread = height.div_ceil(num_threads);

    // Raw pointers are not `Send`, so the addresses are passed into the
    // closure as plain `usize` values. `usize` is `Copy + Send`, which makes
    // the previous `AtomicUsize` store/load round-trip unnecessary: the
    // pointers are immutable for the duration of the parallel region, so no
    // synchronization is needed to share them.
    let input_addr = input as usize;
    let output_addr = output as usize;

    (0..num_threads).into_par_iter().for_each(|thread_idx| {
        let row_start = thread_idx * rows_per_thread;
        let row_end = (row_start + rows_per_thread).min(height);

        // Trailing threads may get an empty stripe when height < num_threads.
        if row_start < row_end {
            let input_ptr = input_addr as *const u64;
            let output_ptr = output_addr as *mut u64;

            // SAFETY: `[row_start, row_end)` stripes are pairwise disjoint and
            // within `height`, so each thread reads/writes distinct regions of
            // the caller-guaranteed-valid buffers.
            unsafe {
                transpose_region_tiled_8b(
                    input_ptr, output_ptr, row_start, row_end, 0, width, width, height,
                );
            }
        }
    });
}
1267
/// Simple element-by-element transpose for small matrices of 8-byte elements.
///
/// Walks the output row by row so writes are sequential; reads stride
/// through the input by `width`.
///
/// # Safety
///
/// Caller must ensure valid pointers for `width * height` elements.
#[cfg(target_arch = "aarch64")]
#[inline]
unsafe fn transpose_small_8b(input: *const u64, output: *mut u64, width: usize, height: usize) {
    for x in 0..width {
        // Start of output row `x` (output is `width` rows of `height`).
        let out_base = x * height;
        for y in 0..height {
            // Input element (y, x) lives at `x + y * width` (row-major).
            // SAFETY: both indices are < width * height by construction.
            unsafe {
                *output.add(out_base + y) = *input.add(x + y * width);
            }
        }
    }
}
1287
/// Tiled transpose using 16×16 tiles for 8-byte elements.
///
/// Complete tiles go through the direct NEON tile routine; the right edge,
/// bottom edge, and bottom-right corner (when dimensions are not multiples
/// of `TILE_SIZE`) fall back to the scalar block routine.
///
/// # Safety
///
/// Caller must ensure valid pointers for `width * height` elements.
#[cfg(target_arch = "aarch64")]
#[inline]
unsafe fn transpose_tiled_8b(input: *const u64, output: *mut u64, width: usize, height: usize) {
    let x_tiles = width / TILE_SIZE;
    let y_tiles = height / TILE_SIZE;
    let rem_x = width % TILE_SIZE;
    let rem_y = height % TILE_SIZE;

    // Full tile rows, each followed by its right-edge remainder (if any).
    for y_tile in 0..y_tiles {
        let y0 = y_tile * TILE_SIZE;

        for x_tile in 0..x_tiles {
            // SAFETY: both tile origins leave >= TILE_SIZE elements to the
            // matrix edge by construction of the tile counts.
            unsafe {
                transpose_tile_16x16_neon_8b(input, output, width, height, x_tile * TILE_SIZE, y0);
            }
        }

        if rem_x > 0 {
            // Right edge: narrower-than-tile strip of full tile height.
            unsafe {
                transpose_block_scalar_8b(
                    input,
                    output,
                    width,
                    height,
                    x_tiles * TILE_SIZE,
                    y0,
                    rem_x,
                    TILE_SIZE,
                );
            }
        }
    }

    if rem_y > 0 {
        let y0 = y_tiles * TILE_SIZE;

        // Bottom edge: shorter-than-tile strip under each full tile column.
        for x_tile in 0..x_tiles {
            unsafe {
                transpose_block_scalar_8b(
                    input,
                    output,
                    width,
                    height,
                    x_tile * TILE_SIZE,
                    y0,
                    TILE_SIZE,
                    rem_y,
                );
            }
        }

        if rem_x > 0 {
            // Bottom-right corner: both dimensions below TILE_SIZE.
            unsafe {
                transpose_block_scalar_8b(
                    input,
                    output,
                    width,
                    height,
                    x_tiles * TILE_SIZE,
                    y0,
                    rem_x,
                    rem_y,
                );
            }
        }
    }
}
1364
/// Recursive cache-oblivious transpose for large matrices of 8-byte elements.
///
/// Repeatedly halves the longer dimension of the region until it is small
/// enough (or too thin) to transpose directly with the tiled routine. This
/// keeps the working set cache-resident at every level without tuning for a
/// particular cache size.
///
/// # Safety
///
/// Caller must ensure valid pointers and that coordinate ranges are within bounds.
#[cfg(target_arch = "aarch64")]
#[allow(clippy::too_many_arguments)]
unsafe fn transpose_recursive_8b(
    input: *const u64,
    output: *mut u64,
    row_start: usize,
    row_end: usize,
    col_start: usize,
    col_end: usize,
    total_cols: usize,
    total_rows: usize,
) {
    let rows = row_end - row_start;
    let cols = col_end - col_start;

    // Base case: region fits the recursion limit, or is too thin for
    // splitting to pay off.
    if (rows <= RECURSIVE_LIMIT && cols <= RECURSIVE_LIMIT) || rows <= 2 || cols <= 2 {
        // SAFETY: caller's bounds contract is forwarded unchanged.
        unsafe {
            transpose_region_tiled_8b(
                input, output, row_start, row_end, col_start, col_end, total_cols, total_rows,
            );
        }
        return;
    }

    // Halve along the longer dimension; the two halves tile the region.
    let (first, second) = if rows >= cols {
        let mid = row_start + rows / 2;
        (
            (row_start, mid, col_start, col_end),
            (mid, row_end, col_start, col_end),
        )
    } else {
        let mid = col_start + cols / 2;
        (
            (row_start, row_end, col_start, mid),
            (row_start, row_end, mid, col_end),
        )
    };

    // SAFETY: both halves are sub-ranges of the caller-validated region.
    unsafe {
        transpose_recursive_8b(
            input, output, first.0, first.1, first.2, first.3, total_cols, total_rows,
        );
        transpose_recursive_8b(
            input, output, second.0, second.1, second.2, second.3, total_cols, total_rows,
        );
    }
}
1427
/// Tiled transpose for a rectangular region of 8-byte elements.
///
/// Base case of the recursive transpose and worker routine for the parallel
/// stripes. Complete tiles use the buffered NEON tile (output is expected to
/// be L3/RAM-resident here); edge remainders fall back to the scalar block.
///
/// # Safety
///
/// Caller must ensure:
/// - Valid pointers for `total_cols * total_rows` elements
/// - `row_start < row_end <= total_rows`
/// - `col_start < col_end <= total_cols`
#[cfg(target_arch = "aarch64")]
#[inline]
#[allow(clippy::too_many_arguments)]
unsafe fn transpose_region_tiled_8b(
    input: *const u64,
    output: *mut u64,
    row_start: usize,
    row_end: usize,
    col_start: usize,
    col_end: usize,
    total_cols: usize,
    total_rows: usize,
) {
    let cols = col_end - col_start;
    let rows = row_end - row_start;

    let x_tiles = cols / TILE_SIZE;
    let y_tiles = rows / TILE_SIZE;
    let rem_x = cols % TILE_SIZE;
    let rem_y = rows % TILE_SIZE;

    // Full tile rows, each followed by its right-edge remainder (if any).
    for y_tile in 0..y_tiles {
        let row = row_start + y_tile * TILE_SIZE;

        for x_tile in 0..x_tiles {
            let col = col_start + x_tile * TILE_SIZE;

            // Buffered tile: L1 staging + write prefetch avoids RFO stalls
            // on the scattered output writes typical for large matrices.
            unsafe {
                transpose_tile_16x16_neon_8b_buffered(
                    input, output, total_cols, total_rows, col, row,
                );
            }
        }

        if rem_x > 0 {
            // Right edge of the region.
            unsafe {
                transpose_block_scalar_8b(
                    input,
                    output,
                    total_cols,
                    total_rows,
                    col_start + x_tiles * TILE_SIZE,
                    row,
                    rem_x,
                    TILE_SIZE,
                );
            }
        }
    }

    if rem_y > 0 {
        let row = row_start + y_tiles * TILE_SIZE;

        // Bottom edge of the region.
        for x_tile in 0..x_tiles {
            unsafe {
                transpose_block_scalar_8b(
                    input,
                    output,
                    total_cols,
                    total_rows,
                    col_start + x_tile * TILE_SIZE,
                    row,
                    TILE_SIZE,
                    rem_y,
                );
            }
        }

        if rem_x > 0 {
            // Bottom-right corner of the region.
            unsafe {
                transpose_block_scalar_8b(
                    input,
                    output,
                    total_cols,
                    total_rows,
                    col_start + x_tiles * TILE_SIZE,
                    row,
                    rem_x,
                    rem_y,
                );
            }
        }
    }
}
1527
/// Transpose a complete 16×16 tile of 8-byte elements using NEON SIMD (direct-to-output).
///
/// Used by the **medium tiled path** where the output likely fits in L2 cache,
/// so transposed blocks are written straight to their final locations.
///
/// # Safety
///
/// Caller must ensure:
/// - Valid pointers for the full matrix
/// - `x_start + 16 <= width`
/// - `y_start + 16 <= height`
#[cfg(target_arch = "aarch64")]
#[inline]
unsafe fn transpose_tile_16x16_neon_8b(
    input: *const u64,
    output: *mut u64,
    width: usize,
    height: usize,
    x_start: usize,
    y_start: usize,
) {
    unsafe {
        // 4×4 grid of 4×4 NEON blocks. Block (block_row, block_col) reads
        // input rows y_start + 4*block_row .. and columns x_start +
        // 4*block_col ..; its transpose is written to output rows
        // x_start + 4*block_col .. at column offset y_start + 4*block_row.
        // Constant trip counts let the compiler fully unroll both loops.
        for block_row in 0..4 {
            let inp = input.add((y_start + 4 * block_row) * width + x_start);
            let out = output.add(x_start * height + y_start + 4 * block_row);
            for block_col in 0..4 {
                transpose_4x4_neon_8b(
                    inp.add(4 * block_col),
                    out.add(4 * block_col * height),
                    width,
                    height,
                );
            }
        }
    }
}
1582
/// Transpose a complete 16×16 tile of 8-byte elements with L1 buffering.
///
/// Used by the **recursive/parallel path** for large matrices where the output
/// is in L3/RAM: the tile is transposed into a stack buffer first, then
/// flushed row-by-row with write prefetching to avoid RFO stalls.
///
/// # Safety
///
/// Caller must ensure:
/// - Valid pointers for the full matrix
/// - `x_start + 16 <= width`
/// - `y_start + 16 <= height`
#[cfg(target_arch = "aarch64")]
#[inline]
unsafe fn transpose_tile_16x16_neon_8b_buffered(
    input: *const u64,
    output: *mut u64,
    width: usize,
    height: usize,
    x_start: usize,
    y_start: usize,
) {
    // Stack buffer for the L1-hot intermediate (2 KB for u64). MaybeUninit
    // skips zero-initialization; every slot is written by the NEON blocks
    // below before the flush loop reads it.
    let mut staging = MaybeUninit::<[u64; TILE_SIZE * TILE_SIZE]>::uninit();
    let staging_ptr = staging.as_mut_ptr().cast::<u64>();

    unsafe {
        // Transpose the 4×4 grid of NEON blocks into the staging buffer.
        // Block (block_row, block_col) reads input rows y_start + 4*block_row
        // .. and columns x_start + 4*block_col ..; its transpose lands at
        // staging offset 4*block_col * TILE_SIZE + 4*block_row, so the write
        // stride is TILE_SIZE (L1-contiguous) instead of `height`.
        for block_row in 0..4 {
            let row_ptr = input.add((y_start + 4 * block_row) * width + x_start);
            for block_col in 0..4 {
                transpose_4x4_neon_8b(
                    row_ptr.add(4 * block_col),
                    staging_ptr.add(4 * block_col * TILE_SIZE + 4 * block_row),
                    width,
                    TILE_SIZE,
                );
            }
        }

        // Flush the staging buffer to the output with write prefetching:
        // prefetch the *next* output line into exclusive state while copying
        // the current one.
        prefetch_write(output.add(x_start * height + y_start) as *const u8);
        for c in 0..TILE_SIZE {
            if c + 1 < TILE_SIZE {
                prefetch_write(output.add((x_start + c + 1) * height + y_start) as *const u8);
            }
            core::ptr::copy_nonoverlapping(
                staging_ptr.add(c * TILE_SIZE),
                output.add((x_start + c) * height + y_start),
                TILE_SIZE,
            );
        }
    }
}
1653
/// Scalar transpose for an arbitrary rectangular block of 8-byte elements.
///
/// Handles edge cases where dimensions don't align to tile boundaries by
/// copying element-by-element.
///
/// # Safety
///
/// Caller must ensure:
/// - Valid pointers for the full matrix
/// - `x_start + block_width <= width`
/// - `y_start + block_height <= height`
#[cfg(target_arch = "aarch64")]
#[inline]
#[allow(clippy::too_many_arguments)]
unsafe fn transpose_block_scalar_8b(
    input: *const u64,
    output: *mut u64,
    width: usize,
    height: usize,
    x_start: usize,
    y_start: usize,
    block_width: usize,
    block_height: usize,
) {
    // Each input column of the block becomes one output row.
    for inner_x in 0..block_width {
        let x = x_start + inner_x;
        // Base of output row `x` (`y + x * height` with y = y_start).
        let out_base = x * height + y_start;
        // Base of input column `x` in the block's first row
        // (`x + y * width` with y = y_start).
        let in_base = y_start * width + x;

        for inner_y in 0..block_height {
            // SAFETY: (x, y) stays inside the block, which the caller
            // guarantees lies within the matrix bounds.
            unsafe {
                *output.add(out_base + inner_y) = *input.add(in_base + inner_y * width);
            }
        }
    }
}
1691
/// Transpose a 4×4 block of 64-bit elements using NEON SIMD.
///
/// The fundamental building block of the 8-byte transpose. A 128-bit NEON
/// register holds only 2 u64 lanes, so each 4-element row spans two
/// registers (8 registers for the block, 8 more for the result — well within
/// NEON's 32).
///
/// # Algorithm
///
/// A 4×4 matrix of 2-lane vectors is four independent 2×2 sub-blocks, each
/// transposed by a single `TRN1`/`TRN2` pair:
///
/// ```text
/// Load: row0 = [a00,a01][a02,a03]   row1 = [a10,a11][a12,a13]
///       row2 = [a20,a21][a22,a23]   row3 = [a30,a31][a32,a33]
///
/// trn1(rowI_lo, rowJ_lo)/trn2(...) transpose the left sub-blocks,
/// trn1(rowI_hi, rowJ_hi)/trn2(...) the right ones.
///
/// Store: out0 = [a00,a10,a20,a30]   out1 = [a01,a11,a21,a31]
///        out2 = [a02,a12,a22,a32]   out3 = [a03,a13,a23,a33]
/// ```
///
/// # Safety
///
/// Caller must ensure:
/// - `src` is valid for reading 4 rows of `src_stride` elements each
/// - `dst` is valid for writing 4 rows of `dst_stride` elements each
/// - The first 4 elements of each row are accessible
#[cfg(target_arch = "aarch64")]
#[inline(always)]
unsafe fn transpose_4x4_neon_8b(
    src: *const u64,
    dst: *mut u64,
    src_stride: usize,
    dst_stride: usize,
) {
    unsafe {
        // Load the four input rows, two registers per row
        // (lo = columns 0..2, hi = columns 2..4).
        let row0_lo = vld1q_u64(src); // [a00, a01]
        let row0_hi = vld1q_u64(src.add(2)); // [a02, a03]
        let row1_lo = vld1q_u64(src.add(src_stride)); // [a10, a11]
        let row1_hi = vld1q_u64(src.add(src_stride + 2)); // [a12, a13]
        let row2_lo = vld1q_u64(src.add(2 * src_stride)); // [a20, a21]
        let row2_hi = vld1q_u64(src.add(2 * src_stride + 2)); // [a22, a23]
        let row3_lo = vld1q_u64(src.add(3 * src_stride)); // [a30, a31]
        let row3_hi = vld1q_u64(src.add(3 * src_stride + 2)); // [a32, a33]

        // Single-stage butterfly: transpose each 2×2 sub-block with
        // vtrn1q_u64 (low lanes) / vtrn2q_u64 (high lanes).

        // Top-left sub-block (rows 0,1 × columns 0,1)
        let out0_lo = vtrn1q_u64(row0_lo, row1_lo); // [a00, a10]
        let out1_lo = vtrn2q_u64(row0_lo, row1_lo); // [a01, a11]
        // Top-right sub-block (rows 0,1 × columns 2,3)
        let out2_lo = vtrn1q_u64(row0_hi, row1_hi); // [a02, a12]
        let out3_lo = vtrn2q_u64(row0_hi, row1_hi); // [a03, a13]
        // Bottom-left sub-block (rows 2,3 × columns 0,1)
        let out0_hi = vtrn1q_u64(row2_lo, row3_lo); // [a20, a30]
        let out1_hi = vtrn2q_u64(row2_lo, row3_lo); // [a21, a31]
        // Bottom-right sub-block (rows 2,3 × columns 2,3)
        let out2_hi = vtrn1q_u64(row2_hi, row3_hi); // [a22, a32]
        let out3_hi = vtrn2q_u64(row2_hi, row3_hi); // [a23, a33]

        // Store the four transposed rows, two registers each.
        vst1q_u64(dst, out0_lo); // [a00, a10, ...
        vst1q_u64(dst.add(2), out0_hi); // ..., a20, a30]
        vst1q_u64(dst.add(dst_stride), out1_lo); // [a01, a11, ...
        vst1q_u64(dst.add(dst_stride + 2), out1_hi); // ..., a21, a31]
        vst1q_u64(dst.add(2 * dst_stride), out2_lo); // [a02, a12, ...
        vst1q_u64(dst.add(2 * dst_stride + 2), out2_hi); // ..., a22, a32]
        vst1q_u64(dst.add(3 * dst_stride), out3_lo); // [a03, a13, ...
        vst1q_u64(dst.add(3 * dst_stride + 2), out3_hi); // ..., a23, a33]
    }
}
1782
1783#[cfg(test)]
1784mod tests {
1785 use alloc::vec;
1786 use alloc::vec::Vec;
1787
1788 use p3_baby_bear::BabyBear;
1789 use p3_field::PrimeCharacteristicRing;
1790 use p3_goldilocks::Goldilocks;
1791 use proptest::prelude::*;
1792
1793 use super::*;
1794
1795 /// Naive reference implementation for correctness testing.
1796 fn transpose_reference<T: Copy + Default>(input: &[T], width: usize, height: usize) -> Vec<T> {
1797 // Allocate output buffer with same size as input.
1798 let mut output = vec![T::default(); width * height];
1799
1800 // For each position (x, y) in the input matrix:
1801 // - Input index: y * width + x (row-major)
1802 // - Output index: x * height + y (transposed row-major)
1803 for y in 0..height {
1804 for x in 0..width {
1805 output[x * height + y] = input[y * width + x];
1806 }
1807 }
1808
1809 output
1810 }
1811
    /// Strategy for generating matrix dimensions.
    ///
    /// Produces `(width, height)` pairs that together exercise every dispatch
    /// path of the transpose: degenerate shapes (empty, single row/column),
    /// the small (scalar) path, the medium (tiled) path including non-aligned
    /// edge remainders, and the large (recursive) path. Path boundaries are
    /// derived from `SMALL_LEN` / `MEDIUM_LEN`, so the strategy tracks the
    /// dispatch constants automatically if they change.
    fn dimension_strategy() -> impl Strategy<Value = (usize, usize)> {
        // Compute boundary dimensions from constants.
        // `small_side` is the largest square that stays in the small (scalar) path.
        // (f64 sqrt then truncation rounds down, keeping the square <= SMALL_LEN.)
        let small_side = (SMALL_LEN as f64).sqrt() as usize;
        // `medium_side` is the largest square that stays in the medium (tiled) path.
        let medium_side = (MEDIUM_LEN as f64).sqrt() as usize;
        // `large_side` is the side length that triggers the large (recursive) path.
        let large_side = medium_side + 1;

        prop_oneof![
            // Edge cases: empty and degenerate matrices
            //
            // Empty matrix (0×0)
            Just((0, 0)),
            // Single row (1×n) - tests degenerate case
            (1..=100_usize).prop_map(|w| (w, 1)),
            // Single column (n×1) - tests degenerate case
            (1..=100_usize).prop_map(|h| (1, h)),
            // Small path: len < SMALL_LEN (scalar transpose)
            //
            // These dimensions exercise the scalar transpose path.

            // Tiny matrices (various shapes within small threshold)
            (1..=small_side, 1..=small_side),
            // Medium path: SMALL_LEN ≤ len < MEDIUM_LEN (tiled TILE_SIZE×TILE_SIZE)
            //
            // These dimensions exercise the tiled TILE_SIZE×TILE_SIZE path.

            // Exactly 4×4 (single NEON block)
            Just((4, 4)),
            // Exactly TILE_SIZE×TILE_SIZE (single tile)
            Just((TILE_SIZE, TILE_SIZE)),
            // Multiple complete tiles (2× and 4× TILE_SIZE)
            Just((TILE_SIZE * 2, TILE_SIZE * 2)),
            Just((TILE_SIZE * 4, TILE_SIZE * 4)),
            // Non-aligned: has remainders in both dimensions
            // Range from just above TILE_SIZE to below 4×TILE_SIZE.
            // These test the scalar fallback for tile edges.
            (
                (TILE_SIZE + 1)..=(TILE_SIZE * 4 - 1),
                (TILE_SIZE + 1)..=(TILE_SIZE * 4 - 1)
            ),
            // Wide rectangle with remainders (medium path)
            (50..=200_usize, 10..=50_usize),
            // Tall rectangle with remainders (medium path)
            (10..=50_usize, 50..=200_usize),
            // Large path: MEDIUM_LEN ≤ len < PARALLEL_THRESHOLD (recursive)
            //
            // These exercise the cache-oblivious recursive subdivision.

            // Square matrices triggering recursion (just above medium threshold)
            Just((large_side, large_side)),
            // Slightly larger square
            Just((large_side + 100, large_side + 100)),
            // Wide rectangle triggering recursion
            Just((large_side * 2, large_side / 2)),
            // Tall rectangle triggering recursion
            Just((large_side / 2, large_side * 2)),
            // Non-power-of-2 dimensions in large range
            Just((large_side + 50, large_side + 75)),
        ]
    }
1875
1876 proptest! {
1877 #[test]
1878 fn proptest_transpose_babybear((width, height) in dimension_strategy()) {
1879 // Skip empty matrices (they're trivially correct).
1880 if width == 0 || height == 0 {
1881 // Just verify it doesn't panic.
1882 let input: [BabyBear; 0] = [];
1883 let mut output: [BabyBear; 0] = [];
1884 transpose(&input, &mut output, width, height);
1885 return Ok(());
1886 }
1887
1888 // Create input matrix with unique values at each position.
1889 let input: Vec<BabyBear> = (0..width * height)
1890 .map(|i| BabyBear::from_u64(i as u64))
1891 .collect();
1892
1893 // Allocate output buffer.
1894 let mut output = vec![BabyBear::ZERO; width * height];
1895
1896 // Run optimized transpose.
1897 transpose(&input, &mut output, width, height);
1898
1899 // Run reference transpose.
1900 let expected = transpose_reference(&input, width, height);
1901
1902 // Verify results match.
1903 prop_assert_eq!(
1904 output,
1905 expected,
1906 "Transpose mismatch for {}×{} matrix",
1907 width,
1908 height
1909 );
1910 }
1911
1912 #[test]
1913 fn proptest_transpose_u64((width, height) in dimension_strategy()) {
1914 // Skip empty and very large matrices for u64 (memory intensive).
1915 if width == 0 || height == 0 || width * height > 100_000 {
1916 return Ok(());
1917 }
1918
1919 // Create input with unique values.
1920 let input: Vec<u64> = (0..width * height).map(|i| i as u64).collect();
1921
1922 // Allocate output.
1923 let mut output = vec![0u64; width * height];
1924
1925 // Run transpose.
1926 transpose(&input, &mut output, width, height);
1927
1928 // Verify against reference.
1929 let expected = transpose_reference(&input, width, height);
1930 prop_assert_eq!(output, expected);
1931 }
1932
1933 #[test]
1934 fn proptest_transpose_u8((width, height) in dimension_strategy()) {
1935 // Skip empty and very large matrices.
1936 if width == 0 || height == 0 || width * height > 100_000 {
1937 return Ok(());
1938 }
1939
1940 // Create input with unique values (wrapping for u8).
1941 let input: Vec<u8> = (0..width * height).map(|i| i as u8).collect();
1942
1943 // Allocate output.
1944 let mut output = vec![0u8; width * height];
1945
1946 // Run transpose.
1947 transpose(&input, &mut output, width, height);
1948
1949 // Verify against reference.
1950 let expected = transpose_reference(&input, width, height);
1951 prop_assert_eq!(output, expected);
1952 }
1953
1954 #[test]
1955 fn proptest_transpose_goldilocks((width, height) in dimension_strategy()) {
1956 // Skip empty matrices.
1957 if width == 0 || height == 0 {
1958 let input: [Goldilocks; 0] = [];
1959 let mut output: [Goldilocks; 0] = [];
1960 transpose(&input, &mut output, width, height);
1961 return Ok(());
1962 }
1963
1964 // Create input matrix with unique values at each position.
1965 let input: Vec<Goldilocks> = (0..width * height)
1966 .map(|i| Goldilocks::from_u64(i as u64))
1967 .collect();
1968
1969 // Allocate output buffer.
1970 let mut output = vec![Goldilocks::ZERO; width * height];
1971
1972 // Run optimized transpose.
1973 transpose(&input, &mut output, width, height);
1974
1975 // Run reference transpose.
1976 let expected = transpose_reference(&input, width, height);
1977
1978 // Verify results match.
1979 prop_assert_eq!(
1980 output,
1981 expected,
1982 "Transpose mismatch for {}×{} matrix",
1983 width,
1984 height
1985 );
1986 }
1987 }
1988}