// fdeflate/decompress.rs

1use simd_adler32::Adler32;
2
3use crate::{
4    huffman::{self, build_table},
5    tables::{
6        self, CLCL_ORDER, DIST_SYM_TO_DIST_BASE, DIST_SYM_TO_DIST_EXTRA, FIXED_DIST_TABLE,
7        FIXED_LITLEN_TABLE, LEN_SYM_TO_LEN_BASE, LEN_SYM_TO_LEN_EXTRA, LITLEN_TABLE_ENTRIES,
8    },
9};
10
/// An error encountered while decompressing a deflate stream.
#[derive(Debug, PartialEq)]
pub enum DecompressionError {
    /// The zlib header is corrupt.
    BadZlibHeader,
    /// All input was consumed, but the end of the stream hasn't been reached.
    InsufficientInput,
    /// A block header specifies an invalid block type.
    InvalidBlockType,
    /// An uncompressed block's NLEN value is invalid.
    InvalidUncompressedBlockLength,
    /// Too many literals were specified.
    InvalidHlit,
    /// Too many distance codes were specified.
    InvalidHdist,
    /// Attempted to repeat a previous code before reading any codes, or past the end of the code
    /// lengths.
    InvalidCodeLengthRepeat,
    /// The stream doesn't specify a valid huffman tree for the code length codes.
    BadCodeLengthHuffmanTree,
    /// The stream doesn't specify a valid huffman tree for the literal/length codes.
    BadLiteralLengthHuffmanTree,
    /// The stream doesn't specify a valid huffman tree for the distance codes.
    BadDistanceHuffmanTree,
    /// The stream contains a literal/length code that was not allowed by the header.
    InvalidLiteralLengthCode,
    /// The stream contains a distance code that was not allowed by the header.
    InvalidDistanceCode,
    /// The stream contains a back-reference as the first symbol.
    InputStartsWithRun,
    /// The stream contains a back-reference that is too far back.
    DistanceTooFarBack,
    /// The deflate stream checksum is incorrect.
    WrongChecksum,
    /// Extra input data.
    ExtraInput,
}
48
/// In-progress state while parsing the header of a dynamic huffman block.
struct BlockHeader {
    /// Number of literal/length codes (the HLIT field plus 257).
    hlit: usize,
    /// Number of distance codes (the HDIST field plus 1).
    hdist: usize,
    /// Number of code length codes (the HCLEN field plus 4).
    hclen: usize,
    /// How many of the `hlit + hdist` code lengths have been decoded so far.
    num_lengths_read: usize,

    /// Low 3-bits are code length code length, high 5-bits are code length code.
    table: [u32; 128],
    /// Decoded code lengths: literal/length lengths come first; distance
    /// lengths are moved to offset 288 by `read_code_lengths`.
    code_lengths: [u8; 320],
}
59
/// Table-entry flag: the entry directly encodes one or two literal bytes
/// (byte count in bits 8..12, byte values in bits 16..32).
pub const LITERAL_ENTRY: u32 = 0x8000;
/// Table-entry flag: the entry is not a plain literal or length code — it is
/// an EOF marker, an invalid code, or a pointer into a secondary table.
pub const EXCEPTIONAL_ENTRY: u32 = 0x4000;
/// Table-entry flag: the code is longer than the primary table lookup covers
/// and must be resolved through the secondary table.
pub const SECONDARY_TABLE_ENTRY: u32 = 0x2000;
63
/// The Decompressor state for a compressed block.
#[derive(Eq, PartialEq, Debug)]
struct CompressedBlock {
    /// Primary literal/length decoding table, indexed by the low 12 bits of the
    /// bit buffer.
    litlen_table: Box<[u32; 4096]>,
    /// Overflow table for literal/length codes too long for the primary table.
    secondary_table: Vec<u16>,

    /// Primary distance decoding table, indexed by the low 9 bits of the
    /// (shifted) bit buffer.
    dist_table: Box<[u32; 512]>,
    /// Overflow table for distance codes too long for the primary table.
    dist_secondary_table: Vec<u16>,

    /// Huffman code assigned to the end-of-block symbol (256).
    eof_code: u16,
    /// Mask covering the low `eof_bits` bits.
    eof_mask: u16,
    /// Length in bits of the end-of-block code.
    eof_bits: u8,
}
77
/// Which portion of the zlib stream the decompressor expects to read next.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
enum State {
    /// Reading the zlib stream header.
    ZlibHeader,
    /// Reading a deflate block header.
    BlockHeader,
    /// Reading the code length code lengths of a dynamic block.
    CodeLengthCodes,
    /// Reading the literal/length and distance code lengths of a dynamic block.
    CodeLengths,
    /// Decoding the body of a compressed (fixed or dynamic huffman) block.
    CompressedData,
    /// Copying the body of a stored (uncompressed) block.
    UncompressedData,
    /// Reading the checksum trailer at the end of the stream.
    Checksum,
    /// The stream has been fully decoded.
    Done,
}
89
/// Decompressor for arbitrary zlib streams.
pub struct Decompressor {
    /// State for decoding a compressed block.
    compression: CompressedBlock,
    // State for decoding a block header.
    header: BlockHeader,
    // Number of bytes left for uncompressed block.
    uncompressed_bytes_left: u16,

    // Bit buffer holding the next unconsumed bits of the input, LSB first.
    buffer: u64,
    // Number of valid bits currently in `buffer` (0..=63).
    nbits: u8,

    // Pending single-byte run (byte, remaining length) that didn't fit in the
    // caller's output buffer.
    queued_rle: Option<(u8, usize)>,
    // Pending back-reference (distance, remaining length) that didn't fit in
    // the caller's output buffer.
    queued_backref: Option<(usize, usize)>,
    // Whether the block currently being decoded is the final one.
    last_block: bool,
    // Whether `compression` currently holds the fixed huffman tables, so a
    // following fixed block can skip rebuilding them.
    fixed_table: bool,

