1use alloc::boxed::Box;
2use alloc::vec::Vec;
3
4use zeroize::Zeroize;
5
6use super::{ActiveKeyExchange, hmac};
7use crate::error::Error;
8use crate::version::TLS13;
9
10pub struct HkdfExpanderUsingHmac(Box<dyn hmac::Key>);
12
13impl HkdfExpanderUsingHmac {
14    fn expand_unchecked(&self, info: &[&[u8]], output: &mut [u8]) {
15        let mut term = hmac::Tag::new(b"");
16
17        for (n, chunk) in output
18            .chunks_mut(self.0.tag_len())
19            .enumerate()
20        {
21            term = self
22                .0
23                .sign_concat(term.as_ref(), info, &[(n + 1) as u8]);
24            chunk.copy_from_slice(&term.as_ref()[..chunk.len()]);
25        }
26    }
27}
28
29impl HkdfExpander for HkdfExpanderUsingHmac {
30    fn expand_slice(&self, info: &[&[u8]], output: &mut [u8]) -> Result<(), OutputLengthError> {
31        if output.len() > 255 * self.0.tag_len() {
32            return Err(OutputLengthError);
33        }
34
35        self.expand_unchecked(info, output);
36        Ok(())
37    }
38
39    fn expand_block(&self, info: &[&[u8]]) -> OkmBlock {
40        let mut tag = [0u8; hmac::Tag::MAX_LEN];
41        let reduced_tag = &mut tag[..self.0.tag_len()];
42        self.expand_unchecked(info, reduced_tag);
43        OkmBlock::new(reduced_tag)
44    }
45
46    fn hash_len(&self) -> usize {
47        self.0.tag_len()
48    }
49}
50
51pub struct HkdfUsingHmac<'a>(pub &'a dyn hmac::Hmac);
53
54impl Hkdf for HkdfUsingHmac<'_> {
55    fn extract_from_zero_ikm(&self, salt: Option<&[u8]>) -> Box<dyn HkdfExpander> {
56        let zeroes = [0u8; hmac::Tag::MAX_LEN];
57        Box::new(HkdfExpanderUsingHmac(self.0.with_key(
58            &self.extract_prk_from_secret(salt, &zeroes[..self.0.hash_output_len()]),
59        )))
60    }
61
62    fn extract_from_secret(&self, salt: Option<&[u8]>, secret: &[u8]) -> Box<dyn HkdfExpander> {
63        Box::new(HkdfExpanderUsingHmac(
64            self.0
65                .with_key(&self.extract_prk_from_secret(salt, secret)),
66        ))
67    }
68
69    fn expander_for_okm(&self, okm: &OkmBlock) -> Box<dyn HkdfExpander> {
70        Box::new(HkdfExpanderUsingHmac(self.0.with_key(okm.as_ref())))
71    }
72
73    fn hmac_sign(&self, key: &OkmBlock, message: &[u8]) -> hmac::Tag {
74        self.0
75            .with_key(key.as_ref())
76            .sign(&[message])
77    }
78}
79
80impl HkdfPrkExtract for HkdfUsingHmac<'_> {
81    fn extract_prk_from_secret(&self, salt: Option<&[u8]>, secret: &[u8]) -> Vec<u8> {
82        let zeroes = [0u8; hmac::Tag::MAX_LEN];
83        let salt = match salt {
84            Some(salt) => salt,
85            None => &zeroes[..self.0.hash_output_len()],
86        };
87        self.0
88            .with_key(salt)
89            .sign(&[secret])
90            .as_ref()
91            .to_vec()
92    }
93}
94
95pub trait HkdfExpander: Send + Sync {
97    fn expand_slice(&self, info: &[&[u8]], output: &mut [u8]) -> Result<(), OutputLengthError>;
109
110    fn expand_block(&self, info: &[&[u8]]) -> OkmBlock;
120
121    fn hash_len(&self) -> usize;
125}
126
127pub trait Hkdf: Send + Sync {
135    fn extract_from_zero_ikm(&self, salt: Option<&[u8]>) -> Box<dyn HkdfExpander>;
141
142    fn extract_from_secret(&self, salt: Option<&[u8]>, secret: &[u8]) -> Box<dyn HkdfExpander>;
146
147    fn extract_from_kx_shared_secret(
155        &self,
156        salt: Option<&[u8]>,
157        kx: Box<dyn ActiveKeyExchange>,
158        peer_pub_key: &[u8],
159    ) -> Result<Box<dyn HkdfExpander>, Error> {
160        Ok(self.extract_from_secret(
161            salt,
162            kx.complete_for_tls_version(peer_pub_key, &TLS13)?
163                .secret_bytes(),
164        ))
165    }
166
167    fn expander_for_okm(&self, okm: &OkmBlock) -> Box<dyn HkdfExpander>;
169
170    fn hmac_sign(&self, key: &OkmBlock, message: &[u8]) -> hmac::Tag;
178
179    fn fips(&self) -> bool {
181        false
182    }
183}
184
185pub(crate) trait HkdfPrkExtract: Hkdf {
195    fn extract_prk_from_secret(&self, salt: Option<&[u8]>, secret: &[u8]) -> Vec<u8>;
202}
203
204pub fn expand<T, const N: usize>(expander: &dyn HkdfExpander, info: &[&[u8]]) -> T
214where
215    T: From<[u8; N]>,
216{
217    let mut output = [0u8; N];
218    expander
219        .expand_slice(info, &mut output)
220        .expect("expand type parameter T is too large");
221    T::from(output)
222}
223
224#[derive(Clone)]
226pub struct OkmBlock {
227    buf: [u8; Self::MAX_LEN],
228    used: usize,
229}
230
231impl OkmBlock {
232    pub fn new(bytes: &[u8]) -> Self {
236        let mut tag = Self {
237            buf: [0u8; Self::MAX_LEN],
238            used: bytes.len(),
239        };
240        tag.buf[..bytes.len()].copy_from_slice(bytes);
241        tag
242    }
243
244    pub const MAX_LEN: usize = 64;
246}
247
248impl Drop for OkmBlock {
249    fn drop(&mut self) {
250        self.buf.zeroize();
251    }
252}
253
254impl AsRef<[u8]> for OkmBlock {
255    fn as_ref(&self) -> &[u8] {
256        &self.buf[..self.used]
257    }
258}
259
/// Error returned by `HkdfExpander::expand_slice` when the requested output
/// is longer than the expander can produce.
#[derive(Debug)]
pub struct OutputLengthError;
264
#[cfg(all(test, feature = "ring"))]
mod tests {
    use std::prelude::v1::*;

