// rustls/crypto/ring/quic.rs

#![allow(clippy::duplicate_mod)]

use alloc::boxed::Box;

use super::ring_like::aead;
use crate::crypto::cipher::{AeadKey, Iv, Nonce};
use crate::error::Error;
use crate::quic;

10pub(crate) struct HeaderProtectionKey(aead::quic::HeaderProtectionKey);
11
12impl HeaderProtectionKey {
13    pub(crate) fn new(key: AeadKey, alg: &'static aead::quic::Algorithm) -> Self {
14        Self(aead::quic::HeaderProtectionKey::new(alg, key.as_ref()).unwrap())
15    }
16
17    fn xor_in_place(
18        &self,
19        sample: &[u8],
20        first: &mut u8,
21        packet_number: &mut [u8],
22        masked: bool,
23    ) -> Result<(), Error> {
24        // This implements "Header Protection Application" almost verbatim.
25        // <https://datatracker.ietf.org/doc/html/rfc9001#section-5.4.1>
26
27        let mask = self
28            .0
29            .new_mask(sample)
30            .map_err(|_| Error::General("sample of invalid length".into()))?;
31
32        // The `unwrap()` will not panic because `new_mask` returns a
33        // non-empty result.
34        let (first_mask, pn_mask) = mask.split_first().unwrap();
35
36        // It is OK for the `mask` to be longer than `packet_number`,
37        // but a valid `packet_number` will never be longer than `mask`.
38        if packet_number.len() > pn_mask.len() {
39            return Err(Error::General("packet number too long".into()));
40        }
41
42        // Infallible from this point on. Before this point, `first` and
43        // `packet_number` are unchanged.
44
45        const LONG_HEADER_FORM: u8 = 0x80;
46        let bits = match *first & LONG_HEADER_FORM == LONG_HEADER_FORM {
47            true => 0x0f,  // Long header: 4 bits masked
48            false => 0x1f, // Short header: 5 bits masked
49        };
50
51        let first_plain = match masked {
52            // When unmasking, use the packet length bits after unmasking
53            true => *first ^ (first_mask & bits),
54            // When masking, use the packet length bits before masking
55            false => *first,
56        };
57        let pn_len = (first_plain & 0x03) as usize + 1;
58
59        *first ^= first_mask & bits;
60        for (dst, m) in packet_number
61            .iter_mut()
62            .zip(pn_mask)
63            .take(pn_len)
64        {
65            *dst ^= m;
66        }
67
68        Ok(())
69    }
70}
71
72impl quic::HeaderProtectionKey for HeaderProtectionKey {
73    fn encrypt_in_place(
74        &self,
75        sample: &[u8],
76        first: &mut u8,
77        packet_number: &mut [u8],
78    ) -> Result<(), Error> {
79        self.xor_in_place(sample, first, packet_number, false)
80    }
81
82    fn decrypt_in_place(
83        &self,
84        sample: &[u8],
85        first: &mut u8,
86        packet_number: &mut [u8],
87    ) -> Result<(), Error> {
88        self.xor_in_place(sample, first, packet_number, true)
89    }
90
91    #[inline]
92    fn sample_len(&self) -> usize {
93        self.0.algorithm().sample_len()
94    }
95}
96
97pub(crate) struct PacketKey {
98    /// Encrypts or decrypts a packet's payload
99    key: aead::LessSafeKey,
100    /// Computes unique nonces for each packet
101    iv: Iv,
102    /// Confidentiality limit (see [`quic::PacketKey::confidentiality_limit`])
103    confidentiality_limit: u64,
104    /// Integrity limit (see [`quic::PacketKey::integrity_limit`])
105    integrity_limit: u64,
106}
107
108impl PacketKey {
109    pub(crate) fn new(
110        key: AeadKey,
111        iv: Iv,
112        confidentiality_limit: u64,
113        integrity_limit: u64,
114        aead_algorithm: &'static aead::Algorithm,
115    ) -> Self {
116        Self {
117            key: aead::LessSafeKey::new(
118                aead::UnboundKey::new(aead_algorithm, key.as_ref()).unwrap(),
119            ),
120            iv,
121            confidentiality_limit,
122            integrity_limit,
123        }
124    }
125}
126
127impl quic::PacketKey for PacketKey {
128    fn encrypt_in_place(
129        &self,
130        packet_number: u64,
131        header: &[u8],
132        payload: &mut [u8],
133    ) -> Result<quic::Tag, Error> {
134        let aad = aead::Aad::from(header);
135        let nonce = aead::Nonce::assume_unique_for_key(Nonce::new(&self.iv, packet_number).0);
136        let tag = self
137            .key
138            .seal_in_place_separate_tag(nonce, aad, payload)
139            .map_err(|_| Error::EncryptError)?;
140        Ok(quic::Tag::from(tag.as_ref()))
141    }
142
143    fn encrypt_in_place_for_path(
144        &self,
145        path_id: u32,
146        packet_number: u64,
147        header: &[u8],
148        payload: &mut [u8],
149    ) -> Result<quic::Tag, Error> {
150        let aad = aead::Aad::from(header);
151        let nonce =
152            aead::Nonce::assume_unique_for_key(Nonce::for_path(path_id, &self.iv, packet_number).0);
153        let tag = self
154            .key
155            .seal_in_place_separate_tag(nonce, aad, payload)
156            .map_err(|_| Error::EncryptError)?;
157        Ok(quic::Tag::from(tag.as_ref()))
158    }
159
160    /// Decrypt a QUIC packet
161    ///
162    /// Takes the packet `header`, which is used as the additional authenticated data, and the
163    /// `payload`, which includes the authentication tag.
164    ///
165    /// If the return value is `Ok`, the decrypted payload can be found in `payload`, up to the
166    /// length found in the return value.
167    fn decrypt_in_place<'a>(
168        &self,
169        packet_number: u64,
170        header: &[u8],
171        payload: &'a mut [u8],
172    ) -> Result<&'a [u8], Error> {
173        let payload_len = payload.len();
174        let aad = aead::Aad::from(header);
175        let nonce = aead::Nonce::assume_unique_for_key(Nonce::new(&self.iv, packet_number).0);
176        self.key
177            .open_in_place(nonce, aad, payload)
178            .map_err(|_| Error::DecryptError)?;
179
180        let plain_len = payload_len - self.key.algorithm().tag_len();
181        Ok(&payload[..plain_len])
182    }
183
184    fn decrypt_in_place_for_path<'a>(
185        &self,
186        path_id: u32,
187        packet_number: u64,
188        header: &[u8],
189        payload: &'a mut [u8],
190    ) -> Result<&'a [u8], Error> {
191        let payload_len = payload.len();
192        let aad = aead::Aad::from(header);
193        let nonce =
194            aead::Nonce::assume_unique_for_key(Nonce::for_path(path_id, &self.iv, packet_number).0);
195        self.key
196            .open_in_place(nonce, aad, payload)
197            .map_err(|_| Error::DecryptError)?;
198
199        let plain_len = payload_len - self.key.algorithm().tag_len();
200        Ok(&payload[..plain_len])
201    }
202
203    /// Tag length for the underlying AEAD algorithm
204    #[inline]
205    fn tag_len(&self) -> usize {
206        self.key.algorithm().tag_len()
207    }
208
209    /// Confidentiality limit (see [`quic::PacketKey::confidentiality_limit`])
210    fn confidentiality_limit(&self) -> u64 {
211        self.confidentiality_limit
212    }
213
214    /// Integrity limit (see [`quic::PacketKey::integrity_limit`])
215    fn integrity_limit(&self) -> u64 {
216        self.integrity_limit
217    }
218}
219
220pub(crate) struct KeyBuilder {
221    pub(crate) packet_alg: &'static aead::Algorithm,
222    pub(crate) header_alg: &'static aead::quic::Algorithm,
223    pub(crate) confidentiality_limit: u64,
224    pub(crate) integrity_limit: u64,
225}
226
227impl quic::Algorithm for KeyBuilder {
228    fn packet_key(&self, key: AeadKey, iv: Iv) -> Box<dyn quic::PacketKey> {
229        Box::new(PacketKey::new(
230            key,
231            iv,
232            self.confidentiality_limit,
233            self.integrity_limit,
234            self.packet_alg,
235        ))
236    }
237
238    fn header_protection_key(&self, key: AeadKey) -> Box<dyn quic::HeaderProtectionKey> {
239        Box::new(HeaderProtectionKey::new(key, self.header_alg))
240    }
241
242    fn aead_key_len(&self) -> usize {
243        self.packet_alg.key_len()
244    }
245
246    fn fips(&self) -> bool {
247        super::fips()
248    }
249}
250
#[cfg(test)]
#[macro_rules_attribute::apply(test_for_each_provider)]
mod tests {
    use super::provider::tls13::{
        TLS13_AES_128_GCM_SHA256_INTERNAL, TLS13_CHACHA20_POLY1305_SHA256_INTERNAL,
    };
    use crate::common_state::Side;
    use crate::crypto::tls13::OkmBlock;
    use crate::quic::*;