    // Current position in the stream's state machine.
    state: State,
    // Running Adler-32 checksum of the decompressed output.
    checksum: Adler32,
    // If set, skip verifying the checksum at the end of the stream.
    ignore_adler32: bool,
}
111
112impl Default for Decompressor {
113    fn default() -> Self {
114        Self::new()
115    }
116}
117
118impl Decompressor {
119    /// Create a new decompressor.
120    pub fn new() -> Self {
121        Self {
122            buffer: 0,
123            nbits: 0,
124            compression: CompressedBlock {
125                litlen_table: Box::new([0; 4096]),
126                dist_table: Box::new([0; 512]),
127                secondary_table: Vec::new(),
128                dist_secondary_table: Vec::new(),
129                eof_code: 0,
130                eof_mask: 0,
131                eof_bits: 0,
132            },
133            header: BlockHeader {
134                hlit: 0,
135                hdist: 0,
136                hclen: 0,
137                table: [0; 128],
138                num_lengths_read: 0,
139                code_lengths: [0; 320],
140            },
141            uncompressed_bytes_left: 0,
142            queued_rle: None,
143            queued_backref: None,
144            checksum: Adler32::new(),
145            state: State::ZlibHeader,
146            last_block: false,
147            ignore_adler32: false,
148            fixed_table: false,
149        }
150    }
151
    /// Ignore the checksum at the end of the stream.
    ///
    /// Useful when the caller doesn't need integrity verification, since it
    /// skips the final Adler-32 comparison.
    pub fn ignore_adler32(&mut self) {
        self.ignore_adler32 = true;
    }
156
    /// Refill the bit buffer from `input`, consuming as many whole input bytes
    /// as fit.
    ///
    /// After the fast path (8+ bytes available) `nbits` lands in `56..=63`,
    /// i.e. the buffer is as full as possible while remaining byte-aligned.
    fn fill_buffer(&mut self, input: &mut &[u8]) {
        if input.len() >= 8 {
            // Read 8 bytes unconditionally (extra high bits are shifted out of
            // the u64), but only advance the input cursor by the number of
            // whole bytes that actually fit.
            self.buffer |= u64::from_le_bytes(input[..8].try_into().unwrap()) << self.nbits;
            *input = &input[(63 - self.nbits as usize) / 8..];
            // Rounds nbits up into 56..=63 without changing its value mod 8.
            self.nbits |= 56;
        } else {
            // Tail of the input: copy what's left into a zero-padded 8-byte
            // scratch buffer and shift it into position.
            let nbytes = input.len().min((63 - self.nbits as usize) / 8);
            let mut input_data = [0; 8];
            input_data[..nbytes].copy_from_slice(&input[..nbytes]);
            // checked_shl guards against an out-of-range shift; when it would
            // overflow, nbytes is 0 so ORing in 0 is correct.
            self.buffer |= u64::from_le_bytes(input_data)
                .checked_shl(self.nbits as u32)
                .unwrap_or(0);
            self.nbits += nbytes as u8 * 8;
            *input = &input[nbytes..];
        }
    }
173
    /// Return the low `nbits` bits of the bit buffer without consuming them.
    /// ("peak" [sic] — the name is a long-standing typo for "peek".)
    fn peak_bits(&mut self, nbits: u8) -> u64 {
        debug_assert!(nbits <= 56 && nbits <= self.nbits);
        self.buffer & ((1u64 << nbits) - 1)
    }
    /// Discard the low `nbits` bits from the bit buffer.
    fn consume_bits(&mut self, nbits: u8) {
        debug_assert!(self.nbits >= nbits);
        self.buffer >>= nbits;
        self.nbits -= nbits;
    }
183
    /// Parse the 3-bit deflate block header (plus any type-specific header
    /// bits) and advance `self.state` accordingly.
    ///
    /// Returns `Ok(())` without consuming anything if not enough bits are
    /// buffered yet; the caller retries once more input arrives. Errors are
    /// returned for malformed headers.
    fn read_block_header(&mut self, remaining_input: &mut &[u8]) -> Result<(), DecompressionError> {
        self.fill_buffer(remaining_input);
        if self.nbits < 10 {
            return Ok(());
        }

        // Bit 0 is the BFINAL flag; bits 1-2 are the block type.
        let start = self.peak_bits(3);
        self.last_block = start & 1 != 0;
        match start >> 1 {
            0b00 => {
                // Stored (uncompressed) block: skip padding to the next byte
                // boundary, then read LEN (16 bits) and NLEN (16 bits, the
                // one's complement of LEN).
                let align_bits = (self.nbits - 3) % 8;
                let header_bits = 3 + 32 + align_bits;
                if self.nbits < header_bits {
                    return Ok(());
                }

                let len = (self.peak_bits(align_bits + 19) >> (align_bits + 3)) as u16;
                let nlen = (self.peak_bits(header_bits) >> (align_bits + 19)) as u16;
                if nlen != !len {
                    return Err(DecompressionError::InvalidUncompressedBlockLength);
                }

                self.state = State::UncompressedData;
                self.uncompressed_bytes_left = len;
                self.consume_bits(header_bits);
                Ok(())
            }
            0b01 => {
                // Fixed huffman block.
                self.consume_bits(3);

                // Check for an entirely empty blocks which can happen if there are "partial
                // flushes" in the deflate stream. With fixed huffman codes, the EOF symbol is
                // 7-bits of zeros so we peak ahead and see if the next 7-bits are all zero.
                if self.peak_bits(7) == 0 {
                    self.consume_bits(7);
                    if self.last_block {
                        self.state = State::Checksum;
                        return Ok(());
                    }

                    // At this point we've consumed the entire block and need to read the next block
                    // header. If tail call optimization were guaranteed, we could just recurse
                    // here. But without it, a long sequence of empty fixed-blocks might cause a
                    // stack overflow. Instead, we consume all empty blocks in a loop and then
                    // recurse. This is the only recursive call this function, and thus is safe.
                    //
                    // The 10-bit pattern 0b010 is a non-final fixed-block header
                    // (3 bits, LSB first) followed by its 7-bit all-zero EOF code.
                    while self.nbits >= 10 && self.peak_bits(10) == 0b010 {
                        self.consume_bits(10);
                        self.fill_buffer(remaining_input);
                    }
                    return self.read_block_header(remaining_input);
                }

                // Build decoding tables if the previous block wasn't also a fixed block.
                if !self.fixed_table {
                    self.fixed_table = true;
                    // The fixed tables are smaller than the dynamic ones, so
                    // tile them across the full-size table buffers.
                    for chunk in self.compression.litlen_table.chunks_exact_mut(512) {
                        chunk.copy_from_slice(&FIXED_LITLEN_TABLE);
                    }
                    for chunk in self.compression.dist_table.chunks_exact_mut(32) {
                        chunk.copy_from_slice(&FIXED_DIST_TABLE);
                    }
                    // Fixed EOF symbol: seven zero bits.
                    self.compression.eof_bits = 7;
                    self.compression.eof_code = 0;
                    self.compression.eof_mask = 0x7f;
                }

                self.state = State::CompressedData;
                Ok(())
            }
            0b10 => {
                // Dynamic huffman block: 3 header bits + 5-bit HLIT + 5-bit
                // HDIST + 4-bit HCLEN = 17 bits.
                if self.nbits < 17 {
                    return Ok(());
                }

                self.header.hlit = (self.peak_bits(8) >> 3) as usize + 257;
                self.header.hdist = (self.peak_bits(13) >> 8) as usize + 1;
                self.header.hclen = (self.peak_bits(17) >> 13) as usize + 4;
                if self.header.hlit > 286 {
                    return Err(DecompressionError::InvalidHlit);
                }
                if self.header.hdist > 30 {
                    return Err(DecompressionError::InvalidHdist);
                }

                self.consume_bits(17);
                self.state = State::CodeLengthCodes;
                // The dynamic tables will overwrite the fixed ones.
                self.fixed_table = false;
                Ok(())
            }
            0b11 => Err(DecompressionError::InvalidBlockType),
            _ => unreachable!(),
        }
    }
277
    /// Read the `hclen` 3-bit code length code lengths of a dynamic block and
    /// build the code-length huffman table into `self.header.table`.
    ///
    /// Returns `Ok(())` without consuming anything if the full set of lengths
    /// isn't available yet (buffered plus unread input).
    fn read_code_length_codes(
        &mut self,
        remaining_input: &mut &[u8],
    ) -> Result<(), DecompressionError> {
        self.fill_buffer(remaining_input);
        // Check availability up front so the loop below never has to back out
        // after partially consuming the lengths.
        if self.nbits as usize + remaining_input.len() * 8 < 3 * self.header.hclen {
            return Ok(());
        }

        let mut code_length_lengths = [0; 19];
        for i in 0..self.header.hclen {
            // Lengths are transmitted in the spec's permuted order (CLCL_ORDER).
            code_length_lengths[CLCL_ORDER[i]] = self.peak_bits(3) as u8;
            self.consume_bits(3);

            // We need to refill the buffer after reading 3 * 18 = 54 bits since the buffer holds
            // between 56 and 63 bits total.
            if i == 17 {
                self.fill_buffer(remaining_input);
            }
        }

        let mut codes = [0; 19];
        if !build_table(
            &code_length_lengths,
            &[],
            &mut codes,
            &mut self.header.table,
            &mut Vec::new(),
            false,
            false,
        ) {
            return Err(DecompressionError::BadCodeLengthHuffmanTree);
        }

        self.state = State::CodeLengths;
        self.header.num_lengths_read = 0;
        Ok(())
    }
316
    /// Decode the `hlit + hdist` literal/length and distance code lengths of a
    /// dynamic block using the code-length table, then build the decoding
    /// tables and transition to `CompressedData`.
    ///
    /// Resumable: returns `Ok(())` mid-way (with progress recorded in
    /// `self.header.num_lengths_read`) when the bit buffer runs low.
    fn read_code_lengths(&mut self, remaining_input: &mut &[u8]) -> Result<(), DecompressionError> {
        let total_lengths = self.header.hlit + self.header.hdist;
        while self.header.num_lengths_read < total_lengths {
            self.fill_buffer(remaining_input);
            // Code-length codes are at most 7 bits.
            if self.nbits < 7 {
                return Ok(());
            }

            let code = self.peak_bits(7);
            let entry = self.header.table[code as usize];
            let length = (entry & 0x7) as u8;
            let symbol = (entry >> 16) as u8;

            debug_assert!(length != 0);
            match symbol {
                // Symbols 0-15 are literal code lengths.
                0..=15 => {
                    self.header.code_lengths[self.header.num_lengths_read] = symbol;
                    self.header.num_lengths_read += 1;
                    self.consume_bits(length);
                }
                // Symbols 16-18 are run-length codes (RFC 1951 §3.2.7).
                16..=18 => {
                    let (base_repeat, extra_bits) = match symbol {
                        16 => (3, 2),
                        17 => (3, 3),
                        18 => (11, 7),
                        _ => unreachable!(),
                    };

                    // Nothing consumed yet, so returning here is a clean pause.
                    if self.nbits < length + extra_bits {
                        return Ok(());
                    }

                    let value = match symbol {
                        // Symbol 16 repeats the previously emitted code length;
                        // it is an error if no length has been emitted yet.
                        16 => {
                            self.header.code_lengths[self
                                .header
                                .num_lengths_read
                                .checked_sub(1)
                                .ok_or(DecompressionError::InvalidCodeLengthRepeat)?]
                        }
                        // Symbols 17 and 18 emit runs of zero lengths.
                        17 => 0,
                        18 => 0,
                        _ => unreachable!(),
                    };

                    let repeat =
                        (self.peak_bits(length + extra_bits) >> length) as usize + base_repeat;
                    if self.header.num_lengths_read + repeat > total_lengths {
                        return Err(DecompressionError::InvalidCodeLengthRepeat);
                    }

                    for i in 0..repeat {
                        self.header.code_lengths[self.header.num_lengths_read + i] = value;
                    }
                    self.header.num_lengths_read += repeat;
                    self.consume_bits(length + extra_bits);
                }
                _ => unreachable!(),
            }
        }

        // Rearrange for `build_tables`: distance lengths move to offset 288 and
        // the gaps (unused literal/length and distance slots) are zeroed.
        self.header
            .code_lengths
            .copy_within(self.header.hlit..total_lengths, 288);
        for i in self.header.hlit..288 {
            self.header.code_lengths[i] = 0;
        }
        for i in 288 + self.header.hdist..320 {
            self.header.code_lengths[i] = 0;
        }

        Self::build_tables(
            self.header.hlit,
            &self.header.code_lengths,
            &mut self.compression,
        )?;
        self.state = State::CompressedData;
        Ok(())
    }
397
398    fn build_tables(
399        hlit: usize,
400        code_lengths: &[u8],
401        compression: &mut CompressedBlock,
402    ) -> Result<(), DecompressionError> {
403        // If there is no code assigned for the EOF symbol then the bitstream is invalid.
404        if code_lengths[256] == 0 {
405            // TODO: Return a dedicated error in this case.
406            return Err(DecompressionError::BadLiteralLengthHuffmanTree);
407        }
408
409        let mut codes = [0; 288];
410        compression.secondary_table.clear();
411        if !huffman::build_table(
412            &code_lengths[..hlit],
413            &LITLEN_TABLE_ENTRIES,
414            &mut codes[..hlit],
415            &mut *compression.litlen_table,
416            &mut compression.secondary_table,
417            false,
418            true,
419        ) {
420            return Err(DecompressionError::BadCodeLengthHuffmanTree);
421        }
422
423        compression.eof_code = codes[256];
424        compression.eof_mask = (1 << code_lengths[256]) - 1;
425        compression.eof_bits = code_lengths[256];
426
427        // Build the distance code table.
428        let lengths = &code_lengths[288..320];
429        if lengths == [0; 32] {
430            compression.dist_table.fill(0);
431        } else {
432            let mut dist_codes = [0; 32];
433            if !huffman::build_table(
434                lengths,
435                &tables::DISTANCE_TABLE_ENTRIES,
436                &mut dist_codes,
437                &mut *compression.dist_table,
438                &mut compression.dist_secondary_table,
439                true,
440                false,
441            ) {
442                return Err(DecompressionError::BadDistanceHuffmanTree);
443            }
444        }
445
446        Ok(())
447    }
448
449    fn read_compressed(
450        &mut self,
451        remaining_input: &mut &[u8],
452        output: &mut [u8],
453        mut output_index: usize,
454    ) -> Result<usize, DecompressionError> {
455        // Fast decoding loop.
456        //
457        // This loop is optimized for speed and is the main decoding loop for the decompressor,
458        // which is used when there are at least 8 bytes of input and output data available. It
459        // assumes that the bitbuffer is full (nbits >= 56) and that litlen_entry has been loaded.
460        //
461        // These assumptions enable a few optimizations:
462        // - Nearly all checks for nbits are avoided.
463        // - Checking the input size is optimized out in the refill function call.
464        // - The litlen_entry for the next loop iteration can be loaded in parallel with refilling
465        //   the bit buffer. This is because when the input is non-empty, the bit buffer actually
466        //   has 64-bits of valid data (even though nbits will be in 56..=63).
467        self.fill_buffer(remaining_input);
468        let mut litlen_entry = self.compression.litlen_table[(self.buffer & 0xfff) as usize];
469        while self.state == State::CompressedData
470            && output_index + 8 <= output.len()
471            && remaining_input.len() >= 8
472        {
473            // First check whether the next symbol is a literal. This code does up to 2 additional
474            // table lookups to decode more literals.
475            let mut bits;
476            let mut litlen_code_bits = litlen_entry as u8;
477            if litlen_entry & LITERAL_ENTRY != 0 {
478                let litlen_entry2 = self.compression.litlen_table
479                    [(self.buffer >> litlen_code_bits & 0xfff) as usize];
480                let litlen_code_bits2 = litlen_entry2 as u8;
481                let litlen_entry3 = self.compression.litlen_table
482                    [(self.buffer >> (litlen_code_bits + litlen_code_bits2) & 0xfff) as usize];
483                let litlen_code_bits3 = litlen_entry3 as u8;
484                let litlen_entry4 = self.compression.litlen_table[(self.buffer
485                    >> (litlen_code_bits + litlen_code_bits2 + litlen_code_bits3)
486                    & 0xfff)
487                    as usize];
488
489                let advance_output_bytes = ((litlen_entry & 0xf00) >> 8) as usize;
490                output[output_index] = (litlen_entry >> 16) as u8;
491                output[output_index + 1] = (litlen_entry >> 24) as u8;
492                output_index += advance_output_bytes;
493
494                if litlen_entry2 & LITERAL_ENTRY != 0 {
495                    let advance_output_bytes2 = ((litlen_entry2 & 0xf00) >> 8) as usize;
496                    output[output_index] = (litlen_entry2 >> 16) as u8;
497                    output[output_index + 1] = (litlen_entry2 >> 24) as u8;
498                    output_index += advance_output_bytes2;
499
500                    if litlen_entry3 & LITERAL_ENTRY != 0 {
501                        let advance_output_bytes3 = ((litlen_entry3 & 0xf00) >> 8) as usize;
502                        output[output_index] = (litlen_entry3 >> 16) as u8;
503                        output[output_index + 1] = (litlen_entry3 >> 24) as u8;
504                        output_index += advance_output_bytes3;
505
506                        litlen_entry = litlen_entry4;
507                        self.consume_bits(litlen_code_bits + litlen_code_bits2 + litlen_code_bits3);
508                        self.fill_buffer(remaining_input);
509                        continue;
510                    } else {
511                        self.consume_bits(litlen_code_bits + litlen_code_bits2);
512                        litlen_entry = litlen_entry3;
513                        litlen_code_bits = litlen_code_bits3;
514                        self.fill_buffer(remaining_input);
515                        bits = self.buffer;
516                    }
517                } else {
518                    self.consume_bits(litlen_code_bits);
519                    bits = self.buffer;
520                    litlen_entry = litlen_entry2;
521                    litlen_code_bits = litlen_code_bits2;
522                    if self.nbits < 48 {
523                        self.fill_buffer(remaining_input);
524                    }
525                }
526            } else {
527                bits = self.buffer;
528            }
529
530            // The next symbol is either a 13+ bit literal, back-reference, or an EOF symbol.
531            let (length_base, length_extra_bits, litlen_code_bits) =
532                if litlen_entry & EXCEPTIONAL_ENTRY == 0 {
533                    (
534                        litlen_entry >> 16,
535                        (litlen_entry >> 8) as u8,
536                        litlen_code_bits,
537                    )
538                } else if litlen_entry & SECONDARY_TABLE_ENTRY != 0 {
539                    let secondary_table_index =
540                        (litlen_entry >> 16) + ((bits >> 12) as u32 & (litlen_entry & 0xff));
541                    let secondary_entry =
542                        self.compression.secondary_table[secondary_table_index as usize];
543                    let litlen_symbol = secondary_entry >> 4;
544                    let litlen_code_bits = (secondary_entry & 0xf) as u8;
545
546                    match litlen_symbol {
547                        0..=255 => {
548                            self.consume_bits(litlen_code_bits);
549                            litlen_entry =
550                                self.compression.litlen_table[(self.buffer & 0xfff) as usize];
551                            self.fill_buffer(remaining_input);
552                            output[output_index] = litlen_symbol as u8;
553                            output_index += 1;
554                            continue;
555                        }
556                        256 => {
557                            self.consume_bits(litlen_code_bits);
558                            self.state = match self.last_block {
559                                true => State::Checksum,
560                                false => State::BlockHeader,
561                            };
562                            break;
563                        }
564                        _ => (
565                            LEN_SYM_TO_LEN_BASE[litlen_symbol as usize - 257] as u32,
566                            LEN_SYM_TO_LEN_EXTRA[litlen_symbol as usize - 257],
567                            litlen_code_bits,
568                        ),
569                    }
570                } else if litlen_code_bits == 0 {
571                    return Err(DecompressionError::InvalidLiteralLengthCode);
572                } else {
573                    self.consume_bits(litlen_code_bits);
574                    self.state = match self.last_block {
575                        true => State::Checksum,
576                        false => State::BlockHeader,
577                    };
578                    break;
579                };
580            bits >>= litlen_code_bits;
581
582            let length_extra_mask = (1 << length_extra_bits) - 1;
583            let length = length_base as usize + (bits & length_extra_mask) as usize;
584            bits >>= length_extra_bits;
585
586            let dist_entry = self.compression.dist_table[(bits & 0x1ff) as usize];
587            let (dist_base, dist_extra_bits, dist_code_bits) = if dist_entry & LITERAL_ENTRY != 0 {
588                (
589                    (dist_entry >> 16) as u16,
590                    (dist_entry >> 8) as u8 & 0xf,
591                    dist_entry as u8,
592                )
593            } else if dist_entry >> 8 == 0 {
594                return Err(DecompressionError::InvalidDistanceCode);
595            } else {
596                let secondary_table_index =
597                    (dist_entry >> 16) + ((bits >> 9) as u32 & (dist_entry & 0xff));
598                let secondary_entry =
599                    self.compression.dist_secondary_table[secondary_table_index as usize];
600                let dist_symbol = (secondary_entry >> 4) as usize;
601                if dist_symbol >= 30 {
602                    return Err(DecompressionError::InvalidDistanceCode);
603                }
604
605                (
606                    DIST_SYM_TO_DIST_BASE[dist_symbol],
607                    DIST_SYM_TO_DIST_EXTRA[dist_symbol],
608                    (secondary_entry & 0xf) as u8,
609                )
610            };
611            bits >>= dist_code_bits;
612
613            let dist = dist_base as usize + (bits & ((1 << dist_extra_bits) - 1)) as usize;
614            if dist > output_index {
615                return Err(DecompressionError::DistanceTooFarBack);
616            }
617
618            self.consume_bits(
619                litlen_code_bits + length_extra_bits + dist_code_bits + dist_extra_bits,
620            );
621            self.fill_buffer(remaining_input);
622            litlen_entry = self.compression.litlen_table[(self.buffer & 0xfff) as usize];
623
624            let copy_length = length.min(output.len() - output_index);
625            if dist == 1 {
626                let last = output[output_index - 1];
627                output[output_index..][..copy_length].fill(last);
628
629                if copy_length < length {
630                    self.queued_rle = Some((last, length - copy_length));
631                    output_index = output.len();
632                    break;
633                }
634            } else if output_index + length + 15 <= output.len() {
635                let start = output_index - dist;
636                output.copy_within(start..start + 16, output_index);
637
638                if length > 16 || dist < 16 {
639                    for i in (0..length).step_by(dist.min(16)).skip(1) {
640                        output.copy_within(start + i..start + i + 16, output_index + i);
641                    }
642                }
643            } else {
644                if dist < copy_length {
645                    for i in 0..copy_length {
646                        output[output_index + i] = output[output_index + i - dist];
647                    }
648                } else {
649                    output.copy_within(
650                        output_index - dist..output_index + copy_length - dist,
651                        output_index,
652                    )
653                }
654
655                if copy_length < length {
656                    self.queued_backref = Some((dist, length - copy_length));
657                    output_index = output.len();
658                    break;
659                }
660            }
661            output_index += copy_length;
662        }
663
664        // Careful decoding loop.
665        //
666        // This loop processes the remaining input when we're too close to the end of the input or
667        // output to use the fast loop.
668        while let State::CompressedData = self.state {
669            self.fill_buffer(remaining_input);
670            if output_index == output.len() {
671                break;
672            }
673
674            let mut bits = self.buffer;
675            let litlen_entry = self.compression.litlen_table[(bits & 0xfff) as usize];
676            let litlen_code_bits = litlen_entry as u8;
677
678            if litlen_entry & LITERAL_ENTRY != 0 {
679                // Fast path: the next symbol is <= 12 bits and a literal, the table specifies the
680                // output bytes and we can directly write them to the output buffer.
681                let advance_output_bytes = ((litlen_entry & 0xf00) >> 8) as usize;
682
683                if self.nbits < litlen_code_bits {
684                    break;
685                } else if output_index + 1 < output.len() {
686                    output[output_index] = (litlen_entry >> 16) as u8;
687                    output[output_index + 1] = (litlen_entry >> 24) as u8;
688                    output_index += advance_output_bytes;
689                    self.consume_bits(litlen_code_bits);
690                    continue;
691                } else if output_index + advance_output_bytes == output.len() {
692                    debug_assert_eq!(advance_output_bytes, 1);
693                    output[output_index] = (litlen_entry >> 16) as u8;
694                    output_index += 1;
695                    self.consume_bits(litlen_code_bits);
696                    break;
697                } else {
698                    debug_assert_eq!(advance_output_bytes, 2);
699                    output[output_index] = (litlen_entry >> 16) as u8;
700                    self.queued_rle = Some(((litlen_entry >> 24) as u8, 1));
701                    output_index += 1;
702                    self.consume_bits(litlen_code_bits);
703                    break;
704                }
705            }
706
707            let (length_base, length_extra_bits, litlen_code_bits) =
708                if litlen_entry & EXCEPTIONAL_ENTRY == 0 {
709                    (
710                        litlen_entry >> 16,
711                        (litlen_entry >> 8) as u8,
712                        litlen_code_bits,
713                    )
714                } else if litlen_entry & SECONDARY_TABLE_ENTRY != 0 {
715                    let secondary_table_index =
716                        (litlen_entry >> 16) + ((bits >> 12) as u32 & (litlen_entry & 0xff));
717                    let secondary_entry =
718                        self.compression.secondary_table[secondary_table_index as usize];
719                    let litlen_symbol = secondary_entry >> 4;
720                    let litlen_code_bits = (secondary_entry & 0xf) as u8;
721
722                    if self.nbits < litlen_code_bits {
723                        break;
724                    } else if litlen_symbol < 256 {
725                        self.consume_bits(litlen_code_bits);
726                        output[output_index] = litlen_symbol as u8;
727                        output_index += 1;
728                        continue;
729                    } else if litlen_symbol == 256 {
730                        self.consume_bits(litlen_code_bits);
731                        self.state = match self.last_block {
732                            true => State::Checksum,
733                            false => State::BlockHeader,
734                        };
735                        break;
736                    }
737
738                    (
739                        LEN_SYM_TO_LEN_BASE[litlen_symbol as usize - 257] as u32,
740                        LEN_SYM_TO_LEN_EXTRA[litlen_symbol as usize - 257],
741                        litlen_code_bits,
742                    )
743                } else if litlen_code_bits == 0 {
744                    return Err(DecompressionError::InvalidLiteralLengthCode);
745                } else {
746                    if self.nbits < litlen_code_bits {
747                        break;
748                    }
749                    self.consume_bits(litlen_code_bits);
750                    self.state = match self.last_block {
751                        true => State::Checksum,
752                        false => State::BlockHeader,
753                    };
754                    break;
755                };
756            bits >>= litlen_code_bits;
757
758            let length_extra_mask = (1 << length_extra_bits) - 1;
759            let length = length_base as usize + (bits & length_extra_mask) as usize;
760            bits >>= length_extra_bits;
761
762            let dist_entry = self.compression.dist_table[(bits & 0x1ff) as usize];
763            let (dist_base, dist_extra_bits, dist_code_bits) = if dist_entry & LITERAL_ENTRY != 0 {
764                (
765                    (dist_entry >> 16) as u16,
766                    (dist_entry >> 8) as u8 & 0xf,
767                    dist_entry as u8,
768                )
769            } else if self.nbits > litlen_code_bits + length_extra_bits + 9 {
770                if dist_entry >> 8 == 0 {
771                    return Err(DecompressionError::InvalidDistanceCode);
772                }
773
774                let secondary_table_index =
775                    (dist_entry >> 16) + ((bits >> 9) as u32 & (dist_entry & 0xff));
776                let secondary_entry =
777                    self.compression.dist_secondary_table[secondary_table_index as usize];
778                let dist_symbol = (secondary_entry >> 4) as usize;
779                if dist_symbol >= 30 {
780                    return Err(DecompressionError::InvalidDistanceCode);
781                }
782
783                (
784                    DIST_SYM_TO_DIST_BASE[dist_symbol],
785                    DIST_SYM_TO_DIST_EXTRA[dist_symbol],
786                    (secondary_entry & 0xf) as u8,
787                )
788            } else {
789                break;
790            };
791            bits >>= dist_code_bits;
792
793            let dist = dist_base as usize + (bits & ((1 << dist_extra_bits) - 1)) as usize;
794            let total_bits =
795                litlen_code_bits + length_extra_bits + dist_code_bits + dist_extra_bits;
796
797            if self.nbits < total_bits {
798                break;
799            } else if dist > output_index {
800                return Err(DecompressionError::DistanceTooFarBack);
801            }
802
803            self.consume_bits(total_bits);
804
805            let copy_length = length.min(output.len() - output_index);
806            if dist == 1 {
807                let last = output[output_index - 1];
808                output[output_index..][..copy_length].fill(last);
809
810                if copy_length < length {
811                    self.queued_rle = Some((last, length - copy_length));
812                    output_index = output.len();
813                    break;
814                }
815            } else if output_index + length + 15 <= output.len() {
816                let start = output_index - dist;
817                output.copy_within(start..start + 16, output_index);
818
819                if length > 16 || dist < 16 {
820                    for i in (0..length).step_by(dist.min(16)).skip(1) {
821                        output.copy_within(start + i..start + i + 16, output_index + i);
822                    }
823                }
824            } else {
825                if dist < copy_length {
826                    for i in 0..copy_length {
827                        output[output_index + i] = output[output_index + i - dist];
828                    }
829                } else {
830                    output.copy_within(
831                        output_index - dist..output_index + copy_length - dist,
832                        output_index,
833                    )
834                }
835
836                if copy_length < length {
837                    self.queued_backref = Some((dist, length - copy_length));
838                    output_index = output.len();
839                    break;
840                }
841            }
842            output_index += copy_length;
843        }
844
845        if self.state == State::CompressedData
846            && self.queued_backref.is_none()
847            && self.queued_rle.is_none()
848            && self.nbits >= 15
849            && self.peak_bits(15) as u16 & self.compression.eof_mask == self.compression.eof_code
850        {
851            self.consume_bits(self.compression.eof_bits);
852            self.state = match self.last_block {
853                true => State::Checksum,
854                false => State::BlockHeader,
855            };
856        }
857
858        Ok(output_index)
859    }
860
    /// Decompresses a chunk of data.
    ///
    /// Returns the number of bytes read from `input` and the number of bytes written to `output`,
    /// or an error if the deflate stream is not valid. `input` is the compressed data. `output` is
    /// the buffer to write the decompressed data to, starting at index `output_position`.
    /// `end_of_input` indicates whether more data may be available in the future.
    ///
    /// The contents of `output` after `output_position` are ignored. However, this function may
    /// write additional data to `output` past what is indicated by the return value.
    ///
    /// When this function returns `Ok`, at least one of the following is true:
    /// - The input is fully consumed.
    /// - The output is full but there are more bytes to output.
    /// - The deflate stream is complete (and `is_done` will return true).
    ///
    /// # Panics
    ///
    /// This function will panic if `output_position` is out of bounds.
    pub fn read(
        &mut self,
        input: &[u8],
        output: &mut [u8],
        output_position: usize,
        end_of_input: bool,
    ) -> Result<(usize, usize), DecompressionError> {
        // A finished stream consumes no further input and produces no output.
        if let State::Done = self.state {
            return Ok((0, 0));
        }

        assert!(output_position <= output.len());

        let mut remaining_input = input;
        let mut output_index = output_position;

        // First drain any RLE run that overflowed the output buffer on a previous call.
        if let Some((data, len)) = self.queued_rle.take() {
            let n = len.min(output.len() - output_index);
            output[output_index..][..n].fill(data);
            output_index += n;
            if n < len {
                // Output is full again; requeue the remainder and report partial progress.
                self.queued_rle = Some((data, len - n));
                return Ok((0, n));
            }
        }
        // Likewise drain any back-reference copy queued from a previous call.
        if let Some((dist, len)) = self.queued_backref.take() {
            let n = len.min(output.len() - output_index);
            // Copy byte-by-byte because the source and destination ranges may overlap
            // whenever `dist` is smaller than the copy length.
            for i in 0..n {
                output[output_index + i] = output[output_index + i - dist];
            }
            output_index += n;
            if n < len {
                self.queued_backref = Some((dist, len - n));
                return Ok((0, n));
            }
        }

        // Main decoding state machine.
        // Each iteration processes the current state; the loop stops once a full pass
        // leaves the state unchanged, meaning we are blocked on input or output space.
        let mut last_state = None;
        while last_state != Some(self.state) {
            last_state = Some(self.state);
            match self.state {
                State::ZlibHeader => {
                    self.fill_buffer(&mut remaining_input);
                    // The zlib header is exactly two bytes; wait for more input if short.
                    if self.nbits < 16 {
                        break;
                    }

                    let input0 = self.peak_bits(8);
                    let input1 = self.peak_bits(16) >> 8 & 0xff;
                    // Validate the CMF/FLG pair: compression method must be 8 (deflate),
                    // the window-size field (CINFO) must be <= 7, a preset dictionary
                    // (FDICT) is not supported, and the 16-bit big-endian header value
                    // must be a multiple of 31 (RFC 1950).
                    if input0 & 0x0f != 0x08
                        || (input0 & 0xf0) > 0x70
                        || input1 & 0x20 != 0
                        || (input0 << 8 | input1) % 31 != 0
                    {
                        return Err(DecompressionError::BadZlibHeader);
                    }

                    self.consume_bits(16);
                    self.state = State::BlockHeader;
                }
                State::BlockHeader => {
                    self.read_block_header(&mut remaining_input)?;
                }
                State::CodeLengthCodes => {
                    self.read_code_length_codes(&mut remaining_input)?;
                }
                State::CodeLengths => {
                    self.read_code_lengths(&mut remaining_input)?;
                }
                State::CompressedData => {
                    output_index =
                        self.read_compressed(&mut remaining_input, output, output_index)?
                }
                State::UncompressedData => {
                    // Drain any bytes from our buffer.
                    // Uncompressed blocks are byte aligned, so the bit count is a
                    // multiple of 8 here.
                    debug_assert_eq!(self.nbits % 8, 0);
                    while self.nbits > 0
                        && self.uncompressed_bytes_left > 0
                        && output_index < output.len()
                    {
                        output[output_index] = self.peak_bits(8) as u8;
                        self.consume_bits(8);
                        output_index += 1;
                        self.uncompressed_bytes_left -= 1;
                    }
                    // Buffer may contain one additional byte. Clear it to avoid confusion.
                    if self.nbits == 0 {
                        self.buffer = 0;
                    }

                    // Copy subsequent bytes directly from the input.
                    let copy_bytes = (self.uncompressed_bytes_left as usize)
                        .min(remaining_input.len())
                        .min(output.len() - output_index);
                    output[output_index..][..copy_bytes]
                        .copy_from_slice(&remaining_input[..copy_bytes]);
                    remaining_input = &remaining_input[copy_bytes..];
                    output_index += copy_bytes;
                    self.uncompressed_bytes_left -= copy_bytes as u16;

                    if self.uncompressed_bytes_left == 0 {
                        self.state = if self.last_block {
                            State::Checksum
                        } else {
                            State::BlockHeader
                        };
                    }
                }
                State::Checksum => {
                    self.fill_buffer(&mut remaining_input);

                    // The Adler-32 checksum begins at the next byte boundary and spans
                    // four bytes, so wait until all of it (plus padding) is buffered.
                    let align_bits = self.nbits % 8;
                    if self.nbits >= 32 + align_bits {
                        // Fold the bytes produced by this call into the running checksum.
                        // NOTE(review): this write happens even when `ignore_adler32` is
                        // set; the flag only skips the comparison below.
                        self.checksum.write(&output[output_position..output_index]);
                        if align_bits != 0 {
                            self.consume_bits(align_bits);
                        }
                        // The stored checksum is big-endian, hence the byte swap.
                        // Presumably disabled under fuzzing so mutated checksums don't
                        // mask other bugs — TODO confirm.
                        #[cfg(not(fuzzing))]
                        if !self.ignore_adler32
                            && (self.peak_bits(32) as u32).swap_bytes() != self.checksum.finish()
                        {
                            return Err(DecompressionError::WrongChecksum);
                        }
                        self.state = State::Done;
                        self.consume_bits(32);
                        break;
                    }
                }
                // `Done` returns at the top of the function before reaching the loop.
                State::Done => unreachable!(),
            }
        }

        // Update the running checksum with the bytes produced by this call. When the
        // stream just finished, the Checksum state above has already accounted for them.
        if !self.ignore_adler32 && self.state != State::Done {
            self.checksum.write(&output[output_position..output_index]);
        }

        // Success requires at least one of: stream complete, more input may arrive
        // later, or the output buffer is full. Otherwise the caller claimed the input
        // was complete but the stream is still unfinished.
        if self.state == State::Done || !end_of_input || output_index == output.len() {
            let input_left = remaining_input.len();
            Ok((input.len() - input_left, output_index - output_position))
        } else {
            Err(DecompressionError::InsufficientInput)
        }
    }
1023
1024    /// Returns true if the decompressor has finished decompressing the input.
1025    pub fn is_done(&self) -> bool {
1026        self.state == State::Done
1027    }
1028}
1029
1030/// Decompress the given data.
1031pub fn decompress_to_vec(input: &[u8]) -> Result<Vec<u8>, DecompressionError> {
1032    match decompress_to_vec_bounded(input, usize::MAX) {
1033        Ok(output) => Ok(output),
1034        Err(BoundedDecompressionError::DecompressionError { inner }) => Err(inner),
1035        Err(BoundedDecompressionError::OutputTooLarge { .. }) => {
1036            unreachable!("Impossible to allocate more than isize::MAX bytes")
1037        }
1038    }
1039}
1040
1041/// An error encountered while decompressing a deflate stream given a bounded maximum output.
1042pub enum BoundedDecompressionError {
1043    /// The input is not a valid deflate stream.
1044    DecompressionError {
1045        /// The underlying error.
1046        inner: DecompressionError,
1047    },
1048
1049    /// The output is too large.
1050    OutputTooLarge {
1051        /// The output decoded so far.
1052        partial_output: Vec<u8>,
1053    },
1054}
1055impl From<DecompressionError> for BoundedDecompressionError {
1056    fn from(inner: DecompressionError) -> Self {
1057        BoundedDecompressionError::DecompressionError { inner }
1058    }
1059}
1060
1061/// Decompress the given data, returning an error if the output is larger than
1062/// `maxlen` bytes.
1063pub fn decompress_to_vec_bounded(
1064    input: &[u8],
1065    maxlen: usize,
1066) -> Result<Vec<u8>, BoundedDecompressionError> {
1067    let mut decoder = Decompressor::new();
1068    let mut output = vec![0; 1024.min(maxlen)];
1069    let mut input_index = 0;
1070    let mut output_index = 0;
1071    loop {
1072        let (consumed, produced) =
1073            decoder.read(&input[input_index..], &mut output, output_index, true)?;
1074        input_index += consumed;
1075        output_index += produced;
1076        if decoder.is_done() || output_index == maxlen {
1077            break;
1078        }
1079        output.resize((output_index + 32 * 1024).min(maxlen), 0);
1080    }
1081    output.resize(output_index, 0);
1082
1083    if decoder.is_done() {
1084        Ok(output)
1085    } else {
1086        Err(BoundedDecompressionError::OutputTooLarge {
1087            partial_output: output,
1088        })
1089    }
1090}
1091
#[cfg(test)]
mod tests {
    use crate::tables::{LENGTH_TO_LEN_EXTRA, LENGTH_TO_SYMBOL};

