Skip to main content

spongefish_circuit/
allocator.rs

1//! Defines the allocator and wires to be used for computing the key-derivation steps.
2
3use alloc::{sync::Arc, vec::Vec};
4use core::borrow::Borrow;
5
6use hashbrown::HashMap;
7use itertools::Itertools;
8use spin::RwLock;
9use spongefish::Unit;
10
/// A symbolic wire over which we perform our computation.
///
/// Wires are identified by a plain index; index `0` is reserved for [`FieldVar::ZERO`],
/// which is also what [`Default`] produces. `Debug` is implemented by hand and prints `v(<index>)`.
#[derive(Clone, Copy, Default, Hash, PartialEq, Eq)]
pub struct FieldVar(usize);
14
15impl FieldVar {
16    /// Maximum number of variables supported by the circuit allocator.
17    pub const MAX_COUNT: usize = 1 << 30;
18    /// The distinguished zero variable.
19    pub const ZERO: Self = Self(0);
20
21    /// Return the variable index.
22    #[must_use]
23    pub const fn index(self) -> usize {
24        self.0
25    }
26
27    /// Construct a variable from an index when it is within the supported range.
28    #[must_use]
29    pub const fn try_from_index(index: usize) -> Option<Self> {
30        if index < Self::MAX_COUNT {
31            Some(Self(index))
32        } else {
33            None
34        }
35    }
36}
37
38impl Unit for FieldVar {
39    const ZERO: Self = Self::ZERO;
40}
41
42impl core::fmt::Debug for FieldVar {
43    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
44        write!(f, "v({})", self.0)
45    }
46}
47
48/// Allocator for field variables.
49///
50/// Creates a new wire identifier when requested,
51/// and keeps tracks of the wires that have been declared as public.
52#[derive(Clone)]
53pub struct VarAllocator<T> {
54    state: Arc<RwLock<AllocatorState<T>>>,
55}
56
57struct AllocatorState<T> {
58    vars_count: usize,
59    public_values: HashMap<FieldVar, T>,
60}
61
62impl<T: Clone + Unit> Default for VarAllocator<T> {
63    fn default() -> Self {
64        Self::new()
65    }
66}
67
68impl<T: Clone + Unit> VarAllocator<T> {
69    #[must_use]
70    pub fn new() -> Self {
71        let zero_var = FieldVar::ZERO;
72        let mut public_values = HashMap::new();
73        public_values.insert(zero_var, T::ZERO);
74        Self {
75            state: Arc::new(RwLock::new(AllocatorState {
76                vars_count: 1,
77                public_values,
78            })),
79        }
80    }
81
82    #[must_use]
83    pub fn new_field_var(&self) -> FieldVar {
84        let mut state = self.state.write();
85        assert!(
86            state.vars_count < FieldVar::MAX_COUNT,
87            "variable count exceeds supported maximum {}",
88            FieldVar::MAX_COUNT,
89        );
90        let var = FieldVar(state.vars_count);
91        state.vars_count += 1;
92        var
93    }
94
95    #[must_use]
96    pub fn allocate_vars<const N: usize>(&self) -> [FieldVar; N] {
97        let mut buf = [FieldVar::default(); N];
98        for x in &mut buf {
99            *x = self.new_field_var();
100        }
101        buf
102    }
103
104    #[must_use]
105    pub fn allocate_vars_vec(&self, count: usize) -> Vec<FieldVar> {
106        {
107            let state = self.state.read();
108            let new_count = state
109                .vars_count
110                .checked_add(count)
111                .expect("variable count overflow");
112            assert!(
113                new_count <= FieldVar::MAX_COUNT,
114                "variable count exceeds supported maximum {}",
115                FieldVar::MAX_COUNT,
116            );
117        }
118        (0..count).map(|_| self.new_field_var()).collect()
119    }
120
121    pub fn allocate_public<const N: usize>(&self, public_values: &[T; N]) -> [FieldVar; N] {
122        let vars = self.allocate_vars();
123        self.set_public_vars(vars, public_values);
124        vars
125    }
126
127    pub fn allocate_public_vec(&self, public_values: &[T]) -> Vec<FieldVar> {
128        let vars = self.allocate_vars_vec(public_values.len());
129        self.set_public_vars(vars.clone(), public_values);
130        vars
131    }
132
133    #[must_use]
134    pub fn vars_count(&self) -> usize {
135        self.state.read().vars_count
136    }
137
138    #[must_use]
139    pub fn is_allocated(&self, var: FieldVar) -> bool {
140        var.index() < self.vars_count()
141    }
142
143    /// Assigns the wire variable `var` to `val`.
144    ///
145    /// If the wire was already present, it is over-written.
146    pub fn set_public_var(&self, var: FieldVar, val: T) {
147        self.state.write().public_values.insert(var, val);
148    }
149
150    /// Sets a list of public variables.
151    ///
152    /// Takes as input two iterators (for wires and values respectively),
153    /// and sets each of them to public values.
154    ///
155    /// # Panics
156    ///
157    /// If the iterators have different length, this function will panic.
158    pub fn set_public_vars<Val, Var>(
159        &self,
160        vars: impl IntoIterator<Item = Var>,
161        vals: impl IntoIterator<Item = Val>,
162    ) where
163        Var: Borrow<FieldVar>,
164        Val: Borrow<T>,
165    {
166        self.state.write().public_values.extend(
167            vars.into_iter()
168                .zip_eq(vals)
169                .map(|(var, val)| (*var.borrow(), val.borrow().clone())),
170        );
171    }
172
173    #[must_use]
174    pub fn public_vars(&self) -> Vec<(FieldVar, T)> {
175        let mut public_values = self
176            .state
177            .read()
178            .public_values
179            .iter()
180            .map(|(var, val)| (*var, val.clone()))
181            .collect::<Vec<_>>();
182        public_values.sort_unstable_by_key(|(var, _)| var.index());
183        public_values
184    }
185}