quick_xml/
utils.rs

1use std::borrow::{Borrow, Cow};
2use std::fmt::{self, Debug, Formatter};
3use std::io;
4use std::ops::Deref;
5
6#[cfg(feature = "async-tokio")]
7use std::{
8    pin::Pin,
9    task::{Context, Poll},
10};
11
12#[cfg(feature = "serialize")]
13use serde::de::{Deserialize, Deserializer, Error, Visitor};
14#[cfg(feature = "serialize")]
15use serde::ser::{Serialize, Serializer};
16
17#[allow(clippy::ptr_arg)]
18pub fn write_cow_string(f: &mut Formatter, cow_string: &Cow<[u8]>) -> fmt::Result {
19    match cow_string {
20        Cow::Owned(s) => {
21            write!(f, "Owned(")?;
22            write_byte_string(f, s)?;
23        }
24        Cow::Borrowed(s) => {
25            write!(f, "Borrowed(")?;
26            write_byte_string(f, s)?;
27        }
28    }
29    write!(f, ")")
30}
31
/// Writes `byte_string` as a double-quoted, human-readable string.
///
/// Bytes in the printable ASCII range `32..=126` are written as characters, except the
/// double quote (34), which is escaped as `\"`. Every other byte is written as uppercase
/// hexadecimal with a `0x` prefix. The `{:#02X}` width never pads, because the prefix
/// alone already takes two columns, so no leading zero is added: byte 10 becomes `0xA`.
///
/// NOTE(review): because hex escapes are not zero-padded, the output can be ambiguous,
/// e.g. `[10, b'B']` and `[0xAB]` both render as `"0xAB"`.
pub fn write_byte_string(f: &mut Formatter, byte_string: &[u8]) -> fmt::Result {
    f.write_str("\"")?;
    byte_string.iter().try_for_each(|&byte| {
        if byte == 34 {
            write!(f, "\\\"")
        } else if (32..=126).contains(&byte) {
            write!(f, "{}", byte as char)
        } else {
            write!(f, "{:#02X}", byte)
        }
    })?;
    f.write_str("\"")?;
    Ok(())
}
44
45////////////////////////////////////////////////////////////////////////////////////////////////////
46
/// A version of [`Cow`] that can borrow from two different buffers, one of which
/// is the deserializer input.
///
/// # Lifetimes
///
/// - `'i`: lifetime of the data that the deserializer borrows from the parsed input
/// - `'s`: lifetime of the data that is owned by a deserializer
pub enum CowRef<'i, 's, B: ToOwned + ?Sized> {
    /// An input borrowed from the parsed data
    Input(&'i B),
    /// An input borrowed from the buffer owned by another deserializer
    Slice(&'s B),
    /// An input taken from an external deserializer, owned by that deserializer
    Owned(B::Owned),
}

impl<'i, 's, B> Deref for CowRef<'i, 's, B>
where
    B: ToOwned + ?Sized,
    B::Owned: Borrow<B>,
{
    type Target = B;

    // Both borrowed variants yield their reference directly; the owned variant is
    // viewed as `B` through `Borrow`.
    fn deref(&self) -> &B {
        match self {
            CowRef::Input(input) => *input,
            CowRef::Slice(slice) => *slice,
            CowRef::Owned(owned) => owned.borrow(),
        }
    }
}

impl<'i, 's, B> Debug for CowRef<'i, 's, B>
where
    B: ToOwned + ?Sized + Debug,
    B::Owned: Debug,
{
    // Delegates to the held value's own `Debug` impl; the variant name is not printed.
    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
        match self {
            CowRef::Input(input) => Debug::fmt(*input, f),
            CowRef::Slice(slice) => Debug::fmt(*slice, f),
            CowRef::Owned(owned) => Debug::fmt(owned, f),
        }
    }
}
94
95////////////////////////////////////////////////////////////////////////////////////////////////////
96
97/// Wrapper around `Vec<u8>` that has a human-readable debug representation:
98/// printable ASCII symbols output as is, all other output in HEX notation.
99///
100/// Also, when [`serialize`] feature is on, this type deserialized using
101/// [`deserialize_byte_buf`](serde::Deserializer::deserialize_byte_buf) instead
102/// of vector's generic [`deserialize_seq`](serde::Deserializer::deserialize_seq)
103///
104/// [`serialize`]: ../index.html#serialize
105#[derive(PartialEq, Eq)]
106pub struct ByteBuf(pub Vec<u8>);
107
108impl Debug for ByteBuf {
109    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
110        write_byte_string(f, &self.0)
111    }
112}
113
#[cfg(feature = "serialize")]
impl<'de> Deserialize<'de> for ByteBuf {
    /// Asks the deserializer for a byte buffer. Owned bytes are taken as is;
    /// borrowed bytes are copied into a new `Vec`.
    fn deserialize<D>(d: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        struct ByteBufVisitor;

        impl<'de> Visitor<'de> for ByteBufVisitor {
            type Value = ByteBuf;

            fn expecting(&self, f: &mut Formatter) -> fmt::Result {
                f.write_str("byte data")
            }

            fn visit_byte_buf<E: Error>(self, v: Vec<u8>) -> Result<Self::Value, E> {
                Ok(ByteBuf(v))
            }

            fn visit_bytes<E: Error>(self, v: &[u8]) -> Result<Self::Value, E> {
                Ok(ByteBuf(Vec::from(v)))
            }
        }

        d.deserialize_byte_buf(ByteBufVisitor)
    }
}
141
#[cfg(feature = "serialize")]
impl Serialize for ByteBuf {
    /// Serializes the wrapped buffer as raw bytes.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let ByteBuf(bytes) = self;
        serializer.serialize_bytes(bytes.as_slice())
    }
}
151
152////////////////////////////////////////////////////////////////////////////////////////////////////
153
154/// Wrapper around `&[u8]` that has a human-readable debug representation:
155/// printable ASCII symbols output as is, all other output in HEX notation.
156///
157/// Also, when [`serialize`] feature is on, this type deserialized using
158/// [`deserialize_bytes`](serde::Deserializer::deserialize_bytes) instead
159/// of vector's generic [`deserialize_seq`](serde::Deserializer::deserialize_seq)
160///
161/// [`serialize`]: ../index.html#serialize
162#[derive(PartialEq, Eq)]
163pub struct Bytes<'de>(pub &'de [u8]);
164
165impl<'de> Debug for Bytes<'de> {
166    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
167        write_byte_string(f, self.0)
168    }
169}
170
#[cfg(feature = "serialize")]
impl<'de> Deserialize<'de> for Bytes<'de> {
    /// Asks the deserializer for bytes. Only `visit_borrowed_bytes` is implemented,
    /// because `Bytes` can only hold data borrowed from the input. Other visitor
    /// methods fall back to serde's defaults.
    fn deserialize<D>(d: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        struct BytesVisitor;

        impl<'de> Visitor<'de> for BytesVisitor {
            type Value = Bytes<'de>;

            fn expecting(&self, f: &mut Formatter) -> fmt::Result {
                f.write_str("borrowed bytes")
            }

            fn visit_borrowed_bytes<E: Error>(self, v: &'de [u8]) -> Result<Self::Value, E> {
                Ok(Bytes(v))
            }
        }

        d.deserialize_bytes(BytesVisitor)
    }
}
194
#[cfg(feature = "serialize")]
impl Serialize for Bytes<'_> {
    /// Serializes the wrapped slice as raw bytes.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let Bytes(bytes) = *self;
        serializer.serialize_bytes(bytes)
    }
}
204
205////////////////////////////////////////////////////////////////////////////////////////////////////
206
/// A simple producer of an infinite stream of bytes, useful in tests.
///
/// Repeats the `chunk` field indefinitely. Both the [`io::Read`] and the [`io::BufRead`]
/// implementations advance through `chunk`, wrap back to its start once it has been
/// fully consumed, and add every delivered byte to `overall_read`.
///
/// An empty `chunk` produces an empty stream: every read returns 0 bytes.
pub struct Fountain<'a> {
    /// That piece of data repeated infinitely...
    pub chunk: &'a [u8],
    /// Number of bytes of the current repetition of `chunk` that were already consumed
    pub consumed: usize,
    /// The overall count of read bytes
    pub overall_read: u64,
}

