1use std::cmp;
9use std::io;
10use std::io::{IoSliceMut, Read, Seek};
11use std::ops::Sub;
12
13use super::SeekBuffered;
14use super::{MediaSource, ReadBytes};
15
/// Builds the `UnexpectedEof` error reported when the stream ends before a read is satisfied.
#[inline(always)]
fn end_of_stream_error<T>() -> io::Result<T> {
    let kind = io::ErrorKind::UnexpectedEof;
    let error = io::Error::new(kind, "end of stream");
    Err(error)
}
20
/// Options controlling the ring buffer used by a `MediaSourceStream`.
#[derive(Clone, Copy, Debug)]
pub struct MediaSourceStreamOptions {
    /// Length of the ring buffer in bytes.
    ///
    /// Must be a power of two and strictly greater than 32 KiB; `MediaSourceStream::new`
    /// panics otherwise.
    pub buffer_len: usize,
}
26
27impl Default for MediaSourceStreamOptions {
28 fn default() -> Self {
29 MediaSourceStreamOptions { buffer_len: 64 * 1024 }
30 }
31}
32
33pub struct MediaSourceStream {
53 inner: Box<dyn MediaSource>,
55 ring: Box<[u8]>,
57 ring_mask: usize,
59 read_pos: usize,
61 write_pos: usize,
63 read_block_len: usize,
65 abs_pos: u64,
67 rel_pos: u64,
70}
71
72impl MediaSourceStream {
73 const MIN_BLOCK_LEN: usize = 1 * 1024;
74 const MAX_BLOCK_LEN: usize = 32 * 1024;
75
76 pub fn new(source: Box<dyn MediaSource>, options: MediaSourceStreamOptions) -> Self {
77 assert!(options.buffer_len.count_ones() == 1);
79 assert!(options.buffer_len > Self::MAX_BLOCK_LEN);
80
81 MediaSourceStream {
82 inner: source,
83 ring: vec![0; options.buffer_len].into_boxed_slice(),
84 ring_mask: options.buffer_len - 1,
85 read_pos: 0,
86 write_pos: 0,
87 read_block_len: Self::MIN_BLOCK_LEN,
88 abs_pos: 0,
89 rel_pos: 0,
90 }
91 }
92
93 #[inline(always)]
96 fn is_buffer_exhausted(&self) -> bool {
97 self.read_pos == self.write_pos
98 }
99
100 fn fetch(&mut self) -> io::Result<()> {
102 if self.is_buffer_exhausted() {
104 let (vec1, vec0) = self.ring.split_at_mut(self.write_pos);
107
108 let actual_read_len = if vec0.len() >= self.read_block_len {
112 self.inner.read(&mut vec0[..self.read_block_len])?
113 }
114 else {
115 let rem = self.read_block_len - vec0.len();
117
118 let ring_vectors = &mut [IoSliceMut::new(vec0), IoSliceMut::new(&mut vec1[..rem])];
119
120 self.inner.read_vectored(ring_vectors)?
121 };
122
123 self.write_pos = (self.write_pos + actual_read_len) & self.ring_mask;
125
126 self.abs_pos += actual_read_len as u64;
128 self.rel_pos += actual_read_len as u64;
129
130 self.read_block_len = cmp::min(self.read_block_len << 1, Self::MAX_BLOCK_LEN);
133 }
134
135 Ok(())
136 }
137
138 fn fetch_or_eof(&mut self) -> io::Result<()> {
141 self.fetch()?;
142
143 if self.is_buffer_exhausted() {
144 return end_of_stream_error();
145 }
146
147 Ok(())
148 }
149
150 #[inline(always)]
152 fn consume(&mut self, len: usize) {
153 self.read_pos = (self.read_pos + len) & self.ring_mask;
154 }
155
156 #[inline(always)]
158 fn continguous_buf(&self) -> &[u8] {
159 if self.write_pos >= self.read_pos {
160 &self.ring[self.read_pos..self.write_pos]
161 }
162 else {
163 &self.ring[self.read_pos..]
164 }
165 }
166
167 fn reset(&mut self, pos: u64) {
169 self.read_pos = 0;
170 self.write_pos = 0;
171 self.read_block_len = Self::MIN_BLOCK_LEN;
172 self.abs_pos = pos;
173 self.rel_pos = 0;
174 }
175}
176
177impl MediaSource for MediaSourceStream {
178 #[inline]
179 fn is_seekable(&self) -> bool {
180 self.inner.is_seekable()
181 }
182
183 #[inline]
184 fn byte_len(&self) -> Option<u64> {
185 self.inner.byte_len()
186 }
187}
188
189impl io::Read for MediaSourceStream {
190 fn read(&mut self, mut buf: &mut [u8]) -> io::Result<usize> {
191 let read_len = buf.len();
192
193 while !buf.is_empty() {
194 self.fetch()?;
196
197 match self.continguous_buf().read(buf) {
200 Ok(0) => break,
201 Ok(count) => {
202 buf = &mut buf[count..];
203 self.consume(count);
204 }
205 Err(ref e) if e.kind() == io::ErrorKind::Interrupted => {}
206 Err(e) => return Err(e),
207 }
208 }
209
210 Ok(read_len - buf.len())
213 }
214}
215
216impl io::Seek for MediaSourceStream {
217 fn seek(&mut self, pos: io::SeekFrom) -> io::Result<u64> {
218 let pos = match pos {
223 io::SeekFrom::Current(0) => return Ok(self.pos()),
224 io::SeekFrom::Current(delta_pos) => {
225 let delta = delta_pos - self.unread_buffer_len() as i64;
226 self.inner.seek(io::SeekFrom::Current(delta))
227 }
228 _ => self.inner.seek(pos),
229 }?;
230
231 self.reset(pos);
232
233 Ok(pos)
234 }
235}
236
237impl ReadBytes for MediaSourceStream {
238 #[inline(always)]
239 fn read_byte(&mut self) -> io::Result<u8> {
240 if self.is_buffer_exhausted() {
244 self.fetch_or_eof()?;
245 }
246
247 let value = self.ring[self.read_pos];
248 self.consume(1);
249
250 Ok(value)
251 }
252
253 fn read_double_bytes(&mut self) -> io::Result<[u8; 2]> {
254 let mut bytes = [0; 2];
255
256 let buf = self.continguous_buf();
257
258 if buf.len() >= 2 {
259 bytes.copy_from_slice(&buf[..2]);
260 self.consume(2);
261 }
262 else {
263 for byte in bytes.iter_mut() {
264 *byte = self.read_byte()?;
265 }
266 };
267
268 Ok(bytes)
269 }
270
271 fn read_triple_bytes(&mut self) -> io::Result<[u8; 3]> {
272 let mut bytes = [0; 3];
273
274 let buf = self.continguous_buf();
275
276 if buf.len() >= 3 {
277 bytes.copy_from_slice(&buf[..3]);
278 self.consume(3);
279 }
280 else {
281 for byte in bytes.iter_mut() {
282 *byte = self.read_byte()?;
283 }
284 };
285 Ok(bytes)
286 }
287
288 fn read_quad_bytes(&mut self) -> io::Result<[u8; 4]> {
289 let mut bytes = [0; 4];
290
291 let buf = self.continguous_buf();
292
293 if buf.len() >= 4 {
294 bytes.copy_from_slice(&buf[..4]);
295 self.consume(4);
296 }
297 else {
298 for byte in bytes.iter_mut() {
299 *byte = self.read_byte()?;
300 }
301 };
302 Ok(bytes)
303 }
304
305 fn read_buf(&mut self, buf: &mut [u8]) -> io::Result<usize> {
306 let read = self.read(buf)?;
308
309 if !buf.is_empty() && read == 0 {
313 end_of_stream_error()
314 }
315 else {
316 Ok(read)
317 }
318 }
319
320 fn read_buf_exact(&mut self, mut buf: &mut [u8]) -> io::Result<()> {
321 while !buf.is_empty() {
322 match self.read(buf) {
323 Ok(0) => break,
324 Ok(count) => {
325 buf = &mut buf[count..];
326 }
327 Err(ref e) if e.kind() == io::ErrorKind::Interrupted => {}
328 Err(e) => return Err(e),
329 }
330 }
331
332 if !buf.is_empty() {
333 end_of_stream_error()
334 }
335 else {
336 Ok(())
337 }
338 }
339
340 fn scan_bytes_aligned<'a>(
341 &mut self,
342 _: &[u8],
343 _: usize,
344 _: &'a mut [u8],
345 ) -> io::Result<&'a mut [u8]> {
346 unimplemented!();
348 }
349
350 fn ignore_bytes(&mut self, mut count: u64) -> io::Result<()> {
351 let ring_len = self.ring.len() as u64;
355
356 while count >= 2 * ring_len && self.is_seekable() {
358 let delta = count.clamp(0, i64::MAX as u64).sub(ring_len);
359 self.seek(io::SeekFrom::Current(delta as i64))?;
360 count -= delta;
361 }
362
363 while count > 0 {
365 self.fetch_or_eof()?;
366 let discard_count = cmp::min(self.unread_buffer_len() as u64, count);
367 self.consume(discard_count as usize);
368 count -= discard_count;
369 }
370 Ok(())
371 }
372
373 fn pos(&self) -> u64 {
374 self.abs_pos - self.unread_buffer_len() as u64
375 }
376}
377
378impl SeekBuffered for MediaSourceStream {
379 fn ensure_seekback_buffer(&mut self, len: usize) {
380 let ring_len = self.ring.len();
381
382 let new_ring_len = (Self::MAX_BLOCK_LEN + len).next_power_of_two();
386
387 if ring_len < new_ring_len {
389 let mut new_ring = vec![0; new_ring_len].into_boxed_slice();
391
392 let (vec0, vec1) = if self.write_pos >= self.read_pos {
394 (&self.ring[self.read_pos..self.write_pos], None)
395 }
396 else {
397 (&self.ring[self.read_pos..], Some(&self.ring[..self.write_pos]))
398 };
399
400 let vec0_len = vec0.len();
402 new_ring[..vec0_len].copy_from_slice(vec0);
403
404 self.write_pos = if let Some(vec1) = vec1 {
405 let total_len = vec0_len + vec1.len();
406 new_ring[vec0_len..total_len].copy_from_slice(vec1);
407 total_len
408 }
409 else {
410 vec0_len
411 };
412
413 self.ring = new_ring;
414 self.ring_mask = new_ring_len - 1;
415 self.read_pos = 0;
416 }
417 }
418
419 fn unread_buffer_len(&self) -> usize {
420 if self.write_pos >= self.read_pos {
421 self.write_pos - self.read_pos
422 }
423 else {
424 self.write_pos + (self.ring.len() - self.read_pos)
425 }
426 }
427
428 fn read_buffer_len(&self) -> usize {
429 let unread_len = self.unread_buffer_len();
430
431 cmp::min(self.ring.len(), self.rel_pos as usize) - unread_len
432 }
433
434 fn seek_buffered(&mut self, pos: u64) -> u64 {
435 let old_pos = self.pos();
436
437 let delta = if pos > old_pos {
439 assert!(pos - old_pos < std::isize::MAX as u64);
440 (pos - old_pos) as isize
441 }
442 else if pos < old_pos {
443 assert!(old_pos - pos < std::isize::MAX as u64);
445 -((old_pos - pos) as isize)
446 }
447 else {
448 0
449 };
450
451 self.seek_buffered_rel(delta)
452 }
453
454 fn seek_buffered_rel(&mut self, delta: isize) -> u64 {
455 if delta < 0 {
456 let abs_delta = cmp::min((-delta) as usize, self.read_buffer_len());
457 self.read_pos = (self.read_pos + self.ring.len() - abs_delta) & self.ring_mask;
458 }
459 else if delta > 0 {
460 let abs_delta = cmp::min(delta as usize, self.unread_buffer_len());
461 self.read_pos = (self.read_pos + abs_delta) & self.ring_mask;
462 }
463
464 self.pos()
465 }
466}
467
#[cfg(test)]
mod tests {
    use super::{MediaSourceStream, ReadBytes, SeekBuffered};
    use std::io::{Cursor, Read};

