1#[cfg(feature = "std")]
36use alloc::collections::VecDeque;
37use alloc::vec::Vec;
38use core::fmt::Debug;
39#[cfg(feature = "std")]
40use std::sync::Mutex;
41
42use crate::enums::CertificateCompressionAlgorithm;
43use crate::msgs::base::{Payload, PayloadU24};
44use crate::msgs::codec::Codec;
45use crate::msgs::handshake::{CertificatePayloadTls13, CompressedCertificatePayload};
46use crate::sync::Arc;
47
48pub fn default_cert_decompressors() -> &'static [&'static dyn CertDecompressor] {
51    &[
52        #[cfg(feature = "brotli")]
53        BROTLI_DECOMPRESSOR,
54        #[cfg(feature = "zlib")]
55        ZLIB_DECOMPRESSOR,
56    ]
57}
58
59pub trait CertDecompressor: Debug + Send + Sync {
61    fn decompress(&self, input: &[u8], output: &mut [u8]) -> Result<(), DecompressionFailed>;
68
69    fn algorithm(&self) -> CertificateCompressionAlgorithm;
71}
72
73pub fn default_cert_compressors() -> &'static [&'static dyn CertCompressor] {
76    &[
77        #[cfg(feature = "brotli")]
78        BROTLI_COMPRESSOR,
79        #[cfg(feature = "zlib")]
80        ZLIB_COMPRESSOR,
81    ]
82}
83
84pub trait CertCompressor: Debug + Send + Sync {
86    fn compress(
95        &self,
96        input: Vec<u8>,
97        level: CompressionLevel,
98    ) -> Result<Vec<u8>, CompressionFailed>;
99
100    fn algorithm(&self) -> CertificateCompressionAlgorithm;
102}
103
/// A hint to a [`CertCompressor`] about how much effort to spend on compression.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub enum CompressionLevel {
    /// Cheaper compression.
    ///
    /// `CompressionCache` uses this level when the result is not cached. The built-in
    /// compressors map it to zlib's default configuration and brotli quality 4.
    Interactive,

    /// Stronger, more expensive compression.
    ///
    /// `CompressionCache` uses this level when the result is stored in the cache, so
    /// the cost is spread over later reuse. The built-in compressors map it to
    /// `Z_BEST_COMPRESSION` for zlib and quality 11 for brotli.
    Amortized,
}
117
/// Error returned by [`CertDecompressor::decompress`]; it carries no detail about the cause.
#[derive(Debug)]
pub struct DecompressionFailed;
121
/// Error returned by [`CertCompressor::compress`], and by `CompressionCache` when
/// compression fails or its internal lock is poisoned; it carries no detail.
#[derive(Debug)]
pub struct CompressionFailed;
125
#[cfg(feature = "zlib")]
mod feat_zlib_rs {
    use zlib_rs::c_api::Z_BEST_COMPRESSION;
    use zlib_rs::{ReturnCode, deflate, inflate};

    use super::*;

    /// Zlib certificate decompressor backed by the `zlib-rs` crate.
    pub const ZLIB_DECOMPRESSOR: &dyn CertDecompressor = &ZlibRsDecompressor;

    /// Zlib certificate compressor backed by the `zlib-rs` crate.
    pub const ZLIB_COMPRESSOR: &dyn CertCompressor = &ZlibRsCompressor;

    #[derive(Debug)]
    struct ZlibRsDecompressor;

    impl CertDecompressor for ZlibRsDecompressor {
        fn decompress(&self, input: &[u8], output: &mut [u8]) -> Result<(), DecompressionFailed> {
            let wanted = output.len();
            let (filled, code) =
                inflate::uncompress_slice(output, input, inflate::InflateConfig::default());
            // Success requires both a clean inflate and an exactly-filled output buffer.
            if matches!(code, ReturnCode::Ok) && filled.len() == wanted {
                Ok(())
            } else {
                Err(DecompressionFailed)
            }
        }

        fn algorithm(&self) -> CertificateCompressionAlgorithm {
            CertificateCompressionAlgorithm::Zlib
        }
    }

    #[derive(Debug)]
    struct ZlibRsCompressor;

    impl CertCompressor for ZlibRsCompressor {
        fn compress(
            &self,
            input: Vec<u8>,
            level: CompressionLevel,
        ) -> Result<Vec<u8>, CompressionFailed> {
            let config = match level {
                CompressionLevel::Interactive => deflate::DeflateConfig::default(),
                CompressionLevel::Amortized => deflate::DeflateConfig::new(Z_BEST_COMPRESSION),
            };
            // Size the buffer with deflate's worst-case bound, then shrink it to what was written.
            let mut buf = alloc::vec![0u8; deflate::compress_bound(input.len())];
            let written = match deflate::compress_slice(&mut buf, &input, config) {
                (done, ReturnCode::Ok) => done.len(),
                _ => return Err(CompressionFailed),
            };
            buf.truncate(written);
            Ok(buf)
        }

