1use simd_adler32::Adler32;
2
3use crate::{
4 huffman::{self, build_table},
5 tables::{
6 self, CLCL_ORDER, DIST_SYM_TO_DIST_BASE, DIST_SYM_TO_DIST_EXTRA, FIXED_DIST_TABLE,
7 FIXED_LITLEN_TABLE, LEN_SYM_TO_LEN_BASE, LEN_SYM_TO_LEN_EXTRA, LITLEN_TABLE_ENTRIES,
8 },
9};
10
/// Errors that can occur while decompressing a zlib/DEFLATE stream.
#[derive(Debug, PartialEq)]
pub enum DecompressionError {
    /// The zlib header is corrupt or uses settings this decoder rejects.
    BadZlibHeader,
    /// `end_of_input` was set but the stream needs more input to finish.
    InsufficientInput,
    /// A block used the reserved block type (0b11).
    InvalidBlockType,
    /// A stored block's NLEN field was not the one's complement of LEN.
    InvalidUncompressedBlockLength,
    /// HLIT exceeded 286 literal/length codes.
    InvalidHlit,
    /// HDIST exceeded 30 distance codes.
    InvalidHdist,
    /// A code-length repeat symbol overflowed the length array or had no
    /// previous length to repeat.
    InvalidCodeLengthRepeat,
    /// The code-length alphabet did not form a valid huffman tree.
    BadCodeLengthHuffmanTree,
    /// The literal/length alphabet did not form a valid huffman tree.
    BadLiteralLengthHuffmanTree,
    /// The distance alphabet did not form a valid huffman tree.
    BadDistanceHuffmanTree,
    /// An undecodable literal/length code was encountered.
    InvalidLiteralLengthCode,
    /// An undecodable distance code was encountered.
    InvalidDistanceCode,
    /// A backreference pointed before the start of the output.
    InputStartsWithRun,
    /// A backreference's distance exceeded the bytes produced so far.
    DistanceTooFarBack,
    /// The Adler-32 checksum trailer did not match the output.
    WrongChecksum,
    /// Data followed the end of the zlib stream.
    ExtraInput,
}
48
/// In-progress state while parsing the header of a dynamic-huffman block.
struct BlockHeader {
    /// Number of literal/length codes (HLIT field + 257).
    hlit: usize,
    /// Number of distance codes (HDIST field + 1).
    hdist: usize,
    /// Number of code-length codes (HCLEN field + 4).
    hclen: usize,
    /// How many of the `hlit + hdist` code lengths have been decoded so far.
    num_lengths_read: usize,

    /// Decoding table for the code-length alphabet (7-bit lookup).
    table: [u32; 128],
    /// Decoded code lengths: literal/length codes first, distance codes
    /// relocated to index 288 by `read_code_lengths`.
    code_lengths: [u8; 320],
}
59
/// Table-entry flag: the entry decodes directly to one or two literal bytes.
pub const LITERAL_ENTRY: u32 = 0x8000;
/// Table-entry flag: the entry needs special handling (end-of-block, invalid
/// code, or continuation in a secondary table).
pub const EXCEPTIONAL_ENTRY: u32 = 0x4000;
/// Table-entry flag (combined with `EXCEPTIONAL_ENTRY`): the code is longer
/// than the primary lookup and continues in the secondary table.
pub const SECONDARY_TABLE_ENTRY: u32 = 0x2000;
63
/// Decoding tables for a single huffman-compressed block.
#[derive(Eq, PartialEq, Debug)]
struct CompressedBlock {
    /// Primary literal/length table, indexed by the low 12 bits of the bit buffer.
    litlen_table: Box<[u32; 4096]>,
    /// Overflow table for literal/length codes longer than the primary lookup.
    secondary_table: Vec<u16>,

    /// Primary distance table, indexed by the next 9 bits.
    dist_table: Box<[u32; 512]>,
    /// Overflow table for distance codes longer than the primary lookup.
    dist_secondary_table: Vec<u16>,

    /// Huffman code of the end-of-block symbol, its bit mask, and its length,
    /// used to detect a trailing EOF without a full table walk.
    eof_code: u16,
    eof_mask: u16,
    eof_bits: u8,
}
77
/// Current stage of the decompression state machine.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
enum State {
    /// Reading the 2-byte zlib header.
    ZlibHeader,
    /// Reading a block's BFINAL/BTYPE bits (and stored-block LEN/NLEN).
    BlockHeader,
    /// Reading the code lengths of the code-length alphabet.
    CodeLengthCodes,
    /// Reading the literal/length and distance code lengths.
    CodeLengths,
    /// Decoding huffman-compressed block data.
    CompressedData,
    /// Copying the contents of a stored (uncompressed) block.
    UncompressedData,
    /// Reading and verifying the Adler-32 trailer.
    Checksum,
    /// The stream has been fully decoded.
    Done,
}
89
/// A streaming zlib/DEFLATE decompressor.
pub struct Decompressor {
    /// Decoding tables for the current compressed block.
    compression: CompressedBlock,
    /// Partially-parsed dynamic block header.
    header: BlockHeader,
    /// Bytes still to emit for the current stored block.
    uncompressed_bytes_left: u16,

    /// Bit buffer holding the next input bits, LSB first.
    buffer: u64,
    /// Number of valid bits in `buffer`.
    nbits: u8,

    /// A single-byte run that did not fit in the previous output buffer.
    queued_rle: Option<(u8, usize)>,
    /// A (distance, length) backreference that did not fit in the previous
    /// output buffer.
    queued_backref: Option<(usize, usize)>,
    /// Whether the current block has the BFINAL flag set.
    last_block: bool,
    /// Whether the fixed huffman tables are currently loaded into `compression`.
    fixed_table: bool,

