p3_util/
array_serialization.rs1use alloc::vec::Vec;
2use core::marker::PhantomData;
3
4use serde::de::{SeqAccess, Visitor};
5use serde::ser::SerializeTuple;
6use serde::{Deserialize, Deserializer, Serialize, Serializer};
7
8pub fn serialize<S: Serializer, T: Serialize, const N: usize>(
9 data: &[T; N],
10 ser: S,
11) -> Result<S::Ok, S::Error> {
12 let mut s = ser.serialize_tuple(N)?;
13 for item in data {
14 s.serialize_element(item)?;
15 }
16 s.end()
17}
18
19struct ArrayVisitor<T, const N: usize>(PhantomData<T>);
20
21impl<'de, T, const N: usize> Visitor<'de> for ArrayVisitor<T, N>
22where
23 T: Deserialize<'de>,
24{
25 type Value = [T; N];
26
27 fn expecting(&self, formatter: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
28 formatter.write_fmt(format_args!("an array of length {N}"))
29 }
30
31 #[inline]
32 fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
33 where
34 A: SeqAccess<'de>,
35 {
36 let mut data = Vec::with_capacity(N);
37 for _ in 0..N {
38 match seq.next_element()? {
39 Some(val) => data.push(val),
40 None => return Err(serde::de::Error::invalid_length(N, &self)),
41 }
42 }
43 data.try_into().map_or_else(|_| unreachable!(), Ok)
44 }
45}
46pub fn deserialize<'de, D, T, const N: usize>(deserializer: D) -> Result<[T; N], D::Error>
47where
48 D: Deserializer<'de>,
49 T: Deserialize<'de>,
50{
51 deserializer.deserialize_tuple(N, ArrayVisitor::<T, N>(PhantomData))
52}
53
#[cfg(test)]
mod tests {
    // The crate appears to be `no_std`, so `ToString` is not in the prelude; import it explicitly.
    use alloc::string::ToString;

    use serde::{Deserialize, Serialize};

    use super::*;

    /// Fixture whose array field is (de)serialized through this module's helpers.
    #[derive(Serialize, Deserialize, Debug, PartialEq)]
    // Empty `bound`s replace the where-clauses serde's derive would otherwise infer.
    #[serde(bound(serialize = "", deserialize = ""))]
    struct Wrapper<const N: usize> {
        #[serde(serialize_with = "serialize", deserialize_with = "deserialize")]
        arr: [u32; N],
    }

    #[test]
    fn test_array_serde_roundtrip() {
        let original = Wrapper::<3> { arr: [10, 20, 30] };

        let json = serde_json::to_string(&original).unwrap();
        assert_eq!(json, r#"{"arr":[10,20,30]}"#);

        let deserialized: Wrapper<3> = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized, original);

        // Hand-written input with insignificant whitespace must parse to the same array.
        let parsed: Wrapper<3> = serde_json::from_str(r#"{ "arr": [10, 20, 30] }"#).unwrap();
        assert_eq!(parsed.arr, [10, 20, 30]);
    }

    #[test]
    fn test_deserialize_wrong_length() {
        let json = r#"{"arr":[1,2]}"#;

        let result: Result<Wrapper<3>, _> = serde_json::from_str(json);
        let err = result.unwrap_err().to_string();
        // The visitor's `expecting` text should surface in the error message.
        assert!(
            err.contains("expected an array of length 3"),
            "unexpected error: {err}"
        );
    }

    #[test]
    fn test_deserialize_too_many_elements() {
        // The visitor only reads `N` elements; serde_json rejects the unconsumed remainder.
        let result: Result<Wrapper<3>, _> = serde_json::from_str(r#"{"arr":[1,2,3,4]}"#);
        assert!(result.is_err());
    }

    #[test]
    fn test_deserialize_non_array() {
        let result: Result<Wrapper<3>, _> = serde_json::from_str(r#"{"arr":7}"#);
        assert!(result.is_err());
    }

    #[test]
    fn test_empty_array() {
        let data = Wrapper::<0> { arr: [] };

        let json = serde_json::to_string(&data).unwrap();
        assert_eq!(json, r#"{"arr":[]}"#);

        let parsed: Wrapper<0> = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed, data);
    }
}