rodio/decoder/symphonia.rs

1use core::fmt;
2use core::time::Duration;
3use symphonia::{
4    core::{
5        audio::{AudioBufferRef, SampleBuffer, SignalSpec},
6        codecs::{Decoder, DecoderOptions, CODEC_TYPE_NULL},
7        errors::Error,
8        formats::{FormatOptions, FormatReader, SeekedTo},
9        io::MediaSourceStream,
10        meta::MetadataOptions,
11        probe::Hint,
12        units::{self, Time},
13    },
14    default::get_probe,
15};
16
17use crate::{source, Source};
18
19use super::DecoderError;
20
/// Number of times decoding is retried on a following packet after a decode error.
///
/// Decoder errors are not considered fatal: the correct action is to just get a
/// new packet and try again. A decode error in more than `MAX_DECODE_RETRIES`
/// (3) consecutive packets, however, is treated as fatal.
const MAX_DECODE_RETRIES: usize = 3;
25
/// Decodes audio through symphonia and yields interleaved `i16` samples.
pub(crate) struct SymphoniaDecoder {
    /// Codec decoder for the selected track.
    decoder: Box<dyn Decoder>,
    /// Index of the next sample to yield from `buffer`. Samples are
    /// interleaved, so this counts samples across all channels, not frames.
    current_frame_offset: usize,
    /// Container reader that supplies the packets to decode.
    format: Box<dyn FormatReader>,
    /// Total duration taken from the codec parameters, if reported.
    total_duration: Option<Time>,
    /// Interleaved samples of the most recently decoded packet.
    buffer: SampleBuffer<i16>,
    /// Signal spec (channel layout and rate) of the most recently decoded packet.
    spec: SignalSpec,
}
34
35impl SymphoniaDecoder {
36    pub(crate) fn new(
37        mss: MediaSourceStream,
38        extension: Option<&str>,
39    ) -> Result<Self, DecoderError> {
40        match SymphoniaDecoder::init(mss, extension) {
41            Err(e) => match e {
42                Error::IoError(e) => Err(DecoderError::IoError(e.to_string())),
43                Error::DecodeError(e) => Err(DecoderError::DecodeError(e)),
44                Error::SeekError(_) => {
45                    unreachable!("Seek errors should not occur during initialization")
46                }
47                Error::Unsupported(_) => Err(DecoderError::UnrecognizedFormat),
48                Error::LimitError(e) => Err(DecoderError::LimitError(e)),
49                Error::ResetRequired => Err(DecoderError::ResetRequired),
50            },
51            Ok(Some(decoder)) => Ok(decoder),
52            Ok(None) => Err(DecoderError::NoStreams),
53        }
54    }
55
56    pub(crate) fn into_inner(self) -> MediaSourceStream {
57        self.format.into_inner()
58    }
59
60    fn init(
61        mss: MediaSourceStream,
62        extension: Option<&str>,
63    ) -> symphonia::core::errors::Result<Option<SymphoniaDecoder>> {
64        let mut hint = Hint::new();
65        if let Some(ext) = extension {
66            hint.with_extension(ext);
67        }
68        let format_opts: FormatOptions = FormatOptions {
69            enable_gapless: true,
70            ..Default::default()
71        };
72        let metadata_opts: MetadataOptions = Default::default();
73        let mut probed = get_probe().format(&hint, mss, &format_opts, &metadata_opts)?;
74
75        let stream = match probed.format.default_track() {
76            Some(stream) => stream,
77            None => return Ok(None),
78        };
79
80        // Select the first supported track
81        let track_id = probed
82            .format
83            .tracks()
84            .iter()
85            .find(|t| t.codec_params.codec != CODEC_TYPE_NULL)
86            .ok_or(symphonia::core::errors::Error::Unsupported(
87                "No track with supported codec",
88            ))?
89            .id;
90
91        let track = probed
92            .format
93            .tracks()
94            .iter()
95            .find(|track| track.id == track_id)
96            .unwrap();
97
98        let mut decoder = symphonia::default::get_codecs()
99            .make(&track.codec_params, &DecoderOptions::default())?;
100        let total_duration = stream
101            .codec_params
102            .time_base
103            .zip(stream.codec_params.n_frames)
104            .map(|(base, frames)| base.calc_time(frames));
105
106        let mut decode_errors: usize = 0;
107        let decoded = loop {
108            let current_frame = match probed.format.next_packet() {
109                Ok(packet) => packet,
110                Err(Error::IoError(_)) => break decoder.last_decoded(),
111                Err(e) => return Err(e),
112            };
113
114            // If the packet does not belong to the selected track, skip over it
115            if current_frame.track_id() != track_id {
116                continue;
117            }
118
119            match decoder.decode(&current_frame) {
120                Ok(decoded) => break decoded,
121                Err(e) => match e {
122                    Error::DecodeError(_) => {
123                        decode_errors += 1;
124                        if decode_errors > MAX_DECODE_RETRIES {
125                            return Err(e);
126                        } else {
127                            continue;
128                        }
129                    }
130                    _ => return Err(e),
131                },
132            }
133        };
134        let spec = decoded.spec().to_owned();
135        let buffer = SymphoniaDecoder::get_buffer(decoded, &spec);
136        Ok(Some(SymphoniaDecoder {
137            decoder,
138            current_frame_offset: 0,
139            format: probed.format,
140            total_duration,
141            buffer,
142            spec,
143        }))
144    }
145
146    #[inline]
147    fn get_buffer(decoded: AudioBufferRef, spec: &SignalSpec) -> SampleBuffer<i16> {
148        let duration = units::Duration::from(decoded.capacity() as u64);
149        let mut buffer = SampleBuffer::<i16>::new(duration, *spec);
150        buffer.copy_interleaved_ref(decoded);
151        buffer
152    }
153}
154
155impl Source for SymphoniaDecoder {
156    #[inline]
157    fn current_frame_len(&self) -> Option<usize> {
158        Some(self.buffer.samples().len())
159    }
160
161    #[inline]
162    fn channels(&self) -> u16 {
163        self.spec.channels.count() as u16
164    }
165
166    #[inline]
167    fn sample_rate(&self) -> u32 {
168        self.spec.rate
169    }
170
171    #[inline]
172    fn total_duration(&self) -> Option<Duration> {
173        self.total_duration
174            .map(|Time { seconds, frac }| Duration::new(seconds, (1f64 / frac) as u32))
175    }
176
177    fn try_seek(&mut self, pos: Duration) -> Result<(), source::SeekError> {
178        use symphonia::core::formats::{SeekMode, SeekTo};
179
180        let seek_beyond_end = self
181            .total_duration()
182            .is_some_and(|dur| dur.saturating_sub(pos).as_millis() < 1);
183
184        let time = if seek_beyond_end {
185            let time = self.total_duration.expect("if guarantees this is Some");
186            skip_back_a_tiny_bit(time) // some decoders can only seek to just before the end
187        } else {
188            pos.as_secs_f64().into()
189        };
190
191        // make sure the next sample is for the right channel
192        let to_skip = self.current_frame_offset % self.channels() as usize;
193
194        let seek_res = self
195            .format
196            .seek(
197                SeekMode::Accurate,
198                SeekTo::Time {
199                    time,
200                    track_id: None,
201                },
202            )
203            .map_err(SeekError::BaseSeek)?;
204
205        self.refine_position(seek_res)?;
206        self.current_frame_offset += to_skip;
207
208        Ok(())
209    }
210}
211
/// Error returned when the try_seek implementation of the symphonia decoder fails.
///
/// Every variant wraps the underlying symphonia error.
#[derive(Debug)]
pub enum SeekError {
    /// Reading the next packet failed while skipping forward from the position
    /// the format reader seeked to towards the requested position.
    Refining(symphonia::core::errors::Error),
    /// The format reader failed to seek.
    BaseSeek(symphonia::core::errors::Error),
    /// A packet failed to decode, and reading the next packet to retry with
    /// failed as well.
    Retrying(symphonia::core::errors::Error),
    /// Decoding failed on several consecutive packets (the initial attempt and
    /// every retry).
    Decoding(symphonia::core::errors::Error),
}
224impl fmt::Display for SeekError {
225    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
226        match self {
227            SeekError::Refining(err) => {
228                write!(
229                    f,
230                    "Could not get next packet while refining seek position: {:?}",
231                    err
232                )
233            }
234            SeekError::BaseSeek(err) => {
235                write!(f, "Format reader failed to seek: {:?}", err)
236            }
237            SeekError::Retrying(err) => {
238                write!(
239                    f,
240                    "Decoding failed retrying on the next packet failed: {:?}",
241                    err
242                )
243            }
244            SeekError::Decoding(err) => {
245                write!(
246                    f,
247                    "Decoding failed on multiple consecutive packets: {:?}",
248                    err
249                )
250            }
251        }
252    }
253}
254impl std::error::Error for SeekError {
255    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
256        match self {
257            SeekError::Refining(err) => Some(err),
258            SeekError::BaseSeek(err) => Some(err),
259            SeekError::Retrying(err) => Some(err),
260            SeekError::Decoding(err) => Some(err),
261        }
262    }
263}
264
265impl SymphoniaDecoder {
266    /// Note frame offset must be set after
267    fn refine_position(&mut self, seek_res: SeekedTo) -> Result<(), source::SeekError> {
268        let mut samples_to_pass = seek_res.required_ts - seek_res.actual_ts;
269        let packet = loop {
270            let candidate = self.format.next_packet().map_err(SeekError::Refining)?;
271            if candidate.dur() > samples_to_pass {
272                break candidate;
273            } else {
274                samples_to_pass -= candidate.dur();
275            }
276        };
277
278        let mut decoded = self.decoder.decode(&packet);
279        for _ in 0..MAX_DECODE_RETRIES {
280            if decoded.is_err() {
281                let packet = self.format.next_packet().map_err(SeekError::Retrying)?;
282                decoded = self.decoder.decode(&packet);
283            }
284        }
285
286        let decoded = decoded.map_err(SeekError::Decoding)?;
287        decoded.spec().clone_into(&mut self.spec);
288        self.buffer = SymphoniaDecoder::get_buffer(decoded, &self.spec);
289        self.current_frame_offset = samples_to_pass as usize * self.channels() as usize;
290        Ok(())
291    }
292}
293
294fn skip_back_a_tiny_bit(
295    Time {
296        mut seconds,
297        mut frac,
298    }: Time,
299) -> Time {
300    frac -= 0.0001;
301    if frac < 0.0 {
302        seconds = seconds.saturating_sub(1);
303        frac = 1.0 - frac;
304    }
305    Time { seconds, frac }
306}
307
308impl Iterator for SymphoniaDecoder {
309    type Item = i16;
310
311    #[inline]
312    fn next(&mut self) -> Option<i16> {
313        if self.current_frame_offset >= self.buffer.len() {
314            let packet = self.format.next_packet().ok()?;
315            let mut decoded = self.decoder.decode(&packet);
316            for _ in 0..MAX_DECODE_RETRIES {
317                if decoded.is_err() {
318                    let packet = self.format.next_packet().ok()?;
319                    decoded = self.decoder.decode(&packet);
320                }
321            }
322            let decoded = decoded.ok()?;
323            decoded.spec().clone_into(&mut self.spec);
324            self.buffer = SymphoniaDecoder::get_buffer(decoded, &self.spec);
325            self.current_frame_offset = 0;
326        }
327
328        let sample = *self.buffer.samples().get(self.current_frame_offset)?;
329        self.current_frame_offset += 1;
330
331        Some(sample)
332    }
333}