    use super::*;
    use rand::Rng;

    /// Compresses `data` with this crate's encoder and asserts that decompressing the
    /// result yields the original bytes.
    fn roundtrip(data: &[u8]) {
        let compressed = crate::compress_to_vec(data);
        let decompressed = decompress_to_vec(&compressed).unwrap();
        assert_eq!(&decompressed, data);
    }

    /// Compresses `data` with miniz_oxide's zlib encoder and asserts that this crate's
    /// decompressor reproduces the original bytes exactly.
    fn roundtrip_miniz_oxide(data: &[u8]) {
        let compressed = miniz_oxide::deflate::compress_to_vec_zlib(data, 3);
        let decompressed = decompress_to_vec(&compressed).unwrap();
        assert_eq!(decompressed.len(), data.len());
        // Compare one byte at a time so a failure message pinpoints the mismatch.
        for (i, (a, b)) in decompressed.chunks(1).zip(data.chunks(1)).enumerate() {
            assert_eq!(a, b, "chunk {}..{}", i, i + 1);
        }
        assert_eq!(&decompressed, data);
    }

    /// Debugging helper: decompresses `data` with both this crate and miniz_oxide and
    /// panics with context at the first index where the two outputs diverge.
    #[allow(unused)]
    fn compare_decompression(data: &[u8]) {
        // let decompressed0 = flate2::read::ZlibDecoder::new(std::io::Cursor::new(&data))
        //     .bytes()
        //     .collect::<Result<Vec<_>, _>>()
        //     .unwrap();
        let decompressed = decompress_to_vec(data).unwrap();
        let decompressed2 = miniz_oxide::inflate::decompress_to_vec_zlib(data).unwrap();
        for i in 0..decompressed.len().min(decompressed2.len()) {
            if decompressed[i] != decompressed2[i] {
                panic!(
                    "mismatch at index {} {:?} {:?}",
                    i,
                    &decompressed[i.saturating_sub(1)..(i + 16).min(decompressed.len())],
                    &decompressed2[i.saturating_sub(1)..(i + 16).min(decompressed2.len())]
                );
            }
        }
        if decompressed != decompressed2 {
            panic!(
                "length mismatch {} {} {:x?}",
                decompressed.len(),
                decompressed2.len(),
                &decompressed2[decompressed.len()..][..16]
            );
        }
        //assert_eq!(decompressed, decompressed2);
    }