        fn algorithm(&self) -> CertificateCompressionAlgorithm {
            CertificateCompressionAlgorithm::Zlib
        }
    }
}
185
186#[cfg(feature = "zlib")]
187pub use feat_zlib_rs::{ZLIB_COMPRESSOR, ZLIB_DECOMPRESSOR};
188
#[cfg(feature = "brotli")]
mod feat_brotli {
    use std::io::{Cursor, Write};

    use super::*;

    /// Internal buffer size passed to `brotli::CompressorWriter::new`.
    const BUFFER_SIZE: usize = 4096;

    /// Base-2 logarithm of the brotli window size (22 => 4 MiB window).
    const LGWIN: u32 = 22;

    /// Brotli quality used for [`CompressionLevel::Interactive`].
    const QUALITY_FAST: u32 = 4;

    /// Brotli quality used for [`CompressionLevel::Amortized`]; 11 is brotli's highest setting.
    const QUALITY_SLOW: u32 = 11;

    /// Brotli certificate decompressor backed by the `brotli` crate.
    pub const BROTLI_DECOMPRESSOR: &dyn CertDecompressor = &BrotliDecompressor;

    /// Brotli certificate compressor backed by the `brotli` crate.
    pub const BROTLI_COMPRESSOR: &dyn CertCompressor = &BrotliCompressor;

    #[derive(Debug)]
    struct BrotliDecompressor;

    impl CertDecompressor for BrotliDecompressor {
        fn decompress(&self, input: &[u8], output: &mut [u8]) -> Result<(), DecompressionFailed> {
            let mut source = Cursor::new(input);
            // The sink is a cursor over `output`, so it cannot grow beyond `output.len()`.
            let mut sink = Cursor::new(output);

            if brotli::BrotliDecompress(&mut source, &mut sink).is_err() {
                return Err(DecompressionFailed);
            }

            // Decompressed data that stops short of the end of `output` is also a failure.
            let produced = sink.position() as usize;
            let expected = sink.into_inner().len();
            if produced == expected {
                Ok(())
            } else {
                Err(DecompressionFailed)
            }
        }

        fn algorithm(&self) -> CertificateCompressionAlgorithm {
            CertificateCompressionAlgorithm::Brotli
        }
    }

    #[derive(Debug)]
    struct BrotliCompressor;

    impl CertCompressor for BrotliCompressor {
        fn compress(
            &self,
            input: Vec<u8>,
            level: CompressionLevel,
        ) -> Result<Vec<u8>, CompressionFailed> {
            let quality = match level {
                CompressionLevel::Interactive => QUALITY_FAST,
                CompressionLevel::Amortized => QUALITY_SLOW,
            };
            // Initial capacity guesses at roughly 2:1 compression.
            let sink = Cursor::new(Vec::with_capacity(input.len() / 2));
            let mut writer = brotli::CompressorWriter::new(sink, BUFFER_SIZE, quality, LGWIN);
            if writer.write_all(&input).is_err() {
                return Err(CompressionFailed);
            }
            // `into_inner` is relied on to finish the brotli stream before returning the sink.
            let sink = writer.into_inner();
            Ok(sink.into_inner())
        }

        fn algorithm(&self) -> CertificateCompressionAlgorithm {
            CertificateCompressionAlgorithm::Brotli
        }
    }
}
265
266#[cfg(feature = "brotli")]
267pub use feat_brotli::{BROTLI_COMPRESSOR, BROTLI_DECOMPRESSOR};
268
/// Optional cache of compressed TLS 1.3 certificate payloads (`CertificatePayloadTls13`).
///
/// Caching avoids recompressing the same certificate payload for every handshake.
#[derive(Debug)]
pub enum CompressionCache {
    /// No caching: every request is compressed afresh at [`CompressionLevel::Interactive`].
    Disabled,

    /// Compressed payloads are produced at [`CompressionLevel::Amortized`] and kept in a
    /// bounded least-recently-used cache.
    ///
    /// Payloads with a non-empty context still bypass the cache. This variant is only
    /// available with the `std` feature.
    #[cfg(feature = "std")]
    Enabled(CompressionCacheInner),
}
284
/// State behind [`CompressionCache::Enabled`].
#[cfg(feature = "std")]
#[derive(Debug)]
pub struct CompressionCacheInner {
    /// Maximum number of entries retained; `CompressionCache::new` never builds this with zero.
    size: usize,

