1use core::marker::PhantomData;
2
3use p3_field::{InjectiveMonomial, PrimeCharacteristicRing};
4use p3_poseidon2::{
5 ExternalLayer, GenericPoseidon2LinearLayers, InternalLayer, MDSMat4, add_rc_and_sbox_generic,
6 external_initial_permute_state, external_terminal_permute_state,
7};
8
9use crate::{
10 FieldParameters, MontyField31, MontyParameters, Poseidon2ExternalLayerMonty31,
11 Poseidon2InternalLayerMonty31, RelativelyPrimePower,
12};
13
14pub trait InternalLayerBaseParameters<MP: MontyParameters, const WIDTH: usize>:
19 Clone + Sync
20{
21 fn internal_layer_mat_mul<R: PrimeCharacteristicRing>(state: &mut [R; WIDTH], sum: R);
24
25 fn generic_internal_linear_layer<R: PrimeCharacteristicRing>(state: &mut [R; WIDTH]) {
28 let part_sum: R = state[1..].iter().cloned().sum();
30 let full_sum = part_sum.clone() + state[0].clone();
31 state[0] = part_sum - state[0].clone();
32 Self::internal_layer_mat_mul(state, full_sum);
33 }
34}
35
36#[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
37pub trait InternalLayerParameters<FP: FieldParameters, const WIDTH: usize>:
38 InternalLayerBaseParameters<FP, WIDTH> + crate::InternalLayerParametersNeon<FP, WIDTH>
39{
40}
41#[cfg(all(
42 target_arch = "x86_64",
43 target_feature = "avx2",
44 not(target_feature = "avx512f")
45))]
46pub trait InternalLayerParameters<FP: FieldParameters, const WIDTH: usize>:
47 InternalLayerBaseParameters<FP, WIDTH> + crate::InternalLayerParametersAVX2<FP, WIDTH>
48{
49}
50#[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))]
51pub trait InternalLayerParameters<FP: FieldParameters, const WIDTH: usize>:
52 InternalLayerBaseParameters<FP, WIDTH> + crate::InternalLayerParametersAVX512<FP, WIDTH>
53{
54}
55#[cfg(not(any(
56 all(target_arch = "aarch64", target_feature = "neon"),
57 all(
58 target_arch = "x86_64",
59 target_feature = "avx2",
60 not(target_feature = "avx512f")
61 ),
62 all(target_arch = "x86_64", target_feature = "avx512f"),
63)))]
64pub trait InternalLayerParameters<FP: FieldParameters, const WIDTH: usize>:
65 InternalLayerBaseParameters<FP, WIDTH>
66{
67}
68
69impl<FP, const WIDTH: usize, P2P, const D: u64> InternalLayer<MontyField31<FP>, WIDTH, D>
70 for Poseidon2InternalLayerMonty31<FP, WIDTH, P2P>
71where
72 FP: FieldParameters + RelativelyPrimePower<D>,
73 P2P: InternalLayerParameters<FP, WIDTH>,
74{
75 fn permute_state(&self, state: &mut [MontyField31<FP>; WIDTH]) {
77 self.internal_constants.iter().for_each(|rc| {
78 state[0] += *rc;
79 state[0] = state[0].injective_exp_n();
80 let part_sum: MontyField31<FP> = state[1..].iter().copied().sum();
81 let full_sum = part_sum + state[0];
82 state[0] = part_sum - state[0];
83 P2P::internal_layer_mat_mul(state, full_sum);
84 });
85 }
86}
87
88impl<FP, const WIDTH: usize, const D: u64> ExternalLayer<MontyField31<FP>, WIDTH, D>
89 for Poseidon2ExternalLayerMonty31<FP, WIDTH>
90where
91 FP: FieldParameters + RelativelyPrimePower<D>,
92{
93 fn permute_state_initial(&self, state: &mut [MontyField31<FP>; WIDTH]) {
95 external_initial_permute_state(
96 state,
97 self.external_constants.get_initial_constants(),
98 add_rc_and_sbox_generic,
99 &MDSMat4,
100 );
101 }
102
103 fn permute_state_terminal(&self, state: &mut [MontyField31<FP>; WIDTH]) {
105 external_terminal_permute_state(
106 state,
107 self.external_constants.get_terminal_constants(),
108 add_rc_and_sbox_generic,
109 &MDSMat4,
110 );
111 }
112}
113
/// Type-level marker providing the Poseidon2 linear layers for a Monty field over any
/// `PrimeCharacteristicRing`.
///
/// `FP` selects the field parameters and `ILBP` the internal-layer parameters.
/// Neither is stored; both are carried only through `PhantomData`.
pub struct GenericPoseidon2LinearLayersMonty31<FP, ILBP> {
    // Zero-sized markers that tie the otherwise-unused type parameters to the struct.
    _phantom1: PhantomData<FP>,
    _phantom2: PhantomData<ILBP>,
}
123
124impl<FP, ILBP, const WIDTH: usize> GenericPoseidon2LinearLayers<WIDTH>
125 for GenericPoseidon2LinearLayersMonty31<FP, ILBP>
126where
127 FP: FieldParameters,
128 ILBP: InternalLayerBaseParameters<FP, WIDTH>,
129{
130 fn internal_linear_layer<R: PrimeCharacteristicRing>(state: &mut [R; WIDTH]) {
131 ILBP::generic_internal_linear_layer(state);
132 }
133}