    /// Checks that the length-symbol tables are mutually consistent: for every length
    /// symbol and every extra-bits combination, the reverse-lookup tables agree.
    #[test]
    fn tables() {
        for (i, &bits) in LEN_SYM_TO_LEN_EXTRA.iter().enumerate() {
            let len_base = LEN_SYM_TO_LEN_BASE[i];
            for j in 0..(1 << bits) {
                // Symbol 284 (i == 27) with all extra bits set encodes length 258, which
                // is canonically represented by symbol 285 instead, so skip it.
                if i == 27 && j == 31 {
                    continue;
                }
                assert_eq!(LENGTH_TO_LEN_EXTRA[len_base + j - 3], bits, "{} {}", i, j);
                assert_eq!(
                    LENGTH_TO_SYMBOL[len_base + j - 3],
                    i as u16 + 257,
                    "{} {}",
                    i,
                    j
                );
            }
        }
    }

    /// Builds the huffman tables for the fixed code lengths and verifies they match the
    /// precomputed `FIXED_LITLEN_TABLE` / `FIXED_DIST_TABLE` constants.
    #[test]
    fn fixed_tables() {
        let mut compression = CompressedBlock {
            litlen_table: Box::new([0; 4096]),
            dist_table: Box::new([0; 512]),
            secondary_table: Vec::new(),
            dist_secondary_table: Vec::new(),
            eof_code: 0,
            eof_mask: 0,
            eof_bits: 0,
        };
        Decompressor::build_tables(288, &FIXED_CODE_LENGTHS, &mut compression).unwrap();

        assert_eq!(compression.litlen_table[..512], FIXED_LITLEN_TABLE);
        assert_eq!(compression.dist_table[..32], FIXED_DIST_TABLE);
    }