    /// Current state-machine stage.
    state: State,
    /// Running Adler-32 over the produced output.
    checksum: Adler32,
    /// When set, the checksum trailer is read but not verified.
    ignore_adler32: bool,
}
111
impl Default for Decompressor {
    /// Equivalent to [`Decompressor::new`].
    fn default() -> Self {
        Self::new()
    }
}
117
118impl Decompressor {
    /// Construct a new decompressor in its initial state, ready to read the
    /// zlib header.
    pub fn new() -> Self {
        Self {
            buffer: 0,
            nbits: 0,
            compression: CompressedBlock {
                litlen_table: Box::new([0; 4096]),
                dist_table: Box::new([0; 512]),
                secondary_table: Vec::new(),
                dist_secondary_table: Vec::new(),
                eof_code: 0,
                eof_mask: 0,
                eof_bits: 0,
            },
            header: BlockHeader {
                hlit: 0,
                hdist: 0,
                hclen: 0,
                table: [0; 128],
                num_lengths_read: 0,
                code_lengths: [0; 320],
            },
            uncompressed_bytes_left: 0,
            queued_rle: None,
            queued_backref: None,
            checksum: Adler32::new(),
            state: State::ZlibHeader,
            last_block: false,
            ignore_adler32: false,
            fixed_table: false,
        }
    }
151
    /// Disable verification of the Adler-32 trailer; the checksum bytes are
    /// still consumed from the input.
    pub fn ignore_adler32(&mut self) {
        self.ignore_adler32 = true;
    }
156
    /// Refill the bit buffer from `input`, advancing `input` past the bytes
    /// actually consumed.
    fn fill_buffer(&mut self, input: &mut &[u8]) {
        if input.len() >= 8 {
            // Fast path: unconditionally OR in 8 little-endian bytes (bits that
            // don't fit are shifted off the top), then advance the input by the
            // number of whole bytes that actually fit. Afterwards at least 56
            // bits are valid.
            self.buffer |= u64::from_le_bytes(input[..8].try_into().unwrap()) << self.nbits;
            *input = &input[(63 - self.nbits as usize) / 8..];
            self.nbits |= 56;
        } else {
            // Slow path near the end of input: copy the bytes that fit into a
            // zero-padded scratch word. `checked_shl` guards the shift-by-64
            // case when the buffer is already full.
            let nbytes = input.len().min((63 - self.nbits as usize) / 8);
            let mut input_data = [0; 8];
            input_data[..nbytes].copy_from_slice(&input[..nbytes]);
            self.buffer |= u64::from_le_bytes(input_data)
                .checked_shl(self.nbits as u32)
                .unwrap_or(0);
            self.nbits += nbytes as u8 * 8;
            *input = &input[nbytes..];
        }
    }
173
    /// Return the low `nbits` bits of the buffer without consuming them.
    /// (The name is a long-standing typo for "peek"; kept for compatibility.)
    fn peak_bits(&mut self, nbits: u8) -> u64 {
        debug_assert!(nbits <= 56 && nbits <= self.nbits);
        self.buffer & ((1u64 << nbits) - 1)
    }
    /// Discard the low `nbits` bits of the buffer.
    fn consume_bits(&mut self, nbits: u8) {
        debug_assert!(self.nbits >= nbits);
        self.buffer >>= nbits;
        self.nbits -= nbits;
    }
183
    /// Parse a block header (BFINAL + BTYPE and, for stored blocks, LEN/NLEN)
    /// and advance to the appropriate state. Returns `Ok(())` without a state
    /// change when more input is needed.
    fn read_block_header(&mut self, remaining_input: &mut &[u8]) -> Result<(), DecompressionError> {
        self.fill_buffer(remaining_input);
        if self.nbits < 10 {
            return Ok(());
        }

        // Bit 0 is BFINAL; bits 1..3 are BTYPE (LSB first).
        let start = self.peak_bits(3);
        self.last_block = start & 1 != 0;
        match start >> 1 {
            0b00 => {
                // Stored block: skip to the next byte boundary, then read the
                // 16-bit LEN and its one's complement NLEN.
                let align_bits = (self.nbits - 3) % 8;
                let header_bits = 3 + 32 + align_bits;
                if self.nbits < header_bits {
                    return Ok(());
                }

                let len = (self.peak_bits(align_bits + 19) >> (align_bits + 3)) as u16;
                let nlen = (self.peak_bits(header_bits) >> (align_bits + 19)) as u16;
                if nlen != !len {
                    return Err(DecompressionError::InvalidUncompressedBlockLength);
                }

                self.state = State::UncompressedData;
                self.uncompressed_bytes_left = len;
                self.consume_bits(header_bits);
                Ok(())
            }
            0b01 => {
                // Fixed-huffman block.
                self.consume_bits(3);

                // An immediate 7-bit all-zero code is the end-of-block symbol,
                // i.e. the block is empty.
                if self.peak_bits(7) == 0 {
                    self.consume_bits(7);
                    if self.last_block {
                        self.state = State::Checksum;
                        return Ok(());
                    }

                    // Fast-skip any further empty non-final fixed blocks: the
                    // 10-bit pattern is header 0b010 followed by 7 zero bits.
                    while self.nbits >= 10 && self.peak_bits(10) == 0b010 {
                        self.consume_bits(10);
                        self.fill_buffer(remaining_input);
                    }
                    return self.read_block_header(remaining_input);
                }

                // Lazily install the fixed tables; they are kept until a
                // dynamic block overwrites them.
                if !self.fixed_table {
                    self.fixed_table = true;
                    for chunk in self.compression.litlen_table.chunks_exact_mut(512) {
                        chunk.copy_from_slice(&FIXED_LITLEN_TABLE);
                    }
                    for chunk in self.compression.dist_table.chunks_exact_mut(32) {
                        chunk.copy_from_slice(&FIXED_DIST_TABLE);
                    }
                    self.compression.eof_bits = 7;
                    self.compression.eof_code = 0;
                    self.compression.eof_mask = 0x7f;
                }

                self.state = State::CompressedData;
                Ok(())
            }
            0b10 => {
                // Dynamic-huffman block: read HLIT, HDIST, and HCLEN.
                if self.nbits < 17 {
                    return Ok(());
                }

                self.header.hlit = (self.peak_bits(8) >> 3) as usize + 257;
                self.header.hdist = (self.peak_bits(13) >> 8) as usize + 1;
                self.header.hclen = (self.peak_bits(17) >> 13) as usize + 4;
                if self.header.hlit > 286 {
                    return Err(DecompressionError::InvalidHlit);
                }
                if self.header.hdist > 30 {
                    return Err(DecompressionError::InvalidHdist);
                }

                self.consume_bits(17);
                self.state = State::CodeLengthCodes;
                self.fixed_table = false;
                Ok(())
            }
            0b11 => Err(DecompressionError::InvalidBlockType),
            _ => unreachable!(),
        }
    }
277
    /// Read the 3-bit lengths of the code-length alphabet (in `CLCL_ORDER`)
    /// and build its decoding table. Returns without a state change if more
    /// input is needed.
    fn read_code_length_codes(
        &mut self,
        remaining_input: &mut &[u8],
    ) -> Result<(), DecompressionError> {
        self.fill_buffer(remaining_input);
        if self.nbits as usize + remaining_input.len() * 8 < 3 * self.header.hclen {
            return Ok(());
        }

        let mut code_length_lengths = [0; 19];
        for i in 0..self.header.hclen {
            code_length_lengths[CLCL_ORDER[i]] = self.peak_bits(3) as u8;
            self.consume_bits(3);

            // Refill partway through so the bit buffer cannot run dry when all
            // 19 lengths are present.
            if i == 17 {
                self.fill_buffer(remaining_input);
            }
        }

        let mut codes = [0; 19];
        if !build_table(
            &code_length_lengths,
            &[],
            &mut codes,
            &mut self.header.table,
            &mut Vec::new(),
            false,
            false,
        ) {
            return Err(DecompressionError::BadCodeLengthHuffmanTree);
        }

        self.state = State::CodeLengths;
        self.header.num_lengths_read = 0;
        Ok(())
    }
316
    /// Decode the `hlit + hdist` literal/length and distance code lengths
    /// using the code-length table, then build the block's decoding tables.
    fn read_code_lengths(&mut self, remaining_input: &mut &[u8]) -> Result<(), DecompressionError> {
        let total_lengths = self.header.hlit + self.header.hdist;
        while self.header.num_lengths_read < total_lengths {
            self.fill_buffer(remaining_input);
            if self.nbits < 7 {
                return Ok(());
            }

            // Decode one symbol via a 7-bit lookup in the code-length table.
            let code = self.peak_bits(7);
            let entry = self.header.table[code as usize];
            let length = (entry & 0x7) as u8;
            let symbol = (entry >> 16) as u8;

            debug_assert!(length != 0);
            match symbol {
                // Symbols 0-15 are literal code lengths.
                0..=15 => {
                    self.header.code_lengths[self.header.num_lengths_read] = symbol;
                    self.header.num_lengths_read += 1;
                    self.consume_bits(length);
                }
                // Symbols 16-18 are repeats: 16 repeats the previous length,
                // 17 and 18 repeat zero.
                16..=18 => {
                    let (base_repeat, extra_bits) = match symbol {
                        16 => (3, 2),
                        17 => (3, 3),
                        18 => (11, 7),
                        _ => unreachable!(),
                    };

                    if self.nbits < length + extra_bits {
                        return Ok(());
                    }

                    let value = match symbol {
                        // A repeat with no previous length is invalid.
                        16 => {
                            self.header.code_lengths[self
                                .header
                                .num_lengths_read
                                .checked_sub(1)
                                .ok_or(DecompressionError::InvalidCodeLengthRepeat)?]
                        }
                        17 => 0,
                        18 => 0,
                        _ => unreachable!(),
                    };

                    let repeat =
                        (self.peak_bits(length + extra_bits) >> length) as usize + base_repeat;
                    if self.header.num_lengths_read + repeat > total_lengths {
                        return Err(DecompressionError::InvalidCodeLengthRepeat);
                    }

                    for i in 0..repeat {
                        self.header.code_lengths[self.header.num_lengths_read + i] = value;
                    }
                    self.header.num_lengths_read += repeat;
                    self.consume_bits(length + extra_bits);
                }
                _ => unreachable!(),
            }
        }

        // Relocate the distance lengths to a fixed offset (288) and zero the
        // gaps so `build_tables` sees litlen lengths at [..288] and distance
        // lengths at [288..320].
        self.header
            .code_lengths
            .copy_within(self.header.hlit..total_lengths, 288);
        for i in self.header.hlit..288 {
            self.header.code_lengths[i] = 0;
        }
        for i in 288 + self.header.hdist..320 {
            self.header.code_lengths[i] = 0;
        }

        Self::build_tables(
            self.header.hlit,
            &self.header.code_lengths,
            &mut self.compression,
        )?;
        self.state = State::CompressedData;
        Ok(())
    }
397
398 fn build_tables(
399 hlit: usize,
400 code_lengths: &[u8],
401 compression: &mut CompressedBlock,
402 ) -> Result<(), DecompressionError> {
403 if code_lengths[256] == 0 {
405 return Err(DecompressionError::BadLiteralLengthHuffmanTree);
407 }
408
409 let mut codes = [0; 288];
410 compression.secondary_table.clear();
411 if !huffman::build_table(
412 &code_lengths[..hlit],
413 &LITLEN_TABLE_ENTRIES,
414 &mut codes[..hlit],
415 &mut *compression.litlen_table,
416 &mut compression.secondary_table,
417 false,
418 true,
419 ) {
420 return Err(DecompressionError::BadCodeLengthHuffmanTree);
421 }
422
423 compression.eof_code = codes[256];
424 compression.eof_mask = (1 << code_lengths[256]) - 1;
425 compression.eof_bits = code_lengths[256];
426
427 let lengths = &code_lengths[288..320];
429 if lengths == [0; 32] {
430 compression.dist_table.fill(0);
431 } else {
432 let mut dist_codes = [0; 32];
433 if !huffman::build_table(
434 lengths,
435 &tables::DISTANCE_TABLE_ENTRIES,
436 &mut dist_codes,
437 &mut *compression.dist_table,
438 &mut compression.dist_secondary_table,
439 true,
440 false,
441 ) {
442 return Err(DecompressionError::BadDistanceHuffmanTree);
443 }
444 }
445
446 Ok(())
447 }
448
    /// Decode huffman-compressed data into `output` starting at `output_index`,
    /// returning the new output index.
    ///
    /// Runs a fast loop first — it assumes at least 8 bytes of input and 8
    /// bytes of output headroom per iteration so it can decode speculatively —
    /// then a careful loop that re-checks `nbits` and output space at every
    /// step to handle the tail.
    fn read_compressed(
        &mut self,
        remaining_input: &mut &[u8],
        output: &mut [u8],
        mut output_index: usize,
    ) -> Result<usize, DecompressionError> {
        self.fill_buffer(remaining_input);
        let mut litlen_entry = self.compression.litlen_table[(self.buffer & 0xfff) as usize];
        while self.state == State::CompressedData
            && output_index + 8 <= output.len()
            && remaining_input.len() >= 8
        {
            let mut bits;
            let mut litlen_code_bits = litlen_entry as u8;
            if litlen_entry & LITERAL_ENTRY != 0 {
                // Speculatively look up the next three table entries; they are
                // only acted on if the preceding entries turn out to be
                // literals (each primary entry can emit up to two bytes).
                let litlen_entry2 = self.compression.litlen_table
                    [(self.buffer >> litlen_code_bits & 0xfff) as usize];
                let litlen_code_bits2 = litlen_entry2 as u8;
                let litlen_entry3 = self.compression.litlen_table
                    [(self.buffer >> (litlen_code_bits + litlen_code_bits2) & 0xfff) as usize];
                let litlen_code_bits3 = litlen_entry3 as u8;
                let litlen_entry4 = self.compression.litlen_table[(self.buffer
                    >> (litlen_code_bits + litlen_code_bits2 + litlen_code_bits3)
                    & 0xfff)
                    as usize];

                // Always write two bytes but only advance by the entry's
                // actual output length (1 or 2); the headroom check above
                // makes the extra write safe.
                let advance_output_bytes = ((litlen_entry & 0xf00) >> 8) as usize;
                output[output_index] = (litlen_entry >> 16) as u8;
                output[output_index + 1] = (litlen_entry >> 24) as u8;
                output_index += advance_output_bytes;

                if litlen_entry2 & LITERAL_ENTRY != 0 {
                    let advance_output_bytes2 = ((litlen_entry2 & 0xf00) >> 8) as usize;
                    output[output_index] = (litlen_entry2 >> 16) as u8;
                    output[output_index + 1] = (litlen_entry2 >> 24) as u8;
                    output_index += advance_output_bytes2;

                    if litlen_entry3 & LITERAL_ENTRY != 0 {
                        let advance_output_bytes3 = ((litlen_entry3 & 0xf00) >> 8) as usize;
                        output[output_index] = (litlen_entry3 >> 16) as u8;
                        output[output_index + 1] = (litlen_entry3 >> 24) as u8;
                        output_index += advance_output_bytes3;

                        litlen_entry = litlen_entry4;
                        self.consume_bits(litlen_code_bits + litlen_code_bits2 + litlen_code_bits3);
                        self.fill_buffer(remaining_input);
                        continue;
                    } else {
                        self.consume_bits(litlen_code_bits + litlen_code_bits2);
                        litlen_entry = litlen_entry3;
                        litlen_code_bits = litlen_code_bits3;
                        self.fill_buffer(remaining_input);
                        bits = self.buffer;
                    }
                } else {
                    self.consume_bits(litlen_code_bits);
                    bits = self.buffer;
                    litlen_entry = litlen_entry2;
                    litlen_code_bits = litlen_code_bits2;
                    if self.nbits < 48 {
                        self.fill_buffer(remaining_input);
                    }
                }
            } else {
                bits = self.buffer;
            }

            // At this point `litlen_entry` is not a plain literal: resolve it
            // to a (length base, extra bits, code bits) triple, handling
            // secondary-table continuations, EOF, and invalid codes.
            let (length_base, length_extra_bits, litlen_code_bits) =
                if litlen_entry & EXCEPTIONAL_ENTRY == 0 {
                    (
                        litlen_entry >> 16,
                        (litlen_entry >> 8) as u8,
                        litlen_code_bits,
                    )
                } else if litlen_entry & SECONDARY_TABLE_ENTRY != 0 {
                    let secondary_table_index =
                        (litlen_entry >> 16) + ((bits >> 12) as u32 & (litlen_entry & 0xff));
                    let secondary_entry =
                        self.compression.secondary_table[secondary_table_index as usize];
                    let litlen_symbol = secondary_entry >> 4;
                    let litlen_code_bits = (secondary_entry & 0xf) as u8;

                    match litlen_symbol {
                        0..=255 => {
                            // Long-coded literal byte.
                            self.consume_bits(litlen_code_bits);
                            litlen_entry =
                                self.compression.litlen_table[(self.buffer & 0xfff) as usize];
                            self.fill_buffer(remaining_input);
                            output[output_index] = litlen_symbol as u8;
                            output_index += 1;
                            continue;
                        }
                        256 => {
                            // End of block.
                            self.consume_bits(litlen_code_bits);
                            self.state = match self.last_block {
                                true => State::Checksum,
                                false => State::BlockHeader,
                            };
                            break;
                        }
                        _ => (
                            LEN_SYM_TO_LEN_BASE[litlen_symbol as usize - 257] as u32,
                            LEN_SYM_TO_LEN_EXTRA[litlen_symbol as usize - 257],
                            litlen_code_bits,
                        ),
                    }
                } else if litlen_code_bits == 0 {
                    return Err(DecompressionError::InvalidLiteralLengthCode);
                } else {
                    // End-of-block symbol in the primary table.
                    self.consume_bits(litlen_code_bits);
                    self.state = match self.last_block {
                        true => State::Checksum,
                        false => State::BlockHeader,
                    };
                    break;
                };
            bits >>= litlen_code_bits;

            // Read the length extra bits, then the distance code and its
            // extra bits, all from the already-peeked `bits`.
            let length_extra_mask = (1 << length_extra_bits) - 1;
            let length = length_base as usize + (bits & length_extra_mask) as usize;
            bits >>= length_extra_bits;

            let dist_entry = self.compression.dist_table[(bits & 0x1ff) as usize];
            let (dist_base, dist_extra_bits, dist_code_bits) = if dist_entry & LITERAL_ENTRY != 0 {
                (
                    (dist_entry >> 16) as u16,
                    (dist_entry >> 8) as u8 & 0xf,
                    dist_entry as u8,
                )
            } else if dist_entry >> 8 == 0 {
                return Err(DecompressionError::InvalidDistanceCode);
            } else {
                let secondary_table_index =
                    (dist_entry >> 16) + ((bits >> 9) as u32 & (dist_entry & 0xff));
                let secondary_entry =
                    self.compression.dist_secondary_table[secondary_table_index as usize];
                let dist_symbol = (secondary_entry >> 4) as usize;
                if dist_symbol >= 30 {
                    return Err(DecompressionError::InvalidDistanceCode);
                }

                (
                    DIST_SYM_TO_DIST_BASE[dist_symbol],
                    DIST_SYM_TO_DIST_EXTRA[dist_symbol],
                    (secondary_entry & 0xf) as u8,
                )
            };
            bits >>= dist_code_bits;

            let dist = dist_base as usize + (bits & ((1 << dist_extra_bits) - 1)) as usize;
            if dist > output_index {
                return Err(DecompressionError::DistanceTooFarBack);
            }

            self.consume_bits(
                litlen_code_bits + length_extra_bits + dist_code_bits + dist_extra_bits,
            );
            self.fill_buffer(remaining_input);
            litlen_entry = self.compression.litlen_table[(self.buffer & 0xfff) as usize];

            // Copy the backreference; spill the remainder into queued_rle /
            // queued_backref if it doesn't fit in `output`.
            let copy_length = length.min(output.len() - output_index);
            if dist == 1 {
                // Distance 1 is a run of the previous byte.
                let last = output[output_index - 1];
                output[output_index..][..copy_length].fill(last);

                if copy_length < length {
                    self.queued_rle = Some((last, length - copy_length));
                    output_index = output.len();
                    break;
                }
            } else if output_index + length + 15 <= output.len() {
                // Enough headroom for 16-byte chunked copies (which may
                // over-copy within the slack).
                let start = output_index - dist;
                output.copy_within(start..start + 16, output_index);

                if length > 16 || dist < 16 {
                    for i in (0..length).step_by(dist.min(16)).skip(1) {
                        output.copy_within(start + i..start + i + 16, output_index + i);
                    }
                }
            } else {
                if dist < copy_length {
                    // Overlapping copy must go byte by byte.
                    for i in 0..copy_length {
                        output[output_index + i] = output[output_index + i - dist];
                    }
                } else {
                    output.copy_within(
                        output_index - dist..output_index + copy_length - dist,
                        output_index,
                    )
                }

                if copy_length < length {
                    self.queued_backref = Some((dist, length - copy_length));
                    output_index = output.len();
                    break;
                }
            }
            output_index += copy_length;
        }

        // Careful loop: identical decoding logic, but every step re-checks the
        // available bits and output space before committing.
        while let State::CompressedData = self.state {
            self.fill_buffer(remaining_input);
            if output_index == output.len() {
                break;
            }

            let mut bits = self.buffer;
            let litlen_entry = self.compression.litlen_table[(bits & 0xfff) as usize];
            let litlen_code_bits = litlen_entry as u8;

            if litlen_entry & LITERAL_ENTRY != 0 {
                let advance_output_bytes = ((litlen_entry & 0xf00) >> 8) as usize;

                if self.nbits < litlen_code_bits {
                    break;
                } else if output_index + 1 < output.len() {
                    output[output_index] = (litlen_entry >> 16) as u8;
                    output[output_index + 1] = (litlen_entry >> 24) as u8;
                    output_index += advance_output_bytes;
                    self.consume_bits(litlen_code_bits);
                    continue;
                } else if output_index + advance_output_bytes == output.len() {
                    debug_assert_eq!(advance_output_bytes, 1);
                    output[output_index] = (litlen_entry >> 16) as u8;
                    output_index += 1;
                    self.consume_bits(litlen_code_bits);
                    break;
                } else {
                    // Two literal bytes but only one slot left: queue the
                    // second byte as a length-1 RLE run.
                    debug_assert_eq!(advance_output_bytes, 2);
                    output[output_index] = (litlen_entry >> 16) as u8;
                    self.queued_rle = Some(((litlen_entry >> 24) as u8, 1));
                    output_index += 1;
                    self.consume_bits(litlen_code_bits);
                    break;
                }
            }

            let (length_base, length_extra_bits, litlen_code_bits) =
                if litlen_entry & EXCEPTIONAL_ENTRY == 0 {
                    (
                        litlen_entry >> 16,
                        (litlen_entry >> 8) as u8,
                        litlen_code_bits,
                    )
                } else if litlen_entry & SECONDARY_TABLE_ENTRY != 0 {
                    let secondary_table_index =
                        (litlen_entry >> 16) + ((bits >> 12) as u32 & (litlen_entry & 0xff));
                    let secondary_entry =
                        self.compression.secondary_table[secondary_table_index as usize];
                    let litlen_symbol = secondary_entry >> 4;
                    let litlen_code_bits = (secondary_entry & 0xf) as u8;

                    if self.nbits < litlen_code_bits {
                        break;
                    } else if litlen_symbol < 256 {
                        self.consume_bits(litlen_code_bits);
                        output[output_index] = litlen_symbol as u8;
                        output_index += 1;
                        continue;
                    } else if litlen_symbol == 256 {
                        self.consume_bits(litlen_code_bits);
                        self.state = match self.last_block {
                            true => State::Checksum,
                            false => State::BlockHeader,
                        };
                        break;
                    }

                    (
                        LEN_SYM_TO_LEN_BASE[litlen_symbol as usize - 257] as u32,
                        LEN_SYM_TO_LEN_EXTRA[litlen_symbol as usize - 257],
                        litlen_code_bits,
                    )
                } else if litlen_code_bits == 0 {
                    return Err(DecompressionError::InvalidLiteralLengthCode);
                } else {
                    if self.nbits < litlen_code_bits {
                        break;
                    }
                    self.consume_bits(litlen_code_bits);
                    self.state = match self.last_block {
                        true => State::Checksum,
                        false => State::BlockHeader,
                    };
                    break;
                };
            bits >>= litlen_code_bits;

            let length_extra_mask = (1 << length_extra_bits) - 1;
            let length = length_base as usize + (bits & length_extra_mask) as usize;
            bits >>= length_extra_bits;

            let dist_entry = self.compression.dist_table[(bits & 0x1ff) as usize];
            let (dist_base, dist_extra_bits, dist_code_bits) = if dist_entry & LITERAL_ENTRY != 0 {
                (
                    (dist_entry >> 16) as u16,
                    (dist_entry >> 8) as u8 & 0xf,
                    dist_entry as u8,
                )
            } else if self.nbits > litlen_code_bits + length_extra_bits + 9 {
                // Only inspect the secondary distance table if the 9 primary
                // bits were actually valid.
                if dist_entry >> 8 == 0 {
                    return Err(DecompressionError::InvalidDistanceCode);
                }

                let secondary_table_index =
                    (dist_entry >> 16) + ((bits >> 9) as u32 & (dist_entry & 0xff));
                let secondary_entry =
                    self.compression.dist_secondary_table[secondary_table_index as usize];
                let dist_symbol = (secondary_entry >> 4) as usize;
                if dist_symbol >= 30 {
                    return Err(DecompressionError::InvalidDistanceCode);
                }

                (
                    DIST_SYM_TO_DIST_BASE[dist_symbol],
                    DIST_SYM_TO_DIST_EXTRA[dist_symbol],
                    (secondary_entry & 0xf) as u8,
                )
            } else {
                break;
            };
            bits >>= dist_code_bits;

            let dist = dist_base as usize + (bits & ((1 << dist_extra_bits) - 1)) as usize;
            let total_bits =
                litlen_code_bits + length_extra_bits + dist_code_bits + dist_extra_bits;

            if self.nbits < total_bits {
                break;
            } else if dist > output_index {
                return Err(DecompressionError::DistanceTooFarBack);
            }

            self.consume_bits(total_bits);

            let copy_length = length.min(output.len() - output_index);
            if dist == 1 {
                let last = output[output_index - 1];
                output[output_index..][..copy_length].fill(last);

                if copy_length < length {
                    self.queued_rle = Some((last, length - copy_length));
                    output_index = output.len();
                    break;
                }
            } else if output_index + length + 15 <= output.len() {
                let start = output_index - dist;
                output.copy_within(start..start + 16, output_index);

                if length > 16 || dist < 16 {
                    for i in (0..length).step_by(dist.min(16)).skip(1) {
                        output.copy_within(start + i..start + i + 16, output_index + i);
                    }
                }
            } else {
                if dist < copy_length {
                    for i in 0..copy_length {
                        output[output_index + i] = output[output_index + i - dist];
                    }
                } else {
                    output.copy_within(
                        output_index - dist..output_index + copy_length - dist,
                        output_index,
                    )
                }

                if copy_length < length {
                    self.queued_backref = Some((dist, length - copy_length));
                    output_index = output.len();
                    break;
                }
            }
            output_index += copy_length;
        }

        // If the very next code is the end-of-block symbol, consume it now so
        // the caller can observe completion even when the output is full.
        if self.state == State::CompressedData
            && self.queued_backref.is_none()
            && self.queued_rle.is_none()
            && self.nbits >= 15
            && self.peak_bits(15) as u16 & self.compression.eof_mask == self.compression.eof_code
        {
            self.consume_bits(self.compression.eof_bits);
            self.state = match self.last_block {
                true => State::Checksum,
                false => State::BlockHeader,
            };
        }

        Ok(output_index)
    }
860
    /// Decompress as much as possible from `input` into
    /// `output[output_position..]`.
    ///
    /// Returns `(bytes_consumed, bytes_produced)`. Set `end_of_input` when no
    /// further input will follow; in that case an incomplete stream that makes
    /// no progress yields `InsufficientInput`.
    pub fn read(
        &mut self,
        input: &[u8],
        output: &mut [u8],
        output_position: usize,
        end_of_input: bool,
    ) -> Result<(usize, usize), DecompressionError> {
        if let State::Done = self.state {
            return Ok((0, 0));
        }

        assert!(output_position <= output.len());

        let mut remaining_input = input;
        let mut output_index = output_position;

        // Flush any run or backreference left over from a previous call whose
        // output buffer filled up.
        if let Some((data, len)) = self.queued_rle.take() {
            let n = len.min(output.len() - output_index);
            output[output_index..][..n].fill(data);
            output_index += n;
            if n < len {
                self.queued_rle = Some((data, len - n));
                return Ok((0, n));
            }
        }
        if let Some((dist, len)) = self.queued_backref.take() {
            let n = len.min(output.len() - output_index);
            for i in 0..n {
                output[output_index + i] = output[output_index + i - dist];
            }
            output_index += n;
            if n < len {
                self.queued_backref = Some((dist, len - n));
                return Ok((0, n));
            }
        }

        // Run the state machine until a full pass makes no state transition
        // (i.e. we are blocked on input or output).
        let mut last_state = None;
        while last_state != Some(self.state) {
            last_state = Some(self.state);
            match self.state {
                State::ZlibHeader => {
                    self.fill_buffer(&mut remaining_input);
                    if self.nbits < 16 {
                        break;
                    }

                    // CMF/FLG: method 8 (deflate), window <= 32K, no preset
                    // dictionary, and the header check value must divide by 31.
                    let input0 = self.peak_bits(8);
                    let input1 = self.peak_bits(16) >> 8 & 0xff;
                    if input0 & 0x0f != 0x08
                        || (input0 & 0xf0) > 0x70
                        || input1 & 0x20 != 0
                        || (input0 << 8 | input1) % 31 != 0
                    {
                        return Err(DecompressionError::BadZlibHeader);
                    }

                    self.consume_bits(16);
                    self.state = State::BlockHeader;
                }
                State::BlockHeader => {
                    self.read_block_header(&mut remaining_input)?;
                }
                State::CodeLengthCodes => {
                    self.read_code_length_codes(&mut remaining_input)?;
                }
                State::CodeLengths => {
                    self.read_code_lengths(&mut remaining_input)?;
                }
                State::CompressedData => {
                    output_index =
                        self.read_compressed(&mut remaining_input, output, output_index)?
                }
                State::UncompressedData => {
                    // First drain whole bytes still sitting in the bit buffer,
                    // then bulk-copy directly from the input slice.
                    debug_assert_eq!(self.nbits % 8, 0);
                    while self.nbits > 0
                        && self.uncompressed_bytes_left > 0
                        && output_index < output.len()
                    {
                        output[output_index] = self.peak_bits(8) as u8;
                        self.consume_bits(8);
                        output_index += 1;
                        self.uncompressed_bytes_left -= 1;
                    }
                    if self.nbits == 0 {
                        self.buffer = 0;
                    }

                    let copy_bytes = (self.uncompressed_bytes_left as usize)
                        .min(remaining_input.len())
                        .min(output.len() - output_index);
                    output[output_index..][..copy_bytes]
                        .copy_from_slice(&remaining_input[..copy_bytes]);
                    remaining_input = &remaining_input[copy_bytes..];
                    output_index += copy_bytes;
                    self.uncompressed_bytes_left -= copy_bytes as u16;

                    if self.uncompressed_bytes_left == 0 {
                        self.state = if self.last_block {
                            State::Checksum
                        } else {
                            State::BlockHeader
                        };
                    }
                }
                State::Checksum => {
                    self.fill_buffer(&mut remaining_input);

                    // Skip to the next byte boundary, then read the big-endian
                    // Adler-32 trailer and compare against the running sum.
                    let align_bits = self.nbits % 8;
                    if self.nbits >= 32 + align_bits {
                        self.checksum.write(&output[output_position..output_index]);
                        if align_bits != 0 {
                            self.consume_bits(align_bits);
                        }
                        #[cfg(not(fuzzing))]
                        if !self.ignore_adler32
                            && (self.peak_bits(32) as u32).swap_bytes() != self.checksum.finish()
                        {
                            return Err(DecompressionError::WrongChecksum);
                        }
                        self.state = State::Done;
                        self.consume_bits(32);
                        break;
                    }
                }
                State::Done => unreachable!(),
            }
        }

        // Fold this call's output into the running checksum (unless we already
        // finished and verified it above).
        if !self.ignore_adler32 && self.state != State::Done {
            self.checksum.write(&output[output_position..output_index]);
        }

        if self.state == State::Done || !end_of_input || output_index == output.len() {
            let input_left = remaining_input.len();
            Ok((input.len() - input_left, output_index - output_position))
        } else {
            Err(DecompressionError::InsufficientInput)
        }
    }
1023
    /// Returns true once the entire stream (including the checksum trailer)
    /// has been consumed.
    pub fn is_done(&self) -> bool {
        self.state == State::Done
    }
1028}
1029
1030pub fn decompress_to_vec(input: &[u8]) -> Result<Vec<u8>, DecompressionError> {
1032 match decompress_to_vec_bounded(input, usize::MAX) {
1033 Ok(output) => Ok(output),
1034 Err(BoundedDecompressionError::DecompressionError { inner }) => Err(inner),
1035 Err(BoundedDecompressionError::OutputTooLarge { .. }) => {
1036 unreachable!("Impossible to allocate more than isize::MAX bytes")
1037 }
1038 }
1039}
1040
/// Error type returned by [`decompress_to_vec_bounded`].
pub enum BoundedDecompressionError {
    /// The stream itself was invalid.
    DecompressionError {
        /// The underlying decompression error.
        inner: DecompressionError,
    },

