1#![allow(clippy::duplicate_mod)]
2
3use alloc::boxed::Box;
4
5use super::ring_like::aead;
6use crate::crypto::cipher::{AeadKey, Iv, Nonce};
7use crate::error::Error;
8use crate::quic;
9
10pub(crate) struct HeaderProtectionKey(aead::quic::HeaderProtectionKey);
11
12impl HeaderProtectionKey {
13 pub(crate) fn new(key: AeadKey, alg: &'static aead::quic::Algorithm) -> Self {
14 Self(aead::quic::HeaderProtectionKey::new(alg, key.as_ref()).unwrap())
15 }
16
17 fn xor_in_place(
18 &self,
19 sample: &[u8],
20 first: &mut u8,
21 packet_number: &mut [u8],
22 masked: bool,
23 ) -> Result<(), Error> {
24 let mask = self
28 .0
29 .new_mask(sample)
30 .map_err(|_| Error::General("sample of invalid length".into()))?;
31
32 let (first_mask, pn_mask) = mask.split_first().unwrap();
35
36 if packet_number.len() > pn_mask.len() {
39 return Err(Error::General("packet number too long".into()));
40 }
41
42 const LONG_HEADER_FORM: u8 = 0x80;
46 let bits = match *first & LONG_HEADER_FORM == LONG_HEADER_FORM {
47 true => 0x0f, false => 0x1f, };
50
51 let first_plain = match masked {
52 true => *first ^ (first_mask & bits),
54 false => *first,
56 };
57 let pn_len = (first_plain & 0x03) as usize + 1;
58
59 *first ^= first_mask & bits;
60 for (dst, m) in packet_number
61 .iter_mut()
62 .zip(pn_mask)
63 .take(pn_len)
64 {
65 *dst ^= m;
66 }
67
68 Ok(())
69 }
70}
71
72impl quic::HeaderProtectionKey for HeaderProtectionKey {
73 fn encrypt_in_place(
74 &self,
75 sample: &[u8],
76 first: &mut u8,
77 packet_number: &mut [u8],
78 ) -> Result<(), Error> {
79 self.xor_in_place(sample, first, packet_number, false)
80 }
81
82 fn decrypt_in_place(
83 &self,
84 sample: &[u8],
85 first: &mut u8,
86 packet_number: &mut [u8],
87 ) -> Result<(), Error> {
88 self.xor_in_place(sample, first, packet_number, true)
89 }
90
91 #[inline]
92 fn sample_len(&self) -> usize {
93 self.0.algorithm().sample_len()
94 }
95}
96
97pub(crate) struct PacketKey {
98 key: aead::LessSafeKey,
100 iv: Iv,
102 confidentiality_limit: u64,
104 integrity_limit: u64,
106}
107
108impl PacketKey {
109 pub(crate) fn new(
110 key: AeadKey,
111 iv: Iv,
112 confidentiality_limit: u64,
113 integrity_limit: u64,
114 aead_algorithm: &'static aead::Algorithm,
115 ) -> Self {
116 Self {
117 key: aead::LessSafeKey::new(
118 aead::UnboundKey::new(aead_algorithm, key.as_ref()).unwrap(),
119 ),
120 iv,
121 confidentiality_limit,
122 integrity_limit,
123 }
124 }
125}
126
127impl quic::PacketKey for PacketKey {
128 fn encrypt_in_place(
129 &self,
130 packet_number: u64,
131 header: &[u8],
132 payload: &mut [u8],
133 ) -> Result<quic::Tag, Error> {
134 let aad = aead::Aad::from(header);
135 let nonce = aead::Nonce::assume_unique_for_key(Nonce::new(&self.iv, packet_number).0);
136 let tag = self
137 .key
138 .seal_in_place_separate_tag(nonce, aad, payload)
139 .map_err(|_| Error::EncryptError)?;
140 Ok(quic::Tag::from(tag.as_ref()))
141 }
142
143 fn encrypt_in_place_for_path(
144 &self,
145 path_id: u32,
146 packet_number: u64,
147 header: &[u8],
148 payload: &mut [u8],
149 ) -> Result<quic::Tag, Error> {
150 let aad = aead::Aad::from(header);
151 let nonce =
152 aead::Nonce::assume_unique_for_key(Nonce::for_path(path_id, &self.iv, packet_number).0);
153 let tag = self
154 .key
155 .seal_in_place_separate_tag(nonce, aad, payload)
156 .map_err(|_| Error::EncryptError)?;
157 Ok(quic::Tag::from(tag.as_ref()))
158 }
159
160 fn decrypt_in_place<'a>(
168 &self,
169 packet_number: u64,
170 header: &[u8],
171 payload: &'a mut [u8],
172 ) -> Result<&'a [u8], Error> {
173 let payload_len = payload.len();
174 let aad = aead::Aad::from(header);
175 let nonce = aead::Nonce::assume_unique_for_key(Nonce::new(&self.iv, packet_number).0);
176 self.key
177 .open_in_place(nonce, aad, payload)
178 .map_err(|_| Error::DecryptError)?;
179
180 let plain_len = payload_len - self.key.algorithm().tag_len();
181 Ok(&payload[..plain_len])
182 }
183
184 fn decrypt_in_place_for_path<'a>(
185 &self,
186 path_id: u32,
187 packet_number: u64,
188 header: &[u8],
189 payload: &'a mut [u8],
190 ) -> Result<&'a [u8], Error> {
191 let payload_len = payload.len();
192 let aad = aead::Aad::from(header);
193 let nonce =
194 aead::Nonce::assume_unique_for_key(Nonce::for_path(path_id, &self.iv, packet_number).0);
195 self.key
196 .open_in_place(nonce, aad, payload)
197 .map_err(|_| Error::DecryptError)?;
198
199 let plain_len = payload_len - self.key.algorithm().tag_len();
200 Ok(&payload[..plain_len])
201 }
202
203 #[inline]
205 fn tag_len(&self) -> usize {
206 self.key.algorithm().tag_len()
207 }
208
209 fn confidentiality_limit(&self) -> u64 {
211 self.confidentiality_limit
212 }
213
214 fn integrity_limit(&self) -> u64 {
216 self.integrity_limit
217 }
218}
219
220pub(crate) struct KeyBuilder {
221 pub(crate) packet_alg: &'static aead::Algorithm,
222 pub(crate) header_alg: &'static aead::quic::Algorithm,
223 pub(crate) confidentiality_limit: u64,
224 pub(crate) integrity_limit: u64,
225}
226
227impl quic::Algorithm for KeyBuilder {
228 fn packet_key(&self, key: AeadKey, iv: Iv) -> Box<dyn quic::PacketKey> {
229 Box::new(PacketKey::new(
230 key,
231 iv,
232 self.confidentiality_limit,
233 self.integrity_limit,
234 self.packet_alg,
235 ))
236 }
237
238 fn header_protection_key(&self, key: AeadKey) -> Box<dyn quic::HeaderProtectionKey> {
239 Box::new(HeaderProtectionKey::new(key, self.header_alg))
240 }
241
242 fn aead_key_len(&self) -> usize {
243 self.packet_alg.key_len()
244 }
245
246 fn fips(&self) -> bool {
247 super::fips()
248 }
249}
250
#[cfg(test)]
#[macro_rules_attribute::apply(test_for_each_provider)]
mod tests {
    use super::provider::tls13::{
        TLS13_AES_128_GCM_SHA256_INTERNAL, TLS13_CHACHA20_POLY1305_SHA256_INTERNAL,
    };
    use crate::common_state::Side;
    use crate::crypto::tls13::OkmBlock;
    use crate::quic::*;

