zbus/address/transport/
tcp.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
use super::encode_percents;
use crate::{Error, Result};
#[cfg(not(feature = "tokio"))]
use async_io::Async;
#[cfg(not(feature = "tokio"))]
use std::net::{SocketAddr, TcpStream, ToSocketAddrs};
use std::{
    collections::HashMap,
    fmt::{Display, Formatter},
    str::FromStr,
};
#[cfg(feature = "tokio")]
use tokio::net::TcpStream;

/// A TCP transport in a D-Bus address.
///
/// Represents both `tcp:` and `nonce-tcp:` addresses: a transport with a `Some` nonce file is
/// formatted as `nonce-tcp:`, otherwise as `tcp:`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct Tcp {
    /// Host name or IP address to connect to (the `host` key).
    pub(super) host: String,
    /// Local address to bind to (the `bind` key). Address parsing currently rejects this key,
    /// so it can only be set programmatically.
    pub(super) bind: Option<String>,
    /// TCP port to connect to (the `port` key).
    pub(super) port: u16,
    /// Requested address family (the `family` key); `None` means no family was specified.
    pub(super) family: Option<TcpTransportFamily>,
    /// Percent-decoded path of the nonce file (the `noncefile` key); `Some` makes this a
    /// `nonce-tcp:` transport.
    pub(super) nonce_file: Option<Vec<u8>>,
}

impl Tcp {
    /// Create a new TCP transport with the given host and port.
    ///
    /// `bind`, `family` and the nonce file are left unset; use the `set_*` builder methods to
    /// configure them.
    pub fn new(host: &str, port: u16) -> Self {
        Self {
            host: host.to_owned(),
            port,
            bind: None,
            family: None,
            nonce_file: None,
        }
    }

    /// Set the `tcp:` address `bind` value.
    pub fn set_bind(mut self, bind: Option<String>) -> Self {
        self.bind = bind;

        self
    }

    /// Set the `tcp:` address `family` value.
    pub fn set_family(mut self, family: Option<TcpTransportFamily>) -> Self {
        self.family = family;

        self
    }

    /// Set the `tcp:` address `noncefile` value.
    ///
    /// Setting `Some` turns this into a `nonce-tcp:` transport.
    pub fn set_nonce_file(mut self, nonce_file: Option<Vec<u8>>) -> Self {
        self.nonce_file = nonce_file;

        self
    }

    /// Returns the `tcp:` address `host` value.
    pub fn host(&self) -> &str {
        &self.host
    }

    /// Returns the `tcp:` address `bind` value.
    pub fn bind(&self) -> Option<&str> {
        self.bind.as_deref()
    }

    /// Returns the `tcp:` address `port` value.
    pub fn port(&self) -> u16 {
        self.port
    }

    /// Returns the `tcp:` address `family` value.
    pub fn family(&self) -> Option<TcpTransportFamily> {
        self.family
    }

    /// The nonce file path, if any.
    pub fn nonce_file(&self) -> Option<&[u8]> {
        self.nonce_file.as_deref()
    }

    /// Take ownership of the nonce file path, if any, leaving `None` in its place.
    pub fn take_nonce_file(&mut self) -> Option<Vec<u8>> {
        self.nonce_file.take()
    }

    /// Build a transport from the already-split key/value options of a `tcp:` or
    /// `nonce-tcp:` address.
    ///
    /// `host` and `port` are required. `family` and `noncefile` are optional, except that
    /// `noncefile` is required when `nonce_tcp_required` is `true`.
    ///
    /// # Errors
    ///
    /// Returns [`Error::Address`] if `bind` is present (not supported yet), if `host` or
    /// `port` is missing, if `port` is not a valid `u16`, or if `noncefile` is missing while
    /// `nonce_tcp_required` is set. Errors from parsing `family` or percent-decoding
    /// `noncefile` are propagated unchanged.
    pub(super) fn from_options(
        opts: HashMap<&str, &str>,
        nonce_tcp_required: bool,
    ) -> Result<Self> {
        let bind = None;
        if opts.contains_key("bind") {
            return Err(Error::Address("`bind` isn't yet supported".into()));
        }

        let host = opts
            .get("host")
            .ok_or_else(|| Error::Address("tcp address is missing `host`".into()))?
            .to_string();
        let port = opts
            .get("port")
            .ok_or_else(|| Error::Address("tcp address is missing `port`".into()))?;
        let port = port
            .parse::<u16>()
            .map_err(|_| Error::Address("invalid tcp `port`".into()))?;
        let family = opts
            .get("family")
            .map(|f| TcpTransportFamily::from_str(f))
            .transpose()?;
        let nonce_file = opts
            .get("noncefile")
            .map(|f| super::decode_percents(f))
            .transpose()?;
        if nonce_tcp_required && nonce_file.is_none() {
            return Err(Error::Address(
                "nonce-tcp address is missing `noncefile`".into(),
            ));
        }

        Ok(Self {
            host,
            bind,
            port,
            family,
            nonce_file,
        })
    }