    /// Smoke test: a short ASCII round trip through this crate's own compressor.
    #[test]
    fn it_works() {
        roundtrip(b"Hello world!");
    }

    /// Round-trips highly compressible constant buffers (exercising RLE-style matches).
    #[test]
    fn constant() {
        roundtrip_miniz_oxide(&[0; 50]);
        roundtrip_miniz_oxide(&vec![5; 2048]);
        roundtrip_miniz_oxide(&vec![128; 2048]);
        roundtrip_miniz_oxide(&vec![254; 2048]);
    }

    /// Round-trips random low-entropy data (bytes in 0..5) to exercise dynamic huffman
    /// blocks with short match distances.
    #[test]
    fn random() {
        let mut rng = rand::thread_rng();
        let mut data = vec![0; 50000];
        for _ in 0..10 {
            for byte in &mut data {
                *byte = rng.gen::<u8>() % 5;
            }
            println!("Random data: {:?}", data);
            roundtrip_miniz_oxide(&data);
        }
    }

    /// Corrupts the stored Adler-32 checksum and verifies that decompression fails with
    /// `WrongChecksum` by default but succeeds when `ignore_adler32` is enabled.
    #[test]
    fn ignore_adler32() {
        let mut compressed = crate::compress_to_vec(b"Hello world!");
        let last_byte = compressed.len() - 1;
        // Flip the final checksum byte to invalidate the Adler-32 trailer.
        compressed[last_byte] = compressed[last_byte].wrapping_add(1);

        match decompress_to_vec(&compressed) {
            Err(DecompressionError::WrongChecksum) => {}
            r => panic!("expected WrongChecksum, got {:?}", r),
        }

        let mut decompressor = Decompressor::new();
        decompressor.ignore_adler32();
        let mut decompressed = vec![0; 1024];
        let decompressed_len = decompressor
            .read(&compressed, &mut decompressed, 0, true)
            .unwrap()
            .1;
        assert_eq!(&decompressed[..decompressed_len], b"Hello world!");
    }

