rodio/source/
linear_ramp.rs

1use std::time::Duration;
2
3use super::SeekError;
4use crate::{Sample, Source};
5
6/// Internal function that builds a `LinearRamp` object.
7pub fn linear_gain_ramp<I>(
8    input: I,
9    duration: Duration,
10    start_gain: f32,
11    end_gain: f32,
12    clamp_end: bool,
13) -> LinearGainRamp<I>
14where
15    I: Source,
16    I::Item: Sample,
17{
18    let duration_nanos = duration.as_nanos() as f32;
19    assert!(duration_nanos > 0.0f32);
20
21    LinearGainRamp {
22        input,
23        elapsed_ns: 0.0f32,
24        total_ns: duration_nanos,
25        start_gain,
26        end_gain,
27        clamp_end,
28        sample_idx: 0u64,
29    }
30}
31
/// Source adapter that multiplies every sample of the wrapped source by a gain
/// which moves linearly between two values over a fixed span of time.
#[derive(Debug, Clone)]
pub struct LinearGainRamp<I> {
    /// The wrapped source whose samples are amplified.
    input: I,
    /// Ramp time that has already been played, in nanoseconds.
    elapsed_ns: f32,
    /// Total length of the ramp, in nanoseconds.
    total_ns: f32,
    /// Gain applied at the very start of the ramp.
    start_gain: f32,
    /// Gain reached when the ramp completes.
    end_gain: f32,
    /// Whether `end_gain` (rather than unity gain) is kept once the ramp is over.
    clamp_end: bool,
    /// Number of samples emitted while ramping; used to spot frame boundaries.
    sample_idx: u64,
}
43
44impl<I> LinearGainRamp<I>
45where
46    I: Source,
47    I::Item: Sample,
48{
49    /// Returns a reference to the innner source.
50    #[inline]
51    pub fn inner(&self) -> &I {
52        &self.input
53    }
54
55    /// Returns a mutable reference to the inner source.
56    #[inline]
57    pub fn inner_mut(&mut self) -> &mut I {
58        &mut self.input
59    }
60
61    /// Returns the inner source.
62    #[inline]
63    pub fn into_inner(self) -> I {
64        self.input
65    }
66}
67
68impl<I> Iterator for LinearGainRamp<I>
69where
70    I: Source,
71    I::Item: Sample,
72{
73    type Item = I::Item;
74
75    #[inline]
76    fn next(&mut self) -> Option<I::Item> {
77        let factor: f32;
78        let remaining_ns = self.total_ns - self.elapsed_ns;
79
80        if remaining_ns < 0.0 {
81            if self.clamp_end {
82                factor = self.end_gain;
83            } else {
84                factor = 1.0f32;
85            }
86        } else {
87            self.sample_idx += 1;
88
89            let p = self.elapsed_ns / self.total_ns;
90            factor = self.start_gain * (1.0f32 - p) + self.end_gain * p;
91        }
92
93        if self.sample_idx % (self.channels() as u64) == 0 {
94            self.elapsed_ns += 1000000000.0 / (self.input.sample_rate() as f32);
95        }
96
97        self.input.next().map(|value| value.amplify(factor))
98    }
99
100    #[inline]
101    fn size_hint(&self) -> (usize, Option<usize>) {
102        self.input.size_hint()
103    }
104}
105
106impl<I> ExactSizeIterator for LinearGainRamp<I>
107where
108    I: Source + ExactSizeIterator,
109    I::Item: Sample,
110{
111}
112
113impl<I> Source for LinearGainRamp<I>
114where
115    I: Source,
116    I::Item: Sample,
117{
118    #[inline]
119    fn current_frame_len(&self) -> Option<usize> {
120        self.input.current_frame_len()
121    }
122
123    #[inline]
124    fn channels(&self) -> u16 {
125        self.input.channels()
126    }
127
128    #[inline]
129    fn sample_rate(&self) -> u32 {
130        self.input.sample_rate()
131    }
132
133    #[inline]
134    fn total_duration(&self) -> Option<Duration> {
135        self.input.total_duration()
136    }
137
138    #[inline]
139    fn try_seek(&mut self, pos: Duration) -> Result<(), SeekError> {
140        self.elapsed_ns = pos.as_nanos() as f32;
141        self.input.try_seek(pos)
142    }
143}
144
#[cfg(test)]
mod tests {
    use approx::assert_abs_diff_eq;