    /// Protects a short-header packet with ChaCha20-Poly1305 keys derived for
    /// `version`, checks the result against `expected`, then removes the
    /// protection and checks that the original payload comes back.
    fn test_short_packet(version: Version, expected: &[u8]) {
        const PN: u64 = 654360564;
        const SECRET: &[u8] = &[
            0x9a, 0xc3, 0x12, 0xa7, 0xf8, 0x77, 0x46, 0x8e, 0xbe, 0x69, 0x42, 0x27, 0x48, 0xad,
            0x00, 0xa1, 0x54, 0x43, 0xf1, 0x82, 0x03, 0xa0, 0x7d, 0x60, 0x60, 0xf6, 0x88, 0xf3,
            0x0f, 0x21, 0x63, 0x2b,
        ];

        let secret = OkmBlock::new(SECRET);
        let builder = KeyBuilder::new(
            &secret,
            version,
            TLS13_CHACHA20_POLY1305_SHA256_INTERNAL
                .quic
                .unwrap(),
            TLS13_CHACHA20_POLY1305_SHA256_INTERNAL.hkdf_provider,
        );
        let packet = builder.packet_key();
        let hpk = builder.header_protection_key();

        const PLAIN: &[u8] = &[0x42, 0x00, 0xbf, 0xf4, 0x01];

        // Packet protection: the first 4 bytes are the header (AAD), the rest is payload.
        let mut buf = PLAIN.to_vec();
        let (header, payload) = buf.split_at_mut(4);
        let tag = packet
            .encrypt_in_place(PN, header, payload)
            .unwrap();
        buf.extend(tag.as_ref());

        // Header protection: sample the ciphertext 4 bytes past the packet number offset.
        let pn_offset = 1;
        let (header, sample) = buf.split_at_mut(pn_offset + 4);
        let (first, rest) = header.split_at_mut(1);
        let sample = &sample[..hpk.sample_len()];
        hpk.encrypt_in_place(sample, &mut first[0], rest)
            .unwrap();

        assert_eq!(&buf, expected);

        // Reverse both steps and recover the original payload.
        let (header, sample) = buf.split_at_mut(pn_offset + 4);
        let (first, rest) = header.split_at_mut(1);
        let sample = &sample[..hpk.sample_len()];
        hpk.decrypt_in_place(sample, &mut first[0], rest)
            .unwrap();

        let (header, payload_tag) = buf.split_at_mut(4);
        let plain = packet
            .decrypt_in_place(PN, header, payload_tag)
            .unwrap();

        assert_eq!(plain, &PLAIN[4..]);
    }