    /// Cached entries, ordered from least recently used (front) to most recently used (back).
    /// Hits are moved to the back; the front entry is evicted when the cache is full.
    entries: Mutex<VecDeque<Arc<CompressionCacheEntry>>>,
}
299
300impl CompressionCache {
301    #[cfg(feature = "std")]
304    pub fn new(size: usize) -> Self {
305        if size == 0 {
306            return Self::Disabled;
307        }
308
309        Self::Enabled(CompressionCacheInner {
310            size,
311            entries: Mutex::new(VecDeque::with_capacity(size)),
312        })
313    }
314
315    pub(crate) fn compression_for(
321        &self,
322        compressor: &dyn CertCompressor,
323        original: &CertificatePayloadTls13<'_>,
324    ) -> Result<Arc<CompressionCacheEntry>, CompressionFailed> {
325        match self {
326            Self::Disabled => Self::uncached_compression(compressor, original),
327
328            #[cfg(feature = "std")]
329            Self::Enabled(_) => self.compression_for_impl(compressor, original),
330        }
331    }
332
333    #[cfg(feature = "std")]
334    fn compression_for_impl(
335        &self,
336        compressor: &dyn CertCompressor,
337        original: &CertificatePayloadTls13<'_>,
338    ) -> Result<Arc<CompressionCacheEntry>, CompressionFailed> {
339        let (max_size, entries) = match self {
340            Self::Enabled(CompressionCacheInner { size, entries }) => (*size, entries),
341            _ => unreachable!(),
342        };
343
344        if !original.context.0.is_empty() {
347            return Self::uncached_compression(compressor, original);
348        }
349
350        let encoding = original.get_encoding();
352        let algorithm = compressor.algorithm();
353
354        let mut cache = entries
355            .lock()
356            .map_err(|_| CompressionFailed)?;
357        for (i, item) in cache.iter().enumerate() {
358            if item.algorithm == algorithm && item.original == encoding {
359                let item = cache.remove(i).unwrap();
361                cache.push_back(item.clone());
362                return Ok(item);
363            }
364        }
365        drop(cache);
366
367        let uncompressed_len = encoding.len() as u32;
369        let compressed = compressor.compress(encoding.clone(), CompressionLevel::Amortized)?;
370        let new_entry = Arc::new(CompressionCacheEntry {
371            algorithm,
372            original: encoding,
373            compressed: CompressedCertificatePayload {
374                alg: algorithm,
375                uncompressed_len,
376                compressed: PayloadU24(Payload::new(compressed)),
377            },
378        });
379
380        let mut cache = entries
382            .lock()
383            .map_err(|_| CompressionFailed)?;
384        if cache.len() == max_size {
385            cache.pop_front();
386        }
387        cache.push_back(new_entry.clone());
388        Ok(new_entry)
389    }
390
391    fn uncached_compression(
393        compressor: &dyn CertCompressor,
394        original: &CertificatePayloadTls13<'_>,
395    ) -> Result<Arc<CompressionCacheEntry>, CompressionFailed> {
396        let algorithm = compressor.algorithm();
397        let encoding = original.get_encoding();
398        let uncompressed_len = encoding.len() as u32;
399        let compressed = compressor.compress(encoding, CompressionLevel::Interactive)?;
400
401        Ok(Arc::new(CompressionCacheEntry {
404            algorithm,
405            original: Vec::new(),
406            compressed: CompressedCertificatePayload {
407                alg: algorithm,
408                uncompressed_len,
409                compressed: PayloadU24(Payload::new(compressed)),
410            },
411        }))
412    }
413}
414
415impl Default for CompressionCache {
416    fn default() -> Self {
417        #[cfg(feature = "std")]
418        {
419            Self::new(4)
421        }
422
423        #[cfg(not(feature = "std"))]
424        {
425            Self::Disabled
426        }
427    }
428}
429
/// One compressed certificate payload, as returned by `CompressionCache::compression_for`.
// `algorithm` and `original` are only read by the std-only cache lookup, hence the
// dead_code allowance for non-std builds.
#[cfg_attr(not(feature = "std"), allow(dead_code))]
#[derive(Debug)]
pub(crate) struct CompressionCacheEntry {
    /// Algorithm used to produce `compressed`; one half of the cache lookup key.
    algorithm: CertificateCompressionAlgorithm,

