use std::cmp;
use std::io;
/// A reader that serves bytes from an internal buffer which is refilled from
/// an inner `io::Read`.
pub struct BufferedReader<R>
where
    R: io::Read,
{
    /// The wrapped reader that the buffer is filled from.
    inner: R,
    /// Fixed-size storage that the inner reader reads into.
    buf: Box<[u8]>,
    /// Index into `buf` of the next byte to hand out.
    pos: u32,
    /// Number of bytes at the start of `buf` that were filled by the last read.
    num_valid: u32,
}

impl<R> BufferedReader<R>
where
    R: io::Read,
{
    /// Wraps `inner` in a buffered reader whose buffer starts out empty.
    ///
    /// The buffer is zero-initialised and holds 2048 bytes in regular builds,
    /// or 31 bytes when compiled with `--cfg fuzzing` (presumably so that
    /// refills happen often while fuzzing).
    pub fn new(inner: R) -> BufferedReader<R> {
        #[cfg(not(fuzzing))]
        const CAPACITY: usize = 2048;
        #[cfg(fuzzing)]
        const CAPACITY: usize = 31;

        BufferedReader {
            inner: inner,
            buf: vec![0; CAPACITY].into_boxed_slice(),
            pos: 0,
            num_valid: 0,
        }
    }

    /// Consumes the buffered reader and returns the wrapped reader.
    ///
    /// Bytes that were already read into the buffer but not yet consumed are
    /// discarded together with the buffer.
    pub fn into_inner(self) -> R {
        let BufferedReader { inner, .. } = self;
        inner
    }
}
pub trait ReadBytes {
fn read_u8(&mut self) -> io::Result<u8>;
fn read_u8_or_eof(&mut self) -> io::Result<Option<u8>>;
fn read_into(&mut self, buffer: &mut [u8]) -> io::Result<()>;
fn skip(&mut self, amount: u32) -> io::Result<()>;
fn read_be_u16(&mut self) -> io::Result<u16> {
let b0 = try!(self.read_u8()) as u16;
let b1 = try!(self.read_u8()) as u16;
Ok(b0 << 8 | b1)
}
fn read_be_u16_or_eof(&mut self) -> io::Result<Option<u16>> {
if let Some(b0) = try!(self.read_u8_or_eof()) {
if let Some(b1) = try!(self.read_u8_or_eof()) {
return Ok(Some((b0 as u16) << 8 | (b1 as u16)));
}
}
Ok(None)
}
fn read_be_u24(&mut self) -> io::Result<u32> {
let b0 = try!(self.read_u8()) as u32;
let b1 = try!(self.read_u8()) as u32;
let b2 = try!(self.read_u8()) as u32;
Ok(b0 << 16 | b1 << 8 | b2)
}
fn read_be_u32(&mut self) -> io::Result<u32> {
let b0 = try!(self.read_u8()) as u32;
let b1 = try!(self.read_u8()) as u32;
let b2 = try!(self.read_u8()) as u32;
let b3 = try!(self.read_u8()) as u32;
Ok(b0 << 24 | b1 << 16 | b2 << 8 | b3)
}
fn read_le_u32(&mut self) -> io::Result<u32> {
let b0 = try!(self.read_u8()) as u32;
let b1 = try!(self.read_u8()) as u32;
let b2 = try!(self.read_u8()) as u32;
let b3 = try!(self.read_u8()) as u32;
Ok(b3 << 24 | b2 << 16 | b1 << 8 | b0)
}
}
/// Pulls the next chunk of data from the inner reader into the buffer.
///
/// On success the read position is rewound to the start of the buffer and the
/// number of buffered bytes is returned; `0` means the inner reader reported
/// end of input. On failure the position and valid count are left untouched,
/// so a later retry performs a fresh read instead of replaying stale buffer
/// contents.
fn refill_buffer<R: io::Read>(reader: &mut BufferedReader<R>) -> io::Result<u32> {
    let reported = try!(reader.inner.read(&mut reader.buf));
    // `io::Read` is safe to implement, so the reported count cannot be trusted
    // for memory safety. Clamping it keeps `num_valid <= buf.len()`, which the
    // unchecked access in `read_u8` relies on. The cast assumes the buffer
    // length fits in a `u32`, as the `u32` position fields already require.
    let num_valid = cmp::min(reported, reader.buf.len()) as u32;
    reader.pos = 0;
    reader.num_valid = num_valid;
    Ok(num_valid)
}