    /// Protects a short-header packet with ChaCha20-Poly1305 keys derived from
    /// a fixed secret for `version`, and checks the protected bytes against
    /// `expected`. It then removes header and packet protection and checks
    /// that the original payload comes back.
    fn test_short_packet(version: Version, expected: &[u8]) {
        const PN: u64 = 654360564;
        const SECRET: &[u8] = &[
            0x9a, 0xc3, 0x12, 0xa7, 0xf8, 0x77, 0x46, 0x8e, 0xbe, 0x69, 0x42, 0x27, 0x48, 0xad,
            0x00, 0xa1, 0x54, 0x43, 0xf1, 0x82, 0x03, 0xa0, 0x7d, 0x60, 0x60, 0xf6, 0x88, 0xf3,
            0x0f, 0x21, 0x63, 0x2b,
        ];

        let secret = OkmBlock::new(SECRET);
        let builder = KeyBuilder::new(
            &secret,
            version,
            TLS13_CHACHA20_POLY1305_SHA256_INTERNAL
                .quic
                .unwrap(),
            TLS13_CHACHA20_POLY1305_SHA256_INTERNAL.hkdf_provider,
        );
        let packet = builder.packet_key();
        let hpk = builder.header_protection_key();

        // 4-byte header followed by a 1-byte payload.
        const PLAIN: &[u8] = &[0x42, 0x00, 0xbf, 0xf4, 0x01];

        let mut buf = PLAIN.to_vec();
        let (header, payload) = buf.split_at_mut(4);
        let tag = packet
            .encrypt_in_place(PN, header, payload)
            .unwrap();
        buf.extend(tag.as_ref());

        let pn_offset = 1;
        let (header, sample) = buf.split_at_mut(pn_offset + 4);
        let (first, rest) = header.split_at_mut(1);
        let sample = &sample[..hpk.sample_len()];
        hpk.encrypt_in_place(sample, &mut first[0], rest)
            .unwrap();

        assert_eq!(&buf, expected);

        let (header, sample) = buf.split_at_mut(pn_offset + 4);
        let (first, rest) = header.split_at_mut(1);
        let sample = &sample[..hpk.sample_len()];
        hpk.decrypt_in_place(sample, &mut first[0], rest)
            .unwrap();

        let (header, payload_tag) = buf.split_at_mut(4);
        let plain = packet
            .decrypt_in_place(PN, header, payload_tag)
            .unwrap();

        assert_eq!(plain, &PLAIN[4..]);
    }