    /// Feeds all but the last byte first, then the remainder, verifying that the
    /// decompressor can resume in the middle of the checksum trailer.
    #[test]
    fn checksum_after_eof() {
        let input = b"Hello world!";
        let compressed = crate::compress_to_vec(input);

        let mut decompressor = Decompressor::new();
        let mut decompressed = vec![0; 1024];
        let (input_consumed, output_written) = decompressor
            .read(
                &compressed[..compressed.len() - 1],
                &mut decompressed,
                0,
                false,
            )
            .unwrap();
        // All payload bytes should already be produced; only the checksum is pending.
        assert_eq!(output_written, input.len());
        assert_eq!(input_consumed, compressed.len() - 1);

        let (input_consumed, output_written) = decompressor
            .read(
                &compressed[input_consumed..],
                &mut decompressed[..output_written],
                output_written,
                true,
            )
            .unwrap();
        assert!(decompressor.is_done());
        assert_eq!(input_consumed, 1);
        assert_eq!(output_written, 0);

        assert_eq!(&decompressed[..input.len()], input);
    }

    /// Verifies that a stream of zero-length stored blocks decompresses to completion
    /// even with an empty output buffer, regardless of the `end_of_input` flag.
    #[test]
    fn zero_length() {
        let mut compressed = crate::compress_to_vec(b"").to_vec();

        // Splice in zero-length non-compressed blocks.
        // (0x00 = block header, 0x0000/0xffff = LEN/NLEN for a zero-byte stored block.)
        for _ in 0..10 {
            println!("compressed len: {}", compressed.len());
            compressed.splice(2..2, [0u8, 0, 0, 0xff, 0xff].into_iter());
        }

        // Ensure that the full input is decompressed, regardless of whether
        // `end_of_input` is set.
        for end_of_input in [true, false] {
            let mut decompressor = Decompressor::new();
            let (input_consumed, output_written) = decompressor
                .read(&compressed, &mut [], 0, end_of_input)
                .unwrap();

            assert!(decompressor.is_done());
            assert_eq!(input_consumed, compressed.len());
            assert_eq!(output_written, 0);
        }
    }

