rustix/net/socket_addr_any.rs

1//! The [`SocketAddrAny`] type and related utilities.
2
3#![allow(unsafe_code)]
4
5use crate::backend::c;
6use crate::backend::net::read_sockaddr;
7use crate::io::Errno;
8use crate::net::addr::{SocketAddrArg, SocketAddrLen, SocketAddrOpaque, SocketAddrStorage};
9#[cfg(unix)]
10use crate::net::SocketAddrUnix;
11use crate::net::{AddressFamily, SocketAddr, SocketAddrV4, SocketAddrV6};
12use core::fmt;
13use core::mem::{size_of, MaybeUninit};
14use core::num::NonZeroU32;
15
16/// Temporary buffer for creating a `SocketAddrAny` from a syscall that writes
17/// to a `sockaddr_t` and `socklen_t`
18///
19/// Unlike `SocketAddrAny`, this does not maintain the invariant that `len`
20/// bytes are initialized.
21pub(crate) struct SocketAddrBuf {
22    pub(crate) len: c::socklen_t,
23    pub(crate) storage: MaybeUninit<SocketAddrStorage>,
24}
25
26impl SocketAddrBuf {
27    #[inline]
28    pub(crate) const fn new() -> Self {
29        Self {
30            len: size_of::<SocketAddrStorage>() as c::socklen_t,
31            storage: MaybeUninit::<SocketAddrStorage>::uninit(),
32        }
33    }
34
35    /// Convert the buffer into [`SocketAddrAny`].
36    ///
37    /// # Safety
38    ///
39    /// A valid address must have been written into `self.storage` and its
40    /// length written into `self.len`.
41    #[inline]
42    pub(crate) unsafe fn into_any(self) -> SocketAddrAny {
43        SocketAddrAny::new(self.storage, bitcast!(self.len))
44    }
45
46    /// Convert the buffer into [`Option<SocketAddrAny>`].
47    ///
48    /// This returns `None` if `len` is zero or other platform-specific
49    /// conditions define the address as empty.
50    ///
51    /// # Safety
52    ///
53    /// Either valid address must have been written into `self.storage` and its
54    /// length written into `self.len`, or `self.len` must have been set to 0.
55    #[inline]
56    pub(crate) unsafe fn into_any_option(self) -> Option<SocketAddrAny> {
57        let len = bitcast!(self.len);
58        if read_sockaddr::sockaddr_nonempty(self.storage.as_ptr().cast(), len) {
59            Some(SocketAddrAny::new(self.storage, len))
60        } else {
61            None
62        }
63    }
64}
65
66/// A type that can hold any kind of socket address, as a safe abstraction for
67/// `sockaddr_storage`.
68///
69/// Socket addresses can be converted to `SocketAddrAny` via the [`From`] and
70/// [`Into`] traits. `SocketAddrAny` can be converted back to a specific socket
71/// address type with [`TryFrom`] and [`TryInto`]. These implementations return
72/// [`Errno::AFNOSUPPORT`] if the address family does not match the requested
73/// type.
74#[derive(Clone)]
75#[doc(alias = "sockaddr_storage")]
76pub struct SocketAddrAny {
77    // Invariants:
78    //  - `len` is at least `size_of::<backend::c::sa_family_t>()`
79    //  - `len` is at most `size_of::<SocketAddrStorage>()`
80    //  - The first `len` bytes of `storage` are initialized.
81    pub(crate) len: NonZeroU32,
82    pub(crate) storage: MaybeUninit<SocketAddrStorage>,
83}
84
85impl SocketAddrAny {
86    /// Creates a socket address from `storage`, which is initialized for `len`
87    /// bytes.
88    ///
89    /// # Panics
90    ///
91    /// if `len` is smaller than the sockaddr header or larger than
92    /// `SocketAddrStorage`.
93    ///
94    /// # Safety
95    ///
96    ///  - `storage` must contain a valid socket address.
97    ///  - `len` bytes must be initialized.
98    #[inline]
99    pub const unsafe fn new(storage: MaybeUninit<SocketAddrStorage>, len: SocketAddrLen) -> Self {
100        assert!(len as usize >= size_of::<read_sockaddr::sockaddr_header>());
101        assert!(len as usize <= size_of::<SocketAddrStorage>());
102        let len = NonZeroU32::new_unchecked(len);
103        Self { storage, len }
104    }
105
106    /// Creates a socket address from reading from `ptr`, which points at `len`
107    /// initialized bytes.
108    ///
109    /// # Panics
110    ///
111    /// if `len` is smaller than the sockaddr header or larger than
112    /// `SocketAddrStorage`.
113    ///
114    /// # Safety
115    ///
116    ///  - `ptr` must be a pointer to memory containing a valid socket address.
117    ///  - `len` bytes must be initialized.
118    pub unsafe fn read(ptr: *const SocketAddrStorage, len: SocketAddrLen) -> Self {
119        assert!(len as usize >= size_of::<read_sockaddr::sockaddr_header>());
120        assert!(len as usize <= size_of::<SocketAddrStorage>());
121        let mut storage = MaybeUninit::<SocketAddrStorage>::uninit();
122        core::ptr::copy_nonoverlapping(
123            ptr.cast::<u8>(),
124            storage.as_mut_ptr().cast::<u8>(),
125            len as usize,
126        );
127        let len = NonZeroU32::new_unchecked(len);
128        Self { storage, len }
129    }
130
131    /// Gets the initialized part of the storage as bytes.
132    #[inline]
133    fn bytes(&self) -> &[u8] {
134        let len = self.len.get() as usize;
135        unsafe { core::slice::from_raw_parts(self.storage.as_ptr().cast(), len) }
136    }
137
138    /// Gets the address family of this socket address.
139    #[inline]
140    pub fn address_family(&self) -> AddressFamily {
141        // SAFETY: Our invariants maintain that the `sa_family` field is
142        // initialized.
143        unsafe {
144            AddressFamily::from_raw(crate::backend::net::read_sockaddr::read_sa_family(
145                self.storage.as_ptr().cast(),
146            ))
147        }
148    }
149
150    /// Returns a raw pointer to the sockaddr.
151    #[inline]
152    pub fn as_ptr(&self) -> *const SocketAddrStorage {
153        self.storage.as_ptr()
154    }
155
156    /// Returns a raw mutable pointer to the sockaddr.
157    #[inline]
158    pub fn as_mut_ptr(&mut self) -> *mut SocketAddrStorage {
159        self.storage.as_mut_ptr()
160    }
161
162    /// Returns the length of the encoded sockaddr.
163    #[inline]
164    pub fn addr_len(&self) -> SocketAddrLen {
165        self.len.get()
166    }
167}
168
169impl PartialEq<Self> for SocketAddrAny {
170    fn eq(&self, other: &Self) -> bool {
171        self.bytes() == other.bytes()
172    }
173}
174
175impl Eq for SocketAddrAny {}
176
177// This just forwards to another `partial_cmp`.
178#[allow(clippy::non_canonical_partial_ord_impl)]
179impl PartialOrd<Self> for SocketAddrAny {
180    fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
181        self.bytes().partial_cmp(other.bytes())
182    }
183}
184
185impl Ord for SocketAddrAny {
186    fn cmp(&self, other: &Self) -> core::cmp::Ordering {
187        self.bytes().cmp(other.bytes())
188    }
189}
190
191impl core::hash::Hash for SocketAddrAny {
192    fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
193        self.bytes().hash(state)
194    }
195}
196
197impl fmt::Debug for SocketAddrAny {
198    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
199        match self.address_family() {
200            AddressFamily::INET => {
201                if let Ok(addr) = SocketAddrV4::try_from(self.clone()) {
202                    return addr.fmt(f);
203                }
204            }
205            AddressFamily::INET6 => {
206                if let Ok(addr) = SocketAddrV6::try_from(self.clone()) {
207                    return addr.fmt(f);
208                }
209            }
210            #[cfg(unix)]
211            AddressFamily::UNIX => {
212                if let Ok(addr) = SocketAddrUnix::try_from(self.clone()) {
213                    return addr.fmt(f);
214                }
215            }
216            #[cfg(target_os = "linux")]
217            AddressFamily::XDP => {
218                if let Ok(addr) = crate::net::xdp::SocketAddrXdp::try_from(self.clone()) {
219                    return addr.fmt(f);
220                }
221            }
222            #[cfg(linux_kernel)]
223            AddressFamily::NETLINK => {
224                if let Ok(addr) = crate::net::netlink::SocketAddrNetlink::try_from(self.clone()) {
225                    return addr.fmt(f);
226                }
227            }
228            _ => {}
229        }
230
231        f.debug_struct("SocketAddrAny")
232            .field("address_family", &self.address_family())
233            .field("namelen", &self.addr_len())
234            .finish()
235    }
236}
237
238// SAFETY: `with_sockaddr` calls `f` with a pointer to its own storage.
239unsafe impl SocketAddrArg for SocketAddrAny {
240    unsafe fn with_sockaddr<R>(
241        &self,
242        f: impl FnOnce(*const SocketAddrOpaque, SocketAddrLen) -> R,
243    ) -> R {
244        f(self.as_ptr().cast(), self.addr_len())
245    }
246}
247
248impl From<SocketAddr> for SocketAddrAny {
249    #[inline]
250    fn from(from: SocketAddr) -> Self {
251        from.as_any()
252    }
253}
254
255impl TryFrom<SocketAddrAny> for SocketAddr {
256    type Error = Errno;
257
258    /// Convert if the address is an IPv4 or IPv6 address.
259    ///
260    /// Returns `Err(Errno::AFNOSUPPORT)` if the address family is not IPv4 or
261    /// IPv6.
262    #[inline]
263    fn try_from(value: SocketAddrAny) -> Result<Self, Self::Error> {
264        match value.address_family() {
265            AddressFamily::INET => read_sockaddr::read_sockaddr_v4(&value).map(SocketAddr::V4),
266            AddressFamily::INET6 => read_sockaddr::read_sockaddr_v6(&value).map(SocketAddr::V6),
267            _ => Err(Errno::AFNOSUPPORT),
268        }
269    }
270}
271
272impl From<SocketAddrV4> for SocketAddrAny {
273    #[inline]
274    fn from(from: SocketAddrV4) -> Self {
275        from.as_any()
276    }
277}
278
279impl TryFrom<SocketAddrAny> for SocketAddrV4 {
280    type Error = Errno;
281
282    /// Convert if the address is an IPv4 address.
283    ///
284    /// Returns `Err(Errno::AFNOSUPPORT)` if the address family is not IPv4.
285    #[inline]
286    fn try_from(value: SocketAddrAny) -> Result<Self, Self::Error> {
287        read_sockaddr::read_sockaddr_v4(&value)
288    }
289}
290
291impl From<SocketAddrV6> for SocketAddrAny {
292    #[inline]
293    fn from(from: SocketAddrV6) -> Self {
294        from.as_any()
295    }
296}
297
298impl TryFrom<SocketAddrAny> for SocketAddrV6 {
299    type Error = Errno;
300
301    /// Convert if the address is an IPv6 address.
302    ///
303    /// Returns `Err(Errno::AFNOSUPPORT)` if the address family is not IPv6.
304    #[inline]
305    fn try_from(value: SocketAddrAny) -> Result<Self, Self::Error> {
306        read_sockaddr::read_sockaddr_v6(&value)
307    }
308}
309
310#[cfg(unix)]
311impl From<SocketAddrUnix> for SocketAddrAny {
312    #[inline]
313    fn from(from: SocketAddrUnix) -> Self {
314        from.as_any()
315    }
316}
317
318#[cfg(unix)]
319impl TryFrom<SocketAddrAny> for SocketAddrUnix {
320    type Error = Errno;
321
322    /// Convert if the address is a Unix socket address.
323    ///
324    /// Returns `Err(Errno::AFNOSUPPORT)` if the address family is not Unix.
325    #[inline]
326    fn try_from(value: SocketAddrAny) -> Result<Self, Self::Error> {
327        read_sockaddr::read_sockaddr_unix(&value)
328    }
329}
330
#[cfg(test)]
mod tests {
    use super::*;

    /// `SocketAddrAny::read` copies an existing address byte-for-byte, so the
    /// copy must compare equal to its source.
    #[test]
    fn any_read() {
        let v6 = SocketAddrV6::new(std::net::Ipv6Addr::LOCALHOST, 7, 8, 9);
        let addr = SocketAddrAny::from(v6);
        // SAFETY: `addr.as_ptr()` points at `addr`'s own storage, whose first
        // `addr.addr_len()` bytes are initialized and hold a valid address.
        let same = unsafe { SocketAddrAny::read(addr.as_ptr(), addr.addr_len()) };
        assert_eq!(addr, same);
    }
}