    #[test]
    fn short_packet_header_protection() {
        // https://www.rfc-editor.org/rfc/rfc9001.html#name-chacha20-poly1305-short-hea
        test_short_packet(
            Version::V1,
            &[
                0x4c, 0xfe, 0x41, 0x89, 0x65, 0x5e, 0x5c, 0xd5, 0x5c, 0x41, 0xf6, 0x90, 0x80, 0x57,
                0x5d, 0x79, 0x99, 0xc2, 0x5a, 0x5b, 0xfb,
            ],
        );
    }

    /// Checks that `Secrets::update` derives the expected next-generation secrets.
    #[test]
    fn key_update_test_vector() {
        fn equal_okm(x: &OkmBlock, y: &OkmBlock) -> bool {
            x.as_ref() == y.as_ref()
        }

        let mut secrets = Secrets::new(
            // Constant dummy values for reproducibility
            OkmBlock::new(
                &[
                    0xb8, 0x76, 0x77, 0x08, 0xf8, 0x77, 0x23, 0x58, 0xa6, 0xea, 0x9f, 0xc4, 0x3e,
                    0x4a, 0xdd, 0x2c, 0x96, 0x1b, 0x3f, 0x52, 0x87, 0xa6, 0xd1, 0x46, 0x7e, 0xe0,
                    0xae, 0xab, 0x33, 0x72, 0x4d, 0xbf,
                ][..],
            ),
            OkmBlock::new(
                &[
                    0x42, 0xdc, 0x97, 0x21, 0x40, 0xe0, 0xf2, 0xe3, 0x98, 0x45, 0xb7, 0x67, 0x61,
                    0x34, 0x39, 0xdc, 0x67, 0x58, 0xca, 0x43, 0x25, 0x9b, 0x87, 0x85, 0x06, 0x82,
                    0x4e, 0xb1, 0xe4, 0x38, 0xd8, 0x55,
                ][..],
            ),
            TLS13_AES_128_GCM_SHA256_INTERNAL,
            TLS13_AES_128_GCM_SHA256_INTERNAL
                .quic
                .unwrap(),
            Side::Client,
            Version::V1,
        );
        secrets.update();

        assert!(equal_okm(
            &secrets.client,
            &OkmBlock::new(
                &[
                    0x42, 0xca, 0xc8, 0xc9, 0x1c, 0xd5, 0xeb, 0x40, 0x68, 0x2e, 0x43, 0x2e, 0xdf,
                    0x2d, 0x2b, 0xe9, 0xf4, 0x1a, 0x52, 0xca, 0x6b, 0x22, 0xd8, 0xe6, 0xcd, 0xb1,
                    0xe8, 0xac, 0xa9, 0x6, 0x1f, 0xce
                ][..]
            )
        ));
        assert!(equal_okm(
            &secrets.server,
            &OkmBlock::new(
                &[
                    0xeb, 0x7f, 0x5e, 0x2a, 0x12, 0x3f, 0x40, 0x7d, 0xb4, 0x99, 0xe3, 0x61, 0xca,
                    0xe5, 0x90, 0xd4, 0xd9, 0x92, 0xe1, 0x4b, 0x7a, 0xce, 0x3, 0xc2, 0x44, 0xe0,
                    0x42, 0x21, 0x15, 0xb6, 0xd3, 0x8a
                ][..]
            )
        ));
    }

