serde_xml_rs/de/
mod.rs

1use std::{io::Read, marker::PhantomData};
2
3use log::trace;
4use serde::de::{self, Unexpected};
5use serde::forward_to_deserialize_any;
6use xml::name::OwnedName;
7use xml::reader::{EventReader, ParserConfig, XmlEvent};
8
9use self::buffer::{BufferedXmlReader, ChildXmlBuffer, RootXmlBuffer};
10use self::map::MapAccess;
11use self::seq::SeqAccess;
12use self::var::EnumAccess;
13use crate::error::{Error, Result};
14use crate::{debug_expect, expect};
15
16mod buffer;
17mod map;
18mod seq;
19mod var;
20
21/// A convenience method for deserialize some object from a string.
22///
23/// ```rust
24/// # use serde::{Deserialize, Serialize};
25/// # use serde_xml_rs::from_str;
26/// #[derive(Debug, Deserialize, PartialEq)]
27/// struct Item {
28///     name: String,
29///     source: String,
30/// }
31/// # fn main() {
32/// let s = r##"<item name="hello" source="world.rs" />"##;
33/// let item: Item = from_str(s).unwrap();
34/// assert_eq!(item, Item { name: "hello".to_string(),source: "world.rs".to_string()});
35/// # }
36/// ```
37pub fn from_str<'de, T: de::Deserialize<'de>>(s: &str) -> Result<T> {
38    from_reader(s.as_bytes())
39}
40
41/// A convenience method for deserialize some object from a reader.
42///
43/// ```rust
44/// # use serde::Deserialize;
45/// # use serde_xml_rs::from_reader;
46/// #[derive(Debug, Deserialize, PartialEq)]
47/// struct Item {
48///     name: String,
49///     source: String,
50/// }
51/// # fn main() {
52/// let s = r##"<item name="hello" source="world.rs" />"##;
53/// let item: Item = from_reader(s.as_bytes()).unwrap();
54/// assert_eq!(item, Item { name: "hello".to_string(),source: "world.rs".to_string()});
55/// # }
56/// ```
57pub fn from_reader<'de, R: Read, T: de::Deserialize<'de>>(reader: R) -> Result<T> {
58    T::deserialize(&mut Deserializer::new_from_reader(reader))
59}
60
61type RootDeserializer<R> = Deserializer<R, RootXmlBuffer<R>>;
62type ChildDeserializer<'parent, R> = Deserializer<R, ChildXmlBuffer<'parent, R>>;
63
64pub struct Deserializer<
65    R: Read, // Kept as type param to avoid type signature breaking-change
66    B: BufferedXmlReader<R> = RootXmlBuffer<R>,
67> {
68    /// XML document nested element depth
69    depth: usize,
70    buffered_reader: B,
71    is_map_value: bool,
72    non_contiguous_seq_elements: bool,
73    marker: PhantomData<R>,
74}
75
76impl<'de, R: Read> RootDeserializer<R> {
77    pub fn new(reader: EventReader<R>) -> Self {
78        let buffered_reader = RootXmlBuffer::new(reader);
79
80        Deserializer {
81            buffered_reader,
82            depth: 0,
83            is_map_value: false,
84            non_contiguous_seq_elements: false,
85            marker: PhantomData,
86        }
87    }
88
89    pub fn new_from_reader(reader: R) -> Self {
90        let config = ParserConfig::new()
91            .trim_whitespace(true)
92            .whitespace_to_characters(true)
93            .cdata_to_characters(true)
94            .ignore_comments(true)
95            .coalesce_characters(true);
96
97        Self::new(EventReader::new_with_config(reader, config))
98    }
99
100    /// Configures whether the deserializer should search all sibling elements when building a
101    /// sequence. Not required if all XML elements for sequences are adjacent. Disabled by
102    /// default. Enabling this option may incur additional memory usage.
103    ///
104    /// ```rust
105    /// # use serde::Deserialize;
106    /// # use serde_xml_rs::from_reader;
107    /// #[derive(Debug, Deserialize, PartialEq)]
108    /// struct Foo {
109    ///     bar: Vec<usize>,
110    ///     baz: String,
111    /// }
112    /// # fn main() {
113    /// let s = r##"
114    ///     <foo>
115    ///         <bar>1</bar>
116    ///         <bar>2</bar>
117    ///         <baz>Hello, world</baz>
118    ///         <bar>3</bar>
119    ///         <bar>4</bar>
120    ///     </foo>
121    /// "##;
122    /// let mut de = serde_xml_rs::Deserializer::new_from_reader(s.as_bytes())
123    ///     .non_contiguous_seq_elements(true);
124    /// let foo = Foo::deserialize(&mut de).unwrap();
125    /// assert_eq!(foo, Foo { bar: vec![1, 2, 3, 4], baz: "Hello, world".to_string()});
126    /// # }
127    /// ```
128    pub fn non_contiguous_seq_elements(mut self, set: bool) -> Self {
129        self.non_contiguous_seq_elements = set;
130        self
131    }
132}
133
134impl<'de, R: Read, B: BufferedXmlReader<R>> Deserializer<R, B> {
135    fn child<'a>(&'a mut self) -> Deserializer<R, ChildXmlBuffer<'a, R>> {
136        let Deserializer {
137            buffered_reader,
138            depth,
139            is_map_value,
140            non_contiguous_seq_elements,
141            ..
142        } = self;
143
144        Deserializer {
145            buffered_reader: buffered_reader.child_buffer(),
146            depth: *depth,
147            is_map_value: *is_map_value,
148            non_contiguous_seq_elements: *non_contiguous_seq_elements,
149            marker: PhantomData,
150        }
151    }
152
153    /// Gets the next XML event without advancing the cursor.
154    fn peek(&mut self) -> Result<&XmlEvent> {
155        let peeked = self.buffered_reader.peek()?;
156
157        trace!("Peeked {:?}", peeked);
158        Ok(peeked)
159    }
160
161    /// Gets the XML event at the cursor and advances the cursor.
162    fn next(&mut self) -> Result<XmlEvent> {
163        let next = self.buffered_reader.next()?;
164
165        match next {
166            XmlEvent::StartElement { .. } => {
167                self.depth += 1;
168            }
169            XmlEvent::EndElement { .. } => {
170                self.depth -= 1;
171            }
172            _ => {}
173        }
174        trace!("Fetched {:?}", next);
175        Ok(next)
176    }
177
178    fn set_map_value(&mut self) {
179        self.is_map_value = true;
180    }
181
182    pub fn unset_map_value(&mut self) -> bool {
183        ::std::mem::replace(&mut self.is_map_value, false)
184    }
185
186    /// If `self.is_map_value`: Performs the read operations specified by `f` on the inner content of an XML element.
187    /// `f` is expected to consume the entire inner contents of the element. The cursor will be moved to the end of the
188    /// element.
189    /// If `!self.is_map_value`: `f` will be performed without additional checks/advances for an outer XML element.
190    fn read_inner_value<V: de::Visitor<'de>, T, F: FnOnce(&mut Self) -> Result<T>>(
191        &mut self,
192        f: F,
193    ) -> Result<T> {
194        if self.unset_map_value() {
195            debug_expect!(self.next(), Ok(XmlEvent::StartElement { name, .. }) => {
196                let result = f(self)?;
197                self.expect_end_element(name)?;
198                Ok(result)
199            })
200        } else {
201            f(self)
202        }
203    }
204
205    fn expect_end_element(&mut self, start_name: OwnedName) -> Result<()> {
206        expect!(self.next()?, XmlEvent::EndElement { name, .. } => {
207            if name == start_name {
208                Ok(())
209            } else {
210                Err(Error::Custom { field: format!(
211                    "End tag </{}> didn't match the start tag <{}>",
212                    name.local_name,
213                    start_name.local_name
214                ) })
215            }
216        })
217    }
218
219    fn prepare_parse_type<V: de::Visitor<'de>>(&mut self) -> Result<String> {
220        if let XmlEvent::StartElement { .. } = *self.peek()? {
221            self.set_map_value()
222        }
223        self.read_inner_value::<V, String, _>(|this| {
224            if let XmlEvent::EndElement { .. } = *this.peek()? {
225                return Err(Error::UnexpectedToken {
226                    token: "EndElement".into(),
227                    found: "Characters".into(),
228                });
229            }
230
231            expect!(this.next()?, XmlEvent::Characters(s) => {
232                return Ok(s)
233            })
234        })
235    }
236}
237
/// Generates a `deserialize_*` method for a type that is parsed from text.
///
/// The generated method obtains the value's text via `prepare_parse_type`, parses it with
/// `str::parse`, and forwards the parsed value to the given visitor method.
macro_rules! deserialize_type {
    ($method:ident => $visit_method:ident) => {
        fn $method<V: de::Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
            let text = self.prepare_parse_type::<V>()?;
            visitor.$visit_method(text.parse()?)
        }
    };
}
246
247impl<'de, 'a, R: Read, B: BufferedXmlReader<R>> de::Deserializer<'de>
248    for &'a mut Deserializer<R, B>
249{
250    type Error = Error;
251
252    forward_to_deserialize_any! {
253        identifier
254    }
255
256    fn deserialize_struct<V: de::Visitor<'de>>(
257        self,
258        _name: &'static str,
259        fields: &'static [&'static str],
260        visitor: V,
261    ) -> Result<V::Value> {
262        self.unset_map_value();
263        expect!(self.next()?, XmlEvent::StartElement { name, attributes, .. } => {
264            let map_value = visitor.visit_map(MapAccess::new(
265                self,
266                attributes,
267                fields.contains(&"$value")
268            ))?;
269            self.expect_end_element(name)?;
270            Ok(map_value)
271        })
272    }
273
274    deserialize_type!(deserialize_i8 => visit_i8);
275    deserialize_type!(deserialize_i16 => visit_i16);
276    deserialize_type!(deserialize_i32 => visit_i32);
277    deserialize_type!(deserialize_i64 => visit_i64);
278    deserialize_type!(deserialize_u8 => visit_u8);
279    deserialize_type!(deserialize_u16 => visit_u16);
280    deserialize_type!(deserialize_u32 => visit_u32);
281    deserialize_type!(deserialize_u64 => visit_u64);
282    deserialize_type!(deserialize_f32 => visit_f32);
283    deserialize_type!(deserialize_f64 => visit_f64);
284
285    fn deserialize_bool<V: de::Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
286        if let XmlEvent::StartElement { .. } = *self.peek()? {
287            self.set_map_value()
288        }
289        self.read_inner_value::<V, V::Value, _>(|this| {
290            if let XmlEvent::EndElement { .. } = *this.peek()? {
291                return visitor.visit_bool(false);
292            }
293            expect!(this.next()?, XmlEvent::Characters(s) => {
294                match s.as_str() {
295                    "true" | "1" => visitor.visit_bool(true),
296                    "false" | "0" => visitor.visit_bool(false),
297                    _ => Err(de::Error::invalid_value(Unexpected::Str(&s), &"a boolean")),
298                }
299
300            })
301        })
302    }
303
304    fn deserialize_char<V: de::Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
305        self.deserialize_string(visitor)
306    }
307
308    fn deserialize_str<V: de::Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
309        self.deserialize_string(visitor)
310    }
311
312    fn deserialize_bytes<V: de::Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
313        self.deserialize_string(visitor)
314    }
315
316    fn deserialize_byte_buf<V: de::Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
317        self.deserialize_string(visitor)
318    }
319
320    fn deserialize_unit<V: de::Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
321        if let XmlEvent::StartElement { .. } = *self.peek()? {
322            self.set_map_value()
323        }
324        self.read_inner_value::<V, V::Value, _>(
325            |this| expect!(this.peek()?, &XmlEvent::EndElement { .. } => visitor.visit_unit()),
326        )
327    }
328
329    fn deserialize_unit_struct<V: de::Visitor<'de>>(
330        self,
331        _name: &'static str,
332        visitor: V,
333    ) -> Result<V::Value> {
334        self.deserialize_unit(visitor)
335    }
336
337    fn deserialize_newtype_struct<V: de::Visitor<'de>>(
338        self,
339        _name: &'static str,
340        visitor: V,
341    ) -> Result<V::Value> {
342        visitor.visit_newtype_struct(self)
343    }
344
345    fn deserialize_tuple_struct<V: de::Visitor<'de>>(
346        self,
347        _name: &'static str,
348        len: usize,
349        visitor: V,
350    ) -> Result<V::Value> {
351        self.deserialize_tuple(len, visitor)
352    }
353
354    fn deserialize_tuple<V: de::Visitor<'de>>(self, len: usize, visitor: V) -> Result<V::Value> {
355        let child_deserializer = self.child();
356
357        visitor.visit_seq(SeqAccess::new(child_deserializer, Some(len)))
358    }
359
360    fn deserialize_enum<V: de::Visitor<'de>>(
361        self,
362        _name: &'static str,
363        _variants: &'static [&'static str],
364        visitor: V,
365    ) -> Result<V::Value> {
366        self.read_inner_value::<V, V::Value, _>(|this| visitor.visit_enum(EnumAccess::new(this)))
367    }
368
369    fn deserialize_string<V: de::Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
370        if let XmlEvent::StartElement { .. } = *self.peek()? {
371            self.set_map_value()
372        }
373        self.read_inner_value::<V, V::Value, _>(|this| {
374            if let XmlEvent::EndElement { .. } = *this.peek()? {
375                return visitor.visit_str("");
376            }
377            expect!(this.next()?, XmlEvent::Characters(s) => {
378                visitor.visit_string(s)
379            })
380        })
381    }
382
383    fn deserialize_seq<V: de::Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
384        let child_deserializer = self.child();
385
386        visitor.visit_seq(SeqAccess::new(child_deserializer, None))
387    }
388
389    fn deserialize_map<V: de::Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
390        self.unset_map_value();
391        expect!(self.next()?, XmlEvent::StartElement { name, attributes, .. } => {
392            let map_value = visitor.visit_map(MapAccess::new(self, attributes, false))?;
393            self.expect_end_element(name)?;
394            Ok(map_value)
395        })
396    }
397
398    fn deserialize_option<V: de::Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
399        match *self.peek()? {
400            XmlEvent::EndElement { .. } => visitor.visit_none(),
401            _ => visitor.visit_some(self),
402        }
403    }
404
405    fn deserialize_ignored_any<V: de::Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
406        self.unset_map_value();
407        let depth = self.depth;
408        loop {
409            self.next()?;
410            if self.depth == depth {
411                break;
412            }
413        }
414        visitor.visit_unit()
415    }
416
417    fn deserialize_any<V: de::Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
418        match *self.peek()? {
419            XmlEvent::StartElement { .. } => self.deserialize_map(visitor),
420            XmlEvent::EndElement { .. } => self.deserialize_unit(visitor),
421            _ => self.deserialize_string(visitor),
422        }
423    }
424}