impl<'a> io::Read for Fountain<'a> {
    /// Copies as many bytes as fit into `buf` from the not-yet-consumed part of `chunk`
    /// and marks them as consumed.
    ///
    /// A single call returns at most the rest of the current repetition of `chunk`, so it
    /// may read fewer than `buf.len()` bytes even though the stream never ends.
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        // Copy the `&'a [u8]` out so that `available` does not borrow `self`,
        // which allows `consume` to be called below.
        let chunk = self.chunk;
        let available = &chunk[self.consumed..];
        let len = buf.len().min(available.len());

        // `copy_from_slice` requires both slices to have the same length, and `buf`
        // may be longer than what is left of `chunk`: only fill its first `len` bytes.
        buf[..len].copy_from_slice(&available[..len]);
        // Advance the position (and `overall_read`) exactly like `BufRead::consume`;
        // without this, consecutive reads would return the same bytes forever.
        io::BufRead::consume(self, len);
        Ok(len)
    }
}

impl<'a> io::BufRead for Fountain<'a> {
    /// Returns the not-yet-consumed rest of the current repetition of `chunk`.
    #[inline]
    fn fill_buf(&mut self) -> io::Result<&[u8]> {
        Ok(&self.chunk[self.consumed..])
    }

    /// Marks `amt` bytes as consumed. Restarts from the beginning of `chunk` once it has
    /// been consumed completely.
    fn consume(&mut self, amt: usize) {
        self.consumed += amt;
        if self.consumed == self.chunk.len() {
            self.consumed = 0;
        }
        self.overall_read += amt as u64;
    }
}
244
#[cfg(feature = "async-tokio")]
impl<'a> tokio::io::AsyncRead for Fountain<'a> {
    /// Fills as much of `buf` as possible from the not-yet-consumed part of `chunk` and
    /// marks those bytes as consumed. Always returns `Poll::Ready`.
    fn poll_read(
        self: Pin<&mut Self>,
        _cx: &mut Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        // `Fountain` only holds a slice reference and integers, so it is `Unpin`.
        let this = self.get_mut();
        // Copy the `&'a [u8]` out so `available` does not keep `this` borrowed.
        let chunk = this.chunk;
        let available = &chunk[this.consumed..];
        let len = buf.remaining().min(available.len());

        buf.put_slice(&available[..len]);
        // Advance like the synchronous `BufRead::consume`, so consecutive polls continue
        // the stream instead of repeating the same prefix.
        io::BufRead::consume(this, len);
        Poll::Ready(Ok(()))
    }
}
260
#[cfg(feature = "async-tokio")]
impl<'a> tokio::io::AsyncBufRead for Fountain<'a> {
    /// Delegates to the synchronous [`io::BufRead::fill_buf`]; always ready.
    #[inline]
    fn poll_fill_buf(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
        let this = Pin::into_inner(self);
        let filled = io::BufRead::fill_buf(this);
        Poll::Ready(filled)
    }