    /// Produces `len` pseudo-random bytes from a fixed-seed LCG, so every run sees the same data.
    fn generate_random_bytes(len: usize) -> Box<[u8]> {
        let mut state: u32 = 0xec57c4bf;

        let mut out = vec![0u8; len];

        for chunk in out.chunks_mut(4) {
            state = state.wrapping_mul(1664525).wrapping_add(1013904223);
            let word = state.to_le_bytes();
            chunk.copy_from_slice(&word[..chunk.len()]);
        }

        out.into_boxed_slice()
    }

    /// Builds a stream with default options over an in-memory copy of `data`.
    fn stream_over(data: Box<[u8]>) -> MediaSourceStream {
        MediaSourceStream::new(Box::new(Cursor::new(data)), Default::default())
    }

    #[test]
    fn verify_mss_read() {
        let data = generate_random_bytes(5 * 96 * 1024);
        let mut mss = stream_over(data.clone());

        // Walk the data in four sections read with different widths, skipping a few bytes
        // between sections, and compare every read against the reference copy.
        let mut offset = 0;

        let section = &data[offset..offset + 96 * 1024];
        for &expected in section {
            assert_eq!(expected, mss.read_byte().unwrap());
        }
        mss.ignore_bytes(11).unwrap();
        offset += section.len() + 11;

        let section = &data[offset..offset + 2 * 48 * 1024];
        for expected in section.chunks_exact(2) {
            assert_eq!(expected, &mss.read_double_bytes().unwrap());
        }
        mss.ignore_bytes(33).unwrap();
        offset += section.len() + 33;

        let section = &data[offset..offset + 3 * 32 * 1024];
        for expected in section.chunks_exact(3) {
            assert_eq!(expected, &mss.read_triple_bytes().unwrap());
        }
        mss.ignore_bytes(55).unwrap();
        offset += section.len() + 55;

        let section = &data[offset..offset + 4 * 24 * 1024];
        for expected in section.chunks_exact(4) {
            assert_eq!(expected, &mss.read_quad_bytes().unwrap());
        }
    }

