1use cpal::Sample;
2
3#[derive(Clone, Debug)]
5pub struct ChannelCountConverter<I>
6where
7 I: Iterator,
8{
9 input: I,
10 from: cpal::ChannelCount,
11 to: cpal::ChannelCount,
12 sample_repeat: Option<I::Item>,
13 next_output_sample_pos: cpal::ChannelCount,
14}
15
16impl<I> ChannelCountConverter<I>
17where
18 I: Iterator,
19{
20 #[inline]
27 pub fn new(
28 input: I,
29 from: cpal::ChannelCount,
30 to: cpal::ChannelCount,
31 ) -> ChannelCountConverter<I> {
32 assert!(from >= 1);
33 assert!(to >= 1);
34
35 ChannelCountConverter {
36 input,
37 from,
38 to,
39 sample_repeat: None,
40 next_output_sample_pos: 0,
41 }
42 }
43
44 #[inline]
46 pub fn into_inner(self) -> I {
47 self.input
48 }
49
50 #[inline]
52 pub fn inner_mut(&mut self) -> &mut I {
53 &mut self.input
54 }
55}
56
57impl<I> Iterator for ChannelCountConverter<I>
58where
59 I: Iterator,
60 I::Item: Sample,
61{
62 type Item = I::Item;
63
64 fn next(&mut self) -> Option<I::Item> {
65 let result = match self.next_output_sample_pos {
66 0 => {
67 let value = self.input.next();
69 self.sample_repeat = value;
70 value
71 }
72 x if x < self.from => self.input.next(),
73 1 => self.sample_repeat,
74 _ => Some(I::Item::EQUILIBRIUM),
75 };
76
77 if result.is_some() {
78 self.next_output_sample_pos += 1;
79 }
80
81 if self.next_output_sample_pos == self.to {
82 self.next_output_sample_pos = 0;
83
84 if self.from > self.to {
85 for _ in self.to..self.from {
86 self.input.next(); }
88 }
89 }
90
91 result
92 }
93
94 #[inline]
95 fn size_hint(&self) -> (usize, Option<usize>) {
96 let (min, max) = self.input.size_hint();
97
98 let consumed = std::cmp::min(self.from, self.next_output_sample_pos) as usize;
99 let calculate = |size| {
100 (size + consumed) / self.from as usize * self.to as usize
101 - self.next_output_sample_pos as usize
102 };
103
104 let min = calculate(min);
105 let max = max.map(calculate);
106
107 (min, max)
108 }
109}
110
111impl<I> ExactSizeIterator for ChannelCountConverter<I>
112where
113 I: ExactSizeIterator,
114 I::Item: Sample,
115{
116}
117
#[cfg(test)]
mod test {
    use super::ChannelCountConverter;

    /// Dropping channels keeps the leading ones of each frame.
    #[test]
    fn remove_channels() {
        // (input, from, to, expected output)
        let cases: [(&[u16], cpal::ChannelCount, cpal::ChannelCount, &[u16]); 2] = [
            (&[1, 2, 3, 4, 5, 6], 3, 2, &[1, 2, 4, 5]),
            (&[1, 2, 3, 4, 5, 6, 7, 8], 4, 1, &[1, 5]),
        ];

        for &(samples, from, to, expected) in cases.iter() {
            let converted: Vec<u16> =
                ChannelCountConverter::new(samples.iter().copied(), from, to).collect();
            assert_eq!(converted, expected);
        }
    }

    /// Adding channels repeats mono input once, then pads with equilibrium samples.
    #[test]
    fn add_channels() {
        // (input, from, to, expected output)
        let cases: [(&[i16], cpal::ChannelCount, cpal::ChannelCount, &[i16]); 3] = [
            (&[1, 2, 3, 4], 1, 2, &[1, 1, 2, 2, 3, 3, 4, 4]),
            (&[1, 2], 1, 4, &[1, 1, 0, 0, 2, 2, 0, 0]),
            (&[1, 2, 3, 4], 2, 4, &[1, 2, 0, 0, 3, 4, 0, 0]),
        ];

        for &(samples, from, to, expected) in cases.iter() {
            let converted: Vec<i16> =
                ChannelCountConverter::new(samples.iter().copied(), from, to).collect();
            assert_eq!(converted, expected);
        }
    }

    /// The size hint is exact at every step until the converter is exhausted.
    #[test]
    fn size_hint() {
        let cases: [(&[i16], cpal::ChannelCount, cpal::ChannelCount); 5] = [
            (&[1, 2, 3], 1, 2),
            (&[1, 2, 3, 4], 2, 4),
            (&[1, 2, 3, 4], 4, 2),
            (&[1, 2, 3, 4, 5, 6], 3, 8),
            (&[1, 2, 3, 4, 5, 6, 7, 8], 4, 1),
        ];

        for &(samples, from, to) in cases.iter() {
            let mut iter = ChannelCountConverter::new(samples.iter().copied(), from, to);
            let expected_total = iter.clone().count();

            for left_in_iter in (0..=expected_total).rev() {
                println!("left_in_iter = {left_in_iter}");
                assert_eq!(iter.size_hint(), (left_in_iter, Some(left_in_iter)));
                iter.next();
            }

            assert_eq!(iter.size_hint(), (0, Some(0)));
        }
    }

    /// Upmixing 2 -> 3 channels grows the length by half.
    #[test]
    fn len_more() {
        let converter = ChannelCountConverter::new(vec![1i16, 2, 3, 4].into_iter(), 2, 3);
        assert_eq!(converter.len(), 6);
    }

    /// Downmixing 2 -> 1 channel halves the length.
    #[test]
    fn len_less() {
        let converter = ChannelCountConverter::new(vec![1i16, 2, 3, 4].into_iter(), 2, 1);
        assert_eq!(converter.len(), 2);
    }
}