1use core::fmt;
2use core::time::Duration;
3use symphonia::{
4 core::{
5 audio::{AudioBufferRef, SampleBuffer, SignalSpec},
6 codecs::{Decoder, DecoderOptions, CODEC_TYPE_NULL},
7 errors::Error,
8 formats::{FormatOptions, FormatReader, SeekedTo},
9 io::MediaSourceStream,
10 meta::MetadataOptions,
11 probe::Hint,
12 units::{self, Time},
13 },
14 default::get_probe,
15};
16
17use crate::{source, Source};
18
19use super::DecoderError;
20
/// Number of decode failures tolerated before the error is treated as fatal.
///
/// A packet that fails to decode is skipped and the next packet is tried;
/// this constant bounds how many times that happens during initialisation,
/// seeking and iteration.
const MAX_DECODE_RETRIES: usize = 3;
25
26pub(crate) struct SymphoniaDecoder {
27 decoder: Box<dyn Decoder>,
28 current_frame_offset: usize,
29 format: Box<dyn FormatReader>,
30 total_duration: Option<Time>,
31 buffer: SampleBuffer<i16>,
32 spec: SignalSpec,
33}
34
35impl SymphoniaDecoder {
36 pub(crate) fn new(
37 mss: MediaSourceStream,
38 extension: Option<&str>,
39 ) -> Result<Self, DecoderError> {
40 match SymphoniaDecoder::init(mss, extension) {
41 Err(e) => match e {
42 Error::IoError(e) => Err(DecoderError::IoError(e.to_string())),
43 Error::DecodeError(e) => Err(DecoderError::DecodeError(e)),
44 Error::SeekError(_) => {
45 unreachable!("Seek errors should not occur during initialization")
46 }
47 Error::Unsupported(_) => Err(DecoderError::UnrecognizedFormat),
48 Error::LimitError(e) => Err(DecoderError::LimitError(e)),
49 Error::ResetRequired => Err(DecoderError::ResetRequired),
50 },
51 Ok(Some(decoder)) => Ok(decoder),
52 Ok(None) => Err(DecoderError::NoStreams),
53 }
54 }
55
56 pub(crate) fn into_inner(self) -> MediaSourceStream {
57 self.format.into_inner()
58 }
59
60 fn init(
61 mss: MediaSourceStream,
62 extension: Option<&str>,
63 ) -> symphonia::core::errors::Result<Option<SymphoniaDecoder>> {
64 let mut hint = Hint::new();
65 if let Some(ext) = extension {
66 hint.with_extension(ext);
67 }
68 let format_opts: FormatOptions = FormatOptions {
69 enable_gapless: true,
70 ..Default::default()
71 };
72 let metadata_opts: MetadataOptions = Default::default();
73 let mut probed = get_probe().format(&hint, mss, &format_opts, &metadata_opts)?;
74
75 let stream = match probed.format.default_track() {
76 Some(stream) => stream,
77 None => return Ok(None),
78 };
79
80 let track_id = probed
82 .format
83 .tracks()
84 .iter()
85 .find(|t| t.codec_params.codec != CODEC_TYPE_NULL)
86 .ok_or(symphonia::core::errors::Error::Unsupported(
87 "No track with supported codec",
88 ))?
89 .id;
90
91 let track = probed
92 .format
93 .tracks()
94 .iter()
95 .find(|track| track.id == track_id)
96 .unwrap();
97
98 let mut decoder = symphonia::default::get_codecs()
99 .make(&track.codec_params, &DecoderOptions::default())?;
100 let total_duration = stream
101 .codec_params
102 .time_base
103 .zip(stream.codec_params.n_frames)
104 .map(|(base, frames)| base.calc_time(frames));
105
106 let mut decode_errors: usize = 0;
107 let decoded = loop {
108 let current_frame = match probed.format.next_packet() {
109 Ok(packet) => packet,
110 Err(Error::IoError(_)) => break decoder.last_decoded(),
111 Err(e) => return Err(e),
112 };
113
114 if current_frame.track_id() != track_id {
116 continue;
117 }
118
119 match decoder.decode(¤t_frame) {
120 Ok(decoded) => break decoded,
121 Err(e) => match e {
122 Error::DecodeError(_) => {
123 decode_errors += 1;
124 if decode_errors > MAX_DECODE_RETRIES {
125 return Err(e);
126 } else {
127 continue;
128 }
129 }
130 _ => return Err(e),
131 },
132 }
133 };
134 let spec = decoded.spec().to_owned();
135 let buffer = SymphoniaDecoder::get_buffer(decoded, &spec);
136 Ok(Some(SymphoniaDecoder {
137 decoder,
138 current_frame_offset: 0,
139 format: probed.format,
140 total_duration,
141 buffer,
142 spec,
143 }))
144 }
145
146 #[inline]
147 fn get_buffer(decoded: AudioBufferRef, spec: &SignalSpec) -> SampleBuffer<i16> {
148 let duration = units::Duration::from(decoded.capacity() as u64);
149 let mut buffer = SampleBuffer::<i16>::new(duration, *spec);
150 buffer.copy_interleaved_ref(decoded);
151 buffer
152 }
153}
154
155impl Source for SymphoniaDecoder {
156 #[inline]
157 fn current_frame_len(&self) -> Option<usize> {
158 Some(self.buffer.samples().len())
159 }
160
161 #[inline]
162 fn channels(&self) -> u16 {
163 self.spec.channels.count() as u16
164 }
165
166 #[inline]
167 fn sample_rate(&self) -> u32 {
168 self.spec.rate
169 }
170
171 #[inline]
172 fn total_duration(&self) -> Option<Duration> {
173 self.total_duration
174 .map(|Time { seconds, frac }| Duration::new(seconds, (1f64 / frac) as u32))
175 }
176
177 fn try_seek(&mut self, pos: Duration) -> Result<(), source::SeekError> {
178 use symphonia::core::formats::{SeekMode, SeekTo};
179
180 let seek_beyond_end = self
181 .total_duration()
182 .is_some_and(|dur| dur.saturating_sub(pos).as_millis() < 1);
183
184 let time = if seek_beyond_end {
185 let time = self.total_duration.expect("if guarantees this is Some");
186 skip_back_a_tiny_bit(time) } else {
188 pos.as_secs_f64().into()
189 };
190
191 let to_skip = self.current_frame_offset % self.channels() as usize;
193
194 let seek_res = self
195 .format
196 .seek(
197 SeekMode::Accurate,
198 SeekTo::Time {
199 time,
200 track_id: None,
201 },
202 )
203 .map_err(SeekError::BaseSeek)?;
204
205 self.refine_position(seek_res)?;
206 self.current_frame_offset += to_skip;
207
208 Ok(())
209 }
210}
211
212#[derive(Debug)]
214pub enum SeekError {
215 Refining(symphonia::core::errors::Error),
217 BaseSeek(symphonia::core::errors::Error),
219 Retrying(symphonia::core::errors::Error),
221 Decoding(symphonia::core::errors::Error),
223}
224impl fmt::Display for SeekError {
225 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
226 match self {
227 SeekError::Refining(err) => {
228 write!(
229 f,
230 "Could not get next packet while refining seek position: {:?}",
231 err
232 )
233 }
234 SeekError::BaseSeek(err) => {
235 write!(f, "Format reader failed to seek: {:?}", err)
236 }
237 SeekError::Retrying(err) => {
238 write!(
239 f,
240 "Decoding failed retrying on the next packet failed: {:?}",
241 err
242 )
243 }
244 SeekError::Decoding(err) => {
245 write!(
246 f,
247 "Decoding failed on multiple consecutive packets: {:?}",
248 err
249 )
250 }
251 }
252 }
253}
254impl std::error::Error for SeekError {
255 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
256 match self {
257 SeekError::Refining(err) => Some(err),
258 SeekError::BaseSeek(err) => Some(err),
259 SeekError::Retrying(err) => Some(err),
260 SeekError::Decoding(err) => Some(err),
261 }
262 }
263}
264
265impl SymphoniaDecoder {
266 fn refine_position(&mut self, seek_res: SeekedTo) -> Result<(), source::SeekError> {
268 let mut samples_to_pass = seek_res.required_ts - seek_res.actual_ts;
269 let packet = loop {
270 let candidate = self.format.next_packet().map_err(SeekError::Refining)?;
271 if candidate.dur() > samples_to_pass {
272 break candidate;
273 } else {
274 samples_to_pass -= candidate.dur();
275 }
276 };
277
278 let mut decoded = self.decoder.decode(&packet);
279 for _ in 0..MAX_DECODE_RETRIES {
280 if decoded.is_err() {
281 let packet = self.format.next_packet().map_err(SeekError::Retrying)?;
282 decoded = self.decoder.decode(&packet);
283 }
284 }
285
286 let decoded = decoded.map_err(SeekError::Decoding)?;
287 decoded.spec().clone_into(&mut self.spec);
288 self.buffer = SymphoniaDecoder::get_buffer(decoded, &self.spec);
289 self.current_frame_offset = samples_to_pass as usize * self.channels() as usize;
290 Ok(())
291 }
292}
293
294fn skip_back_a_tiny_bit(
295 Time {
296 mut seconds,
297 mut frac,
298 }: Time,
299) -> Time {
300 frac -= 0.0001;
301 if frac < 0.0 {
302 seconds = seconds.saturating_sub(1);
303 frac = 1.0 - frac;
304 }
305 Time { seconds, frac }
306}
307
308impl Iterator for SymphoniaDecoder {
309 type Item = i16;
310
311 #[inline]
312 fn next(&mut self) -> Option<i16> {
313 if self.current_frame_offset >= self.buffer.len() {
314 let packet = self.format.next_packet().ok()?;
315 let mut decoded = self.decoder.decode(&packet);
316 for _ in 0..MAX_DECODE_RETRIES {
317 if decoded.is_err() {
318 let packet = self.format.next_packet().ok()?;
319 decoded = self.decoder.decode(&packet);
320 }
321 }
322 let decoded = decoded.ok()?;
323 decoded.spec().clone_into(&mut self.spec);
324 self.buffer = SymphoniaDecoder::get_buffer(decoded, &self.spec);
325 self.current_frame_offset = 0;
326 }
327
328 let sample = *self.buffer.samples().get(self.current_frame_offset)?;
329 self.current_frame_offset += 1;
330
331 Some(sample)
332 }
333}