    #[test]
    fn short_packet_header_protection_v2() {
        // https://www.ietf.org/archive/id/draft-ietf-quic-v2-10.html#name-chacha20-poly1305-short-head
        test_short_packet(
            Version::V2,
            &[
                0x55, 0x58, 0xb1, 0xc6, 0x0a, 0xe7, 0xb6, 0xb9, 0x32, 0xbc, 0x27, 0xd7, 0x86, 0xf4,
                0xbc, 0x2b, 0xb2, 0x0f, 0x21, 0x62, 0xba,
            ],
        );
    }

    /// Protects the server Initial packet from the QUIC v2 draft and compares
    /// it with the published sample.
    #[test]
    fn initial_test_vector_v2() {
        // https://www.ietf.org/archive/id/draft-ietf-quic-v2-10.html#name-sample-packet-protection-2
        let icid = [0x83, 0x94, 0xc8, 0xf0, 0x3e, 0x51, 0x57, 0x08];
        let server = Keys::initial(
            Version::V2,
            TLS13_AES_128_GCM_SHA256_INTERNAL,
            TLS13_AES_128_GCM_SHA256_INTERNAL
                .quic
                .unwrap(),
            &icid,
            Side::Server,
        );
        let mut server_payload = [
            0x02, 0x00, 0x00, 0x00, 0x00, 0x06, 0x00, 0x40, 0x5a, 0x02, 0x00, 0x00, 0x56, 0x03,
            0x03, 0xee, 0xfc, 0xe7, 0xf7, 0xb3, 0x7b, 0xa1, 0xd1, 0x63, 0x2e, 0x96, 0x67, 0x78,
            0x25, 0xdd, 0xf7, 0x39, 0x88, 0xcf, 0xc7, 0x98, 0x25, 0xdf, 0x56, 0x6d, 0xc5, 0x43,
            0x0b, 0x9a, 0x04, 0x5a, 0x12, 0x00, 0x13, 0x01, 0x00, 0x00, 0x2e, 0x00, 0x33, 0x00,
            0x24, 0x00, 0x1d, 0x00, 0x20, 0x9d, 0x3c, 0x94, 0x0d, 0x89, 0x69, 0x0b, 0x84, 0xd0,
            0x8a, 0x60, 0x99, 0x3c, 0x14, 0x4e, 0xca, 0x68, 0x4d, 0x10, 0x81, 0x28, 0x7c, 0x83,
            0x4d, 0x53, 0x11, 0xbc, 0xf3, 0x2b, 0xb9, 0xda, 0x1a, 0x00, 0x2b, 0x00, 0x02, 0x03,
            0x04,
        ];
        let mut server_header = [
            0xd1, 0x6b, 0x33, 0x43, 0xcf, 0x00, 0x08, 0xf0, 0x67, 0xa5, 0x50, 0x2a, 0x42, 0x62,
            0xb5, 0x00, 0x40, 0x75, 0x00, 0x01,
        ];
        let tag = server
            .local
            .packet
            .encrypt_in_place(1, &server_header, &mut server_payload)
            .unwrap();
        // The last two header bytes hold the packet number.
        let (first, rest) = server_header.split_at_mut(1);
        let rest_len = rest.len();
        server
            .local
            .header
            .encrypt_in_place(
                &server_payload[2..18],
                &mut first[0],
                &mut rest[rest_len - 2..],
            )
            .unwrap();
        let mut server_packet = server_header.to_vec();
        server_packet.extend(server_payload);
        server_packet.extend(tag.as_ref());
        let expected_server_packet = [
            0xdc, 0x6b, 0x33, 0x43, 0xcf, 0x00, 0x08, 0xf0, 0x67, 0xa5, 0x50, 0x2a, 0x42, 0x62,
            0xb5, 0x00, 0x40, 0x75, 0xd9, 0x2f, 0xaa, 0xf1, 0x6f, 0x05, 0xd8, 0xa4, 0x39, 0x8c,
            0x47, 0x08, 0x96, 0x98, 0xba, 0xee, 0xa2, 0x6b, 0x91, 0xeb, 0x76, 0x1d, 0x9b, 0x89,
            0x23, 0x7b, 0xbf, 0x87, 0x26, 0x30, 0x17, 0x91, 0x53, 0x58, 0x23, 0x00, 0x35, 0xf7,
            0xfd, 0x39, 0x45, 0xd8, 0x89, 0x65, 0xcf, 0x17, 0xf9, 0xaf, 0x6e, 0x16, 0x88, 0x6c,
            0x61, 0xbf, 0xc7, 0x03, 0x10, 0x6f, 0xba, 0xf3, 0xcb, 0x4c, 0xfa, 0x52, 0x38, 0x2d,
            0xd1, 0x6a, 0x39, 0x3e, 0x42, 0x75, 0x75, 0x07, 0x69, 0x80, 0x75, 0xb2, 0xc9, 0x84,
            0xc7, 0x07, 0xf0, 0xa0, 0x81, 0x2d, 0x8c, 0xd5, 0xa6, 0x88, 0x1e, 0xaf, 0x21, 0xce,
            0xda, 0x98, 0xf4, 0xbd, 0x23, 0xf6, 0xfe, 0x1a, 0x3e, 0x2c, 0x43, 0xed, 0xd9, 0xce,
            0x7c, 0xa8, 0x4b, 0xed, 0x85, 0x21, 0xe2, 0xe1, 0x40,
        ];
        assert_eq!(server_packet[..], expected_server_packet[..]);
    }