impl<R: io::Read> ReadBytes for BufferedReader<R> {
    /// Returns the next byte, refilling the buffer when it is exhausted.
    ///
    /// Fails with `UnexpectedEof` if the inner reader has no more data, and
    /// propagates any error from the inner reader.
    #[inline(always)]
    fn read_u8(&mut self) -> io::Result<u8> {
        if self.pos == self.num_valid {
            if try!(refill_buffer(self)) == 0 {
                return Err(io::Error::new(io::ErrorKind::UnexpectedEof,
                                          "Expected one more byte."));
            }
        }

        // SAFETY: `pos <= num_valid <= buf.len()` holds at all times (every
        // refill clamps `num_valid` and resets `pos`, and `pos` only advances
        // up to `num_valid`). `pos != num_valid` at this point, either because
        // the check above was false or because the refill produced at least
        // one byte, hence `pos < buf.len()`.
        let byte = unsafe { *self.buf.get_unchecked(self.pos as usize) };
        self.pos += 1;
        Ok(byte)
    }

    /// Returns the next byte, or `Ok(None)` if the inner reader has no more
    /// data.
    fn read_u8_or_eof(&mut self) -> io::Result<Option<u8>> {
        if self.pos == self.num_valid {
            if try!(refill_buffer(self)) == 0 {
                return Ok(None);
            }
        }
        self.read_u8().map(Some)
    }

    /// Fills all of `buffer`, refilling the internal buffer as often as
    /// needed.
    ///
    /// Fails with `UnexpectedEof` if the inner reader runs out of data first;
    /// in that case the start of `buffer` may already have been overwritten.
    fn read_into(&mut self, buffer: &mut [u8]) -> io::Result<()> {
        let mut filled = 0;
        while filled < buffer.len() {
            // Copy as much as the internal buffer can currently provide.
            let available = (self.num_valid - self.pos) as usize;
            let count = cmp::min(buffer.len() - filled, available);
            let start = self.pos as usize;
            buffer[filled..filled + count]
                .copy_from_slice(&self.buf[start..start + count]);
            filled += count;
            self.pos += count as u32;

            if filled < buffer.len() {
                if try!(refill_buffer(self)) == 0 {
                    return Err(io::Error::new(io::ErrorKind::UnexpectedEof,
                                              "Expected more bytes."));
                }
            }
        }
        Ok(())
    }

    /// Discards the next `amount` bytes, refilling the internal buffer as
    /// often as needed.
    ///
    /// Fails with `UnexpectedEof` if the inner reader runs out of data first.
    fn skip(&mut self, mut amount: u32) -> io::Result<()> {
        while amount > 0 {
            // Drop whatever part of the request is already buffered.
            let buffered = self.num_valid - self.pos;
            let skip_now = cmp::min(amount, buffered);
            self.pos += skip_now;
            amount -= skip_now;

            if amount > 0 {
                if try!(refill_buffer(self)) == 0 {
                    return Err(io::Error::new(io::ErrorKind::UnexpectedEof,
                                              "Expected more bytes."));
                }
            }
        }
        Ok(())
    }
}
/// Lets a mutable reference to a `ReadBytes` be used as a `ReadBytes` itself;
/// every required method forwards to the referenced reader.
impl<'r, R> ReadBytes for &'r mut R
where
    R: ReadBytes,
{
    #[inline(always)]
    fn read_u8(&mut self) -> io::Result<u8> {
        (**self).read_u8()
    }

    fn read_u8_or_eof(&mut self) -> io::Result<Option<u8>> {
        (**self).read_u8_or_eof()
    }

    fn read_into(&mut self, buffer: &mut [u8]) -> io::Result<()> {
        (**self).read_into(buffer)
    }

    fn skip(&mut self, amount: u32) -> io::Result<()> {
        (**self).skip(amount)
    }
}
/// Reads directly from the bytes wrapped by an `io::Cursor`, without any
/// intermediate buffering.
///
/// All methods fail with `UnexpectedEof` when the request extends past the
/// end of the underlying data, and leave the cursor position unchanged in
/// that case.
impl<T: AsRef<[u8]>> ReadBytes for io::Cursor<T> {
    /// Returns the byte at the current position and advances by one.
    fn read_u8(&mut self) -> io::Result<u8> {
        let pos = self.position();
        let byte = {
            let data = self.get_ref().as_ref();
            if pos >= data.len() as u64 {
                return Err(io::Error::new(io::ErrorKind::UnexpectedEof, "unexpected eof"));
            }
            // `pos < data.len()`, so the conversion to `usize` is lossless.
            data[pos as usize]
        };
        // Cannot overflow: `pos` is below a length that itself fits in `u64`.
        self.set_position(pos + 1);
        Ok(byte)
    }