    mod test_utils;
    use tables::FIXED_CODE_LENGTHS;
    use test_utils::{decompress_by_chunks, TestDecompressionError};

    /// Decompresses `input` both in a single chunk and byte-by-byte and asserts that the
    /// two strategies produce identical results (success or failure alike).
    fn verify_no_sensitivity_to_input_chunking(
        input: &[u8],
    ) -> Result<Vec<u8>, TestDecompressionError> {
        let r_whole = decompress_by_chunks(input, vec![input.len()], false);
        let r_bytewise = decompress_by_chunks(input, std::iter::repeat(1), false);
        assert_eq!(r_whole, r_bytewise);
        r_whole // Returning an arbitrary result, since this is equal to `r_bytewise`.
    }

    /// This is a regression test found by the `buf_independent` fuzzer from the `png` crate.  When
    /// this test case was found, the results were unexpectedly different when 1) decompressing the
    /// whole input (successful result) vs 2) decompressing byte-by-byte
    /// (`Err(InvalidDistanceCode)`).
    #[test]
    fn test_input_chunking_sensitivity_when_handling_distance_codes() {
        let result = verify_no_sensitivity_to_input_chunking(include_bytes!(
            "../tests/input-chunking-sensitivity-example1.zz"
        ))
        .unwrap();
        assert_eq!(result.len(), 281);
        assert_eq!(simd_adler32::adler32(&result.as_slice()), 751299);
    }