    // This test is based on picoquic's output for `multipath_aead_test` in
    // `picoquictest/multipath_test.c`.
    //
    // See <https://github.com/private-octopus/picoquic/blob/be0d99e6d4f8759cb7920425351c06a1c6f4a958/picoquictest/multipath_test.c#L1537-L1606>
    #[test]
    fn test_multipath_aead_basic() {
        // NOTE(review): byte 25 is 35, which breaks the 0..=31 pattern. It is kept
        // as-is because EXPECTED was computed from this exact secret; confirm
        // against picoquic before changing either value.
        const SECRET: &[u8; 32] = &[
            0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
            24, 35, 26, 27, 28, 29, 30, 31,
        ];
        const PN: u64 = 12345;
        const PATH_ID: u32 = 2;
        const PAYLOAD: &[u8] = b"The quick brown fox jumps over the lazy dog";
        const HEADER: &[u8] = b"This is a test";

        const EXPECTED: &[u8] = &[
            123, 139, 232, 52, 136, 25, 201, 143, 250, 89, 87, 39, 37, 63, 0, 210, 220, 227, 186,
            140, 183, 251, 13, 203, 6, 116, 204, 100, 166, 64, 43, 185, 174, 85, 212, 163, 242,
            141, 24, 166, 62, 228, 187, 137, 248, 31, 152, 126, 240, 151, 79, 51, 253, 130, 43,
            114, 173, 234, 254,
        ];

        let secret = OkmBlock::new(SECRET);
        let builder = KeyBuilder::new(
            &secret,
            Version::V1,
            TLS13_AES_128_GCM_SHA256_INTERNAL
                .quic
                .unwrap(),
            TLS13_AES_128_GCM_SHA256_INTERNAL.hkdf_provider,
        );

        let packet = builder.packet_key();
        let mut buf = PAYLOAD.to_vec();
        let tag = packet
            .encrypt_in_place_for_path(PATH_ID, PN, HEADER, &mut buf)
            .unwrap();
        buf.extend_from_slice(tag.as_ref());

        assert_eq!(buf.as_slice(), EXPECTED);
    }