    use super::{Hkdf, HkdfUsingHmac, expand};
    use crate::crypto::ring::hmac;

    /// Minimal `From<[u8; N]>` target so `expand` can be asked for `N` bytes.
    struct ByteArray<const N: usize>([u8; N]);

    impl<const N: usize> From<[u8; N]> for ByteArray<N> {
        fn from(array: [u8; N]) -> Self {
            Self(array)
        }
    }

    /// RFC 5869 appendix A.1: basic SHA-256 case.
    #[test]
    fn test_case_1() {
        let hkdf = HkdfUsingHmac(&hmac::HMAC_SHA256);
        let ikm = &[0x0b; 22];
        let salt = &[
            0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c,
        ];
        // `info` is deliberately split in two to exercise fragment handling.
        let info: &[&[u8]] = &[
            &[0xf0, 0xf1, 0xf2],
            &[0xf3, 0xf4, 0xf5, 0xf6, 0xf7, 0xf8, 0xf9],
        ];

        let output: ByteArray<42> = expand(
            hkdf.extract_from_secret(Some(salt), ikm)
                .as_ref(),
            info,
        );

        assert_eq!(
            &output.0,
            &[
                0x3c, 0xb2, 0x5f, 0x25, 0xfa, 0xac, 0xd5, 0x7a, 0x90, 0x43, 0x4f, 0x64, 0xd0, 0x36,
                0x2f, 0x2a, 0x2d, 0x2d, 0x0a, 0x90, 0xcf, 0x1a, 0x5a, 0x4c, 0x5d, 0xb0, 0x2d, 0x56,
                0xec, 0xc4, 0xc5, 0xbf, 0x34, 0x00, 0x72, 0x08, 0xd5, 0xb8, 0x87, 0x18, 0x58, 0x65
            ]
        );
    }

    /// RFC 5869 appendix A.2: SHA-256 with longer inputs and outputs.
    #[test]
    fn test_case_2() {
        let hkdf = HkdfUsingHmac(&hmac::HMAC_SHA256);
        let ikm: Vec<u8> = (0x00u8..=0x4f).collect();
        let salt: Vec<u8> = (0x60u8..=0xaf).collect();
        let info: Vec<u8> = (0xb0u8..=0xff).collect();

        let output: ByteArray<82> = expand(
            hkdf.extract_from_secret(Some(&salt), &ikm)
                .as_ref(),
            &[&info],
        );

        assert_eq!(
            &output.0,
            &[
                0xb1, 0x1e, 0x39, 0x8d, 0xc8, 0x03, 0x27, 0xa1, 0xc8, 0xe7, 0xf7, 0x8c, 0x59, 0x6a,
                0x49, 0x34, 0x4f, 0x01, 0x2e, 0xda, 0x2d, 0x4e, 0xfa, 0xd8, 0xa0, 0x50, 0xcc, 0x4c,
                0x19, 0xaf, 0xa9, 0x7c, 0x59, 0x04, 0x5a, 0x99, 0xca, 0xc7, 0x82, 0x72, 0x71, 0xcb,
                0x41, 0xc6, 0x5e, 0x59, 0x0e, 0x09, 0xda, 0x32, 0x75, 0x60, 0x0c, 0x2f, 0x09, 0xb8,
                0x36, 0x77, 0x93, 0xa9, 0xac, 0xa3, 0xdb, 0x71, 0xcc, 0x30, 0xc5, 0x81, 0x79, 0xec,
                0x3e, 0x87, 0xc1, 0x4c, 0x01, 0xd5, 0xc1, 0xf3, 0x43, 0x4f, 0x1d, 0x87
            ]
        );
    }