    /// Delegates to the synchronous [`io::BufRead::consume`].
    #[inline]
    fn consume(self: Pin<&mut Self>, amt: usize) {
        let this = Pin::into_inner(self);
        io::BufRead::consume(this, amt)
    }
}
273
274////////////////////////////////////////////////////////////////////////////////////////////////////
275
/// Returns `true` if `b` is an XML whitespace byte: space, horizontal tab,
/// line feed or carriage return.
#[inline]
pub const fn is_whitespace(b: u8) -> bool {
    b == b' ' || b == b'\t' || b == b'\n' || b == b'\r'
}

/// Calculates the name length of element-like content: the name is the first word in
/// `bytes`, where words are separated by XML whitespace. Returns 0 when `bytes` is
/// empty or starts with whitespace.
///
/// 'Whitespace' refers to the definition used by [`is_whitespace`].
#[inline]
pub const fn name_len(bytes: &[u8]) -> usize {
    // Iterator adapters such as `position` cannot be used in a `const fn`, so the
    // leading name is skipped with slice patterns. Its length is then what was
    // removed from the front.
    let mut rest = bytes;
    while let [first, tail @ ..] = rest {
        if is_whitespace(*first) {
            break;
        }
        rest = tail;
    }
    bytes.len() - rest.len()
}