    /// Returns the byte at the current position and advances by one, or
    /// `Ok(None)` if the position is at or past the end of the data.
    fn read_u8_or_eof(&mut self) -> io::Result<Option<u8>> {
        let pos = self.position();
        let byte = {
            let data = self.get_ref().as_ref();
            if pos >= data.len() as u64 {
                return Ok(None);
            }
            data[pos as usize]
        };
        self.set_position(pos + 1);
        Ok(Some(byte))
    }

    /// Copies the next `buffer.len()` bytes into `buffer` and advances past
    /// them.
    fn read_into(&mut self, buffer: &mut [u8]) -> io::Result<()> {
        let pos = self.position();
        let len = self.get_ref().as_ref().len() as u64;
        // The position can be set to any `u64`, so the end offset is computed
        // with overflow checking; an overflowing request is past the end.
        let end = match pos.checked_add(buffer.len() as u64) {
            Some(end) if end <= len => end,
            _ => return Err(io::Error::new(io::ErrorKind::UnexpectedEof, "unexpected eof")),
        };
        // `pos <= end <= len`, so both offsets fit in `usize`.
        buffer.copy_from_slice(&self.get_ref().as_ref()[pos as usize..end as usize]);
        self.set_position(end);
        Ok(())
    }

    /// Advances the position by `amount` bytes.
    fn skip(&mut self, amount: u32) -> io::Result<()> {
        let pos = self.position();
        let len = self.get_ref().as_ref().len() as u64;
        // Checked for the same reason as in `read_into`: a wrapped sum would
        // otherwise move the position backwards instead of failing.
        match pos.checked_add(amount as u64) {
            Some(end) if end <= len => {
                self.set_position(end);
                Ok(())
            }
            _ => Err(io::Error::new(io::ErrorKind::UnexpectedEof, "unexpected eof")),
        }
    }
}
/// `read_into` on a `BufferedReader` fills consecutive buffers and fails once
/// fewer bytes remain than were requested.
#[test]
fn verify_read_into_buffered_reader() {
    // Nine bytes: the 3- and 5-byte reads succeed, the final 2-byte read cannot.
    let source = io::Cursor::new(vec![2u8, 3, 5, 7, 11, 13, 17, 19, 23]);
    let mut reader = BufferedReader::new(source);

    let mut first = [0u8; 3];
    let mut second = [0u8; 5];
    let mut third = [0u8; 2];

    reader.read_into(&mut first).ok().unwrap();
    reader.read_into(&mut second).ok().unwrap();
    assert!(reader.read_into(&mut third).is_err());

    assert_eq!(first, [2u8, 3, 5]);
    assert_eq!(second, [7u8, 11, 13, 17, 19]);
}
/// `read_into` on an `io::Cursor` fills consecutive buffers and fails once
/// fewer bytes remain than were requested.
#[test]
fn verify_read_into_cursor() {
    // Nine bytes: the 3- and 5-byte reads succeed, the final 2-byte read cannot.
    let mut cursor = io::Cursor::new(vec![2u8, 3, 5, 7, 11, 13, 17, 19, 23]);

    let mut first = [0u8; 3];
    let mut second = [0u8; 5];
    let mut third = [0u8; 2];

    cursor.read_into(&mut first).ok().unwrap();
    cursor.read_into(&mut second).ok().unwrap();
    assert!(cursor.read_into(&mut third).is_err());

    assert_eq!(first, [2u8, 3, 5]);
    assert_eq!(second, [7u8, 11, 13, 17, 19]);
}
/// Single-byte reads on a `BufferedReader` return the input in order and
/// signal the end as `None` (`read_u8_or_eof`) or an error (`read_u8`).
#[test]
fn verify_read_u8_buffered_reader() {
    let source = io::Cursor::new(vec![0u8, 2, 129, 89, 122]);
    let mut reader = BufferedReader::new(source);

    for &expected in [0u8, 2, 129, 89].iter() {
        assert_eq!(reader.read_u8().unwrap(), expected);
    }

    // The last byte is taken through the EOF-aware variant.
    assert_eq!(reader.read_u8_or_eof().unwrap(), Some(122));
    assert_eq!(reader.read_u8_or_eof().unwrap(), None);
    assert!(reader.read_u8().is_err());
}
/// Single-byte reads on an `io::Cursor` return the input in order and signal
/// the end as `None` (`read_u8_or_eof`) or an error (`read_u8`).
#[test]
fn verify_read_u8_cursor() {
    let mut cursor = io::Cursor::new(vec![0u8, 2, 129, 89, 122]);

    for &expected in [0u8, 2, 129, 89].iter() {
        assert_eq!(cursor.read_u8().unwrap(), expected);
    }

    // The last byte is taken through the EOF-aware variant.
    assert_eq!(cursor.read_u8_or_eof().unwrap(), Some(122));
    assert_eq!(cursor.read_u8_or_eof().unwrap(), None);
    assert!(cursor.read_u8().is_err());
}
/// Big-endian 16-bit reads on a `BufferedReader`; the trailing odd byte makes
/// the third read fail.
#[test]
fn verify_read_be_u16_buffered_reader() {
    let source = io::Cursor::new(vec![0u8, 2, 129, 89, 122]);
    let mut reader = BufferedReader::new(source);

    for &expected in [2u16, 33113].iter() {
        assert_eq!(reader.read_be_u16().ok(), Some(expected));
    }
    assert!(reader.read_be_u16().is_err());
}
/// Big-endian 16-bit reads on an `io::Cursor`; the trailing odd byte makes the
/// third read fail.
#[test]
fn verify_read_be_u16_cursor() {
    let mut cursor = io::Cursor::new(vec![0u8, 2, 129, 89, 122]);

    for &expected in [2u16, 33113].iter() {
        assert_eq!(cursor.read_be_u16().ok(), Some(expected));
    }
    assert!(cursor.read_be_u16().is_err());
}
/// Big-endian 24-bit reads on a `BufferedReader`; the single trailing byte
/// makes the third read fail.
#[test]
fn verify_read_be_u24_buffered_reader() {
    let source = io::Cursor::new(vec![0u8, 0, 2, 0x8f, 0xff, 0xf3, 122]);
    let mut reader = BufferedReader::new(source);

    for &expected in [2u32, 9_437_171].iter() {
        assert_eq!(reader.read_be_u24().ok(), Some(expected));
    }
    assert!(reader.read_be_u24().is_err());
}
/// Big-endian 24-bit reads on an `io::Cursor`; the single trailing byte makes
/// the third read fail.
#[test]
fn verify_read_be_u24_cursor() {
    let mut cursor = io::Cursor::new(vec![0u8, 0, 2, 0x8f, 0xff, 0xf3, 122]);

    for &expected in [2u32, 9_437_171].iter() {
        assert_eq!(cursor.read_be_u24().ok(), Some(expected));
    }
    assert!(cursor.read_be_u24().is_err());
}
/// Big-endian 32-bit reads on a `BufferedReader`; the single trailing byte
/// makes the third read fail.
#[test]
fn verify_read_be_u32_buffered_reader() {
    let source = io::Cursor::new(vec![0u8, 0, 0, 2, 0x80, 0x01, 0xff, 0xe9, 0]);
    let mut reader = BufferedReader::new(source);

    for &expected in [2u32, 2_147_614_697].iter() {
        assert_eq!(reader.read_be_u32().ok(), Some(expected));
    }
    assert!(reader.read_be_u32().is_err());
}
/// Big-endian 32-bit reads on an `io::Cursor`; the single trailing byte makes
/// the third read fail.
#[test]
fn verify_read_be_u32_cursor() {
    let mut cursor = io::Cursor::new(vec![0u8, 0, 0, 2, 0x80, 0x01, 0xff, 0xe9, 0]);

    for &expected in [2u32, 2_147_614_697].iter() {
        assert_eq!(cursor.read_be_u32().ok(), Some(expected));
    }
    assert!(cursor.read_be_u32().is_err());
}
/// Little-endian 32-bit reads on a `BufferedReader`; the single trailing byte
/// makes the third read fail.
#[test]
fn verify_read_le_u32_buffered_reader() {
    let source = io::Cursor::new(vec![2u8, 0, 0, 0, 0xe9, 0xff, 0x01, 0x80, 0]);
    let mut reader = BufferedReader::new(source);

    for &expected in [2u32, 2_147_614_697].iter() {
        assert_eq!(reader.read_le_u32().ok(), Some(expected));
    }
    assert!(reader.read_le_u32().is_err());
}
/// Little-endian 32-bit reads on an `io::Cursor`; the single trailing byte
/// makes the third read fail.
#[test]
fn verify_read_le_u32_cursor() {
    let mut cursor = io::Cursor::new(vec![2u8, 0, 0, 0, 0xe9, 0xff, 0x01, 0x80, 0]);

    for &expected in [2u32, 2_147_614_697].iter() {
        assert_eq!(cursor.read_le_u32().ok(), Some(expected));
    }
    assert!(cursor.read_le_u32().is_err());
}
/// Shifts `x` left by `shift` bits, where `shift` may be as large as 8.
///
/// The shift is carried out on a 32-bit value and truncated afterwards, so a
/// shift by the full width of a `u8` yields 0 rather than overflowing.
#[inline(always)]
fn shift_left(x: u8, shift: u32) -> u8 {
    debug_assert!(shift <= 8);
    let widened = x as u32;
    (widened << shift) as u8
}
/// Shifts `x` right by `shift` bits, where `shift` may be as large as 8.
///
/// The shift is carried out on a 32-bit value and truncated afterwards, so a
/// shift by the full width of a `u8` yields 0 rather than overflowing.
#[inline(always)]
fn shift_right(x: u8, shift: u32) -> u8 {
    debug_assert!(shift <= 8);
    let widened = x as u32;
    (widened >> shift) as u8
}
/// Reads a `ReadBytes` source in units of bits, taking the most significant
/// bit of every byte first.
pub struct Bitstream<R>
where
    R: ReadBytes,
{
    /// The source that whole bytes are pulled from.
    reader: R,
    /// The not yet consumed bits of the last byte read, aligned to the most
    /// significant end. `read_leq_u8` debug-asserts that the bits below them
    /// are zero.
    data: u8,
    /// How many leading bits of `data` have not been consumed yet.
    bits_left: u32,
}