    /// RFC 5869 appendix A.3: SHA-256 with zero-length salt and info.
    #[test]
    fn test_case_3() {
        let hkdf = HkdfUsingHmac(&hmac::HMAC_SHA256);
        let ikm = &[0x0b; 22];
        let salt = &[];
        let info = &[];

        let output: ByteArray<42> = expand(
            hkdf.extract_from_secret(Some(salt), ikm)
                .as_ref(),
            info,
        );

        assert_eq!(
            &output.0,
            &[
                0x8d, 0xa4, 0xe7, 0x75, 0xa5, 0x63, 0xc1, 0x8f, 0x71, 0x5f, 0x80, 0x2a, 0x06, 0x3c,
                0x5a, 0x31, 0xb8, 0xa1, 0x1f, 0x5c, 0x5e, 0xe1, 0x87, 0x9e, 0xc3, 0x45, 0x4e, 0x5f,
                0x3c, 0x73, 0x8d, 0x2d, 0x9d, 0x20, 0x13, 0x95, 0xfa, 0xa4, 0xb6, 0x1a, 0x96, 0xc8
            ]
        );
    }

    /// Exercises the `salt == None` path with SHA-384, where the salt falls
    /// back to zero bytes of the hash output length.
    #[test]
    fn test_salt_not_provided() {
        let hkdf = HkdfUsingHmac(&hmac::HMAC_SHA384);
        let ikm = &[0x0b; 40];
        let info = &[&b"hel"[..], &b"lo"[..]];

        let output: ByteArray<96> = expand(
            hkdf.extract_from_secret(None, ikm)
                .as_ref(),
            info,
        );

        assert_eq!(
            &output.0,
            &[
                0xd5, 0x45, 0xdd, 0x3a, 0xff, 0x5b, 0x19, 0x46, 0xd4, 0x86, 0xfd, 0xb8, 0xd8, 0x88,
                0x2e, 0xe0, 0x1c, 0xc1, 0xa5, 0x48, 0xb6, 0x05, 0x75, 0xe4, 0xd7, 0x5d, 0x0f, 0x5f,
                0x23, 0x40, 0xee, 0x6c, 0x9e, 0x7c, 0x65, 0xd0, 0xee, 0x79, 0xdb, 0xb2, 0x07, 0x1d,
                0x66, 0xa5, 0x50, 0xc4, 0x8a, 0xa3, 0x93, 0x86, 0x8b, 0x7c, 0x69, 0x41, 0x6b, 0x3e,
                0x61, 0x44, 0x98, 0xb8, 0xc2, 0xfc, 0x82, 0x82, 0xae, 0xcd, 0x46, 0xcf, 0xb1, 0x47,
                0xdc, 0xd0, 0x69, 0x0d, 0x19, 0xad, 0xe6, 0x6c, 0x70, 0xfe, 0x87, 0x92, 0x04, 0xb6,
                0x82, 0x2d, 0x97, 0x7e, 0x46, 0x80, 0x4c, 0xe5, 0x76, 0x72, 0xb4, 0xb8
            ]
        );
    }

    /// One byte more than 255 SHA-256 blocks must be rejected.
    #[test]
    fn test_output_length_bounds() {
        let hkdf = HkdfUsingHmac(&hmac::HMAC_SHA256);
        let ikm = &[];
        let info = &[];

        let mut output = [0u8; 32 * 255 + 1];
        assert!(
            hkdf.extract_from_secret(None, ikm)
                .expand_slice(info, &mut output)
                .is_err()
        );
    }

    /// Exactly 255 SHA-256 blocks is the largest request and must be accepted;
    /// this guards against an off-by-one in the length check.
    #[test]
    fn test_output_length_at_limit() {
        let hkdf = HkdfUsingHmac(&hmac::HMAC_SHA256);
        let ikm = &[];
        let info = &[];

        let mut output = [0u8; 32 * 255];
        assert!(
            hkdf.extract_from_secret(None, ikm)
                .expand_slice(info, &mut output)
                .is_ok()
        );
    }

    /// `expand_block` must yield the same bytes as expanding one hash-length
    /// block through `expand_slice`.
    #[test]
    fn test_expand_block_matches_expand_slice() {
        let hkdf = HkdfUsingHmac(&hmac::HMAC_SHA256);
        let ikm = &[0x0b; 22];
        let info: &[&[u8]] = &[&b"label"[..]];
        let expander = hkdf.extract_from_secret(None, ikm);

        let mut expected = [0u8; 32];
        expander
            .expand_slice(info, &mut expected)
            .unwrap();

        let block = expander.expand_block(info);
        let block_bytes: &[u8] = block.as_ref();
        assert_eq!(block_bytes, &expected[..]);
    }
}