    use super::*;
    use crate::buffer::SamplesBuffer;

    /// Builds a one-channel, 1 Hz `SamplesBuffer` holding `length` copies of
    /// `value`.
    fn const_source(length: u8, value: f32) -> SamplesBuffer<f32> {
        SamplesBuffer::new(1, 1, vec![value; usize::from(length)])
    }

    /// Builds a one-channel, 1 Hz `SamplesBuffer` of `length` samples that
    /// repeats the entries of `values` in order.
    fn cycle_source(length: u8, values: Vec<f32>) -> SamplesBuffer<f32> {
        let sample_count = usize::from(length);
        let mut data: Vec<f32> = Vec::with_capacity(sample_count);
        for position in 0..sample_count {
            data.push(values[position % values.len()]);
        }

        SamplesBuffer::new(1, 1, data)
    }

    #[test]
    fn test_linear_ramp() {
        let source1 = const_source(10, 1.0f32);
        let mut faded = linear_gain_ramp(source1, Duration::from_secs(4), 0.0, 1.0, true);

        // Four ramp steps up to full gain, then the gain holds at 1.0.
        let expected_samples = [0.0f32, 0.25, 0.5, 0.75, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0];
        for expected in expected_samples {
            assert_eq!(faded.next(), Some(expected));
        }
        assert_eq!(faded.next(), None);
    }

    #[test]
    fn test_linear_ramp_clamped() {
        let source1 = const_source(10, 1.0f32);
        let mut faded = linear_gain_ramp(source1, Duration::from_secs(4), 0.0, 0.5, true);

        // Fading in for four steps; from the fifth sample on the fade is done
        // and the gain stays clamped at 0.5.
        let expected_samples = [0.0f32, 0.125, 0.25, 0.375, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5];
        for expected in expected_samples {
            assert_eq!(faded.next(), Some(expected));
        }
        assert_eq!(faded.next(), None);
    }

    #[test]
    fn test_linear_ramp_seek() {
        let source1 = cycle_source(20, vec![0.0f32, 0.4f32, 0.8f32]);
        let mut faded = linear_gain_ramp(source1, Duration::from_secs(10), 0.0, 1.0, true);

        // Source values 0, 0.4, 0.8 with ramp gains 0.0, 0.1, 0.2.
        for expected in [0.0f32, 0.04, 0.16] {
            assert_abs_diff_eq!(faded.next().unwrap(), expected);
        }

        assert!(
            faded.try_seek(Duration::from_secs(5)).is_ok(),
            "try_seek() failed!"
        );
        // Source values 0.8, 0, 0.4 with ramp gains 0.5, 0.6, 0.7.
        for expected in [0.40f32, 0.0, 0.28] {
            assert_abs_diff_eq!(faded.next().unwrap(), expected);
        }

        assert!(
            faded.try_seek(Duration::from_secs(0)).is_ok(),
            "try_seek() failed!"
        );
        // Back at the start: source values 0, 0.4, 0.8 with gains 0.0, 0.1, 0.2.
        for expected in [0.0f32, 0.04, 0.16] {
            assert_abs_diff_eq!(faded.next().unwrap(), expected);
        }

        assert!(
            faded.try_seek(Duration::from_secs(10)).is_ok(),
            "try_seek() failed!"
        );
        // At the end of the ramp: source values 0.4, 0.8, 0 with gain 1.0.
        for expected in [0.4f32, 0.8, 0.0] {
            assert_abs_diff_eq!(faded.next().unwrap(), expected);
        }
    }
}