    /// Uncompressed encoding of the payload; the other half of the cache key.
    /// Left empty for entries built without the cache.
    original: Vec<u8>,

    /// The compressed payload handed out via `compressed_cert_payload`.
    compressed: CompressedCertificatePayload<'static>,
}
440
441impl CompressionCacheEntry {
442    pub(crate) fn compressed_cert_payload(&self) -> CompressedCertificatePayload<'_> {
443        self.compressed.as_borrowed()
444    }
445}
446
#[cfg(all(test, any(feature = "brotli", feature = "zlib")))]
mod tests {
    use std::{println, vec};

    use super::*;

    #[test]
    #[cfg(feature = "zlib")]
    fn test_zlib() {
        test_compressor(ZLIB_COMPRESSOR, ZLIB_DECOMPRESSOR);
    }

    #[test]
    #[cfg(feature = "brotli")]
    fn test_brotli() {
        test_compressor(BROTLI_COMPRESSOR, BROTLI_DECOMPRESSOR);
    }

    /// Runs the shared checks on a compressor/decompressor pair:
    /// - matching algorithms;
    /// - round trips at several sizes;
    /// - rejection of wrongly-sized outputs;
    /// - rejection of junk input.
    fn test_compressor(comp: &dyn CertCompressor, decomp: &dyn CertDecompressor) {
        assert_eq!(comp.algorithm(), decomp.algorithm());
        for sz in [16, 64, 512, 2048, 8192, 16384] {
            test_trivial_pairwise(comp, decomp, sz);
        }
        test_decompress_wrong_len(comp, decomp);
        test_decompress_garbage(decomp);
    }

    /// Round-trips `plain_len` zero bytes at both compression levels.
    fn test_trivial_pairwise(
        comp: &dyn CertCompressor,
        decomp: &dyn CertDecompressor,
        plain_len: usize,
    ) {
        let original = vec![0u8; plain_len];

        for level in [CompressionLevel::Interactive, CompressionLevel::Amortized] {
            let compressed = comp
                .compress(original.clone(), level)
                .unwrap();
            println!(
                "{:?} compressed trivial {} -> {} using {:?} level",
                comp.algorithm(),
                original.len(),
                compressed.len(),
                level
            );
            // Pre-filled with 0xff so bytes the decompressor failed to write would be detected.
            let mut recovered = vec![0xffu8; plain_len];
            decomp
                .decompress(&compressed, &mut recovered)
                .unwrap();
            assert_eq!(original, recovered);
        }
    }

    /// Decompressing into an output one byte too long, or one byte too short, must fail.
    fn test_decompress_wrong_len(comp: &dyn CertCompressor, decomp: &dyn CertDecompressor) {
        let original = vec![0u8; 2048];
        let compressed = comp
            .compress(original.clone(), CompressionLevel::Interactive)
            .unwrap();
        println!("{compressed:?}");

        // Output buffer one byte larger than the real data.
        let mut recovered = vec![0xffu8; original.len() + 1];
        decomp
            .decompress(&compressed, &mut recovered)
            .unwrap_err();

        // Output buffer one byte smaller than the real data.
        let mut recovered = vec![0xffu8; original.len() - 1];
        decomp
            .decompress(&compressed, &mut recovered)
            .unwrap_err();
    }

    /// Expects decompression of 1024 zero bytes (not produced by the compressor) to fail.
    fn test_decompress_garbage(decomp: &dyn CertDecompressor) {
        let junk = [0u8; 1024];
        let mut recovered = vec![0u8; 512];
        decomp
            .decompress(&junk, &mut recovered)
            .unwrap_err();
    }