    #[test]
    fn verify_mss_read_to_end() {
        let data = generate_random_bytes(5 * 96 * 1024);
        let mut mss = stream_over(data.clone());

        let mut output = Vec::new();
        let count = mss.read_to_end(&mut output).unwrap();

        assert_eq!(count, data.len());
        assert_eq!(&output[..], &data[..]);
    }

    #[test]
    fn verify_mss_seek_buffered() {
        let mut mss = stream_over(generate_random_bytes(1024 * 1024));

        // Nothing has been fetched yet.
        assert_eq!(mss.read_buffer_len(), 0);
        assert_eq!(mss.unread_buffer_len(), 0);

        mss.ignore_bytes(5122).unwrap();

        assert_eq!(5122, mss.pos());
        assert_eq!(mss.read_buffer_len(), 5122);

        let upper = mss.read_byte().unwrap();

        // Step back 1000 bytes, then forward 999 to land just before `upper` again.
        assert_eq!(mss.seek_buffered_rel(-1000), 4123);
        assert_eq!(mss.pos(), 4123);
        assert_eq!(mss.read_buffer_len(), 4123);

        assert_eq!(mss.seek_buffered_rel(999), 5122);
        assert_eq!(mss.pos(), 5122);
        assert_eq!(mss.read_buffer_len(), 5122);

        assert_eq!(upper, mss.read_byte().unwrap());
    }

