// p3_util/transpose/rectangular.rs
1//! High-performance matrix transpose for generic `Copy` types.
2//!
3//! This module provides an optimized **out-of-place** matrix transpose.
4//!
5//! # Overview
6//!
7//! Matrix transposition converts a row-major matrix into its column-major equivalent.
8//! For a matrix `A` with dimensions `height × width`:
9//!
10//! ```text
11//!     A[i][j] → A^T[j][i]
12//! ```
13//!
14//! In memory (row-major layout), element at position `(row, col)` is stored at:
15//! - **Input**: `input[row * width + col]`
16//! - **Output**: `output[col * height + row]`
17//!
18//! # Architecture-Specific Optimizations
19//!
20//! On **ARM64** (aarch64), this module uses NEON SIMD intrinsics for:
21//! - **4-byte elements** (typical for 32-bit field elements like `MontyField31`, `BabyBear`)
22//!   using a 2-stage butterfly (`vtrn1q_u32`/`vtrn2q_u32` then `vtrn1q_u64`/`vtrn2q_u64`).
23//! - **8-byte elements** (typical for 64-bit field elements like `Goldilocks`)
24//!   using a simpler 1-stage butterfly (`vtrn1q_u64`/`vtrn2q_u64`) on pairs of registers.
25//!
26//! On other architectures or for other element sizes, it falls back to the `transpose` crate.
27//!
28//! # Key Optimizations
29//!
30//! ## NEON SIMD Registers (128-bit)
31//!
32//! ARM64 NEON provides 32 vector registers, each holding 128 bits.
33//! - For 32-bit (4-byte) elements, each register holds exactly **`BLOCK_SIZE` elements**.
34//! - A `BLOCK_SIZE`×`BLOCK_SIZE` block (`BLOCK_SIZE`^2 elements) fits perfectly in **`BLOCK_SIZE` registers**.
35//!
36//! ```text
37//!     ┌─────────────────────────────────┐
38//!     │  q0 = [ a00, a01, a02, a03 ]    │  ← 128 bits = 4 × 32-bit
39//!     │  q1 = [ a10, a11, a12, a13 ]    │
40//!     │  q2 = [ a20, a21, a22, a23 ]    │
41//!     │  q3 = [ a30, a31, a32, a33 ]    │
42//!     └─────────────────────────────────┘
43//! ```
44//!
45//! ## In-Register Transpose (Butterfly Network)
46//!
47//! We transpose a `BLOCK_SIZE`×`BLOCK_SIZE` block entirely in registers using a 2-stage butterfly:
48//!
49//! **Stage 1**: Swap pairs of 32-bit elements using `TRN1`/`TRN2`
50//! **Stage 2**: Swap pairs of 64-bit elements using `TRN1`/`TRN2` on reinterpreted u64
51//!
52//! ```text
53//!     Input:          After Stage 1:      After Stage 2 (Output):
54//!     ┌─────────────┐ ┌─────────────┐     ┌─────────────┐
55//!     │ a b │ c d   │ │ a e │ c g   │     │ a e │ i m   │
56//!     │ e f │ g h   │ │ b f │ d h   │     │ b f │ j n   │
57//!     │─────┼───────│ │─────┼───────│     │─────┼───────│
58//!     │ i j │ k l   │ │ i m │ k o   │     │ c g │ k o   │
59//!     │ m n │ o p   │ │ j n │ l p   │     │ d h │ l p   │
60//!     └─────────────┘ └─────────────┘     └─────────────┘
61//! ```
62//!
63//! ## Multi-Level Tiling Strategy
64//!
65//! Different strategies for different matrix sizes:
66//! - **Small (<`SMALL_LEN` elements)**: Scalar transpose - fits in L1, no overhead
67//! - **Medium (<`MEDIUM_LEN` elements)**: `TILE_SIZE`×`TILE_SIZE` Tiled - L2-friendly tiles
68//! - **Large (≥`MEDIUM_LEN` elements)**: Recursive + Tiled - Cache-oblivious
69
70#[cfg(target_arch = "aarch64")]
71use core::arch::aarch64::*;
72#[cfg(target_arch = "aarch64")]
73use core::mem::MaybeUninit;
74#[cfg(all(target_arch = "aarch64", feature = "parallel"))]
75use core::sync::atomic::{AtomicUsize, Ordering};
76
/// Software prefetch for write (PRFM PSTL1KEEP).
///
/// Brings the cache line containing `ptr` into the L1 data cache in exclusive
/// state, preparing for a subsequent store. This avoids Read-For-Ownership
/// (RFO) stalls when writing to memory not already in L1.
///
/// # Safety
///
/// The asm block never dereferences `ptr` from Rust's point of view (PRFM is
/// a cache hint), but callers should still pass addresses within buffers they
/// legitimately access.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
unsafe fn prefetch_write(ptr: *const u8) {
    // PRFM PSTL1KEEP: Prefetch for Store, L1 cache, temporal (keep in cache).
    // `readonly` is appropriate: the instruction is a hint and does not
    // modify architecturally-visible memory.
    unsafe {
        core::arch::asm!(
            "prfm pstl1keep, [{ptr}]",
            ptr = in(reg) ptr,
            options(readonly, nostack, preserves_flags),
        );
    }
}
94
/// Maximum number of elements for the simple scalar transpose.
///
/// For matrices with at most `SMALL_LEN` elements (~1KB for 4-byte elements),
/// the overhead of tiling isn't worth it.
///
/// Direct element-by-element copy is faster because:
/// - The entire matrix fits in L1 cache (32-64KB on most CPUs)
/// - No tile boundary calculations needed
/// - Branch prediction works well for small loops
#[cfg(any(target_arch = "aarch64", test))]
const SMALL_LEN: usize = 255;

/// Maximum number of elements for the single-level tiled transpose.
///
/// For matrices up to `MEDIUM_LEN` elements (~4MB for 4-byte elements), we use
/// a simple tiled approach with `TILE_SIZE`×`TILE_SIZE` tiles.
///
/// The whole matrix may well exceed L2 cache (256KB-512KB); what matters is
/// that each ~1KB tile keeps the active working set cache-resident with good
/// spatial locality.
///
/// Beyond this threshold, we switch to recursive subdivision to ensure
/// cache-oblivious behavior for very large matrices.
#[cfg(any(target_arch = "aarch64", test))]
const MEDIUM_LEN: usize = 1024 * 1024;

/// Side length of a tile in elements.
///
/// We use `TILE_SIZE`×`TILE_SIZE` tiles because:
/// - `TILE_SIZE`×`TILE_SIZE` × 4 bytes = 1KB per tile, fitting in L1 cache
/// - `TILE_SIZE` is divisible by `BLOCK_SIZE`, allowing exactly (`TILE_SIZE`/`BLOCK_SIZE`)^2 NEON blocks per tile
/// - Good balance between tile overhead and cache utilization
#[cfg(any(target_arch = "aarch64", test))]
const TILE_SIZE: usize = 16;

/// Maximum dimension for recursive base case.
///
/// When recursively subdividing large matrices, we stop when both
/// dimensions are ≤ `RECURSIVE_LIMIT` elements.
///
/// At this point, the sub-matrix (up to `RECURSIVE_LIMIT`×`RECURSIVE_LIMIT`
/// elements, i.e. ~64KB at 4 bytes/element) comfortably fits in L2 cache,
/// so we switch to tiled transpose.
#[cfg(target_arch = "aarch64")]
const RECURSIVE_LIMIT: usize = 128;