    /// The output would have exceeded the requested size bound.
    OutputTooLarge {
        /// The output produced before the bound was hit.
        partial_output: Vec<u8>,
    },
}
impl From<DecompressionError> for BoundedDecompressionError {
    /// Wrap a plain decompression error (enables `?` in the bounded API).
    fn from(inner: DecompressionError) -> Self {
        BoundedDecompressionError::DecompressionError { inner }
    }
}
1060
/// Decompress an entire zlib stream, failing with `OutputTooLarge` if the
/// output would exceed `maxlen` bytes.
pub fn decompress_to_vec_bounded(
    input: &[u8],
    maxlen: usize,
) -> Result<Vec<u8>, BoundedDecompressionError> {
    let mut decoder = Decompressor::new();
    let mut output = vec![0; 1024.min(maxlen)];
    let mut input_index = 0;
    let mut output_index = 0;
    loop {
        let (consumed, produced) =
            decoder.read(&input[input_index..], &mut output, output_index, true)?;
        input_index += consumed;
        output_index += produced;
        if decoder.is_done() || output_index == maxlen {
            break;
        }
        // Grow the output buffer in 32 KiB steps, never past `maxlen`.
        output.resize((output_index + 32 * 1024).min(maxlen), 0);
    }
    // Trim the unused tail.
    output.resize(output_index, 0);

    if decoder.is_done() {
        Ok(output)
    } else {
        Err(BoundedDecompressionError::OutputTooLarge {
            partial_output: output,
        })
    }
}
1091
#[cfg(test)]
mod tests {
    use crate::tables::{LENGTH_TO_LEN_EXTRA, LENGTH_TO_SYMBOL};