    /// Whether `addr` is acceptable for the requested address `family`.
    ///
    /// `None` means no family was requested, so every resolved address is acceptable. Both
    /// `connect` variants use this helper so they agree on `family` handling.
    fn family_matches(family: Option<TcpTransportFamily>, addr: &std::net::SocketAddr) -> bool {
        match family {
            Some(TcpTransportFamily::Ipv4) => addr.is_ipv4(),
            Some(TcpTransportFamily::Ipv6) => addr.is_ipv6(),
            None => true,
        }
    }

    /// Resolve `host:port`, keep only addresses matching `family`, and connect to the first
    /// one that accepts the connection.
    ///
    /// # Errors
    ///
    /// Returns an error if name resolution fails. Otherwise returns the error from the last
    /// failed connection attempt, or `Error::Address("Failed to connect")` if no address
    /// matched.
    #[cfg(not(feature = "tokio"))]
    pub(super) async fn connect(self) -> Result<Async<TcpStream>> {
        // `to_socket_addrs` performs blocking name resolution, so it runs on a blocking task.
        let addrs = crate::Task::spawn_blocking(
            move || -> Result<Vec<SocketAddr>> {
                let family = self.family();
                let addrs = (self.host(), self.port())
                    .to_socket_addrs()?
                    .filter(|a| Self::family_matches(family, a))
                    .collect::<Vec<_>>();
                Ok(addrs)
            },
            "connect tcp",
        )
        .await
        .map_err(|e| Error::Address(format!("Failed to receive TCP addresses: {e}")))?;

        // Try the addresses one after another; we could attempt connections in parallel?
        let mut last_err = Error::Address("Failed to connect".into());
        for addr in addrs {
            match Async::<TcpStream>::connect(addr).await {
                Ok(stream) => return Ok(stream),
                Err(e) => last_err = e.into(),
            }
        }

        Err(last_err)
    }

    /// Resolve `host:port`, keep only addresses matching `family`, and connect to the first
    /// one that accepts the connection.
    ///
    /// Unlike a plain `TcpStream::connect((host, port))`, this honours the `family` option,
    /// in the same way as the non-tokio implementation.
    ///
    /// # Errors
    ///
    /// Returns [`Error::InputOutput`] if name resolution fails. Otherwise returns the error
    /// from the last failed connection attempt, or `Error::Address("Failed to connect")` if
    /// no address matched.
    #[cfg(feature = "tokio")]
    pub(super) async fn connect(self) -> Result<TcpStream> {
        let family = self.family();
        let addrs = tokio::net::lookup_host((self.host(), self.port()))
            .await
            .map_err(|e| Error::InputOutput(e.into()))?;

        let mut last_err = Error::Address("Failed to connect".into());
        for addr in addrs.filter(|a| Self::family_matches(family, a)) {
            match TcpStream::connect(addr).await {
                Ok(stream) => return Ok(stream),
                Err(e) => last_err = Error::InputOutput(e.into()),
            }
        }

        Err(last_err)
    }
}

impl Display for Tcp {
    /// Render the transport as a D-Bus address string.
    ///
    /// Produces `nonce-tcp:noncefile=…,host=…,port=…` when a nonce file is set, or
    /// `tcp:host=…,port=…` otherwise. Optional `bind` and `family` keys are appended at the end.
    /// The nonce file, host and bind values are percent-encoded.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        if let Some(path) = self.nonce_file() {
            f.write_str("nonce-tcp:noncefile=")?;
            encode_percents(f, path)?;
            f.write_str(",")?;
        } else {
            f.write_str("tcp:")?;
        }

        f.write_str("host=")?;
        encode_percents(f, self.host().as_bytes())?;
        write!(f, ",port={}", self.port())?;

        if let Some(local) = self.bind() {
            f.write_str(",bind=")?;
            encode_percents(f, local.as_bytes())?;
        }

        match self.family() {
            Some(family) => write!(f, ",family={family}"),
            None => Ok(()),
        }
    }
}

/// A `tcp:` address family.
///
/// Corresponds to the value of the `family` key in a TCP D-Bus address: `ipv4` or `ipv6`.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum TcpTransportFamily {
    /// IPv4 (`family=ipv4`).
    Ipv4,
    /// IPv6 (`family=ipv6`).
    Ipv6,
}

impl FromStr for TcpTransportFamily {
    type Err = Error;

    fn from_str(family: &str) -> Result<Self> {
        match family {
            "ipv4" => Ok(Self::Ipv4),
            "ipv6" => Ok(Self::Ipv6),
            _ => Err(Error::Address(format!(
                "invalid tcp address `family`: {family}"
            ))),
        }
    }
}

impl Display for TcpTransportFamily {
    /// Write the `family` key value: `ipv4` or `ipv6`.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            Self::Ipv4 => "ipv4",
            Self::Ipv6 => "ipv6",
        };
        f.write_str(name)
    }
}