    #[test]
    fn short_packet_header_protection() {
        // ChaCha20-Poly1305 short header packet, RFC 9001 Appendix A.5.
        test_short_packet(
            Version::V1,
            &[
                0x4c, 0xfe, 0x41, 0x89, 0x65, 0x5e, 0x5c, 0xd5, 0x5c, 0x41, 0xf6, 0x90, 0x80, 0x57,
                0x5d, 0x79, 0x99, 0xc2, 0x5a, 0x5b, 0xfb,
            ],
        );
    }

    #[test]
    fn key_update_test_vector() {
        fn equal_okm(x: &OkmBlock, y: &OkmBlock) -> bool {
            x.as_ref() == y.as_ref()
        }

        let mut secrets = Secrets::new(
            OkmBlock::new(
                &[
                    0xb8, 0x76, 0x77, 0x08, 0xf8, 0x77, 0x23, 0x58, 0xa6, 0xea, 0x9f, 0xc4, 0x3e,
                    0x4a, 0xdd, 0x2c, 0x96, 0x1b, 0x3f, 0x52, 0x87, 0xa6, 0xd1, 0x46, 0x7e, 0xe0,
                    0xae, 0xab, 0x33, 0x72, 0x4d, 0xbf,
                ][..],
            ),
            OkmBlock::new(
                &[
                    0x42, 0xdc, 0x97, 0x21, 0x40, 0xe0, 0xf2, 0xe3, 0x98, 0x45, 0xb7, 0x67, 0x61,
                    0x34, 0x39, 0xdc, 0x67, 0x58, 0xca, 0x43, 0x25, 0x9b, 0x87, 0x85, 0x06, 0x82,
                    0x4e, 0xb1, 0xe4, 0x38, 0xd8, 0x55,
                ][..],
            ),
            TLS13_AES_128_GCM_SHA256_INTERNAL,
            TLS13_AES_128_GCM_SHA256_INTERNAL
                .quic
                .unwrap(),
            Side::Client,
            Version::V1,
        );
        secrets.update();

        assert!(equal_okm(
            &secrets.client,
            &OkmBlock::new(
                &[
                    0x42, 0xca, 0xc8, 0xc9, 0x1c, 0xd5, 0xeb, 0x40, 0x68, 0x2e, 0x43, 0x2e, 0xdf,
                    0x2d, 0x2b, 0xe9, 0xf4, 0x1a, 0x52, 0xca, 0x6b, 0x22, 0xd8, 0xe6, 0xcd, 0xb1,
                    0xe8, 0xac, 0xa9, 0x06, 0x1f, 0xce
                ][..]
            )
        ));
        assert!(equal_okm(
            &secrets.server,
            &OkmBlock::new(
                &[
                    0xeb, 0x7f, 0x5e, 0x2a, 0x12, 0x3f, 0x40, 0x7d, 0xb4, 0x99, 0xe3, 0x61, 0xca,
                    0xe5, 0x90, 0xd4, 0xd9, 0x92, 0xe1, 0x4b, 0x7a, 0xce, 0x03, 0xc2, 0x44, 0xe0,
                    0x42, 0x21, 0x15, 0xb6, 0xd3, 0x8a
                ][..]
            )
        ));
    }