    /// Walks a four-entry cache through hits, misses and LRU evictions.
    ///
    /// Notation: `zN` / `bN` are the zlib / brotli entries for `certN`, and cache
    /// contents are listed from least to most recently used.
    #[test]
    #[cfg(all(feature = "brotli", feature = "zlib"))]
    fn test_cache_evicts_lru() {
        use core::sync::atomic::{AtomicBool, Ordering};

        use pki_types::CertificateDer;

        // With `std`, `Default` yields a four-entry cache; the expectations below rely on
        // that. NOTE(review): assumes these features imply `std` — confirm.
        let cache = CompressionCache::default();

        let cert = CertificateDer::from(vec![1]);

        // The four payloads differ only in the second argument to `new`, so their
        // encodings (the cache keys) presumably differ.
        let cert1 = CertificatePayloadTls13::new([&cert].into_iter(), Some(b"1"));
        let cert2 = CertificatePayloadTls13::new([&cert].into_iter(), Some(b"2"));
        let cert3 = CertificatePayloadTls13::new([&cert].into_iter(), Some(b"3"));
        let cert4 = CertificatePayloadTls13::new([&cert].into_iter(), Some(b"4"));

        // How to read the calls below:
        // - Each temporary `RequireCompress` is dropped at the end of its statement.
        // - On drop it asserts whether `compress` ran: `true` means a cache miss is
        //   expected, `false` means a hit.

        // Fill the cache with zlib entries: [z1, z2, z3, z4].
        cache
            .compression_for(
                &RequireCompress(ZLIB_COMPRESSOR, AtomicBool::default(), true),
                &cert1,
            )
            .unwrap();
        cache
            .compression_for(
                &RequireCompress(ZLIB_COMPRESSOR, AtomicBool::default(), true),
                &cert2,
            )
            .unwrap();
        cache
            .compression_for(
                &RequireCompress(ZLIB_COMPRESSOR, AtomicBool::default(), true),
                &cert3,
            )
            .unwrap();
        cache
            .compression_for(
                &RequireCompress(ZLIB_COMPRESSOR, AtomicBool::default(), true),
                &cert4,
            )
            .unwrap();

        // The algorithm is part of the key, so brotli/cert4 misses.
        // It evicts the LRU entry z1: [z2, z3, z4, b4].
        cache
            .compression_for(
                &RequireCompress(BROTLI_COMPRESSOR, AtomicBool::default(), true),
                &cert4,
            )
            .unwrap();

        // All still cached, so these are hits. Touching them front-to-back leaves
        // the order unchanged: [z2, z3, z4, b4].
        cache
            .compression_for(
                &RequireCompress(ZLIB_COMPRESSOR, AtomicBool::default(), false),
                &cert2,
            )
            .unwrap();
        cache
            .compression_for(
                &RequireCompress(ZLIB_COMPRESSOR, AtomicBool::default(), false),
                &cert3,
            )
            .unwrap();
        cache
            .compression_for(
                &RequireCompress(ZLIB_COMPRESSOR, AtomicBool::default(), false),
                &cert4,
            )
            .unwrap();
        cache
            .compression_for(
                &RequireCompress(BROTLI_COMPRESSOR, AtomicBool::default(), false),
                &cert4,
            )
            .unwrap();

        // z1 was evicted, so it is recompressed, evicting z2: [z3, z4, b4, z1].
        cache
            .compression_for(
                &RequireCompress(ZLIB_COMPRESSOR, AtomicBool::default(), true),
                &cert1,
            )
            .unwrap();

        // Hits on z4, z3 and z1 move each to the MRU end: [b4, z4, z3, z1].
        cache
            .compression_for(
                &RequireCompress(ZLIB_COMPRESSOR, AtomicBool::default(), false),
                &cert4,
            )
            .unwrap();
        cache
            .compression_for(
                &RequireCompress(ZLIB_COMPRESSOR, AtomicBool::default(), false),
                &cert3,
            )
            .unwrap();
        cache
            .compression_for(
                &RequireCompress(ZLIB_COMPRESSOR, AtomicBool::default(), false),
                &cert1,
            )
            .unwrap();

        // brotli/cert1 misses and evicts the now-LRU b4: [z4, z3, z1, b1].
        cache
            .compression_for(
                &RequireCompress(BROTLI_COMPRESSOR, AtomicBool::default(), true),
                &cert1,
            )
            .unwrap();

        // b4 was just evicted, so it must be compressed again.
        cache
            .compression_for(
                &RequireCompress(BROTLI_COMPRESSOR, AtomicBool::default(), true),
                &cert4,
            )
            .unwrap();

        // Wraps a real compressor.
        // - Field 1 records whether `compress` was called. It is an `AtomicBool`
        //   because `compress` takes `&self` and the trait requires `Sync`.
        // - Field 2 is the expected value of field 1, checked on drop.
        #[derive(Debug)]
        struct RequireCompress(&'static dyn CertCompressor, AtomicBool, bool);

        impl CertCompressor for RequireCompress {
            fn compress(
                &self,
                input: Vec<u8>,
                level: CompressionLevel,
            ) -> Result<Vec<u8>, CompressionFailed> {
                self.1.store(true, Ordering::SeqCst);
                self.0.compress(input, level)
            }

            fn algorithm(&self) -> CertificateCompressionAlgorithm {
                self.0.algorithm()
            }
        }

        impl Drop for RequireCompress {
            fn drop(&mut self) {
                assert_eq!(self.1.load(Ordering::SeqCst), self.2);
            }
        }
    }
}