1use crate::de::simple_type::UnitOnly;
2use crate::de::str2bool;
3use crate::encoding::Decoder;
4use crate::errors::serialize::DeError;
5use crate::name::QName;
6use crate::utils::CowRef;
7use serde::de::{DeserializeSeed, Deserializer, EnumAccess, Visitor};
8use serde::{forward_to_deserialize_any, serde_if_integer128};
9use std::borrow::Cow;
10
/// Generates a `Deserializer` method named `$method` that parses the stored
/// name as a number and passes it to the visitor's `$visit` method.
///
/// The numeric type is inferred from the argument type of `$visit`; a parse
/// failure is propagated with `?`, converted into `Self::Error`.
macro_rules! deserialize_num {
    ($method:ident, $visit:ident) => {
        fn $method<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value, Self::Error> {
            let number = self.name.parse()?;
            visitor.$visit(number)
        }
    };
}
21
22#[inline]
26fn decode_name<'n>(name: QName<'n>, decoder: Decoder) -> Result<Cow<'n, str>, DeError> {
27 let local = name.local_name();
28 Ok(decoder.decode(local.into_inner())?)
29}
30
31pub struct QNameDeserializer<'i, 'd> {
74 name: CowRef<'i, 'd, str>,
75}
76
77impl<'i, 'd> QNameDeserializer<'i, 'd> {
78 pub fn from_attr(
80 name: QName<'d>,
81 decoder: Decoder,
82 key_buf: &'d mut String,
83 ) -> Result<Self, DeError> {
84 key_buf.clear();
85 key_buf.push('@');
86
87 if name.as_namespace_binding().is_some() {
90 decoder.decode_into(name.into_inner(), key_buf)?;
91 } else {
92 let local = name.local_name();
93 decoder.decode_into(local.into_inner(), key_buf)?;
94 };
95
96 Ok(Self {
97 name: CowRef::Slice(key_buf),
98 })
99 }
100
101 pub fn from_elem(name: CowRef<'i, 'd, [u8]>, decoder: Decoder) -> Result<Self, DeError> {
103 let local = match name {
104 CowRef::Input(borrowed) => match decode_name(QName(borrowed), decoder)? {
105 Cow::Borrowed(borrowed) => CowRef::Input(borrowed),
106 Cow::Owned(owned) => CowRef::Owned(owned),
107 },
108 CowRef::Slice(borrowed) => match decode_name(QName(borrowed), decoder)? {
109 Cow::Borrowed(borrowed) => CowRef::Slice(borrowed),
110 Cow::Owned(owned) => CowRef::Owned(owned),
111 },
112 CowRef::Owned(owned) => match decode_name(QName(&owned), decoder)? {
113 Cow::Borrowed(_) => CowRef::Owned(String::from_utf8(owned).unwrap()),
116 Cow::Owned(owned) => CowRef::Owned(owned),
117 },
118 };
119
120 Ok(Self { name: local })
121 }
122}
123
124impl<'de, 'd> Deserializer<'de> for QNameDeserializer<'de, 'd> {
125 type Error = DeError;
126
127 forward_to_deserialize_any! {
128 char str string
129 bytes byte_buf
130 seq tuple tuple_struct
131 map struct
132 ignored_any
133 }
134
135 fn deserialize_bool<V>(self, visitor: V) -> Result<V::Value, Self::Error>
144 where
145 V: Visitor<'de>,
146 {
147 str2bool(self.name.as_ref(), visitor)
148 }
149
150 deserialize_num!(deserialize_i8, visit_i8);
151 deserialize_num!(deserialize_i16, visit_i16);
152 deserialize_num!(deserialize_i32, visit_i32);
153 deserialize_num!(deserialize_i64, visit_i64);
154
155 deserialize_num!(deserialize_u8, visit_u8);
156 deserialize_num!(deserialize_u16, visit_u16);
157 deserialize_num!(deserialize_u32, visit_u32);
158 deserialize_num!(deserialize_u64, visit_u64);
159
160 serde_if_integer128! {
161 deserialize_num!(deserialize_i128, visit_i128);
162 deserialize_num!(deserialize_u128, visit_u128);
163 }
164
165 deserialize_num!(deserialize_f32, visit_f32);
166 deserialize_num!(deserialize_f64, visit_f64);
167
168 fn deserialize_unit<V>(self, visitor: V) -> Result<V::Value, Self::Error>
170 where
171 V: Visitor<'de>,
172 {
173 visitor.visit_unit()
174 }
175
176 fn deserialize_unit_struct<V>(
178 self,
179 _name: &'static str,
180 visitor: V,
181 ) -> Result<V::Value, Self::Error>
182 where
183 V: Visitor<'de>,
184 {
185 self.deserialize_unit(visitor)
186 }
187
188 #[inline]
190 fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
191 where
192 V: Visitor<'de>,
193 {
194 self.deserialize_identifier(visitor)
195 }
196
197 fn deserialize_option<V>(self, visitor: V) -> Result<V::Value, Self::Error>
200 where
201 V: Visitor<'de>,
202 {
203 if self.name.is_empty() {
204 visitor.visit_none()
205 } else {
206 visitor.visit_some(self)
207 }
208 }
209
210 fn deserialize_newtype_struct<V>(
211 self,
212 _name: &'static str,
213 visitor: V,
214 ) -> Result<V::Value, Self::Error>
215 where
216 V: Visitor<'de>,
217 {
218 visitor.visit_newtype_struct(self)
219 }
220
221 fn deserialize_identifier<V>(self, visitor: V) -> Result<V::Value, Self::Error>
227 where
228 V: Visitor<'de>,
229 {
230 match self.name {
231 CowRef::Input(name) => visitor.visit_borrowed_str(name),
232 CowRef::Slice(name) => visitor.visit_str(name),
233 CowRef::Owned(name) => visitor.visit_string(name),
234 }
235 }
236
237 fn deserialize_enum<V>(
238 self,
239 _name: &str,
240 _variants: &'static [&'static str],
241 visitor: V,
242 ) -> Result<V::Value, Self::Error>
243 where
244 V: Visitor<'de>,
245 {
246 visitor.visit_enum(self)
247 }
248}
249
250impl<'de, 'd> EnumAccess<'de> for QNameDeserializer<'de, 'd> {
251 type Error = DeError;
252 type Variant = UnitOnly;
253
254 fn variant_seed<V>(self, seed: V) -> Result<(V::Value, Self::Variant), Self::Error>
255 where
256 V: DeserializeSeed<'de>,
257 {
258 let name = seed.deserialize(self)?;
259 Ok((name, UnitOnly))
260 }
261}
262
/// Tests for [`QNameDeserializer`]: each test builds the deserializer directly
/// from a string literal and checks the produced value or the reported error.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::se::key::QNameSerializer;
    use crate::utils::{ByteBuf, Bytes};
    use pretty_assertions::assert_eq;
    use serde::de::IgnoredAny;
    use serde::{Deserialize, Serialize};
    use std::collections::HashMap;

    #[derive(Debug, Deserialize, Serialize, PartialEq)]
    struct Unit;

    #[derive(Debug, Deserialize, Serialize, PartialEq)]
    struct Newtype(String);

    #[derive(Debug, Deserialize, Serialize, PartialEq)]
    struct Tuple((), ());

    #[derive(Debug, Deserialize, Serialize, PartialEq)]
    struct Struct {
        key: String,
        val: usize,
    }

    #[derive(Debug, Deserialize, Serialize, PartialEq)]
    enum Enum {
        Unit,
        #[serde(rename = "@Attr")]
        Attr,
        Newtype(String),
        Tuple(String, usize),
        Struct {
            key: String,
            val: usize,
        },
    }

    #[derive(Debug, Deserialize, PartialEq)]
    #[serde(field_identifier)]
    enum Id {
        Field,
    }

    /// Wrapper around `IgnoredAny` that compares equal to any other `Any`,
    /// so it can be used with `assert_eq!` in the test macros.
    #[derive(Debug, Deserialize)]
    #[serde(transparent)]
    struct Any(IgnoredAny);
    impl PartialEq for Any {
        fn eq(&self, _other: &Any) -> bool {
            true
        }
    }

    /// Checks that `$input` deserializes into `$result`, without the
    /// serialization round-trip performed by `deserialized_to!`.
    macro_rules! deserialized_to_only {
        ($name:ident: $type:ty = $input:literal => $result:expr) => {
            #[test]
            fn $name() {
                let de = QNameDeserializer {
                    name: CowRef::Input($input),
                };
                let data: $type = Deserialize::deserialize(de).unwrap();

                assert_eq!(data, $result);
            }
        };
    }

    /// Checks that `$input` deserializes into `$result` and that serializing
    /// the value with `QNameSerializer` gives `$input` back.
    macro_rules! deserialized_to {
        ($name:ident: $type:ty = $input:literal => $result:expr) => {
            #[test]
            fn $name() {
                let de = QNameDeserializer {
                    name: CowRef::Input($input),
                };
                let data: $type = Deserialize::deserialize(de).unwrap();

                assert_eq!(data, $result);

                // Round-trip: the serializer must produce the original input
                assert_eq!(
                    data.serialize(QNameSerializer {
                        writer: String::new()
                    })
                    .unwrap(),
                    $input
                );
            }
        };
    }

    /// Checks that deserializing `$input` into `$type` fails with the
    /// `DeError::$kind` variant carrying exactly `$reason`.
    macro_rules! err {
        ($name:ident: $type:ty = $input:literal => $kind:ident($reason:literal)) => {
            #[test]
            fn $name() {
                let de = QNameDeserializer {
                    name: CowRef::Input($input),
                };
                let err = <$type as Deserialize>::deserialize(de).unwrap_err();

                match err {
                    DeError::$kind(e) => assert_eq!(e, $reason),
                    _ => panic!(
                        "Expected `Err({}({}))`, but got `{:?}`",
                        stringify!($kind),
                        $reason,
                        err
                    ),
                }
            }
        };
    }

    deserialized_to!(false_: bool = "false" => false);
    deserialized_to!(true_: bool = "true" => true);

    deserialized_to!(i8_: i8 = "-2" => -2);
    deserialized_to!(i16_: i16 = "-2" => -2);
    deserialized_to!(i32_: i32 = "-2" => -2);
    deserialized_to!(i64_: i64 = "-2" => -2);

    deserialized_to!(u8_: u8 = "3" => 3);
    deserialized_to!(u16_: u16 = "3" => 3);
    deserialized_to!(u32_: u32 = "3" => 3);
    deserialized_to!(u64_: u64 = "3" => 3);

    serde_if_integer128! {
        deserialized_to!(i128_: i128 = "-2" => -2);
        deserialized_to!(u128_: u128 = "2" => 2);
    }

    deserialized_to!(f32_: f32 = "1.23" => 1.23);
    deserialized_to!(f64_: f64 = "1.23" => 1.23);

    deserialized_to!(char_unescaped: char = "h" => 'h');
    // The name is passed to the visitor as-is, so an entity reference stays a
    // multi-character string and cannot be a `char`. (A bare "<" would be a
    // valid one-character string and the test could never fail as expected.)
    err!(char_escaped: char = "&lt;"
        => Custom("invalid value: string \"&lt;\", expected a character"));

    deserialized_to!(string: String = "<escaped string" => "<escaped string");
    deserialized_to!(borrowed_str: &str = "name" => "name");

    err!(byte_buf: ByteBuf = "<escaped string"
        => Custom("invalid type: string \"<escaped string\", expected byte data"));
    err!(borrowed_bytes: Bytes = "name"
        => Custom("invalid type: string \"name\", expected borrowed bytes"));

    deserialized_to!(option_none: Option<String> = "" => None);
    deserialized_to!(option_some: Option<String> = "name" => Some("name".into()));

    // Units are deserialized from any name, so only the deserialization
    // direction is checked for them.
    deserialized_to_only!(unit: () = "anything" => ());
    deserialized_to_only!(unit_struct: Unit = "anything" => Unit);

    deserialized_to!(newtype: Newtype = "<escaped string" => Newtype("<escaped string".into()));

    err!(seq: Vec<()> = "name"
        => Custom("invalid type: string \"name\", expected a sequence"));
    err!(tuple: ((), ()) = "name"
        => Custom("invalid type: string \"name\", expected a tuple of size 2"));
    err!(tuple_struct: Tuple = "name"
        => Custom("invalid type: string \"name\", expected tuple struct Tuple"));

    err!(map: HashMap<(), ()> = "name"
        => Custom("invalid type: string \"name\", expected a map"));
    err!(struct_: Struct = "name"
        => Custom("invalid type: string \"name\", expected struct Struct"));

    deserialized_to!(enum_unit: Enum = "Unit" => Enum::Unit);
    deserialized_to!(enum_unit_for_attr: Enum = "@Attr" => Enum::Attr);
    err!(enum_newtype: Enum = "Newtype"
        => Custom("invalid type: unit value, expected a string"));
    err!(enum_tuple: Enum = "Tuple"
        => Custom("invalid type: unit value, expected tuple variant Enum::Tuple"));
    err!(enum_struct: Enum = "Struct"
        => Custom("invalid type: unit value, expected struct variant Enum::Struct"));

    // Field identifiers and ignored values are deserialize-only in these tests
    deserialized_to_only!(identifier: Id = "Field" => Id::Field);
    deserialized_to_only!(ignored_any: Any = "any-name" => Any(IgnoredAny));
}