Created August 23, 2024 14:10
Save pohzipohzi/81cf2ae92a00287fccdac1c2d98b7ae1 to your computer and use it in GitHub Desktop.
Minimal FFT implementation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use std::{ | |
f64::consts::PI, | |
fs::File, | |
io::Read, | |
iter::zip, | |
time::{Duration, Instant}, | |
}; | |
fn main() { | |
let mut t = [Duration::ZERO; 4]; | |
let n = 10; | |
for _ in 0..n { | |
for (i, dur) in g().into_iter().enumerate() { | |
t[i] += dur; | |
} | |
} | |
println!( | |
"dft={}s fft={}s idft={}s ifft={}s", | |
t[0].as_secs_f32() / n as f32, | |
t[1].as_secs_f32() / n as f32, | |
t[2].as_secs_f32() / n as f32, | |
t[3].as_secs_f32() / n as f32 | |
); | |
} | |
fn g() -> [Duration; 4] { | |
let a = gen_a(); | |
let t0 = Instant::now(); | |
let dft_res = dft(&a, false); | |
let t1 = Instant::now(); | |
let fft_res = fft(&a, false); | |
let t2 = Instant::now(); | |
check(&dft_res, &fft_res); | |
let t3 = Instant::now(); | |
let idft_res = norm(dft(&dft_res, true)); | |
let t4 = Instant::now(); | |
let ifft_res = norm(fft(&fft_res, true)); | |
let t5 = Instant::now(); | |
check(&a, &idft_res); | |
check(&a, &ifft_res); | |
check(&idft_res, &ifft_res); | |
[t1 - t0, t2 - t1, t4 - t3, t5 - t4] | |
} | |
fn gen_a() -> Vec<Complex> { | |
const FFT_SIZE: usize = 2usize.pow(12); | |
let mut rng = File::open("/dev/urandom").unwrap(); | |
let mut tmp = [0; FFT_SIZE * 4]; | |
rng.read_exact(&mut tmp).unwrap(); | |
(0..FFT_SIZE) | |
.map(|i| { | |
let re = i16::from_le_bytes([tmp[i * 4], tmp[i * 4 + 1]]) as f64 / i16::MAX as f64; | |
let im = i16::from_le_bytes([tmp[i * 4 + 2], tmp[i * 4 + 3]]) as f64 / i16::MAX as f64; | |
Complex::new(re, im) | |
}) | |
.collect() | |
} | |
fn check(a: &[Complex], b: &[Complex]) { | |
assert_eq!(a.len(), b.len()); | |
let precision = 2.0f64.powi(-16); | |
for (c0, c1) in zip(a, b) { | |
assert!((c0.re - c1.re).abs() < precision); | |
assert!((c0.im - c1.im).abs() < precision); | |
} | |
} | |
fn norm(mut a: Vec<Complex>) -> Vec<Complex> { | |
let len = a.len() as f64; | |
a.iter_mut().for_each(|c| *c *= 1.0 / len); | |
a | |
} | |
fn dft(a: &[Complex], inv: bool) -> Vec<Complex> { | |
(0..a.len()) | |
.map(|i| { | |
let w_i = if inv { | |
2.0 * PI * -(i as f64) / a.len() as f64 | |
} else { | |
2.0 * PI * i as f64 / a.len() as f64 | |
}; | |
let mut y_i = Complex::default(); | |
for (pow, a_i) in a.iter().enumerate() { | |
let w_i_pow = w_i * pow as f64; | |
y_i += *a_i * Complex::new(w_i_pow.cos(), w_i_pow.sin()); | |
} | |
y_i | |
}) | |
.collect() | |
} | |
fn fft(a: &[Complex], inv: bool) -> Vec<Complex> { | |
let a_len = a.len(); | |
if a_len == 1 { | |
return vec![a[0]]; | |
} | |
let mut a0 = Vec::with_capacity(a_len / 2); | |
let mut a1 = Vec::with_capacity(a_len / 2); | |
(0..a_len).for_each(|i| { | |
if i & 1 == 0 { | |
a0.push(a[i]) | |
} else { | |
a1.push(a[i]) | |
} | |
}); | |
let y0 = fft(&a0, inv); | |
let y1 = fft(&a1, inv); | |
(0..a_len) | |
.map(|i| { | |
let angle = if inv { | |
2.0 * -(i as f64) * PI / a_len as f64 | |
} else { | |
2.0 * i as f64 * PI / a_len as f64 | |
}; | |
let w_i = Complex::new(angle.cos(), angle.sin()); | |
y0[i % (a_len / 2)] + w_i * y1[i % (a_len / 2)] | |
}) | |
.collect() | |
} | |
/// Minimal complex number with `f64` components.
///
/// It provides just the arithmetic used by the DFT/FFT routines in this file.
/// `Debug` and `PartialEq` are derived so that values can be compared and
/// printed in assertions and diagnostics.
#[derive(Clone, Copy, Debug, Default, PartialEq)]
struct Complex {
    /// Real part.
    re: f64,
    /// Imaginary part.
    im: f64,
}

impl Complex {
    /// Creates the complex number `re + im*i`.
    fn new(re: f64, im: f64) -> Self {
        Self { re, im }
    }
}

/// Complex multiplication: `(a + bi)(c + di) = (ac - bd) + (ad + bc)i`.
impl std::ops::Mul for Complex {
    type Output = Self;
    fn mul(self, rhs: Self) -> Self::Output {
        Self {
            re: self.re * rhs.re - self.im * rhs.im,
            im: self.re * rhs.im + self.im * rhs.re,
        }
    }
}

/// Scales both components by a real factor.
impl std::ops::Mul<f64> for Complex {
    type Output = Self;
    fn mul(self, rhs: f64) -> Self::Output {
        Self {
            re: self.re * rhs,
            im: self.im * rhs,
        }
    }
}

/// In-place real scaling; same arithmetic as `Mul<f64>`.
impl std::ops::MulAssign<f64> for Complex {
    fn mul_assign(&mut self, rhs: f64) {
        *self = *self * rhs
    }
}

/// Component-wise addition.
impl std::ops::Add for Complex {
    type Output = Self;
    fn add(self, rhs: Self) -> Self::Output {
        Self {
            re: self.re + rhs.re,
            im: self.im + rhs.im,
        }
    }
}

/// In-place addition; same arithmetic as `Add`.
impl std::ops::AddAssign for Complex {
    fn add_assign(&mut self, rhs: Self) {
        *self = *self + rhs
    }
}

/// Component-wise subtraction, the counterpart of `Add`.
impl std::ops::Sub for Complex {
    type Output = Self;
    fn sub(self, rhs: Self) -> Self::Output {
        Self {
            re: self.re - rhs.re,
            im: self.im - rhs.im,
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.