/// Returns `bytes` with all leading XML whitespace removed.
///
/// 'Whitespace' refers to the definition used by [`is_whitespace`].
#[inline]
pub const fn trim_xml_start(mut bytes: &[u8]) -> &[u8] {
    // Slice patterns (rather than iterators) keep this usable in `const` contexts.
    loop {
        match bytes {
            [first, rest @ ..] if is_whitespace(*first) => bytes = rest,
            _ => return bytes,
        }
    }
}

/// Returns `bytes` with all trailing XML whitespace removed.
///
/// 'Whitespace' refers to the definition used by [`is_whitespace`].
#[inline]
pub const fn trim_xml_end(mut bytes: &[u8]) -> &[u8] {
    // Slice patterns (rather than iterators) keep this usable in `const` contexts.
    loop {
        match bytes {
            [rest @ .., last] if is_whitespace(*last) => bytes = rest,
            _ => return bytes,
        }
    }
}
334
335////////////////////////////////////////////////////////////////////////////////////////////////////
336
#[cfg(test)]
mod tests {
    use super::*;
    use pretty_assertions::assert_eq;

    /// Renders `bytes` through the `Debug` impl of `ByteBuf`.
    fn debug_repr(bytes: &[u8]) -> String {
        format!("{:?}", ByteBuf(bytes.to_vec()))
    }

    #[test]
    fn write_byte_string0() {
        // A line feed followed by eight spaces
        let mut raw = vec![b'\n'];
        raw.resize(9, b' ');
        assert_eq!(debug_repr(&raw), "\"0xA        \"");
    }

    #[test]
    fn write_byte_string1() {
        assert_eq!(
            debug_repr(b"http://www.w3.org/2002/07/owl#"),
            r##""http://www.w3.org/2002/07/owl#""##
        );
    }

    #[test]
    fn write_byte_string3() {
        assert_eq!(
            debug_repr(b"Class IRI=\"#B\""),
            r##""Class IRI=\"#B\"""##
        );
    }

    #[test]
    fn name_len() {
        let cases: [(&[u8], usize); 9] = [
            (b"", 0),
            (b" abc", 0),
            (b" \t\r\n", 0),
            (b"abc", 3),
            (b"abc ", 3),
            (b"a bc", 1),
            (b"ab\tc", 2),
            (b"ab\rc", 2),
            (b"ab\nc", 2),
        ];
        for (input, expected) in cases {
            assert_eq!(super::name_len(input), expected, "input: {:?}", Bytes(input));
        }
    }

    #[test]
    fn trim_xml_start() {
        let trim = |input: &'static [u8]| Bytes(super::trim_xml_start(input));

        assert_eq!(trim(b""), Bytes(b""));
        assert_eq!(trim(b"abc"), Bytes(b"abc"));
        assert_eq!(
            trim(b"\r\n\t ab \t\r\nc \t\r\n"),
            Bytes(b"ab \t\r\nc \t\r\n")
        );
    }

    #[test]
    fn trim_xml_end() {
        let trim = |input: &'static [u8]| Bytes(super::trim_xml_end(input));

        assert_eq!(trim(b""), Bytes(b""));
        assert_eq!(trim(b"abc"), Bytes(b"abc"));
        assert_eq!(
            trim(b"\r\n\t ab \t\r\nc \t\r\n"),
            Bytes(b"\r\n\t ab \t\r\nc")
        );
    }
}