/// Minimum number of elements before enabling parallel processing.
///
/// Parallel transpose only pays off for large matrices because:
/// - Thread spawn/join overhead (~1-10μs)
/// - Cache coherency traffic between cores
/// - Memory bandwidth becomes the bottleneck, not compute
///
/// At `PARALLEL_THRESHOLD` elements, the work per thread is large enough that
/// parallelism overhead is amortized.
#[cfg(all(target_arch = "aarch64", feature = "parallel"))]
const PARALLEL_THRESHOLD: usize = 4 * 1024 * 1024;
149
150/// Transpose a matrix from row-major `input` to row-major `output`.
151///
152/// Given an input matrix with `height` rows and `width` columns, produces
153/// an output matrix with `width` rows and `height` columns.
154///
155/// # Memory Layout
156///
157/// Both input and output are stored in **row-major order**.
158///
159/// ```text
160///     Input (height=2, width=3):       Output (height=3, width=2):
161///
162///     Row 0: [ a, b, c ]               Row 0: [ a, d ]
163///     Row 1: [ d, e, f ]               Row 1: [ b, e ]
164///                                      Row 2: [ c, f ]
165///
166///     Memory: [a, b, c, d, e, f]       Memory: [a, d, b, e, c, f]
167/// ```
168///
169/// # Index Transformation
170///
171/// - Initial element at `input[row * width + col]`,
172/// - Transposed position is `output[col * height + row]`.
173///
174/// # Arguments
175///
176/// * `input` - Source matrix in row-major order
177/// * `output` - Destination buffer in row-major order
178/// * `width` - Number of columns in the input matrix
179/// * `height` - Number of rows in the input matrix
180///
181/// # Panics
182///
183/// Panics if:
184/// - `input.len() != width * height`
185/// - `output.len() != width * height`
186#[inline]
187pub fn transpose<T: Copy + Send + Sync>(
188    input: &[T],
189    output: &mut [T],
190    width: usize,
191    height: usize,
192) {
193    // Input validation
194    assert_eq!(
195        input.len(),
196        width * height,
197        "Input length {} doesn't match width*height = {}",
198        input.len(),
199        width * height
200    );
201    assert_eq!(
202        output.len(),
203        width * height,
204        "Output length {} doesn't match width*height = {}",
205        output.len(),
206        width * height
207    );
208
209    // Handle empty matrices
210    if width == 0 || height == 0 {
211        return;
212    }
213
214    // Architecture dispatch
215    #[cfg(target_arch = "aarch64")]
216    {
217        // Use NEON-optimized path for 4-byte elements.
218        //
219        // This covers common field types like MontyField31.
220        if core::mem::size_of::<T>() == 4 {
221            // SAFETY:
222            // - input/output lengths verified above
223            // - T is 4 bytes, matching u32 size and alignment
224            // - Pointers derived from valid slices
225            unsafe {
226                transpose_neon_4b(
227                    input.as_ptr().cast::<u32>(),
228                    output.as_mut_ptr().cast::<u32>(),
229                    width,
230                    height,
231                );
232            }
233            return;
234        }
235
236        // Use NEON-optimized path for 8-byte elements.
237        //
238        // This covers 64-bit field types like Goldilocks.
239        // A 128-bit NEON register holds 2 u64 elements, so we use
240        // pairs of registers per row and a 1-stage butterfly.
241        if core::mem::size_of::<T>() == 8 {
242            // SAFETY:
243            // - input/output lengths verified above
244            // - T is 8 bytes, matching u64 size and alignment
245            // - Pointers derived from valid slices
246            unsafe {
247                transpose_neon_8b(
248                    input.as_ptr().cast::<u64>(),
249                    output.as_mut_ptr().cast::<u64>(),
250                    width,
251                    height,
252                );
253            }
254            return;
255        }
256    }
257
258    // Fallback for non-ARM64 or unsupported element sizes.
259    transpose::transpose(input, output, width, height);
260}
261
/// Top-level NEON transpose dispatcher for 4-byte elements.
///
/// Picks a strategy from the total element count `len = width * height`:
///
/// - `len >= PARALLEL_THRESHOLD` (only with the `parallel` feature):
///   multi-threaded stripes, each stripe using the tiled algorithm.
/// - `len <= SMALL_LEN`: plain scalar transpose — fits in L1, no tiling
///   overhead.
/// - `SMALL_LEN < len <= MEDIUM_LEN`: single-level `TILE_SIZE`×`TILE_SIZE`
///   tiling.
/// - `len > MEDIUM_LEN`: cache-oblivious recursive subdivision with tiled
///   leaves.
///
/// # Safety
///
/// Caller must ensure `input` and `output` point to valid memory regions
/// of at least `width * height` elements each.
#[cfg(target_arch = "aarch64")]
#[inline]
unsafe fn transpose_neon_4b(input: *const u32, output: *mut u32, width: usize, height: usize) {
    let len = width * height;

    // Parallel path first: it applies regardless of the sequential strategy.
    #[cfg(feature = "parallel")]
    if len >= PARALLEL_THRESHOLD {
        // SAFETY: Caller guarantees valid pointers.
        unsafe { transpose_neon_4b_parallel(input, output, width, height) };
        return;
    }

    if len <= SMALL_LEN {
        // Small: direct scalar copy.
        //
        // SAFETY: Caller guarantees valid pointers.
        unsafe { transpose_small_4b(input, output, width, height) };
    } else if len <= MEDIUM_LEN {
        // Medium: one level of tiling.
        //
        // SAFETY: Caller guarantees valid pointers.
        unsafe { transpose_tiled_4b(input, output, width, height) };
    } else {
        // Large: recursive subdivision over the full coordinate range,
        // bottoming out in tiled leaves (cache-oblivious).
        //
        // SAFETY: Caller guarantees valid pointers.
        unsafe { transpose_recursive_4b(input, output, 0, height, 0, width, width, height) };
    }
}
333
/// Parallel transpose for very large matrices (≥ `PARALLEL_THRESHOLD` elements).
///
/// The matrix is cut into horizontal stripes of `div_ceil(height, threads)`
/// rows; each rayon task transposes its stripe independently with the tiled
/// region algorithm.
///
/// # Data Race Safety
///
/// A task handling rows `[r_start, r_end)` of the input writes only columns
/// `[r_start, r_end)` of the transposed output, so the output regions of
/// different tasks are disjoint and no synchronization is needed beyond the
/// stripe assignment itself.
///
/// # Safety
///
/// Caller must ensure valid pointers for `width * height` elements.
#[cfg(all(target_arch = "aarch64", feature = "parallel"))]
#[inline]
unsafe fn transpose_neon_4b_parallel(
    input: *const u32,
    output: *mut u32,
    width: usize,
    height: usize,
) {
    use rayon::prelude::*;

    // One stripe per available worker thread, rounded up so the stripes
    // cover every row.
    let num_threads = rayon::current_num_threads();
    let stripe = height.div_ceil(num_threads);

    // Raw pointers are not `Send`, so stash their addresses in atomics to
    // make the closure below `Send`. Relaxed ordering suffices: the values
    // never change and each task touches a disjoint slice of the output.
    let inp = AtomicUsize::new(input as usize);
    let out = AtomicUsize::new(output as usize);

    (0..num_threads).into_par_iter().for_each(|t| {
        // Row range owned by task `t`.
        let row_start = t * stripe;
        let row_end = height.min(row_start + stripe);

        // With more threads than rows some tasks get an empty range.
        if row_start >= row_end {
            return;
        }

        // Recover the pointers from their atomic storage.
        let input_ptr = inp.load(Ordering::Relaxed) as *const u32;
        let output_ptr = out.load(Ordering::Relaxed) as *mut u32;

        // SAFETY:
        // - Pointers are valid for the whole matrix (caller contract).
        // - Tasks write disjoint output columns, so there is no aliasing.
        unsafe {
            transpose_region_tiled_4b(
                input_ptr, output_ptr, row_start, row_end, 0, width, width, height,
            );
        }
    });
}
416
/// Simple element-by-element transpose for small matrices.
///
/// For matrices with at most `SMALL_LEN` elements, the overhead of tiling
/// isn't justified: the whole matrix fits in L1 and a direct copy wins.
///
/// # Loop Order
///
/// The outer loop walks input columns (= output rows), so consecutive writes
/// land at consecutive output addresses — friendly to the write-combining
/// buffers.
///
/// # Safety
///
/// Caller must ensure valid pointers for `width * height` elements.
#[cfg(target_arch = "aarch64")]
#[inline]
unsafe fn transpose_small_4b(input: *const u32, output: *mut u32, width: usize, height: usize) {
    for col in 0..width {
        // Base of output row `col` (which holds `height` elements).
        //
        // SAFETY: `col * height <= (width - 1) * height`, within the output
        // allocation of `width * height` elements.
        let dst_row = unsafe { output.add(col * height) };

        for row in 0..height {
            // Gather element (row, col) of the input into slot `row` of the
            // current output row.
            //
            // SAFETY: `row * width + col < width * height` and `row < height`,
            // so both accesses are in bounds.
            unsafe {
                *dst_row.add(row) = *input.add(row * width + col);
            }
        }
    }
}
458
/// Tiled transpose using `TILE_SIZE`×`TILE_SIZE` tiles of NEON blocks.
///
/// Intended for medium-sized matrices (`SMALL_LEN` to `MEDIUM_LEN` elements).
///
/// The rectangle covered by complete tiles is handled by the fully-unrolled
/// NEON tile kernel. The leftover strips — right edge (width not a multiple
/// of `TILE_SIZE`), bottom edge (height not a multiple), and the bottom-right
/// corner where they overlap — are finished with the scalar block transpose.
///
/// ```text
///     ┌────────────┬────────────┬───────┐
///     │  NEON tile │  NEON tile │ rem_x │   right-edge strip per tile row
///     ├────────────┼────────────┼───────┤
///     │  NEON tile │  NEON tile │ rem_x │
///     ├────────────┴────────────┼───────┤
///     │        rem_y            │corner │   bottom strip + corner at the end
///     └─────────────────────────┴───────┘
/// ```
///
/// # Safety
///
/// Caller must ensure valid pointers for `width * height` elements.
#[cfg(target_arch = "aarch64")]
#[inline]
unsafe fn transpose_tiled_4b(input: *const u32, output: *mut u32, width: usize, height: usize) {
    // Extent covered by complete tiles in each dimension.
    let full_cols = (width / TILE_SIZE) * TILE_SIZE;
    let full_rows = (height / TILE_SIZE) * TILE_SIZE;

    // Width/height of the leftover strips.
    let remainder_x = width - full_cols;
    let remainder_y = height - full_rows;

    // Complete tiles, row of tiles by row of tiles.
    for y_start in (0..full_rows).step_by(TILE_SIZE) {
        for x_start in (0..full_cols).step_by(TILE_SIZE) {
            // SAFETY: `(x_start, y_start)` is the top-left corner of a
            // complete tile, so the kernel stays within bounds.
            unsafe {
                transpose_tile_16x16_neon(input, output, width, height, x_start, y_start);
            }
        }

        // Right-edge strip for this tile row, if the width leaves one.
        if remainder_x > 0 {
            // SAFETY: the block spans columns [full_cols, width) and rows
            // [y_start, y_start + TILE_SIZE), all in bounds.
            unsafe {
                transpose_block_scalar(
                    input,
                    output,
                    width,
                    height,
                    full_cols,   // x_start
                    y_start,     // y_start
                    remainder_x, // block_width
                    TILE_SIZE,   // block_height
                );
            }
        }
    }

    // Bottom strip (rows [full_rows, height)), if the height leaves one.
    if remainder_y > 0 {
        for x_start in (0..full_cols).step_by(TILE_SIZE) {
            // SAFETY: block spans a complete tile width under the last tile
            // row; all coordinates in bounds.
            unsafe {
                transpose_block_scalar(
                    input,
                    output,
                    width,
                    height,
                    x_start,     // x_start
                    full_rows,   // y_start
                    TILE_SIZE,   // block_width
                    remainder_y, // block_height
                );
            }
        }

        // Bottom-right corner: intersection of both remainder strips.
        if remainder_x > 0 {
            // SAFETY: block spans [full_cols, width) × [full_rows, height).
            unsafe {
                transpose_block_scalar(
                    input,
                    output,
                    width,
                    height,
                    full_cols,   // x_start
                    full_rows,   // y_start
                    remainder_x, // block_width
                    remainder_y, // block_height
                );
            }
        }
    }
}
581
/// Recursive cache-oblivious transpose for large matrices.
///
/// Repeatedly halves the region along its longer dimension until the
/// sub-region fits in cache, then hands the leaf to the tiled transpose.
/// No cache sizes need to be known: halving eventually fits any cache level.
///
/// # Base Case
///
/// Recursion stops when both dimensions are ≤ `RECURSIVE_LIMIT` (the region
/// then fits in L2 — ~64KB at `RECURSIVE_LIMIT`×`RECURSIVE_LIMIT`×4 bytes),
/// or when either dimension is ≤ 2, where further splitting is pointless.
///
/// # Parameters
///
/// Sub-regions are described by coordinate ranges into the original matrix,
/// never by materialized sub-arrays:
/// - `row_start..row_end` / `col_start..col_end`: the region to transpose
/// - `total_cols`, `total_rows`: full matrix dimensions, used as strides
///
/// # Safety
///
/// Caller must ensure valid pointers and that coordinate ranges are within
/// bounds.
#[cfg(target_arch = "aarch64")]
#[allow(clippy::too_many_arguments)]
unsafe fn transpose_recursive_4b(
    input: *const u32,
    output: *mut u32,
    row_start: usize,
    row_end: usize,
    col_start: usize,
    col_end: usize,
    total_cols: usize,
    total_rows: usize,
) {
    let rows = row_end - row_start;
    let cols = col_end - col_start;

    // Leaf: cache-resident, or too thin to benefit from another split.
    let is_leaf = (rows <= RECURSIVE_LIMIT && cols <= RECURSIVE_LIMIT) || rows <= 2 || cols <= 2;
    if is_leaf {
        // SAFETY: caller contract (valid pointers, ranges in bounds).
        unsafe {
            transpose_region_tiled_4b(
                input, output, row_start, row_end, col_start, col_end, total_cols, total_rows,
            );
        }
        return;
    }

    if rows >= cols {
        // Tall (or square) region: split the row range in half.
        let mid = row_start + rows / 2;

        // SAFETY: `row_start <= mid <= row_end`, so both halves satisfy the
        // caller contract.
        unsafe {
            transpose_recursive_4b(
                input, output, row_start, mid, col_start, col_end, total_cols, total_rows,
            );
            transpose_recursive_4b(
                input, output, mid, row_end, col_start, col_end, total_cols, total_rows,
            );
        }
    } else {
        // Wide region: split the column range in half.
        let mid = col_start + cols / 2;

        // SAFETY: `col_start <= mid <= col_end`, so both halves satisfy the
        // caller contract.
        unsafe {
            transpose_recursive_4b(
                input, output, row_start, row_end, col_start, mid, total_cols, total_rows,
            );
            transpose_recursive_4b(
                input, output, row_start, row_end, mid, col_end, total_cols, total_rows,
            );
        }
    }
}
708
/// Tiled transpose for a rectangular region within a larger matrix.
///
/// Operates on the sub-region `[col_start, col_end) × [row_start, row_end)`
/// of a `total_cols × total_rows` matrix, addressing elements with the full
/// matrix strides. Serves as the leaf of the recursive transpose and as the
/// per-stripe worker of the parallel path.
///
/// Complete `TILE_SIZE`×`TILE_SIZE` tiles use the *buffered* NEON kernel:
/// for large matrices the output is likely in L3/RAM, so L1 buffering plus
/// write prefetching avoids RFO stalls on scattered output writes. Edge
/// remainders fall back to the scalar block transpose.
///
/// # Safety
///
/// Caller must ensure:
/// - Valid pointers for `total_cols * total_rows` elements
/// - `row_start < row_end <= total_rows`
/// - `col_start < col_end <= total_cols`
#[cfg(target_arch = "aarch64")]
#[inline]
#[allow(clippy::too_many_arguments)]
unsafe fn transpose_region_tiled_4b(
    input: *const u32,
    output: *mut u32,
    row_start: usize,
    row_end: usize,
    col_start: usize,
    col_end: usize,
    total_cols: usize,
    total_rows: usize,
) {
    let cols = col_end - col_start;
    let rows = row_end - row_start;

    // Absolute coordinates where the area covered by complete tiles ends.
    let full_col_end = col_start + (cols / TILE_SIZE) * TILE_SIZE;
    let full_row_end = row_start + (rows / TILE_SIZE) * TILE_SIZE;

    // Width/height of the remainder strips on the region's edges.
    let remainder_x = col_end - full_col_end;
    let remainder_y = row_end - full_row_end;

    // Complete tiles, walked in absolute matrix coordinates.
    for row in (row_start..full_row_end).step_by(TILE_SIZE) {
        for col in (col_start..full_col_end).step_by(TILE_SIZE) {
            // SAFETY: `(col, row)` is the top-left corner of a complete tile
            // inside the region, which the caller guarantees is in bounds.
            unsafe {
                transpose_tile_16x16_neon_buffered(input, output, total_cols, total_rows, col, row);
            }
        }

        // Right-edge strip of this tile row.
        if remainder_x > 0 {
            // SAFETY: block spans [full_col_end, col_end) × TILE_SIZE rows,
            // within the region bounds.
            unsafe {
                transpose_block_scalar(
                    input,
                    output,
                    total_cols,
                    total_rows,
                    full_col_end, // x_start
                    row,          // y_start
                    remainder_x,  // block_width
                    TILE_SIZE,    // block_height
                );
            }
        }
    }

    // Bottom strip of the region.
    if remainder_y > 0 {
        for col in (col_start..full_col_end).step_by(TILE_SIZE) {
            // SAFETY: block spans TILE_SIZE columns × [full_row_end, row_end),
            // within the region bounds.
            unsafe {
                transpose_block_scalar(
                    input,
                    output,
                    total_cols,
                    total_rows,
                    col,          // x_start
                    full_row_end, // y_start
                    TILE_SIZE,    // block_width
                    remainder_y,  // block_height
                );
            }
        }

        // Bottom-right corner: both remainders overlap.
        if remainder_x > 0 {
            // SAFETY: block spans [full_col_end, col_end) × [full_row_end,
            // row_end), within the region bounds.
            unsafe {
                transpose_block_scalar(
                    input,
                    output,
                    total_cols,
                    total_rows,
                    full_col_end, // x_start
                    full_row_end, // y_start
                    remainder_x,  // block_width
                    remainder_y,  // block_height
                );
            }
        }
    }
}
836
/// Transpose a complete 16×16 tile using NEON SIMD (direct-to-output).
///
/// The tile is processed as a 4×4 grid of 4×4 NEON blocks: the block at
/// grid position (r, c) of the input lands at grid position (c, r) of the
/// output. Both loops have constant trip counts, so LLVM fully unrolls them.
///
/// Used by the **medium tiled path** where the output likely fits in L2 cache
/// and the overhead of L1 buffering isn't justified.
///
/// # Safety
///
/// Caller must ensure:
/// - Valid pointers for the full matrix
/// - `x_start + 16 <= width`
/// - `y_start + 16 <= height`
#[cfg(target_arch = "aarch64")]
#[inline]
unsafe fn transpose_tile_16x16_neon(
    input: *const u32,
    output: *mut u32,
    width: usize,
    height: usize,
    x_start: usize,
    y_start: usize,
) {
    for block_row in 0..4 {
        // SAFETY: The caller guarantees the full 16×16 tile is in bounds,
        // so every 4×4 sub-block read/write stays inside both matrices.
        unsafe {
            // Top-left of this band of 4 input rows…
            let inp = input.add((y_start + 4 * block_row) * width + x_start);
            // …and the mirrored position in the output (rows/columns swap).
            let out = output.add(x_start * height + y_start + 4 * block_row);
            for block_col in 0..4 {
                transpose_4x4_neon(
                    inp.add(4 * block_col),
                    out.add(4 * block_col * height),
                    width,
                    height,
                );
            }
        }
    }
}
895
/// Transpose a complete 16×16 tile using NEON SIMD with L1 buffering.
///
/// Same 4×4 grid of NEON blocks as the direct version, but the transposed
/// tile is staged in a stack buffer and then streamed to the output with
/// write prefetching (`PRFM PSTL1KEEP`).
///
/// Used by the **recursive/parallel path** for large matrices (≥ `MEDIUM_LEN`)
/// where the output is in L3/RAM and direct scattered writes would stall on
/// Read-For-Ownership (RFO) cache line fetches.
///
/// # Safety
///
/// Caller must ensure:
/// - Valid pointers for the full matrix
/// - `x_start + 16 <= width`
/// - `y_start + 16 <= height`
#[cfg(target_arch = "aarch64")]
#[inline]
unsafe fn transpose_tile_16x16_neon_buffered(
    input: *const u32,
    output: *mut u32,
    width: usize,
    height: usize,
    x_start: usize,
    y_start: usize,
) {
    // Stack buffer for the L1-hot transpose (1 KB for u32). MaybeUninit
    // skips zero-initialization; every slot is written by the NEON blocks
    // before the flush loop reads it.
    let mut buffer = MaybeUninit::<[u32; TILE_SIZE * TILE_SIZE]>::uninit();
    let buf = buffer.as_mut_ptr().cast::<u32>();

    unsafe {
        // Stage 1: transpose the 4×4 grid of NEON blocks into the buffer.
        // Buffer layout: buf[col * TILE_SIZE + row]. Writes have stride
        // TILE_SIZE (contiguous in L1) instead of `height` (scattered).
        for block_row in 0..4 {
            let inp = input.add((y_start + 4 * block_row) * width + x_start);
            for block_col in 0..4 {
                transpose_4x4_neon(
                    inp.add(4 * block_col),
                    buf.add(4 * block_col * TILE_SIZE + 4 * block_row),
                    width,
                    TILE_SIZE,
                );
            }
        }

        // Stage 2: flush buffer rows to the output. Each row is TILE_SIZE
        // u32s (64 bytes = one cache line); prefetching the next destination
        // line into exclusive state avoids RFO stalls.
        prefetch_write(output.add(x_start * height + y_start) as *const u8);
        for c in 0..TILE_SIZE {
            let dst = output.add((x_start + c) * height + y_start);
            if c + 1 < TILE_SIZE {
                prefetch_write(output.add((x_start + c + 1) * height + y_start) as *const u8);
            }
            core::ptr::copy_nonoverlapping(buf.add(c * TILE_SIZE), dst, TILE_SIZE);
        }
    }
}
977
/// Scalar transpose for an arbitrary rectangular block.
///
/// Fallback for regions whose dimensions don't align to tile boundaries:
/// the right edge (`block_width < TILE_SIZE`), the bottom edge
/// (`block_height < TILE_SIZE`), and the bottom-right corner (both).
/// Copies one element at a time; each input column is written as one
/// contiguous output row.
///
/// # Safety
///
/// Caller must ensure:
/// - Valid pointers for the full matrix
/// - `x_start + block_width <= width`
/// - `y_start + block_height <= height`
#[cfg(target_arch = "aarch64")]
#[inline]
#[allow(clippy::too_many_arguments)]
unsafe fn transpose_block_scalar(
    input: *const u32,
    output: *mut u32,
    width: usize,
    height: usize,
    x_start: usize,
    y_start: usize,
    block_width: usize,
    block_height: usize,
) {
    // One input column (= one output row) at a time.
    for inner_x in 0..block_width {
        let x = x_start + inner_x;

        // SAFETY: `x < x_start + block_width <= width` and every `y` below
        // satisfies `y < y_start + block_height <= height`, so all computed
        // indices are inside both matrices.
        unsafe {
            // Start of the output row that receives input column `x`.
            let out_row = output.add(x * height + y_start);
            for inner_y in 0..block_height {
                let y = y_start + inner_y;
                *out_row.add(inner_y) = *input.add(y * width + x);
            }
        }
    }
}
1029
/// Transpose a 4×4 block of 32-bit elements using NEON SIMD.
///
/// The fundamental kernel of the whole transpose: the block is loaded into
/// four 128-bit registers and transposed entirely in-register via a
/// two-stage butterfly of `TRN1`/`TRN2` instructions.
///
/// Stage 1 (`vtrn1q_u32`/`vtrn2q_u32`) interleaves the 32-bit lanes of row
/// pairs; stage 2 (`vtrn1q_u64`/`vtrn2q_u64`) swaps 64-bit halves between
/// the stage-1 results:
///
/// ```text
///     rows:                stage 1:               stage 2 (output):
///     [a00 a01 a02 a03]    [a00 a10 a02 a12]      [a00 a10 a20 a30]
///     [a10 a11 a12 a13] →  [a01 a11 a03 a13]   →  [a01 a11 a21 a31]
///     [a20 a21 a22 a23]    [a20 a30 a22 a32]      [a02 a12 a22 a32]
///     [a30 a31 a32 a33]    [a21 a31 a23 a33]      [a03 a13 a23 a33]
/// ```
///
/// Rows are read with stride `src_stride` and written with stride
/// `dst_stride`, so the same kernel serves both the direct-to-output and the
/// L1-buffered paths. Cost: 4 loads + 8 permutes + 4 stores ≈ 1 cycle per
/// element.
///
/// # Safety
///
/// Caller must ensure:
/// - `src` is valid for reading 4 rows of `src_stride` elements each
/// - `dst` is valid for writing 4 rows of `dst_stride` elements each
/// - The first 4 elements of each row are accessible
#[cfg(target_arch = "aarch64")]
#[inline(always)]
unsafe fn transpose_4x4_neon(src: *const u32, dst: *mut u32, src_stride: usize, dst_stride: usize) {
    unsafe {
        // Load the four input rows; each vld1q_u32 reads 4 consecutive u32s
        // (16 bytes), 64 bytes total — one cache line on most ARM64 CPUs.
        let row0 = vld1q_u32(src); //                     [a00, a01, a02, a03]
        let row1 = vld1q_u32(src.add(src_stride)); //     [a10, a11, a12, a13]
        let row2 = vld1q_u32(src.add(2 * src_stride)); // [a20, a21, a22, a23]
        let row3 = vld1q_u32(src.add(3 * src_stride)); // [a30, a31, a32, a33]

        // Stage 1: 32-bit TRN on row pairs.
        // vtrn1q_u32 keeps even lanes: [a[0], b[0], a[2], b[2]];
        // vtrn2q_u32 keeps odd lanes:  [a[1], b[1], a[3], b[3]].
        let even01 = vtrn1q_u32(row0, row1); // [a00, a10, a02, a12]
        let odd01 = vtrn2q_u32(row0, row1); //  [a01, a11, a03, a13]
        let even23 = vtrn1q_u32(row2, row3); // [a20, a30, a22, a32]
        let odd23 = vtrn2q_u32(row2, row3); //  [a21, a31, a23, a33]

        // Stage 2: 64-bit TRN across the stage-1 results. Reinterpreting the
        // vectors as u64x2 lets TRN move whole 64-bit halves:
        // vtrn1q_u64(a, b) → [a.lo, b.lo], vtrn2q_u64(a, b) → [a.hi, b.hi].
        let out0 = vreinterpretq_u32_u64(vtrn1q_u64(
            vreinterpretq_u64_u32(even01),
            vreinterpretq_u64_u32(even23),
        )); // [a00, a10, a20, a30] — input column 0
        let out1 = vreinterpretq_u32_u64(vtrn1q_u64(
            vreinterpretq_u64_u32(odd01),
            vreinterpretq_u64_u32(odd23),
        )); // [a01, a11, a21, a31] — input column 1
        let out2 = vreinterpretq_u32_u64(vtrn2q_u64(
            vreinterpretq_u64_u32(even01),
            vreinterpretq_u64_u32(even23),
        )); // [a02, a12, a22, a32] — input column 2
        let out3 = vreinterpretq_u32_u64(vtrn2q_u64(
            vreinterpretq_u64_u32(odd01),
            vreinterpretq_u64_u32(odd23),
        )); // [a03, a13, a23, a33] — input column 3

        // Store the four transposed rows.
        vst1q_u32(dst, out0);
        vst1q_u32(dst.add(dst_stride), out1);
        vst1q_u32(dst.add(2 * dst_stride), out2);
        vst1q_u32(dst.add(3 * dst_stride), out3);
    }
}
1178
1179// ============================================================================
1180// 8-byte (u64) NEON transpose functions
1181//
1182// These are analogous to the 4-byte functions above, but operate on u64
1183// elements. Since a 128-bit NEON register holds 2 u64 elements, each row
1184// of a 4×4 block requires 2 registers (8 total for a block). The transpose
1185// uses a single-stage butterfly with vtrn1q_u64/vtrn2q_u64 on four 2×2
1186// sub-blocks.
1187// ============================================================================
1188
/// Top-level NEON transpose dispatcher for 8-byte elements.
///
/// Picks a strategy from the total element count, mirroring the 4-byte
/// dispatcher `transpose_neon_4b`:
/// - `parallel` feature + `len >= PARALLEL_THRESHOLD` → multi-threaded stripes,
/// - `len <= SMALL_LEN` → plain scalar loop,
/// - `len <= MEDIUM_LEN` → 16×16 tiling,
/// - otherwise → cache-oblivious recursion.
///
/// # Safety
///
/// Caller must ensure `input` and `output` point to valid memory regions
/// of at least `width * height` elements each.
#[cfg(target_arch = "aarch64")]
#[inline]
unsafe fn transpose_neon_8b(input: *const u64, output: *mut u64, width: usize, height: usize) {
    let len = width * height;

    // Very large inputs: split across rayon workers when available.
    #[cfg(feature = "parallel")]
    if len >= PARALLEL_THRESHOLD {
        // SAFETY: forwards the caller's contract unchanged.
        unsafe { transpose_neon_8b_parallel(input, output, width, height) };
        return;
    }

    // Serial paths, selected by total size.
    // SAFETY: forwards the caller's contract unchanged.
    unsafe {
        if len <= SMALL_LEN {
            transpose_small_8b(input, output, width, height);
        } else if len <= MEDIUM_LEN {
            transpose_tiled_8b(input, output, width, height);
        } else {
            transpose_recursive_8b(input, output, 0, height, 0, width, width, height);
        }
    }
}
1227
/// Parallel transpose for very large matrices of 8-byte elements.
///
/// Splits the input into horizontal stripes of roughly equal height, one
/// per rayon worker. Stripes write to disjoint parts of the output, so no
/// synchronization is needed beyond getting the pointers into the closure.
///
/// # Safety
///
/// Caller must ensure valid pointers for `width * height` elements.
#[cfg(all(target_arch = "aarch64", feature = "parallel"))]
#[inline]
unsafe fn transpose_neon_8b_parallel(
    input: *const u64,
    output: *mut u64,
    width: usize,
    height: usize,
) {
    use rayon::prelude::*;

    let num_threads = rayon::current_num_threads();
    // Rows handled per stripe; the last stripe may be shorter (or empty).
    let stripe_height = height.div_ceil(num_threads);

    // Raw pointers are not `Send`; launder them through atomics so the
    // closure can be shared across worker threads.
    let input_addr = AtomicUsize::new(input as usize);
    let output_addr = AtomicUsize::new(output as usize);

    (0..num_threads).into_par_iter().for_each(|stripe| {
        let first_row = stripe * stripe_height;
        let last_row = usize::min(first_row + stripe_height, height);

        // Trailing stripes can be empty when `height` is small.
        if first_row >= last_row {
            return;
        }

        let in_ptr = input_addr.load(Ordering::Relaxed) as *const u64;
        let out_ptr = output_addr.load(Ordering::Relaxed) as *mut u64;

        // SAFETY: Each stripe covers a disjoint, in-bounds row range, so
        // concurrent workers never write the same output element.
        unsafe {
            transpose_region_tiled_8b(
                in_ptr, out_ptr, first_row, last_row, 0, width, width, height,
            );
        }
    });
}
1267
/// Simple element-by-element transpose for small matrices of 8-byte elements.
///
/// Walks the output in storage order: input column `x` becomes one
/// contiguous output row, so all writes are sequential.
///
/// # Safety
///
/// Caller must ensure valid pointers for `width * height` elements.
#[cfg(target_arch = "aarch64")]
#[inline]
unsafe fn transpose_small_8b(input: *const u64, output: *mut u64, width: usize, height: usize) {
    for x in 0..width {
        // Base index of the output row that receives input column `x`.
        let out_base = x * height;
        for y in 0..height {
            // SAFETY: `x < width` and `y < height` keep both indices in range.
            unsafe {
                *output.add(out_base + y) = *input.add(y * width + x);
            }
        }
    }
}
1287
/// Tiled transpose using 16×16 tiles for 8-byte elements.
///
/// Complete tiles go through the direct NEON tile kernel; the right edge,
/// bottom edge, and bottom-right corner (when dimensions aren't multiples
/// of `TILE_SIZE`) fall back to the scalar block transpose.
///
/// # Safety
///
/// Caller must ensure valid pointers for `width * height` elements.
#[cfg(target_arch = "aarch64")]
#[inline]
unsafe fn transpose_tiled_8b(input: *const u64, output: *mut u64, width: usize, height: usize) {
    // Whole tiles per axis, plus leftover elements along each edge.
    let tiles_x = width / TILE_SIZE;
    let tiles_y = height / TILE_SIZE;
    let edge_w = width % TILE_SIZE;
    let edge_h = height % TILE_SIZE;

    for ty in 0..tiles_y {
        // Full 16×16 tiles across this tile row.
        for tx in 0..tiles_x {
            // SAFETY: tile origin + TILE_SIZE stays within the matrix.
            unsafe {
                transpose_tile_16x16_neon_8b(
                    input,
                    output,
                    width,
                    height,
                    tx * TILE_SIZE,
                    ty * TILE_SIZE,
                );
            }
        }

        // Narrow strip on the right edge of this tile row.
        if edge_w > 0 {
            // SAFETY: block fits between the last full tile column and `width`.
            unsafe {
                transpose_block_scalar_8b(
                    input,
                    output,
                    width,
                    height,
                    tiles_x * TILE_SIZE,
                    ty * TILE_SIZE,
                    edge_w,
                    TILE_SIZE,
                );
            }
        }
    }

    // Short strip along the bottom edge.
    if edge_h > 0 {
        for tx in 0..tiles_x {
            // SAFETY: block fits between the last full tile row and `height`.
            unsafe {
                transpose_block_scalar_8b(
                    input,
                    output,
                    width,
                    height,
                    tx * TILE_SIZE,
                    tiles_y * TILE_SIZE,
                    TILE_SIZE,
                    edge_h,
                );
            }
        }

        // Bottom-right corner where both edges meet.
        if edge_w > 0 {
            // SAFETY: corner block lies within both edge strips.
            unsafe {
                transpose_block_scalar_8b(
                    input,
                    output,
                    width,
                    height,
                    tiles_x * TILE_SIZE,
                    tiles_y * TILE_SIZE,
                    edge_w,
                    edge_h,
                );
            }
        }
    }
}
1364
/// Recursive cache-oblivious transpose for large matrices of 8-byte elements.
///
/// Halves the longer side of the region until it either fits
/// `RECURSIVE_LIMIT` in both dimensions or becomes too thin to split
/// profitably, then hands the leaf region to the tiled transpose.
///
/// # Safety
///
/// Caller must ensure valid pointers and that coordinate ranges are within bounds.
#[cfg(target_arch = "aarch64")]
#[allow(clippy::too_many_arguments)]
unsafe fn transpose_recursive_8b(
    input: *const u64,
    output: *mut u64,
    row_start: usize,
    row_end: usize,
    col_start: usize,
    col_end: usize,
    total_cols: usize,
    total_rows: usize,
) {
    let rows = row_end - row_start;
    let cols = col_end - col_start;

    // Leaf: the region fits the recursion limit, or is a thin strip.
    let small_enough = rows <= RECURSIVE_LIMIT && cols <= RECURSIVE_LIMIT;
    if small_enough || rows <= 2 || cols <= 2 {
        // SAFETY: forwards the caller's in-bounds coordinates unchanged.
        unsafe {
            transpose_region_tiled_8b(
                input, output, row_start, row_end, col_start, col_end, total_cols, total_rows,
            );
        }
        return;
    }

    // Split the longer dimension in half and recurse on both halves.
    if rows >= cols {
        let mid = row_start + rows / 2;

        // SAFETY: both halves stay within the original row range.
        unsafe {
            transpose_recursive_8b(
                input, output, row_start, mid, col_start, col_end, total_cols, total_rows,
            );
            transpose_recursive_8b(
                input, output, mid, row_end, col_start, col_end, total_cols, total_rows,
            );
        }
    } else {
        let mid = col_start + cols / 2;

        // SAFETY: both halves stay within the original column range.
        unsafe {
            transpose_recursive_8b(
                input, output, row_start, row_end, col_start, mid, total_cols, total_rows,
            );
            transpose_recursive_8b(
                input, output, row_start, row_end, mid, col_end, total_cols, total_rows,
            );
        }
    }
}
1427
/// Tiled transpose for a rectangular region of 8-byte elements.
///
/// Base case of the recursive transpose and the per-stripe worker of the
/// parallel path. Full tiles use the L1-buffered NEON kernel — at this
/// size the output is likely in L3/RAM, so buffering + write prefetching
/// avoids RFO stalls on scattered writes. Edge leftovers use the scalar
/// block transpose.
///
/// # Safety
///
/// Caller must ensure:
/// - Valid pointers for `total_cols * total_rows` elements
/// - `row_start < row_end <= total_rows`
/// - `col_start < col_end <= total_cols`
#[cfg(target_arch = "aarch64")]
#[inline]
#[allow(clippy::too_many_arguments)]
unsafe fn transpose_region_tiled_8b(
    input: *const u64,
    output: *mut u64,
    row_start: usize,
    row_end: usize,
    col_start: usize,
    col_end: usize,
    total_cols: usize,
    total_rows: usize,
) {
    // Region extent, decomposed into full tiles plus edge leftovers.
    let region_w = col_end - col_start;
    let region_h = row_end - row_start;
    let tiles_x = region_w / TILE_SIZE;
    let tiles_y = region_h / TILE_SIZE;
    let edge_w = region_w % TILE_SIZE;
    let edge_h = region_h % TILE_SIZE;

    for ty in 0..tiles_y {
        let row = row_start + ty * TILE_SIZE;

        // Full 16×16 tiles across this tile row.
        for tx in 0..tiles_x {
            let col = col_start + tx * TILE_SIZE;

            // SAFETY: tile origin + TILE_SIZE stays inside the region bounds.
            unsafe {
                transpose_tile_16x16_neon_8b_buffered(
                    input, output, total_cols, total_rows, col, row,
                );
            }
        }

        // Narrow strip on the region's right edge.
        if edge_w > 0 {
            // SAFETY: block lies between the last full tile column and `col_end`.
            unsafe {
                transpose_block_scalar_8b(
                    input,
                    output,
                    total_cols,
                    total_rows,
                    col_start + tiles_x * TILE_SIZE,
                    row,
                    edge_w,
                    TILE_SIZE,
                );
            }
        }
    }

    // Short strip along the region's bottom edge.
    if edge_h > 0 {
        let row = row_start + tiles_y * TILE_SIZE;

        for tx in 0..tiles_x {
            // SAFETY: block lies between the last full tile row and `row_end`.
            unsafe {
                transpose_block_scalar_8b(
                    input,
                    output,
                    total_cols,
                    total_rows,
                    col_start + tx * TILE_SIZE,
                    row,
                    TILE_SIZE,
                    edge_h,
                );
            }
        }

        // Bottom-right corner where both edges meet.
        if edge_w > 0 {
            // SAFETY: corner block lies within both edge strips.
            unsafe {
                transpose_block_scalar_8b(
                    input,
                    output,
                    total_cols,
                    total_rows,
                    col_start + tiles_x * TILE_SIZE,
                    row,
                    edge_w,
                    edge_h,
                );
            }
        }
    }
}
1527
/// Transpose a complete 16×16 tile of 8-byte elements using NEON SIMD (direct-to-output).
///
/// The tile is processed as a 4×4 grid of 4×4 NEON blocks: the block at
/// grid position (r, c) of the input lands at grid position (c, r) of the
/// output. Both loops have constant trip counts, so LLVM fully unrolls them.
///
/// Used by the **medium tiled path** where the output likely fits in L2 cache.
///
/// # Safety
///
/// Caller must ensure:
/// - Valid pointers for the full matrix
/// - `x_start + 16 <= width`
/// - `y_start + 16 <= height`
#[cfg(target_arch = "aarch64")]
#[inline]
unsafe fn transpose_tile_16x16_neon_8b(
    input: *const u64,
    output: *mut u64,
    width: usize,
    height: usize,
    x_start: usize,
    y_start: usize,
) {
    for block_row in 0..4 {
        // SAFETY: The caller guarantees the full 16×16 tile is in bounds,
        // so every 4×4 sub-block read/write stays inside both matrices.
        unsafe {
            // Top-left of this band of 4 input rows…
            let inp = input.add((y_start + 4 * block_row) * width + x_start);
            // …and the mirrored position in the output (rows/columns swap).
            let out = output.add(x_start * height + y_start + 4 * block_row);
            for block_col in 0..4 {
                transpose_4x4_neon_8b(
                    inp.add(4 * block_col),
                    out.add(4 * block_col * height),
                    width,
                    height,
                );
            }
        }
    }
}
1582
/// Transpose a complete 16×16 tile of 8-byte elements with L1 buffering.
///
/// Same 4×4 grid of NEON blocks as the direct version, but the transposed
/// tile is staged in a stack buffer and then streamed to the output with
/// write prefetching. Used by the **recursive/parallel path** for large
/// matrices where the output is in L3/RAM; buffering + prefetch avoids
/// RFO stalls on scattered writes.
///
/// # Safety
///
/// Caller must ensure:
/// - Valid pointers for the full matrix
/// - `x_start + 16 <= width`
/// - `y_start + 16 <= height`
#[cfg(target_arch = "aarch64")]
#[inline]
unsafe fn transpose_tile_16x16_neon_8b_buffered(
    input: *const u64,
    output: *mut u64,
    width: usize,
    height: usize,
    x_start: usize,
    y_start: usize,
) {
    // Stack buffer for the L1-hot transpose (2 KB for u64). MaybeUninit
    // skips zero-initialization; every slot is written by the NEON blocks
    // before the flush loop reads it.
    let mut buffer = MaybeUninit::<[u64; TILE_SIZE * TILE_SIZE]>::uninit();
    let buf = buffer.as_mut_ptr().cast::<u64>();

    unsafe {
        // Stage 1: transpose the 4×4 grid of NEON blocks into the buffer.
        // Buffer layout: buf[col * TILE_SIZE + row]. Writes have stride
        // TILE_SIZE (contiguous in L1) instead of `height` (scattered).
        for block_row in 0..4 {
            let inp = input.add((y_start + 4 * block_row) * width + x_start);
            for block_col in 0..4 {
                transpose_4x4_neon_8b(
                    inp.add(4 * block_col),
                    buf.add(4 * block_col * TILE_SIZE + 4 * block_row),
                    width,
                    TILE_SIZE,
                );
            }
        }

        // Stage 2: flush buffer rows to the output, prefetching the next
        // destination line into exclusive state to dodge RFO stalls.
        prefetch_write(output.add(x_start * height + y_start) as *const u8);
        for c in 0..TILE_SIZE {
            let dst = output.add((x_start + c) * height + y_start);
            if c + 1 < TILE_SIZE {
                prefetch_write(output.add((x_start + c + 1) * height + y_start) as *const u8);
            }
            core::ptr::copy_nonoverlapping(buf.add(c * TILE_SIZE), dst, TILE_SIZE);
        }
    }
}
1653
/// Scalar transpose for an arbitrary rectangular block of 8-byte elements.
///
/// Used for handling edge cases where dimensions don't align to tile boundaries.
///
/// # Safety
///
/// Caller must ensure:
/// - Valid pointers for the full matrix
/// - `x_start + block_width <= width`
/// - `y_start + block_height <= height`
#[cfg(target_arch = "aarch64")]
#[inline]
#[allow(clippy::too_many_arguments)]
unsafe fn transpose_block_scalar_8b(
    input: *const u64,
    output: *mut u64,
    width: usize,
    height: usize,
    x_start: usize,
    y_start: usize,
    block_width: usize,
    block_height: usize,
) {
    // Walk the block element by element; input position (x, y) maps to
    // output position (y, x) in row-major layout.
    for dx in 0..block_width {
        let x = x_start + dx;
        for dy in 0..block_height {
            let y = y_start + dy;
            // SAFETY: the caller guarantees x < width and y < height, so both
            // computed indices lie inside their respective matrices.
            unsafe {
                *output.add(x * height + y) = *input.add(y * width + x);
            }
        }
    }
}
1691
/// Transpose a 4×4 block of 64-bit elements using NEON SIMD.
///
/// This is the fundamental building block for 8-byte element transpose.
///
/// A 128-bit NEON register holds two `u64` lanes, so each 4-element row
/// occupies two registers; the full block uses 8 registers for the input and
/// 8 for the transposed result (16 total, well within NEON's 32 registers).
///
/// # Algorithm
///
/// The 4×4 transpose decomposes into four independent 2×2 sub-block
/// transposes, each done with one `vtrn1q_u64`/`vtrn2q_u64` pair:
///
/// ```text
///     Load:  [a00,a01] [a02,a03]      Store: [a00,a10] [a20,a30]
///            [a10,a11] [a12,a13]             [a01,a11] [a21,a31]
///            [a20,a21] [a22,a23]             [a02,a12] [a22,a32]
///            [a30,a31] [a32,a33]             [a03,a13] [a23,a33]
/// ```
///
/// The sub-blocks built from rows 0-1 supply the low halves of the output
/// rows, while those from rows 2-3 supply the high halves.
///
/// # Safety
///
/// Caller must ensure:
/// - `src` is valid for reading 4 rows of `src_stride` elements each
/// - `dst` is valid for writing 4 rows of `dst_stride` elements each
/// - The first 4 elements of each row are accessible
#[cfg(target_arch = "aarch64")]
#[inline(always)]
unsafe fn transpose_4x4_neon_8b(
    src: *const u64,
    dst: *mut u64,
    src_stride: usize,
    dst_stride: usize,
) {
    unsafe {
        // Source row base pointers.
        let s0 = src;
        let s1 = src.add(src_stride);
        let s2 = src.add(2 * src_stride);
        let s3 = src.add(3 * src_stride);

        // Load each input row as a (low, high) register pair.
        let in0_lo = vld1q_u64(s0); // [a00, a01]
        let in0_hi = vld1q_u64(s0.add(2)); // [a02, a03]
        let in1_lo = vld1q_u64(s1); // [a10, a11]
        let in1_hi = vld1q_u64(s1.add(2)); // [a12, a13]
        let in2_lo = vld1q_u64(s2); // [a20, a21]
        let in2_hi = vld1q_u64(s2.add(2)); // [a22, a23]
        let in3_lo = vld1q_u64(s3); // [a30, a31]
        let in3_hi = vld1q_u64(s3.add(2)); // [a32, a33]

        // Transpose the four 2×2 sub-blocks: trn1 gathers even lanes,
        // trn2 gathers odd lanes.
        let out0_lo = vtrn1q_u64(in0_lo, in1_lo); // [a00, a10]
        let out1_lo = vtrn2q_u64(in0_lo, in1_lo); // [a01, a11]
        let out2_lo = vtrn1q_u64(in0_hi, in1_hi); // [a02, a12]
        let out3_lo = vtrn2q_u64(in0_hi, in1_hi); // [a03, a13]
        let out0_hi = vtrn1q_u64(in2_lo, in3_lo); // [a20, a30]
        let out1_hi = vtrn2q_u64(in2_lo, in3_lo); // [a21, a31]
        let out2_hi = vtrn1q_u64(in2_hi, in3_hi); // [a22, a32]
        let out3_hi = vtrn2q_u64(in2_hi, in3_hi); // [a23, a33]

        // Destination row base pointers.
        let d0 = dst;
        let d1 = dst.add(dst_stride);
        let d2 = dst.add(2 * dst_stride);
        let d3 = dst.add(3 * dst_stride);

        // Store the four transposed rows, two registers per row.
        vst1q_u64(d0, out0_lo); // [a00, a10, a20, a30]
        vst1q_u64(d0.add(2), out0_hi);
        vst1q_u64(d1, out1_lo); // [a01, a11, a21, a31]
        vst1q_u64(d1.add(2), out1_hi);
        vst1q_u64(d2, out2_lo); // [a02, a12, a22, a32]
        vst1q_u64(d2.add(2), out2_hi);
        vst1q_u64(d3, out3_lo); // [a03, a13, a23, a33]
        vst1q_u64(d3.add(2), out3_hi);
    }
}
1782
#[cfg(test)]
mod tests {
    use alloc::vec;
    use alloc::vec::Vec;

