1use std::convert::TryInto;
14use std::f32;
15
16use lazy_static::lazy_static;
17
18use super::complex::Complex;
19
/// Declares a lazily-initialized static named `$name` that holds the first half of the twiddle
/// factors for an FFT of size `N = 2^$bi`.
///
/// Entry `k` of the table is `cos(2*pi*k / N) - i*sin(2*pi*k / N)`. Each factor is evaluated in
/// `f64` and then narrowed to `f32`.
macro_rules! fft_twiddle_table {
    ($bi:expr, $name:ident) => {
        lazy_static! {
            static ref $name: [Complex; (1 << $bi) >> 1] = {
                // Only N / 2 factors are stored.
                const HALF: usize = (1 << $bi) >> 1;

                // Angular step between consecutive factors: pi / (N / 2) == 2 * pi / N.
                let step = std::f64::consts::PI / HALF as f64;

                let mut factors = [Complex::default(); HALF];

                for (k, factor) in factors.iter_mut().enumerate() {
                    let phase = step * k as f64;
                    *factor = Complex::new(phase.cos() as f32, -phase.sin() as f32);
                }

                factors
            };
        }
    };
}
40
41fft_twiddle_table!(6, FFT_TWIDDLE_TABLE_64);
42fft_twiddle_table!(7, FFT_TWIDDLE_TABLE_128);
43fft_twiddle_table!(8, FFT_TWIDDLE_TABLE_256);
44fft_twiddle_table!(9, FFT_TWIDDLE_TABLE_512);
45fft_twiddle_table!(10, FFT_TWIDDLE_TABLE_1024);
46fft_twiddle_table!(11, FFT_TWIDDLE_TABLE_2048);
47fft_twiddle_table!(12, FFT_TWIDDLE_TABLE_4096);
48fft_twiddle_table!(13, FFT_TWIDDLE_TABLE_8192);
49fft_twiddle_table!(14, FFT_TWIDDLE_TABLE_16384);
50fft_twiddle_table!(15, FFT_TWIDDLE_TABLE_32768);
51fft_twiddle_table!(16, FFT_TWIDDLE_TABLE_65536);
52
53fn fft_twiddle_factors(n: usize) -> &'static [Complex] {
55 match n {
57 64 => FFT_TWIDDLE_TABLE_64.as_ref(),
58 128 => FFT_TWIDDLE_TABLE_128.as_ref(),
59 256 => FFT_TWIDDLE_TABLE_256.as_ref(),
60 512 => FFT_TWIDDLE_TABLE_512.as_ref(),
61 1024 => FFT_TWIDDLE_TABLE_1024.as_ref(),
62 2048 => FFT_TWIDDLE_TABLE_2048.as_ref(),
63 4096 => FFT_TWIDDLE_TABLE_4096.as_ref(),
64 8192 => FFT_TWIDDLE_TABLE_8192.as_ref(),
65 16384 => FFT_TWIDDLE_TABLE_16384.as_ref(),
66 32768 => FFT_TWIDDLE_TABLE_32768.as_ref(),
67 65536 => FFT_TWIDDLE_TABLE_65536.as_ref(),
68 _ => panic!("fft size too large"),
69 }
70}
71
/// A complex FFT of a fixed size.
///
/// The only state is the precomputed bit-reversal permutation, so cloning an `Fft` is a cheap
/// way to reuse that table.
#[derive(Debug, Clone)]
pub struct Fft {
    /// Bit-reversal permutation: entry `i` is the index the `i`-th reordered element is read
    /// from. Its length is the FFT size.
    perm: Box<[u16]>,
}
76
77impl Fft {
78 pub const MAX_SIZE: usize = 1 << 16;
80
81 pub fn new(n: usize) -> Self {
82 assert!(n.is_power_of_two());
84 assert!(n <= Fft::MAX_SIZE);
87
88 let n = n as u16;
90 let shift = n.leading_zeros() + 1;
91 let perm = (0..n).map(|i| i.reverse_bits() >> shift).collect();
92
93 Self { perm }
94 }
95
96 pub fn size(&self) -> usize {
98 self.perm.len()
99 }
100
101 pub fn ifft(&self, x: &[Complex], y: &mut [Complex]) {
103 let n = x.len();
104 assert_eq!(n, y.len());
105 assert_eq!(n, self.perm.len());
106
107 for (x, y) in self.perm.iter().map(|&i| x[usize::from(i)]).zip(y.iter_mut()) {
109 *y = Complex { re: x.im, im: x.re };
110 }
111
112 Self::transform(y, n);
114
115 let c = 1.0 / n as f32;
117
118 for y in y.iter_mut() {
119 *y = Complex { re: c * y.im, im: c * y.re };
120 }
121 }
122
123 pub fn ifft_inplace(&self, x: &mut [Complex]) {
125 let n = x.len();
126 assert_eq!(n, self.perm.len());
127
128 for (i, &j) in self.perm.iter().enumerate() {
130 let j = usize::from(j);
131
132 if i <= j {
133 let xi = x[i];
135 let xj = x[j];
136 x[i] = Complex::new(xj.im, xj.re);
137 x[j] = Complex::new(xi.im, xi.re);
138 }
139 }
140
141 Self::transform(x, n);
143
144 let c = 1.0 / n as f32;
146
147 for x in x.iter_mut() {
148 *x = Complex { re: c * x.im, im: c * x.re };
149 }
150 }
151
152 pub fn fft_inplace(&self, x: &mut [Complex]) {
154 let n = x.len();
155 assert_eq!(n, x.len());
156 assert_eq!(n, self.perm.len());
157
158 for (i, &j) in self.perm.iter().enumerate() {
159 let j = usize::from(j);
160
161 if i < j {
162 x.swap(i, j);
163 }
164 }
165
166 match n {
168 1 => (),
169 2 => fft2(x.try_into().unwrap()),
170 4 => fft4(x.try_into().unwrap()),
171 8 => fft8(x.try_into().unwrap()),
172 16 => fft16(x.try_into().unwrap()),
173 _ => Self::transform(x, n),
174 }
175 }
176
177 pub fn fft(&self, x: &[Complex], y: &mut [Complex]) {
179 let n = x.len();
180 assert_eq!(n, y.len());
181 assert_eq!(n, self.perm.len());
182
183 for (x, y) in self.perm.iter().map(|&i| x[usize::from(i)]).zip(y.iter_mut()) {
185 *y = x;
186 }
187
188 match n {
190 1 => (),
191 2 => fft2(y.try_into().unwrap()),
192 4 => fft4(y.try_into().unwrap()),
193 8 => fft8(y.try_into().unwrap()),
194 16 => fft16(y.try_into().unwrap()),
195 _ => Self::transform(y, n),
196 }
197 }
198
199 fn transform(x: &mut [Complex], n: usize) {
200 fn to_arr(x: &mut [Complex]) -> Option<&mut [Complex; 32]> {
201 x.try_into().ok()
202 }
203
204 if let Some(x) = to_arr(x) {
205 fft32(x);
206 }
207 else {
208 let n_half = n >> 1;
209
210 let (even, odd) = x.split_at_mut(n_half);
211
212 Self::transform(even, n_half);
213 Self::transform(odd, n_half);
214
215 let twiddle = fft_twiddle_factors(n);
216
217 for ((e, o), w) in
218 even.chunks_exact_mut(2).zip(odd.chunks_exact_mut(2)).zip(twiddle.chunks_exact(2))
219 {
220 let p0 = e[0];
221 let q0 = o[0] * w[0];
222
223 e[0] = p0 + q0;
224 o[0] = p0 - q0;
225
226 let p1 = e[1];
227 let q1 = o[1] * w[1];
228
229 e[1] = p1 + q1;
230 o[1] = p1 - q1;
231 }
232 }
233 }
234}
235
/// Shorthand for building a `Complex` from its real part followed by its imaginary part.
macro_rules! complex {
    ($real:expr, $imag:expr) => {
        Complex { re: $real, im: $imag }
    };
}
241
242fn fft32(x: &mut [Complex; 32]) {
243 let mut x0 = [
244 x[0], x[1], x[2], x[3], x[4], x[5], x[6], x[7], x[8], x[9], x[10], x[11], x[12], x[13],
245 x[14], x[15],
246 ];
247 let mut x1 = [
248 x[16], x[17], x[18], x[19], x[20], x[21], x[22], x[23], x[24], x[25], x[26], x[27], x[28],
249 x[29], x[30], x[31],
250 ];
251
252 fft16(&mut x0);
253 fft16(&mut x1);
254
255 let a4 = f32::consts::FRAC_1_SQRT_2 * x1[4].re;
256 let b4 = f32::consts::FRAC_1_SQRT_2 * x1[4].im;
257 let a12 = -f32::consts::FRAC_1_SQRT_2 * x1[12].re;
258 let b12 = -f32::consts::FRAC_1_SQRT_2 * x1[12].im;
259
260 let x1p = [
261 x1[0],
262 complex!(0.98078528040323044913, -0.19509032201612826785) * x1[1],
263 complex!(0.92387953251128675613, -0.38268343236508977173) * x1[2],
264 complex!(0.83146961230254523708, -0.55557023301960222474) * x1[3],
265 complex!(a4 + b4, b4 - a4),
266 complex!(0.55557023301960222474, -0.83146961230254523708) * x1[5],
267 complex!(0.38268343236508977173, -0.92387953251128675613) * x1[6],
268 complex!(0.19509032201612826785, -0.98078528040323044913) * x1[7],
269 complex!(x1[8].im, -x1[8].re),
270 complex!(-0.19509032201612826785, -0.98078528040323044913) * x1[9],
271 complex!(-0.38268343236508977173, -0.92387953251128675613) * x1[10],
272 complex!(-0.55557023301960222474, -0.83146961230254523708) * x1[11],
273 complex!(a12 - b12, a12 + b12),
274 complex!(-0.83146961230254523708, -0.55557023301960222474) * x1[13],
275 complex!(-0.92387953251128675613, -0.38268343236508977173) * x1[14],
276 complex!(-0.98078528040323044913, -0.19509032201612826785) * x1[15],
277 ];
278
279 x[0] = x0[0] + x1p[0];
280 x[1] = x0[1] + x1p[1];
281 x[2] = x0[2] + x1p[2];
282 x[3] = x0[3] + x1p[3];
283 x[4] = x0[4] + x1p[4];
284 x[5] = x0[5] + x1p[5];
285 x[6] = x0[6] + x1p[6];
286 x[7] = x0[7] + x1p[7];
287 x[8] = x0[8] + x1p[8];
288 x[9] = x0[9] + x1p[9];
289 x[10] = x0[10] + x1p[10];
290 x[11] = x0[11] + x1p[11];
291 x[12] = x0[12] + x1p[12];
292 x[13] = x0[13] + x1p[13];
293 x[14] = x0[14] + x1p[14];
294 x[15] = x0[15] + x1p[15];
295
296 x[16] = x0[0] - x1p[0];
297 x[17] = x0[1] - x1p[1];
298 x[18] = x0[2] - x1p[2];
299 x[19] = x0[3] - x1p[3];
300 x[20] = x0[4] - x1p[4];
301 x[21] = x0[5] - x1p[5];
302 x[22] = x0[6] - x1p[6];
303 x[23] = x0[7] - x1p[7];
304 x[24] = x0[8] - x1p[8];
305 x[25] = x0[9] - x1p[9];
306 x[26] = x0[10] - x1p[10];
307 x[27] = x0[11] - x1p[11];
308 x[28] = x0[12] - x1p[12];
309 x[29] = x0[13] - x1p[13];
310 x[30] = x0[14] - x1p[14];
311 x[31] = x0[15] - x1p[15];
312}
313
314#[inline(always)]
315fn fft16(x: &mut [Complex; 16]) {
316 let mut x0 = [x[0], x[1], x[2], x[3], x[4], x[5], x[6], x[7]];
317 let mut x1 = [x[8], x[9], x[10], x[11], x[12], x[13], x[14], x[15]];
318
319 fft8(&mut x0);
320 fft8(&mut x1);
321
322 let a2 = f32::consts::FRAC_1_SQRT_2 * x1[2].re;
323 let b2 = f32::consts::FRAC_1_SQRT_2 * x1[2].im;
324 let a6 = -f32::consts::FRAC_1_SQRT_2 * x1[6].re;
325 let b6 = -f32::consts::FRAC_1_SQRT_2 * x1[6].im;
326
327 let x1p = [
328 x1[0],
329 complex!(0.92387953251128675613, -0.38268343236508977173) * x1[1],
330 complex!(a2 + b2, b2 - a2),
331 complex!(0.38268343236508977173, -0.92387953251128675613) * x1[3],
332 complex!(x1[4].im, -x1[4].re),
333 complex!(-0.38268343236508977173, -0.92387953251128675613) * x1[5],
334 complex!(a6 - b6, a6 + b6),
335 complex!(-0.92387953251128675613, -0.38268343236508977173) * x1[7],
336 ];
337
338 x[0] = x0[0] + x1p[0];
339 x[1] = x0[1] + x1p[1];
340 x[2] = x0[2] + x1p[2];
341 x[3] = x0[3] + x1p[3];
342 x[4] = x0[4] + x1p[4];
343 x[5] = x0[5] + x1p[5];
344 x[6] = x0[6] + x1p[6];
345 x[7] = x0[7] + x1p[7];
346
347 x[8] = x0[0] - x1p[0];
348 x[9] = x0[1] - x1p[1];
349 x[10] = x0[2] - x1p[2];
350 x[11] = x0[3] - x1p[3];
351 x[12] = x0[4] - x1p[4];
352 x[13] = x0[5] - x1p[5];
353 x[14] = x0[6] - x1p[6];
354 x[15] = x0[7] - x1p[7];
355}
356
357#[inline(always)]
358fn fft8(x: &mut [Complex; 8]) {
359 let mut x0 = [x[0], x[1], x[2], x[3]];
360 let mut x1 = [x[4], x[5], x[6], x[7]];
361
362 fft4(&mut x0);
363 fft4(&mut x1);
364
365 let a1 = f32::consts::FRAC_1_SQRT_2 * x1[1].re;
366 let b1 = f32::consts::FRAC_1_SQRT_2 * x1[1].im;
367 let a3 = -f32::consts::FRAC_1_SQRT_2 * x1[3].re;
368 let b3 = -f32::consts::FRAC_1_SQRT_2 * x1[3].im;
369
370 let x1p = [
371 x1[0],
372 complex!(a1 + b1, b1 - a1),
373 complex!(x1[2].im, -x1[2].re),
374 complex!(a3 - b3, a3 + b3),
375 ];
376
377 x[0] = x0[0] + x1p[0];
378 x[1] = x0[1] + x1p[1];
379 x[2] = x0[2] + x1p[2];
380 x[3] = x0[3] + x1p[3];
381
382 x[4] = x0[0] - x1p[0];
383 x[5] = x0[1] - x1p[1];
384 x[6] = x0[2] - x1p[2];
385 x[7] = x0[3] - x1p[3];
386}
387
388#[inline(always)]
389fn fft4(x: &mut [Complex; 4]) {
390 let x0 = [x[0] + x[1], x[0] - x[1]];
391 let x1 = [x[2] + x[3], x[2] - x[3]];
392
393 let x1p1 = complex!(x1[1].im, -x1[1].re);
394
395 x[0] = x0[0] + x1[0];
396 x[1] = x0[1] + x1p1;
397
398 x[2] = x0[0] - x1[0];
399 x[3] = x0[1] - x1p1;
400}
401
402#[inline(always)]
403fn fft2(x: &mut [Complex; 2]) {
404 let x0 = x[0];
405 x[0] = x0 + x[1];
406 x[1] = x0 - x[1];
407}
408
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64;

    /// Reference O(n^2) DFT of `x`, written to `y`. Sums are accumulated in `f64`.
    ///
    /// The twiddle index is reduced with a bit mask, so the length must be a power of two.
    fn dft_naive(x: &[Complex], y: &mut [Complex]) {
        assert_eq!(x.len(), y.len());

        let n = x.len() as u64;
        let theta = 2.0 * f64::consts::PI / (x.len() as f64);

        for (i, out) in y.iter_mut().enumerate() {
            let (re, im) = x.iter().enumerate().fold((0f64, 0f64), |(re, im), (j, v)| {
                let vre = f64::from(v.re);
                let vim = f64::from(v.im);

                // i * j modulo n.
                let ij = ((i as u64) * (j as u64)) & (n - 1);

                let wre = (theta * ij as f64).cos();
                let wim = -(theta * ij as f64).sin();

                (re + ((vre * wre) - (vim * wim)), im + ((vre * wim) + (vim * wre)))
            });

            *out = Complex { re: re as f32, im: im as f32 };
        }
    }

    /// Reference inverse DFT of `x`, written to `y`: swap re/im, run the forward DFT, swap back,
    /// and scale by `1 / n`.
    fn idft_naive(x: &[Complex], y: &mut [Complex]) {
        let swapped: Vec<Complex> = x.iter().map(|v| Complex { re: v.im, im: v.re }).collect();

        dft_naive(&swapped, y);

        let scale = 1.0 / x.len() as f32;

        for v in y.iter_mut() {
            *v = Complex { re: scale * v.im, im: scale * v.re };
        }
    }

    /// Returns `true` if both components of `lhs` and `rhs` differ by less than `epsilon`.
    fn check_complex(lhs: Complex, rhs: Complex, epsilon: f32) -> bool {
        let re_close = (lhs.re - rhs.re).abs() < epsilon;
        let im_close = (lhs.im - rhs.im).abs() < epsilon;

        re_close && im_close
    }

    /// Asserts that `actual` and `expected` agree element-wise to within `epsilon`.
    fn assert_all_close(actual: &[Complex], expected: &[Complex], epsilon: f32) {
        for (&a, &e) in actual.iter().zip(expected) {
            assert!(check_complex(a, e, epsilon));
        }
    }

    #[rustfmt::skip]
    const TEST_VECTOR: [Complex; 64] = [
        Complex { re: -1.82036, im: -0.72591 },
        Complex { re:  1.21002, im:  0.75897 },
        Complex { re:  1.31084, im: -0.51285 },
        Complex { re:  1.26443, im:  1.57430 },
        Complex { re: -1.93680, im:  0.69987 },
        Complex { re:  0.85269, im: -0.20148 },
        Complex { re:  1.10078, im:  0.88904 },
        Complex { re: -1.20634, im: -0.07612 },
        Complex { re:  1.43358, im: -1.91248 },
        Complex { re:  0.10594, im: -0.30743 },
        Complex { re:  1.51258, im:  0.99538 },
        Complex { re: -1.33673, im:  0.23797 },
        Complex { re:  0.43738, im: -1.69900 },
        Complex { re: -0.95355, im: -0.33534 },
        Complex { re: -0.05479, im: -0.32739 },
        Complex { re: -1.85529, im: -1.93157 },
        Complex { re: -1.04220, im:  1.04277 },
        Complex { re: -0.17585, im:  0.40640 },
        Complex { re:  0.09893, im:  1.89538 },
        Complex { re:  1.25018, im: -0.85052 },
        Complex { re: -1.60696, im: -1.41320 },
        Complex { re: -0.25171, im: -0.13830 },
        Complex { re:  1.17782, im: -1.41225 },
        Complex { re: -0.35389, im: -0.30323 },
        Complex { re: -0.16485, im: -0.83675 },
        Complex { re: -1.66729, im: -0.52132 },
        Complex { re:  1.41246, im:  1.58295 },
        Complex { re: -1.84041, im:  0.15331 },
        Complex { re: -1.38897, im:  1.16180 },
        Complex { re:  0.27927, im: -1.84254 },
        Complex { re: -0.46229, im:  0.09699 },
        Complex { re:  1.21027, im: -0.31551 },
        Complex { re:  0.26195, im: -1.19340 },
        Complex { re:  1.60673, im:  1.07094 },
        Complex { re: -0.07456, im: -0.63058 },
        Complex { re: -1.77397, im:  1.39608 },
        Complex { re: -0.80300, im:  0.08858 },
        Complex { re: -0.06289, im:  1.48840 },
        Complex { re:  0.66399, im:  0.49451 },
        Complex { re: -1.49827, im:  1.61856 },
        Complex { re: -1.39006, im:  0.67652 },
        Complex { re: -0.90232, im:  0.18255 },
        Complex { re:  0.00525, im: -1.05797 },
        Complex { re:  0.53688, im:  0.88532 },
        Complex { re:  0.52712, im: -0.58205 },
        Complex { re: -1.77624, im: -0.66799 },
        Complex { re:  1.65335, im: -1.72668 },
        Complex { re: -0.24568, im:  1.50477 },
        Complex { re: -0.15051, im:  0.67824 },
        Complex { re: -1.96744, im:  0.81734 },
        Complex { re: -1.62630, im: -0.73233 },
        Complex { re: -1.98698, im:  0.63824 },
        Complex { re:  0.78115, im:  0.97909 },
        Complex { re:  0.97392, im:  1.82166 },
        Complex { re:  1.30982, im: -1.23975 },
        Complex { re:  0.85813, im:  0.80842 },
        Complex { re: -1.13934, im:  0.81352 },
        Complex { re: -1.22092, im:  0.98348 },
        Complex { re: -1.67949, im:  0.78278 },
        Complex { re: -1.77411, im:  0.00424 },
        Complex { re:  1.93204, im: -0.03273 },
        Complex { re:  1.38529, im:  1.31798 },
        Complex { re:  0.61666, im: -0.09798 },
        Complex { re:  1.02132, im:  1.70293 },
    ];

    /// The out-of-place forward FFT matches the naive DFT.
    #[test]
    fn verify_fft() {
        let mut expected = [Complex::default(); TEST_VECTOR.len()];
        dft_naive(&TEST_VECTOR, &mut expected);

        let mut actual = [Complex::default(); TEST_VECTOR.len()];
        Fft::new(TEST_VECTOR.len()).fft(&TEST_VECTOR, &mut actual);

        assert_all_close(&actual, &expected, 0.00001);
    }

    /// The in-place forward FFT matches the naive DFT.
    #[test]
    fn verify_fft_inplace() {
        let mut expected = [Complex::default(); TEST_VECTOR.len()];
        dft_naive(&TEST_VECTOR, &mut expected);

        let mut actual = [Complex::default(); TEST_VECTOR.len()];
        actual.copy_from_slice(&TEST_VECTOR);
        Fft::new(TEST_VECTOR.len()).fft_inplace(&mut actual);

        assert_all_close(&actual, &expected, 0.00001);
    }

    /// The out-of-place inverse FFT matches the naive inverse DFT.
    #[test]
    fn verify_ifft() {
        let mut expected = [Complex::default(); TEST_VECTOR.len()];
        idft_naive(&TEST_VECTOR, &mut expected);

        let mut actual = [Complex::default(); TEST_VECTOR.len()];
        Fft::new(TEST_VECTOR.len()).ifft(&TEST_VECTOR, &mut actual);

        assert_all_close(&actual, &expected, 0.00001);
    }

    /// The in-place inverse FFT matches the naive inverse DFT.
    #[test]
    fn verify_ifft_inplace() {
        let mut expected = [Complex::default(); TEST_VECTOR.len()];
        idft_naive(&TEST_VECTOR, &mut expected);

        let mut actual = [Complex::default(); TEST_VECTOR.len()];
        actual.copy_from_slice(&TEST_VECTOR);
        Fft::new(TEST_VECTOR.len()).ifft_inplace(&mut actual);

        assert_all_close(&actual, &expected, 0.00001);
    }

    /// A forward transform followed by an inverse transform reproduces the input.
    #[test]
    fn verify_fft_reversible() {
        let fft = Fft::new(TEST_VECTOR.len());

        let mut forward = [Complex::default(); TEST_VECTOR.len()];
        fft.fft(&TEST_VECTOR, &mut forward);

        let mut round_trip = [Complex::default(); TEST_VECTOR.len()];
        fft.ifft(&forward, &mut round_trip);

        assert_all_close(&round_trip, &TEST_VECTOR, 0.000001);
    }

    /// An in-place forward transform followed by an in-place inverse reproduces the input.
    #[test]
    fn verify_fft_inplace_reversible() {
        let fft = Fft::new(TEST_VECTOR.len());

        let mut buf = [Complex::default(); TEST_VECTOR.len()];
        buf.copy_from_slice(&TEST_VECTOR);

        fft.fft_inplace(&mut buf);
        fft.ifft_inplace(&mut buf);

        assert_all_close(&buf, &TEST_VECTOR, 0.000001);
    }
}