    /// This is a regression test found by the `inflate_bytewise3` fuzzer from the `fdeflate`
    /// crate.  When this test case was found, the results were unexpectedly different when 1)
    /// decompressing the whole input (`Err(DistanceTooFarBack)`) vs 2) decompressing byte-by-byte
    /// (successful result)`).
    #[test]
    fn test_input_chunking_sensitivity_when_no_end_of_block_symbol_example1() {
        let err = verify_no_sensitivity_to_input_chunking(include_bytes!(
            "../tests/input-chunking-sensitivity-example2.zz"
        ))
        .unwrap_err();
        assert_eq!(
            err,
            TestDecompressionError::ProdError(DecompressionError::BadLiteralLengthHuffmanTree)
        );
    }

    /// This is a regression test found by the `inflate_bytewise3` fuzzer from the `fdeflate`
    /// crate.  When this test case was found, the results were unexpectedly different when 1)
    /// decompressing the whole input (`Err(InvalidDistanceCode)`) vs 2) decompressing byte-by-byte
    /// (successful result)`).
    #[test]
    fn test_input_chunking_sensitivity_when_no_end_of_block_symbol_example2() {
        let err = verify_no_sensitivity_to_input_chunking(include_bytes!(
            "../tests/input-chunking-sensitivity-example3.zz"
        ))
        .unwrap_err();
        assert_eq!(
            err,
            TestDecompressionError::ProdError(DecompressionError::BadLiteralLengthHuffmanTree)
        );
    }
}