    use super::*;
    use rand::Rng;

    /// Compress with this crate and assert decompression round-trips.
    fn roundtrip(data: &[u8]) {
        let compressed = crate::compress_to_vec(data);
        let decompressed = decompress_to_vec(&compressed).unwrap();
        assert_eq!(&decompressed, data);
    }

    /// Compress with miniz_oxide and assert this crate decompresses it,
    /// reporting the first mismatching chunk on failure.
    fn roundtrip_miniz_oxide(data: &[u8]) {
        let compressed = miniz_oxide::deflate::compress_to_vec_zlib(data, 3);
        let decompressed = decompress_to_vec(&compressed).unwrap();
        assert_eq!(decompressed.len(), data.len());
        for (i, (a, b)) in decompressed.chunks(1).zip(data.chunks(1)).enumerate() {
            assert_eq!(a, b, "chunk {}..{}", i, i + 1);
        }
        assert_eq!(&decompressed, data);
    }

    /// Compare this crate's decompression of `data` against miniz_oxide's,
    /// panicking with context at the first differing byte.
    #[allow(unused)]
    fn compare_decompression(data: &[u8]) {
        let decompressed = decompress_to_vec(data).unwrap();
        let decompressed2 = miniz_oxide::inflate::decompress_to_vec_zlib(data).unwrap();
        for i in 0..decompressed.len().min(decompressed2.len()) {
            if decompressed[i] != decompressed2[i] {
                panic!(
                    "mismatch at index {} {:?} {:?}",
                    i,
                    &decompressed[i.saturating_sub(1)..(i + 16).min(decompressed.len())],
                    &decompressed2[i.saturating_sub(1)..(i + 16).min(decompressed2.len())]
                );
            }
        }
        if decompressed != decompressed2 {
            panic!(
                "length mismatch {} {} {:x?}",
                decompressed.len(),
                decompressed2.len(),
                &decompressed2[decompressed.len()..][..16]
            );
        }
    }