    #[test]
    fn short_packet_header_protection_v2() {
        test_short_packet(
            Version::V2,
            &[
                0x55, 0x58, 0xb1, 0xc6, 0x0a, 0xe7, 0xb6, 0xb9, 0x32, 0xbc, 0x27, 0xd7, 0x86, 0xf4,
                0xbc, 0x2b, 0xb2, 0x0f, 0x21, 0x62, 0xba,
            ],
        );
    }

    #[test]
    fn initial_test_vector_v2() {
        // Server Initial packet from the QUIC v2 test vectors
        // (RFC 9369, Appendix A.3).
        let icid = [0x83, 0x94, 0xc8, 0xf0, 0x3e, 0x51, 0x57, 0x08];
        let server = Keys::initial(
            Version::V2,
            TLS13_AES_128_GCM_SHA256_INTERNAL,
            TLS13_AES_128_GCM_SHA256_INTERNAL
                .quic
                .unwrap(),
            &icid,
            Side::Server,
        );
        let mut server_payload = [
            0x02, 0x00, 0x00, 0x00, 0x00, 0x06, 0x00, 0x40, 0x5a, 0x02, 0x00, 0x00, 0x56, 0x03,
            0x03, 0xee, 0xfc, 0xe7, 0xf7, 0xb3, 0x7b, 0xa1, 0xd1, 0x63, 0x2e, 0x96, 0x67, 0x78,
            0x25, 0xdd, 0xf7, 0x39, 0x88, 0xcf, 0xc7, 0x98, 0x25, 0xdf, 0x56, 0x6d, 0xc5, 0x43,
            0x0b, 0x9a, 0x04, 0x5a, 0x12, 0x00, 0x13, 0x01, 0x00, 0x00, 0x2e, 0x00, 0x33, 0x00,
            0x24, 0x00, 0x1d, 0x00, 0x20, 0x9d, 0x3c, 0x94, 0x0d, 0x89, 0x69, 0x0b, 0x84, 0xd0,
            0x8a, 0x60, 0x99, 0x3c, 0x14, 0x4e, 0xca, 0x68, 0x4d, 0x10, 0x81, 0x28, 0x7c, 0x83,
            0x4d, 0x53, 0x11, 0xbc, 0xf3, 0x2b, 0xb9, 0xda, 0x1a, 0x00, 0x2b, 0x00, 0x02, 0x03,
            0x04,
        ];
        let mut server_header = [
            0xd1, 0x6b, 0x33, 0x43, 0xcf, 0x00, 0x08, 0xf0, 0x67, 0xa5, 0x50, 0x2a, 0x42, 0x62,
            0xb5, 0x00, 0x40, 0x75, 0x00, 0x01,
        ];
        let tag = server
            .local
            .packet
            .encrypt_in_place(1, &server_header, &mut server_payload)
            .unwrap();
        // The packet number is the last two header bytes.
        let (first, rest) = server_header.split_at_mut(1);
        let rest_len = rest.len();
        server
            .local
            .header
            .encrypt_in_place(
                &server_payload[2..18],
                &mut first[0],
                &mut rest[rest_len - 2..],
            )
            .unwrap();
        let mut server_packet = server_header.to_vec();
        server_packet.extend(server_payload);
        server_packet.extend(tag.as_ref());
        let expected_server_packet = [
            0xdc, 0x6b, 0x33, 0x43, 0xcf, 0x00, 0x08, 0xf0, 0x67, 0xa5, 0x50, 0x2a, 0x42, 0x62,
            0xb5, 0x00, 0x40, 0x75, 0xd9, 0x2f, 0xaa, 0xf1, 0x6f, 0x05, 0xd8, 0xa4, 0x39, 0x8c,
            0x47, 0x08, 0x96, 0x98, 0xba, 0xee, 0xa2, 0x6b, 0x91, 0xeb, 0x76, 0x1d, 0x9b, 0x89,
            0x23, 0x7b, 0xbf, 0x87, 0x26, 0x30, 0x17, 0x91, 0x53, 0x58, 0x23, 0x00, 0x35, 0xf7,
            0xfd, 0x39, 0x45, 0xd8, 0x89, 0x65, 0xcf, 0x17, 0xf9, 0xaf, 0x6e, 0x16, 0x88, 0x6c,
            0x61, 0xbf, 0xc7, 0x03, 0x10, 0x6f, 0xba, 0xf3, 0xcb, 0x4c, 0xfa, 0x52, 0x38, 0x2d,
            0xd1, 0x6a, 0x39, 0x3e, 0x42, 0x75, 0x75, 0x07, 0x69, 0x80, 0x75, 0xb2, 0xc9, 0x84,
            0xc7, 0x07, 0xf0, 0xa0, 0x81, 0x2d, 0x8c, 0xd5, 0xa6, 0x88, 0x1e, 0xaf, 0x21, 0xce,
            0xda, 0x98, 0xf4, 0xbd, 0x23, 0xf6, 0xfe, 0x1a, 0x3e, 0x2c, 0x43, 0xed, 0xd9, 0xce,
            0x7c, 0xa8, 0x4b, 0xed, 0x85, 0x21, 0xe2, 0xe1, 0x40,
        ];
        assert_eq!(server_packet[..], expected_server_packet[..]);
    }