    #[test]
    fn verify_reading_be() {
        let mut mss = stream_over(generate_random_bytes(1024 * 1024));

        mss.ignore_bytes(2).unwrap();

        assert_eq!(mss.read_be_f32().unwrap(), -72818055000000000000000000000.0);
        assert_eq!(mss.read_be_f64().unwrap(), -0.000000000000011582640453292664);

        assert_eq!(mss.read_be_u16().unwrap(), 32624);
        assert_eq!(mss.read_be_u24().unwrap(), 6739677);
        assert_eq!(mss.read_be_u32().unwrap(), 1569552917);
        assert_eq!(mss.read_be_u64().unwrap(), 6091217585348000864);
    }

    #[test]
    fn verify_reading_le() {
        let mut mss = stream_over(generate_random_bytes(1024 * 1024));

        mss.ignore_bytes(1024).unwrap();

        assert_eq!(mss.read_f32().unwrap(), -0.00000000000000000000000000048426285);
        assert_eq!(mss.read_f64().unwrap(), -6444325820119113.0);

        assert_eq!(mss.read_u16().unwrap(), 36195);
        assert_eq!(mss.read_u24().unwrap(), 6710386);
        assert_eq!(mss.read_u32().unwrap(), 2378776723);
        assert_eq!(mss.read_u64().unwrap(), 5170196279331153683);
    }
}