    /// Sanity-check that the length lookup tables are mutually consistent.
    #[test]
    fn tables() {
        for (i, &bits) in LEN_SYM_TO_LEN_EXTRA.iter().enumerate() {
            let len_base = LEN_SYM_TO_LEN_BASE[i];
            for j in 0..(1 << bits) {
                if i == 27 && j == 31 {
                    continue;
                }
                assert_eq!(LENGTH_TO_LEN_EXTRA[len_base + j - 3], bits, "{} {}", i, j);
                assert_eq!(
                    LENGTH_TO_SYMBOL[len_base + j - 3],
                    i as u16 + 257,
                    "{} {}",
                    i,
                    j
                );
            }
        }
    }

    /// The precomputed fixed tables must match tables built from the fixed
    /// code lengths.
    #[test]
    fn fixed_tables() {
        let mut compression = CompressedBlock {
            litlen_table: Box::new([0; 4096]),
            dist_table: Box::new([0; 512]),
            secondary_table: Vec::new(),
            dist_secondary_table: Vec::new(),
            eof_code: 0,
            eof_mask: 0,
            eof_bits: 0,
        };
        Decompressor::build_tables(288, &FIXED_CODE_LENGTHS, &mut compression).unwrap();

        assert_eq!(compression.litlen_table[..512], FIXED_LITLEN_TABLE);
        assert_eq!(compression.dist_table[..32], FIXED_DIST_TABLE);
    }