    /// Pins the multipath (path-id aware) packet protection output against a
    /// fixed expected ciphertext and tag.
    #[test]
    fn test_multipath_aead_basic() {
        // NOTE(review): byte 25 is 35 rather than 25. EXPECTED was generated
        // with this exact secret, so do not "fix" it without regenerating
        // EXPECTED.
        const SECRET: &[u8; 32] = &[
            0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
            24, 35, 26, 27, 28, 29, 30, 31,
        ];
        const PN: u64 = 12345;
        const PATH_ID: u32 = 2;
        const PAYLOAD: &[u8] = b"The quick brown fox jumps over the lazy dog";
        const HEADER: &[u8] = b"This is a test";

        // Ciphertext (43 bytes) followed by the 16-byte tag.
        const EXPECTED: &[u8] = &[
            123, 139, 232, 52, 136, 25, 201, 143, 250, 89, 87, 39, 37, 63, 0, 210, 220, 227, 186,
            140, 183, 251, 13, 203, 6, 116, 204, 100, 166, 64, 43, 185, 174, 85, 212, 163, 242,
            141, 24, 166, 62, 228, 187, 137, 248, 31, 152, 126, 240, 151, 79, 51, 253, 130, 43,
            114, 173, 234, 254,
        ];

        let secret = OkmBlock::new(SECRET);
        let builder = KeyBuilder::new(
            &secret,
            Version::V1,
            TLS13_AES_128_GCM_SHA256_INTERNAL
                .quic
                .unwrap(),
            TLS13_AES_128_GCM_SHA256_INTERNAL.hkdf_provider,
        );

        let packet = builder.packet_key();
        let mut buf = PAYLOAD.to_vec();
        let tag = packet
            .encrypt_in_place_for_path(PATH_ID, PN, HEADER, &mut buf)
            .unwrap();
        buf.extend_from_slice(tag.as_ref());

        assert_eq!(buf.as_slice(), EXPECTED);
    }

    /// Checks that multipath encryption followed by multipath decryption
    /// round-trips for several path ids.
    #[test]
    fn test_multipath_aead_roundtrip() {
        const SECRET: &[u8; 32] = &[
            0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
            24, 35, 26, 27, 28, 29, 30, 31,
        ];
        const PAYLOAD: &[u8] = b"The quick brown fox jumps over the lazy dog";
        const HEADER: &[u8] = b"This is a test";
        const PN: u64 = 12345;

        const TEST_PATH_IDS: &[u32] = &[0, 1, 2, 0xaead];

        let secret = OkmBlock::new(SECRET);
        let builder = KeyBuilder::new(
            &secret,
            Version::V1,
            TLS13_AES_128_GCM_SHA256_INTERNAL
                .quic
                .unwrap(),
            TLS13_AES_128_GCM_SHA256_INTERNAL.hkdf_provider,
        );
        let packet = builder.packet_key();

        for &path_id in TEST_PATH_IDS {
            let mut buf = PAYLOAD.to_vec();
            let tag = packet
                .encrypt_in_place_for_path(path_id, PN, HEADER, &mut buf)
                .unwrap();
            buf.extend_from_slice(tag.as_ref());
            let decrypted = packet
                .decrypt_in_place_for_path(path_id, PN, HEADER, &mut buf)
                .unwrap();
            assert_eq!(decrypted, PAYLOAD);
        }
    }
}