serde_with/
flatten_maybe.rs

1use crate::prelude::*;
2
/// Support deserializing from both the flattened and the non-flattened representation
///
/// Some serialization formats idiomatically flatten fields into the parent, while others prefer
/// nesting them. With `#[serde(flatten)]` alone, only the flattened form is accepted.
///
/// This macro generates a function that accepts either the flattened or the nested form. It
/// returns an error when both forms are present, or when neither is. The field must still carry
/// the `flatten` attribute for the helper to work. Serialization is not affected, so the
/// serialized output is always the flattened form.
///
/// # Examples
///
/// ```rust
/// # use serde::Deserialize;
/// #
/// // Setup the types
/// #[derive(Deserialize, Debug)]
/// struct S {
///     #[serde(flatten, deserialize_with = "deserialize_t")]
///     t: T,
/// }
///
/// #[derive(Deserialize, Debug)]
/// struct T {
///     i: i32,
/// }
///
/// // The macro creates custom deserialization code.
/// // You need to specify a function name and the field name of the flattened field.
/// serde_with::flattened_maybe!(deserialize_t, "t");
///
/// # fn main() {
/// // Supports both flattened
/// let j = r#" {"i":1} "#;
/// assert!(serde_json::from_str::<S>(j).is_ok());
/// # // Ensure the t field is not dead code
/// # assert_eq!(serde_json::from_str::<S>(j).unwrap().t.i, 1);
///
/// // and non-flattened versions.
/// let j = r#" {"t":{"i":1}} "#;
/// assert!(serde_json::from_str::<S>(j).is_ok());
///
/// // Ensure that the value is given
/// let j = r#" {} "#;
/// assert!(serde_json::from_str::<S>(j).is_err());
///
/// // and only occurs once, not multiple times.
/// let j = r#" {"i":1,"t":{"i":1}} "#;
/// assert!(serde_json::from_str::<S>(j).is_err());
/// # }
/// ```
#[macro_export]
macro_rules! flattened_maybe {
    ($fn:ident, $field:tt) => {
        fn $fn<'de, T, D>(deserializer: D) -> $crate::__private__::Result<T, D::Error>
        where
            D: $crate::__private__::Deserializer<'de>,
            T: $crate::__private__::Deserialize<'de>,
        {
            // The seed carries the name of the nested field; all real work happens in its
            // `DeserializeSeed` implementation.
            let seed: $crate::flatten_maybe::FlattenedMaybe<T> =
                $crate::flatten_maybe::FlattenedMaybe($field, $crate::__private__::PhantomData);
            $crate::__private__::DeserializeSeed::deserialize(seed, deserializer)
        }
    };
}
70
/// Deserialization seed used by [`flattened_maybe!`] to read a field that may appear either
/// flattened into its parent or nested under its own name.
pub struct FlattenedMaybe<T>(
    /// Name of the nested (non-flattened) field.
    pub &'static str,
    /// Marker for the type `T` that the seed produces.
    pub PhantomData<T>,
);
75
76impl<'de, T> DeserializeSeed<'de> for FlattenedMaybe<T>
77where
78    T: Deserialize<'de>,
79{
80    type Value = T;
81
82    fn deserialize<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
83    where
84        D: Deserializer<'de>,
85    {
86        #[allow(non_camel_case_types)]
87        enum Field<'de> {
88            // Marked for the non-flattened field
89            field_not_flat,
90            // Rest, buffered to be deserialized later
91            other(content::de::Content<'de>),
92        }
93
94        struct FieldVisitor<'a> {
95            fieldname: &'a str,
96        }
97
98        impl<'a, 'de> Visitor<'de> for FieldVisitor<'a> {
99            type Value = Field<'de>;
100            fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
101                fmt::Formatter::write_str(formatter, "field identifier")
102            }
103            fn visit_bool<E>(self, value: bool) -> Result<Self::Value, E>
104            where
105                E: DeError,
106            {
107                Ok(Field::other(content::de::Content::Bool(value)))
108            }
109            fn visit_i8<E>(self, value: i8) -> Result<Self::Value, E>
110            where
111                E: DeError,
112            {
113                Ok(Field::other(content::de::Content::I8(value)))
114            }
115            fn visit_i16<E>(self, value: i16) -> Result<Self::Value, E>
116            where
117                E: DeError,
118            {
119                Ok(Field::other(content::de::Content::I16(value)))
120            }
121            fn visit_i32<E>(self, value: i32) -> Result<Self::Value, E>
122            where
123                E: DeError,
124            {
125                Ok(Field::other(content::de::Content::I32(value)))
126            }
127            fn visit_i64<E>(self, value: i64) -> Result<Self::Value, E>
128            where
129                E: DeError,
130            {
131                Ok(Field::other(content::de::Content::I64(value)))
132            }
133            fn visit_i128<E>(self, value: i128) -> Result<Self::Value, E>
134            where
135                E: DeError,
136            {
137                Ok(Field::other(content::de::Content::I128(value)))
138            }
139            fn visit_u8<E>(self, value: u8) -> Result<Self::Value, E>
140            where
141                E: DeError,
142            {
143                Ok(Field::other(content::de::Content::U8(value)))
144            }
145            fn visit_u16<E>(self, value: u16) -> Result<Self::Value, E>
146            where
147                E: DeError,
148            {
149                Ok(Field::other(content::de::Content::U16(value)))
150            }
151            fn visit_u32<E>(self, value: u32) -> Result<Self::Value, E>
152            where
153                E: DeError,
154            {
155                Ok(Field::other(content::de::Content::U32(value)))
156            }
157            fn visit_u64<E>(self, value: u64) -> Result<Self::Value, E>
158            where
159                E: DeError,
160            {
161                Ok(Field::other(content::de::Content::U64(value)))
162            }
163            fn visit_u128<E>(self, value: u128) -> Result<Self::Value, E>
164            where
165                E: DeError,
166            {
167                Ok(Field::other(content::de::Content::U128(value)))
168            }
169            fn visit_f32<E>(self, value: f32) -> Result<Self::Value, E>
170            where
171                E: DeError,
172            {
173                Ok(Field::other(content::de::Content::F32(value)))
174            }
175            fn visit_f64<E>(self, value: f64) -> Result<Self::Value, E>
176            where
177                E: DeError,
178            {
179                Ok(Field::other(content::de::Content::F64(value)))
180            }
181            fn visit_char<E>(self, value: char) -> Result<Self::Value, E>
182            where
183                E: DeError,
184            {
185                Ok(Field::other(content::de::Content::Char(value)))
186            }
187            fn visit_unit<E>(self) -> Result<Self::Value, E>
188            where
189                E: DeError,
190            {
191                Ok(Field::other(content::de::Content::Unit))
192            }
193            fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
194            where
195                E: DeError,
196            {
197                if value == self.fieldname {
198                    Ok(Field::field_not_flat)
199                } else {
200                    let value = content::de::Content::String(ToString::to_string(value));
201                    Ok(Field::other(value))
202                }
203            }
204            fn visit_bytes<E>(self, value: &[u8]) -> Result<Self::Value, E>
205            where
206                E: DeError,
207            {
208                if value == self.fieldname.as_bytes() {
209                    Ok(Field::field_not_flat)
210                } else {
211                    let value = content::de::Content::ByteBuf(value.to_vec());
212                    Ok(Field::other(value))
213                }
214            }
215            fn visit_borrowed_str<E>(self, value: &'de str) -> Result<Self::Value, E>
216            where
217                E: DeError,
218            {
219                if value == self.fieldname {
220                    Ok(Field::field_not_flat)
221                } else {
222                    let value = content::de::Content::Str(value);
223                    Ok(Field::other(value))
224                }
225            }
226            fn visit_borrowed_bytes<E>(self, value: &'de [u8]) -> Result<Self::Value, E>
227            where
228                E: DeError,
229            {
230                if value == self.fieldname.as_bytes() {
231                    Ok(Field::field_not_flat)
232                } else {
233                    let value = content::de::Content::Bytes(value);
234                    Ok(Field::other(value))
235                }
236            }
237        }
238
239        impl<'de> DeserializeSeed<'de> for FieldVisitor<'_> {
240            type Value = Field<'de>;
241
242            #[inline]
243            fn deserialize<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
244            where
245                D: Deserializer<'de>,
246            {
247                Deserializer::deserialize_identifier(deserializer, self)
248            }
249        }
250
251        struct FlattenedMaybeVisitor<T> {
252            is_human_readable: bool,
253            fieldname: &'static str,
254            marker: PhantomData<T>,
255        }
256
257        impl<'de, T> Visitor<'de> for FlattenedMaybeVisitor<T>
258        where
259            T: Deserialize<'de>,
260        {
261            type Value = T;
262
263            fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
264                formatter.write_fmt(format_args!(
265                    "a structure with a maybe flattened field `{}`",
266                    self.fieldname,
267                ))
268            }
269
270            #[inline]
271            fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
272            where
273                A: MapAccess<'de>,
274            {
275                // Set to Some if field is present
276                let mut value_not_flat: Option<Option<T>> = None;
277                // Collect all other fields or the flattened fields
278                let mut collect =
279                    Vec::<Option<(content::de::Content<'_>, content::de::Content<'_>)>>::new();
280
281                // Iterate over the map
282                while let Some(key) = MapAccess::next_key_seed(
283                    &mut map,
284                    FieldVisitor {
285                        fieldname: self.fieldname,
286                    },
287                )? {
288                    match key {
289                        Field::field_not_flat => {
290                            if Option::is_some(&value_not_flat) {
291                                return Err(<A::Error as DeError>::duplicate_field(self.fieldname));
292                            }
293                            value_not_flat = Some(MapAccess::next_value::<Option<T>>(&mut map)?);
294                        }
295                        Field::other(name) => {
296                            collect.push(Some((name, MapAccess::next_value(&mut map)?)));
297                        }
298                    }
299                }
300
301                // Map is done, now check what we got
302                let value_not_flat = value_not_flat.flatten();
303                // Try to reconstruct the flattened structure
304                let value_flat: Option<T> =
305                    Deserialize::deserialize(content::de::FlatMapDeserializer(
306                        &mut collect,
307                        PhantomData,
308                        self.is_human_readable,
309                    ))?;
310
311                // Check that exactly one of the two options is set
312                match (value_flat, value_not_flat) {
313                    (Some(t), None) | (None, Some(t)) => Ok(t),
314                    (None, None) => Err(DeError::missing_field(self.fieldname)),
315                    (Some(_), Some(_)) => Err(DeError::custom(format_args!(
316                        "`{}` is both flattened and not",
317                        self.fieldname,
318                    ))),
319                }
320            }
321        }
322
323        let is_human_readable = deserializer.is_human_readable();
324        Deserializer::deserialize_map(
325            deserializer,
326            FlattenedMaybeVisitor {
327                is_human_readable,
328                fieldname: self.0,
329                marker: PhantomData,
330            },
331        )
332    }
333}