    #[test]
    fn it_works() {
        roundtrip(b"Hello world!");
    }

    #[test]
    fn constant() {
        roundtrip_miniz_oxide(&[0; 50]);
        roundtrip_miniz_oxide(&vec![5; 2048]);
        roundtrip_miniz_oxide(&vec![128; 2048]);
        roundtrip_miniz_oxide(&vec![254; 2048]);
    }

    #[test]
    fn random() {
        let mut rng = rand::thread_rng();
        let mut data = vec![0; 50000];
        for _ in 0..10 {
            for byte in &mut data {
                *byte = rng.gen::<u8>() % 5;
            }
            println!("Random data: {:?}", data);
            roundtrip_miniz_oxide(&data);
        }
    }

    /// Corrupting the trailer must fail normally but pass with
    /// `ignore_adler32` enabled.
    #[test]
    fn ignore_adler32() {
        let mut compressed = crate::compress_to_vec(b"Hello world!");
        let last_byte = compressed.len() - 1;
        compressed[last_byte] = compressed[last_byte].wrapping_add(1);

        match decompress_to_vec(&compressed) {
            Err(DecompressionError::WrongChecksum) => {}
            r => panic!("expected WrongChecksum, got {:?}", r),
        }

        let mut decompressor = Decompressor::new();
        decompressor.ignore_adler32();
        let mut decompressed = vec![0; 1024];
        let decompressed_len = decompressor
            .read(&compressed, &mut decompressed, 0, true)
            .unwrap()
            .1;
        assert_eq!(&decompressed[..decompressed_len], b"Hello world!");
    }