impl<R> Bitstream<R>
where
    R: ReadBytes,
{
    /// Creates a bitstream over `reader` with no bits buffered.
    pub fn new(reader: R) -> Bitstream<R> {
        Bitstream {
            reader: reader,
            data: 0,
            bits_left: 0,
        }
    }

    /// Returns a byte in which the `bits` most significant bits are set.
    #[inline(always)]
    fn mask_u8(bits: u32) -> u8 {
        debug_assert!(bits <= 8);
        shift_left(0xff, 8 - bits)
    }

    /// Reads a single bit and returns whether it was set.
    #[inline(always)]
    pub fn read_bit(&mut self) -> io::Result<bool> {
        const TOP_BIT: u8 = 0b1000_0000;

        if self.bits_left == 0 {
            // Nothing buffered: fetch a byte, hand out its top bit and keep
            // the remaining seven.
            let fresh_byte = try!(self.reader.read_u8());
            self.data = fresh_byte << 1;
            self.bits_left = 7;
            return Ok((fresh_byte & TOP_BIT) != 0);
        }

        let is_set = (self.data & TOP_BIT) != 0;
        self.data = self.data << 1;
        self.bits_left = self.bits_left - 1;
        Ok(is_set)
    }

    /// Counts the zero bits up to the next one bit and consumes all of them
    /// together with that one bit.
    #[inline(always)]
    pub fn read_unary(&mut self) -> io::Result<u32> {
        let leading = self.data.leading_zeros();

        if leading < self.bits_left {
            // The terminating one bit is among the buffered bits.
            let consumed = leading + 1;
            self.data = self.data << consumed;
            self.bits_left = self.bits_left - consumed;
            return Ok(leading);
        }

        // All buffered bits are zeros; keep counting in fresh bytes until one
        // of them contains a one bit.
        let mut count = self.bits_left;
        loop {
            let fresh_byte = try!(self.reader.read_u8());
            let zeros = fresh_byte.leading_zeros();
            count = count + zeros;
            if zeros < 8 {
                self.bits_left = 8 - (zeros + 1);
                self.data = shift_left(fresh_byte, zeros + 1);
                return Ok(count);
            }
        }
    }

    /// Reads `bits` bits (at most 8) and returns them in the low bits of the
    /// result.
    #[inline(always)]
    pub fn read_leq_u8(&mut self, bits: u32) -> io::Result<u8> {
        debug_assert!(bits <= 8);

        // `aligned` holds the requested bits at the most significant end.
        let aligned = if bits > self.bits_left {
            // Not enough buffered: combine what is left with the top of a
            // fresh byte. State is only updated once the read has succeeded.
            let missing = bits - self.bits_left;
            let high_part = self.data;
            let fresh_byte = try!(self.reader.read_u8());
            let low_part = (fresh_byte & Self::mask_u8(missing)) >> self.bits_left;
            self.data = shift_left(fresh_byte, missing);
            self.bits_left = 8 - missing;
            high_part | low_part
        } else {
            let taken = self.data & Self::mask_u8(bits);
            self.data = self.data << bits;
            self.bits_left = self.bits_left - bits;
            taken
        };

        // Invariants: fewer than 8 bits buffered, nothing set below them.
        debug_assert!(self.bits_left < 8);
        debug_assert_eq!(self.data & !Self::mask_u8(self.bits_left), 0u8);

        Ok(shift_right(aligned, 8 - bits))
    }

    /// Reads `bits` bits, where `bits` is more than 8 and at most 16, and
    /// returns them in the low bits of the result.
    #[inline(always)]
    pub fn read_gt_u8_leq_u16(&mut self, bits: u32) -> io::Result<u32> {
        debug_assert!((8 < bits) && (bits <= 16));

        // Number of bits that must come from fresh bytes.
        let bits_to_read = bits - self.bits_left;

        // Move the buffered bits to their place in the result and clear
        // everything below them.
        let keep_mask = 0xffff_ffff_u32 << bits_to_read;
        let buffered = ((self.data as u32) << (bits - 8)) & keep_mask;

        let first = try!(self.reader.read_u8()) as u32;
        let from_first = if bits_to_read >= 8 {
            first << (bits_to_read - 8)
        } else {
            first >> (8 - bits_to_read)
        };
        let partial = buffered | from_first;

        if bits_to_read <= 8 {
            // One fresh byte suffices; its unused low bits become the buffer.
            self.bits_left = 8 - bits_to_read;
            self.data = first.wrapping_shl(8 - self.bits_left) as u8;
            return Ok(partial);
        }

        // A second byte is needed for the lowest bits of the result.
        let second = try!(self.reader.read_u8()) as u32;
        let from_second = second >> (16 - bits_to_read);
        self.bits_left = 16 - bits_to_read;
        self.data = second.wrapping_shl(8 - self.bits_left) as u8;
        Ok(partial | from_second)
    }

    /// Reads `bits` bits (at most 16) and returns them in the low bits of the
    /// result.
    #[inline(always)]
    pub fn read_leq_u16(&mut self, bits: u32) -> io::Result<u16> {
        debug_assert!(bits <= 16);

        if bits > 8 {
            // A full byte for the high part, the remainder for the low part.
            let low_width = bits - 8;
            let high = try!(self.read_leq_u8(8)) as u16;
            let low = try!(self.read_leq_u8(low_width)) as u16;
            Ok((high << low_width) | low)
        } else {
            self.read_leq_u8(bits).map(|value| value as u16)
        }
    }

    /// Reads `bits` bits (at most 32) and returns them in the low bits of the
    /// result.
    #[inline(always)]
    pub fn read_leq_u32(&mut self, bits: u32) -> io::Result<u32> {
        debug_assert!(bits <= 32);

        if bits > 16 {
            // Sixteen bits for the high part, the remainder for the low part.
            let low_width = bits - 16;
            let high = try!(self.read_leq_u16(16)) as u32;
            let low = try!(self.read_leq_u16(low_width)) as u32;
            Ok((high << low_width) | low)
        } else {
            self.read_leq_u16(bits).map(|value| value as u32)
        }
    }
}
/// Single-bit reads interleaved with short `read_leq_u8` calls walk through
/// two bytes bit by bit and fail once the input is used up.
#[test]
fn verify_read_bit() {
    let source = io::Cursor::new(vec![0b1010_0100, 0b1110_0001]);
    let mut bits = Bitstream::new(BufferedReader::new(source));

    for &expected in [true, false, true].iter() {
        assert_eq!(bits.read_bit().unwrap(), expected);
    }
    assert_eq!(bits.read_leq_u8(1).unwrap(), 0);

    // The remainder of the first byte, then the start of the second.
    for &expected in [false, true, false, false, true, true, true].iter() {
        assert_eq!(bits.read_bit().unwrap(), expected);
    }
    assert_eq!(bits.read_leq_u8(2).unwrap(), 0);

    for &expected in [false, false, true].iter() {
        assert_eq!(bits.read_bit().unwrap(), expected);
    }

    // All sixteen bits have been consumed.
    assert!(bits.read_bit().is_err());
}
/// `read_unary` counts runs of zeros, including runs that span several bytes,
/// and leaves the bits after the terminating one available.
#[test]
fn verify_read_unary() {
    let source = io::Cursor::new(vec![
        0b1010_0100,
        0b1000_0000,
        0b0010_0000,
        0b0000_0000,
        0b0000_1010,
    ]);
    let mut bits = Bitstream::new(BufferedReader::new(source));

    for &expected in [0u32, 1, 2, 2, 9, 17].iter() {
        assert_eq!(bits.read_unary().unwrap(), expected);
    }

    // The three bits after the last terminator are still readable.
    assert_eq!(bits.read_leq_u8(3).unwrap(), 0b010);
    assert!(bits.read_bit().is_err());
}
/// `read_leq_u8` with every width from 0 to 8, both within one byte and across
/// byte boundaries.
#[test]
fn verify_read_leq_u8() {
    let source = io::Cursor::new(vec![
        0b1010_0101,
        0b1110_0001,
        0b1101_0010,
        0b0101_0101,
        0b0111_0011,
        0b0011_1111,
        0b1010_1010,
        0b0000_1100,
    ]);
    let mut bits = Bitstream::new(BufferedReader::new(source));

    // (number of bits to read, expected value), in reading order.
    let cases = [
        (0u32, 0u8),
        (1, 1),
        (1, 0),
        (2, 0b10),
        (2, 0b01),
        (3, 0b011),
        (3, 0b110),
        (4, 0b0001),
        (5, 0b11010),
        (6, 0b010010),
        (7, 0b1010101),
        (8, 0b11001100),
        (6, 0b111111),
        (8, 0b10101010),
        (4, 0b0000),
        (1, 1),
        (1, 1),
        (2, 0b00),
    ];

    for &(width, expected) in cases.iter() {
        assert_eq!(bits.read_leq_u8(width).unwrap(), expected);
    }
}
/// `read_gt_u8_leq_u16` with 10-bit reads at different bit offsets, mixed with
/// `read_leq_u8`, until the input runs out.
#[test]
fn verify_read_gt_u8_get_u16() {
    let source = io::Cursor::new(vec![
        0b1010_0101,
        0b1110_0001,
        0b1101_0010,
        0b0101_0101,
        0b1111_0000,
    ]);
    let mut stream = Bitstream::new(BufferedReader::new(source));

    assert_eq!(stream.read_gt_u8_leq_u16(10).unwrap(), 0b1010_0101_11);
    assert_eq!(stream.read_gt_u8_leq_u16(10).unwrap(), 0b10_0001_1101);
    assert_eq!(stream.read_leq_u8(3).unwrap(), 0b001);
    assert_eq!(stream.read_gt_u8_leq_u16(10).unwrap(), 0b0_0101_0101_1);
    assert_eq!(stream.read_leq_u8(7).unwrap(), 0b111_0000);

    // All forty bits have been consumed.
    assert!(stream.read_gt_u8_leq_u16(10).is_err());
}
/// `read_leq_u16` with widths below and above 8 bits.
#[test]
fn verify_read_leq_u16() {
    let source = io::Cursor::new(vec![0b1010_0101, 0b1110_0001, 0b1101_0010, 0b0101_0101]);
    let mut bits = Bitstream::new(BufferedReader::new(source));

    // (number of bits to read, expected value), in reading order.
    let cases = [
        (0u32, 0u16),
        (1, 1),
        (13, 0b010_0101_1110_00),
        (9, 0b01_1101_001),
    ];

    for &(width, expected) in cases.iter() {
        assert_eq!(bits.read_leq_u16(width).unwrap(), expected);
    }
}
/// `read_leq_u32` with widths below and above 16 bits.
#[test]
fn verify_read_leq_u32() {
    let source = io::Cursor::new(vec![0b1010_0101, 0b1110_0001, 0b1101_0010, 0b0101_0101]);
    let mut bits = Bitstream::new(BufferedReader::new(source));

    // (number of bits to read, expected value), in reading order.
    let cases = [
        (1u32, 1u32),
        (17, 0b010_0101_1110_0001_11),
        (14, 0b01_0010_0101_0101),
    ];

    for &(width, expected) in cases.iter() {
        assert_eq!(bits.read_leq_u32(width).unwrap(), expected);
    }
}
/// A short prefix followed by a run of 17-bit reads, each expected to be bit
/// 16 set on top of a negative 16-bit sample's two's-complement pattern.
#[test]
fn verify_read_mixed() {
    let source = io::Cursor::new(vec![
        0x03, 0xc7, 0xbf, 0xe5, 0x9b, 0x74, 0x1e, 0x3a, 0xdd, 0x7d,
        0xc5, 0x5e, 0xf6, 0xbf, 0x78, 0x1b, 0xbd,
    ]);
    let mut bits = Bitstream::new(BufferedReader::new(source));

    assert_eq!(bits.read_leq_u8(6).unwrap(), 0);
    assert_eq!(bits.read_leq_u8(1).unwrap(), 1);

    // Bit 16 of every 17-bit value is expected to be set.
    let minus = 1u32 << 16;
    let samples = [
        -14401_i16,
        -13514,
        -12168,
        -10517,
        -9131,
        -8489,
        -8698,
    ];

    for &sample in samples.iter() {
        assert_eq!(bits.read_leq_u32(17).unwrap(), minus | (sample as u16 as u32));
    }
}