    use p3_baby_bear::BabyBear;
    use p3_field::PrimeCharacteristicRing;
    use p3_goldilocks::Goldilocks;
    use proptest::prelude::*;

    use super::*;

    /// Naive reference implementation for correctness testing.
    ///
    /// Output position `j` (row-major in the transposed matrix) corresponds to
    /// transposed coordinates `(x, y) = (j / height, j % height)`, so it reads
    /// `input[y * width + x]`. Building the output by `collect` avoids the
    /// `T: Default` bound a zero-fill-then-overwrite approach would need.
    fn transpose_reference<T: Copy>(input: &[T], width: usize, height: usize) -> Vec<T> {
        (0..width * height)
            .map(|j| input[(j % height) * width + j / height])
            .collect()
    }

    /// Strategy for generating matrix dimensions.
    fn dimension_strategy() -> impl Strategy<Value = (usize, usize)> {
        // Compute boundary dimensions from constants.
        // `small_side` is the largest square that stays in the small (scalar) path.
        let small_side = (SMALL_LEN as f64).sqrt() as usize;
        // `medium_side` is the largest square that stays in the medium (tiled) path.
        let medium_side = (MEDIUM_LEN as f64).sqrt() as usize;
        // `large_side` is the side length that triggers the large (recursive) path.
        let large_side = medium_side + 1;

        prop_oneof![
            // Edge cases: empty and degenerate matrices
            //
            // Empty matrix (0×0)
            Just((0, 0)),
            // Single row (1×n) - tests degenerate case
            (1..=100_usize).prop_map(|w| (w, 1)),
            // Single column (n×1) - tests degenerate case
            (1..=100_usize).prop_map(|h| (1, h)),
            // Small path: len < SMALL_LEN (scalar transpose)
            //
            // These dimensions exercise the scalar transpose path.

            // Tiny matrices (various shapes within small threshold)
            (1..=small_side, 1..=small_side),
            // Medium path: SMALL_LEN ≤ len < MEDIUM_LEN (tiled TILE_SIZE×TILE_SIZE)
            //
            // These dimensions exercise the tiled TILE_SIZE×TILE_SIZE path.

            // Exactly 4×4 (single NEON block)
            Just((4, 4)),
            // Exactly TILE_SIZE×TILE_SIZE (single tile)
            Just((TILE_SIZE, TILE_SIZE)),
            // Multiple complete tiles (2× and 4× TILE_SIZE)
            Just((TILE_SIZE * 2, TILE_SIZE * 2)),
            Just((TILE_SIZE * 4, TILE_SIZE * 4)),
            // Non-aligned: has remainders in both dimensions
            // Range from just above TILE_SIZE to below 4×TILE_SIZE.
            // These test the scalar fallback for tile edges.
            (
                (TILE_SIZE + 1)..=(TILE_SIZE * 4 - 1),
                (TILE_SIZE + 1)..=(TILE_SIZE * 4 - 1)
            ),
            // Wide rectangle with remainders (medium path)
            (50..=200_usize, 10..=50_usize),
            // Tall rectangle with remainders (medium path)
            (10..=50_usize, 50..=200_usize),
            // Large path: MEDIUM_LEN ≤ len < PARALLEL_THRESHOLD (recursive)
            //
            // These exercise the cache-oblivious recursive subdivision.

            // Square matrices triggering recursion (just above medium threshold)
            Just((large_side, large_side)),
            // Slightly larger square
            Just((large_side + 100, large_side + 100)),
            // Wide rectangle triggering recursion
            Just((large_side * 2, large_side / 2)),
            // Tall rectangle triggering recursion
            Just((large_side / 2, large_side * 2)),
            // Non-power-of-2 dimensions in large range
            Just((large_side + 50, large_side + 75)),
        ]
    }