    /// The checksum trailer may arrive in a later `read` call than the data.
    #[test]
    fn checksum_after_eof() {
        let input = b"Hello world!";
        let compressed = crate::compress_to_vec(input);

        let mut decompressor = Decompressor::new();
        let mut decompressed = vec![0; 1024];
        let (input_consumed, output_written) = decompressor
            .read(
                &compressed[..compressed.len() - 1],
                &mut decompressed,
                0,
                false,
            )
            .unwrap();
        assert_eq!(output_written, input.len());
        assert_eq!(input_consumed, compressed.len() - 1);

        let (input_consumed, output_written) = decompressor
            .read(
                &compressed[input_consumed..],
                &mut decompressed[..output_written],
                output_written,
                true,
            )
            .unwrap();
        assert!(decompressor.is_done());
        assert_eq!(input_consumed, 1);
        assert_eq!(output_written, 0);

        assert_eq!(&decompressed[..input.len()], input);
    }

    /// A stream padded with empty stored blocks must decode to nothing, even
    /// with a zero-length output buffer.
    #[test]
    fn zero_length() {
        let mut compressed = crate::compress_to_vec(b"").to_vec();

        for _ in 0..10 {
            println!("compressed len: {}", compressed.len());
            compressed.splice(2..2, [0u8, 0, 0, 0xff, 0xff].into_iter());
        }

        for end_of_input in [true, false] {
            let mut decompressor = Decompressor::new();
            let (input_consumed, output_written) = decompressor
                .read(&compressed, &mut [], 0, end_of_input)
                .unwrap();

            assert!(decompressor.is_done());
            assert_eq!(input_consumed, compressed.len());
            assert_eq!(output_written, 0);
        }
    }