    // This test is based on `multipath_aead_test` in `picoquictest/multipath_test.c`
    //
    // See <https://github.com/private-octopus/picoquic/blob/be0d99e6d4f8759cb7920425351c06a1c6f4a958/picoquictest/multipath_test.c#L1537-L1606>
    #[test]
    fn test_multipath_aead_roundtrip() {
        const SECRET: &[u8; 32] = &[
            0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
            24, 35, 26, 27, 28, 29, 30, 31,
        ];
        const PAYLOAD: &[u8] = b"The quick brown fox jumps over the lazy dog";
        const HEADER: &[u8] = b"This is a test";
        const PN: u64 = 12345;

        const TEST_PATH_IDS: &[u32] = &[0, 1, 2, 0xaead];

        let secret = OkmBlock::new(SECRET);
        let builder = KeyBuilder::new(
            &secret,
            Version::V1,
            TLS13_AES_128_GCM_SHA256_INTERNAL
                .quic
                .unwrap(),
            TLS13_AES_128_GCM_SHA256_INTERNAL.hkdf_provider,
        );
        let packet = builder.packet_key();

        for &path_id in TEST_PATH_IDS {
            let mut buf = PAYLOAD.to_vec();
            let tag = packet
                .encrypt_in_place_for_path(path_id, PN, HEADER, &mut buf)
                .unwrap();
            buf.extend_from_slice(tag.as_ref());
            let decrypted = packet
                .decrypt_in_place_for_path(path_id, PN, HEADER, &mut buf)
                .unwrap();
            assert_eq!(decrypted, PAYLOAD);
        }
    }
}