    proptest! {
        #[test]
        fn proptest_transpose_babybear((width, height) in dimension_strategy()) {
            // Skip empty matrices (they're trivially correct).
            if width == 0 || height == 0 {
                // Just verify it doesn't panic.
                let input: [BabyBear; 0] = [];
                let mut output: [BabyBear; 0] = [];
                transpose(&input, &mut output, width, height);
                return Ok(());
            }

            // Create input matrix with unique values at each position.
            let input: Vec<BabyBear> = (0..width * height)
                .map(|i| BabyBear::from_u64(i as u64))
                .collect();

            // Allocate output buffer.
            let mut output = vec![BabyBear::ZERO; width * height];

            // Run optimized transpose.
            transpose(&input, &mut output, width, height);

            // Run reference transpose.
            let expected = transpose_reference(&input, width, height);

            // Verify results match.
            prop_assert_eq!(
                output,
                expected,
                "Transpose mismatch for {}×{} matrix",
                width,
                height
            );
        }

        #[test]
        fn proptest_transpose_u64((width, height) in dimension_strategy()) {
            // Skip empty and very large matrices for u64 (memory intensive).
            if width == 0 || height == 0 || width * height > 100_000 {
                return Ok(());
            }

            // Create input with unique values.
            let input: Vec<u64> = (0..width * height).map(|i| i as u64).collect();

            // Allocate output.
            let mut output = vec![0u64; width * height];

            // Run transpose.
            transpose(&input, &mut output, width, height);

            // Verify against reference.
            let expected = transpose_reference(&input, width, height);
            prop_assert_eq!(output, expected);
        }

        #[test]
        fn proptest_transpose_u8((width, height) in dimension_strategy()) {
            // Skip empty and very large matrices.
            if width == 0 || height == 0 || width * height > 100_000 {
                return Ok(());
            }

            // Create input with unique values (wrapping for u8).
            let input: Vec<u8> = (0..width * height).map(|i| i as u8).collect();

            // Allocate output.
            let mut output = vec![0u8; width * height];

            // Run transpose.
            transpose(&input, &mut output, width, height);

            // Verify against reference.
            let expected = transpose_reference(&input, width, height);
            prop_assert_eq!(output, expected);
        }

        #[test]
        fn proptest_transpose_goldilocks((width, height) in dimension_strategy()) {
            // Skip empty matrices.
            if width == 0 || height == 0 {
                let input: [Goldilocks; 0] = [];
                let mut output: [Goldilocks; 0] = [];
                transpose(&input, &mut output, width, height);
                return Ok(());
            }

            // Create input matrix with unique values at each position.
            let input: Vec<Goldilocks> = (0..width * height)
                .map(|i| Goldilocks::from_u64(i as u64))
                .collect();

            // Allocate output buffer.
            let mut output = vec![Goldilocks::ZERO; width * height];

            // Run optimized transpose.
            transpose(&input, &mut output, width, height);

            // Run reference transpose.
            let expected = transpose_reference(&input, width, height);

            // Verify results match.
            prop_assert_eq!(
                output,
                expected,
                "Transpose mismatch for {}×{} matrix",
                width,
                height
            );
        }
    }
}