    mod test_utils;
    use tables::FIXED_CODE_LENGTHS;
    use test_utils::{decompress_by_chunks, TestDecompressionError};

    /// Decompress `input` whole and one byte at a time and assert the results
    /// agree, returning the (shared) outcome.
    fn verify_no_sensitivity_to_input_chunking(
        input: &[u8],
    ) -> Result<Vec<u8>, TestDecompressionError> {
        let r_whole = decompress_by_chunks(input, vec![input.len()], false);
        let r_bytewise = decompress_by_chunks(input, std::iter::repeat(1), false);
        assert_eq!(r_whole, r_bytewise);
        r_whole }

    #[test]
    fn test_input_chunking_sensitivity_when_handling_distance_codes() {
        let result = verify_no_sensitivity_to_input_chunking(include_bytes!(
            "../tests/input-chunking-sensitivity-example1.zz"
        ))
        .unwrap();
        assert_eq!(result.len(), 281);
        assert_eq!(simd_adler32::adler32(&result.as_slice()), 751299);
    }

    #[test]
    fn test_input_chunking_sensitivity_when_no_end_of_block_symbol_example1() {
        let err = verify_no_sensitivity_to_input_chunking(include_bytes!(
            "../tests/input-chunking-sensitivity-example2.zz"
        ))
        .unwrap_err();
        assert_eq!(
            err,
            TestDecompressionError::ProdError(DecompressionError::BadLiteralLengthHuffmanTree)
        );
    }

    #[test]
    fn test_input_chunking_sensitivity_when_no_end_of_block_symbol_example2() {
        let err = verify_no_sensitivity_to_input_chunking(include_bytes!(
            "../tests/input-chunking-sensitivity-example3.zz"
        ))
        .unwrap_err();
        assert_eq!(
            err,
            TestDecompressionError::ProdError(DecompressionError::BadLiteralLengthHuffmanTree)
        );
    }
}