Skip to content

Instantly share code, notes, and snippets.

@ingoogni
Last active June 10, 2025 13:43
Show Gist options
  • Save ingoogni/cb5bb90c4374ec271c8ff2ef94814865 to your computer and use it in GitHub Desktop.
Save ingoogni/cb5bb90c4374ec271c8ff2ef94814865 to your computer and use it in GitHub Desktop.
fft: Stockham AVX, Cooley-Tukey, window functions
# FFT Stockham
# http://wwwa.pikara.ne.jp/okojisan/otfft-en/optimization1.html
# original source of scalar Nim version by "Amb" :
# https://gist.github.com/amb/4f70bcbea897024452d683b40d18be1f
# Timings
# === Scalar float64 ===
# FFT/IFFT test on 4096 points
# Total time for 100 iterations: (seconds: 0, nanosecond: 23475000)
# Maximum error after FFT/IFFT: 3.2219820586602874e-10
#
# === Scalar float32 ===
# FFT/IFFT test on 4096 points
# Total time for 100 iterations: (seconds: 0, nanosecond: 24718500)
# Maximum error after FFT/IFFT: 0.0818290039896965
#
# === Complex32 ===
# FFT/IFFT test on 4096 points
# Total time for 100 iterations: (seconds: 0, nanosecond: 7036800)
# Maximum error after FFT/IFFT: 0.081829004
#
# === Direct Interleaved ===
# Direct interleaved FFT/IFFT test on 4096 points
# Total time for 100 iterations: (seconds: 0, nanosecond: 4666200)
# Maximum error after interleaved FFT/IFFT: 0.08154297
import std/[complex, math, monotimes]
import nimsimd/sse42
import nimsimd/avx2
# Compiler switches for the SSE4.2 / AVX2 intrinsics imported from nimsimd.
# GCC and Clang share one flag syntax; MSVC (vcc) uses /arch.
when defined(gcc) or defined(clang):
  {.passC: "-mavx2 -msse4.2".}
# NOTE(review): this branch uses `localPassC` while the GCC/Clang branch uses
# `passC` - confirm the difference is intended.
when defined(vcc):
  {.localPassC: "/arch:SSE4.2 /arch:AVX2".}
template simdComplexMul4(a, b: M256): M256 =
  ## Lane-wise complex product `a * b` of four interleaved complex float32
  ## values. Both operands use the layout
  ## [re0, im0, re1, im1, re2, im2, re3, im3].
  let
    aReal = mm256_moveldup_ps(a)            # [a.re0, a.re0, a.re1, a.re1, ...]
    aImag = mm256_movehdup_ps(a)            # [a.im0, a.im0, a.im1, a.im1, ...]
    bSwapped = mm256_shuffle_ps(b, b, 0xB1) # [b.im0, b.re0, b.im1, b.re1, ...]
    straight = mm256_mul_ps(aReal, b)       # [a.re*b.re, a.re*b.im, ...]
    crossed = mm256_mul_ps(aImag, bSwapped) # [a.im*b.im, a.im*b.re, ...]
  # addsub gives straight - crossed in the real lanes and
  # straight + crossed in the imaginary lanes.
  mm256_addsub_ps(straight, crossed)
template simdComplexMul2(a, b: M128): M128 =
  ## Lane-wise complex product `a * b` of two interleaved complex float32
  ## values laid out as [re0, im0, re1, im1].
  let
    aReal = mm_moveldup_ps(a)            # [a.re0, a.re0, a.re1, a.re1]
    aImag = mm_movehdup_ps(a)            # [a.im0, a.im0, a.im1, a.im1]
    bSwapped = mm_shuffle_ps(b, b, 0xB1) # [b.im0, b.re0, b.im1, b.re1]
    straight = mm_mul_ps(aReal, b)       # [a.re*b.re, a.re*b.im, ...]
    crossed = mm_mul_ps(aImag, bSwapped) # [a.im*b.im, a.im*b.re, ...]
  # Real lanes: straight - crossed; imaginary lanes: straight + crossed.
  mm_addsub_ps(straight, crossed)
# Complex addition and subtraction are plain element-wise float operations
# on the interleaved [re, im, ...] layout.
template simdComplexAdd4(a, b: M256): M256 =
  ## Sum of four interleaved complex values.
  mm256_add_ps(a, b)

template simdComplexSub4(a, b: M256): M256 =
  ## Difference `a - b` of four interleaved complex values.
  mm256_sub_ps(a, b)

template simdComplexAdd2(a, b: M128): M128 =
  ## Sum of two interleaved complex values.
  mm_add_ps(a, b)

template simdComplexSub2(a, b: M128): M128 =
  ## Difference `a - b` of two interleaved complex values.
  mm_sub_ps(a, b)
# Unaligned loads and stores: `data` points at the first float32 of the
# group and needs no particular alignment.
template loadComplex4(data: ptr float32): M256 =
  ## Reads 8 consecutive float32 values (4 interleaved complex numbers).
  mm256_loadu_ps(data)

template loadComplex2(data: ptr float32): M128 =
  ## Reads 4 consecutive float32 values (2 interleaved complex numbers).
  mm_loadu_ps(data)

template storeComplex4(data: ptr float32, v: M256) =
  ## Writes the 8 float32 lanes of `v` starting at `data`.
  mm256_storeu_ps(data, v)

template storeComplex2(data: ptr float32, v: M128) =
  ## Writes the 4 float32 lanes of `v` starting at `data`.
  mm_storeu_ps(data, v)
# Compile time LUT for twiddle factors
const thetaLutSize = 32768
const thetaLut = static:
  ## Entry k holds exp(-2*pi*i*k / thetaLutSize) = (cos, -sin) as Complex32.
  var arr: array[thetaLutSize, Complex32]
  for k, v in mpairs(arr):
    # Evaluate the angle and the trig functions in float64 and round to
    # float32 only once, so table accuracy does not depend on float32
    # intermediate rounding of `step * k`.
    let angle = 2.0 * PI * float64(k) / float64(thetaLutSize)
    v = complex(float32(cos(angle)), -float32(sin(angle)))
  arr
# Runtime SIMD lookup tables
# Entry k is twiddle k of `thetaLut`, broadcast across a whole register so it
# can multiply 2 (SSE) or 4 (AVX) interleaved complex values at once.
var thetaLutSSE: array[thetaLutSize, M128]
var thetaLutAVX: array[thetaLutSize, M256]

proc initThetaLutSSE*() =
  ## Fills `thetaLutSSE` from `thetaLut`. Calling it again just rewrites the
  ## same values.
  for k in 0..<thetaLutSize:
    let w = thetaLut[k]
    # Store as [re, im, re, im] for 2 identical complex numbers
    thetaLutSSE[k] = mm_setr_ps(w.re, w.im, w.re, w.im)

proc initThetaLutAVX*() =
  ## Fills `thetaLutAVX` from `thetaLut`. Calling it again just rewrites the
  ## same values.
  for k in 0..<thetaLutSize:
    let w = thetaLut[k]
    # Store as [re, im, re, im, re, im, re, im] for 4 identical complex numbers
    thetaLutAVX[k] = mm256_setr_ps(w.re, w.im, w.re, w.im,
                                   w.re, w.im, w.re, w.im)

# Fill both tables when the module is initialised. Without this, a caller
# that forgot the two init procs would run the FFT against all-zero twiddles
# and get silently wrong output. The procs stay exported so existing explicit
# calls keep working.
initThetaLutSSE()
initThetaLutAVX()
# Convert Complex32 array to interleaved float32 array
proc toInterleaved*(x: seq[Complex32]): seq[float32] =
  ## Flattens `x` into [re0, im0, re1, im1, ...]; the result has twice as
  ## many elements as `x`.
  result = newSeqOfCap[float32](2 * x.len)
  for c in x:
    result.add c.re
    result.add c.im
# Convert interleaved float32 array back to Complex32 array
proc fromInterleaved*(x: seq[float32]): seq[Complex32] =
  ## Rebuilds complex numbers from [re0, im0, re1, im1, ...].
  ## `x` must hold an even number of floats.
  assert x.len mod 2 == 0
  let count = x.len div 2
  result = newSeqOfCap[Complex32](count)
  for k in 0..<count:
    let base = k * 2
    result.add complex(x[base], x[base + 1])
# Core Stockham FFT algorithm with optimized SIMD
proc fftSimdButterfly(
  n: int, s: int, eo: bool,
  x: var seq[float32], y: var seq[float32]
) =
  ## One radix-2 Stockham stage on interleaved [re, im, ...] data, followed
  ## by a recursive call for the next stage with `x` and `y` swapped.
  ##
  ## * `n`  - length of the sub-transform handled by this stage.
  ## * `s`  - stride; doubles on every stage.
  ## * `eo` - flips on every stage; when it is true at the last stage
  ##          (`n == 1`) the first `s` complex values of `x` are copied
  ##          into `y`.
  ## * `x`  - buffer read by this stage.
  ## * `y`  - buffer written by this stage.
  ##
  ## Twiddles come from the lookup tables while `n <= thetaLutSize`. Larger
  ## stages compute the twiddle directly, because the table has no entry for
  ## every `p` there and truncating the index gives wrong factors.
  if n == 1:
    if eo:
      for q in 0..<s:
        y[q * 2] = x[q * 2]         # real part
        y[q * 2 + 1] = x[q * 2 + 1] # imaginary part
    return

  let m = n shr 1
  let useLut = n <= thetaLutSize
  let theta0 = float32(thetaLutSize) / float32(n)

  for p in 0..<m:
    # Scalar twiddle w = exp(-2*pi*i*p/n).
    var wre, wim: float32
    var fp = 0
    if useLut:
      fp = int(float32(p) * theta0)
      wre = thetaLut[fp].re
      wim = thetaLut[fp].im
    else:
      let angle = 2.0 * PI * float64(p) / float64(n)
      wre = float32(cos(angle))
      wim = float32(-sin(angle))
    # The same twiddle broadcast over 2 and 4 complex lanes.
    let
      wpSSE =
        if useLut: thetaLutSSE[fp]
        else: mm_setr_ps(wre, wim, wre, wim)
      wpAVX =
        if useLut: thetaLutAVX[fp]
        else: mm256_setr_ps(wre, wim, wre, wim, wre, wim, wre, wim)

    # Complex-element offsets of the two inputs and the two outputs.
    let
      aBase = s * p
      bBase = s * (p + m)
      sumBase = s * (p shl 1)
      multBase = s * ((p shl 1) + 1)

    # Process in groups of 4 complex numbers with AVX
    var q = 0
    while q + 3 < s:
      let aPtr = cast[ptr float32](addr x[(q + aBase) * 2])
      let bPtr = cast[ptr float32](addr x[(q + bBase) * 2])
      let a4 = loadComplex4(aPtr)
      let b4 = loadComplex4(bPtr)
      # Butterfly: sum = a + b, mult = (a - b) * w
      let sum4 = simdComplexAdd4(a4, b4)
      let diff4 = simdComplexSub4(a4, b4)
      let mult4 = simdComplexMul4(diff4, wpAVX)
      let sumPtr = cast[ptr float32](addr y[(q + sumBase) * 2])
      let multPtr = cast[ptr float32](addr y[(q + multBase) * 2])
      storeComplex4(sumPtr, sum4)
      storeComplex4(multPtr, mult4)
      q += 4

    # Process remaining elements in groups of 2 with SSE
    while q + 1 < s:
      let aPtr = cast[ptr float32](addr x[(q + aBase) * 2])
      let bPtr = cast[ptr float32](addr x[(q + bBase) * 2])
      let a2 = loadComplex2(aPtr)
      let b2 = loadComplex2(bPtr)
      let sum2 = simdComplexAdd2(a2, b2)
      let diff2 = simdComplexSub2(a2, b2)
      let mult2 = simdComplexMul2(diff2, wpSSE)
      let sumPtr = cast[ptr float32](addr y[(q + sumBase) * 2])
      let multPtr = cast[ptr float32](addr y[(q + multBase) * 2])
      storeComplex2(sumPtr, sum2)
      storeComplex2(multPtr, mult2)
      q += 2

    # Handle any remaining single complex number with scalar arithmetic
    while q < s:
      let aIdx = (q + aBase) * 2
      let bIdx = (q + bBase) * 2
      let are = x[aIdx]
      let aim = x[aIdx + 1]
      let bre = x[bIdx]
      let bim = x[bIdx + 1]
      # diff = a - b
      let diffRe = are - bre
      let diffIm = aim - bim
      let sumIdx = (q + sumBase) * 2
      let multIdx = (q + multBase) * 2
      # sum = a + b
      y[sumIdx] = are + bre
      y[sumIdx + 1] = aim + bim
      # mult = diff * w
      y[multIdx] = diffRe * wre - diffIm * wim
      y[multIdx + 1] = diffRe * wim + diffIm * wre
      q += 1

  fftSimdButterfly(n shr 1, s shl 1, not eo, y, x)
# SIMD-optimized FFT using Stockham algorithm
proc fftSimdInterleaved(x: var seq[float32]) =
  ## In-place forward FFT of interleaved [re, im, ...] data.
  ## The number of complex values (`x.len div 2`) must be a positive power
  ## of two. A scratch buffer the size of `x` is allocated on every call.
  let n = x.len div 2 # Number of complex elements
  assert n > 0 and n.isPowerOfTwo()
  var scratch = newSeq[float32](x.len)
  fftSimdButterfly(n = n, s = 1, eo = false, x = x, y = scratch)
proc fft*(x: var seq[Complex32]) =
  ## Forward FFT of `x`; the result replaces the contents of `x`.
  ## `x` must be non-empty and its length a power of two.
  ## The data is converted to an interleaved float32 buffer, transformed
  ## there, and converted back.
  assert x.len > 0
  assert x.len.isPowerOfTwo()
  var flat = x.toInterleaved()
  flat.fftSimdInterleaved()
  x = flat.fromInterleaved()
proc ifft*(x: var seq[Complex32]) =
  ## Inverse FFT of `x`; the result replaces the contents of `x`.
  ## Uses the identity ifft(x) = conj(fft(conj(x) / n)), so the length
  ## requirements of `fft` apply.
  let n = x.len
  let fn = complex(1.0'f32 / float32(n))
  # Scale by 1/n and conjugate every input sample.
  for v in x.mitems:
    v = (v * fn).conjugate
  fft(x)
  # Conjugate once more to finish the inverse transform.
  for v in x.mitems:
    v = v.conjugate
# Direct interleaved interface for maximum performance
proc fftInterleaved*(x: var seq[float32]) =
  ## Forward FFT, in place, on interleaved data [re, im, re, im, ...].
  ## `x.len div 2` is the number of complex values and must be a positive
  ## power of two. No conversion to or from `Complex32` takes place.
  let complexCount = x.len div 2
  assert complexCount > 0 and complexCount.isPowerOfTwo()
  fftSimdInterleaved(x)
proc ifftInterleaved*(x: var seq[float32]) =
  ## Inverse FFT, in place, on interleaved data [re, im, re, im, ...].
  ## Same length requirements as `fftInterleaved`.
  let n = x.len div 2
  let fn = 1.0'f32 / float32(n)
  # Scale every sample by 1/n and conjugate it before the forward transform.
  var re = 0
  while re <= x.len - 1:
    x[re] = x[re] * fn          # scale real part
    x[re + 1] = -x[re + 1] * fn # conjugate and scale imaginary part
    re += 2
  fftInterleaved(x)
  # Conjugate the result: negate every imaginary part.
  var im = 1
  while im <= x.len - 1:
    x[im] = -x[im]
    im += 2
when isMainModule:
  # Benchmark and round-trip check for both public interfaces.
  const
    iterations = 100
    points = 4096

  # Initialize the SIMD lookup tables
  initThetaLutSSE()
  initThetaLutAVX()

  echo "=== Complex32 Interface Test ==="
  var testData = newSeq[Complex32](points)
  for i, sample in testData.mpairs:
    sample = complex(float32(i), float32(i * 2))

  # Warmup
  fft(testData)
  ifft(testData)

  # Benchmark: `iterations` forward/inverse round trips.
  let start1 = getMonoTime()
  for _ in 0..<iterations:
    fft(testData)
    ifft(testData)
  let duration1 = getMonoTime() - start1
  echo "FFT/IFFT test on ", testData.len, " points"
  echo "Total time for ", iterations, " iterations: ", duration1

  # Correctness: one more round trip, compared against the pristine ramp.
  # The data has already been through all the round trips above.
  var originalData = newSeq[Complex32](testData.len)
  for i, sample in originalData.mpairs:
    sample = complex(float32(i), float32(i * 2))
  fft(testData)
  ifft(testData)
  var maxError = 0.0'f32
  for i in 0..<testData.len:
    let deviation = abs(testData[i] - originalData[i])
    if deviation > maxError:
      maxError = deviation
  echo "Maximum error after FFT/IFFT: ", maxError

  # Test with direct interleaved interface
  echo "\n=== Direct Interleaved Interface ==="
  var interleavedTest = newSeq[float32](points * 2)
  for i in 0..<points:
    interleavedTest[i * 2] = float32(i)
    interleavedTest[i * 2 + 1] = float32(i * 2)

  # Warmup
  fftInterleaved(interleavedTest)
  ifftInterleaved(interleavedTest)

  # Benchmark direct interface
  let start2 = getMonoTime()
  for _ in 0..<iterations:
    fftInterleaved(interleavedTest)
    ifftInterleaved(interleavedTest)
  let duration2 = getMonoTime() - start2
  echo "Direct interleaved FFT/IFFT test on 4096 points"
  echo "Total time for ", iterations, " iterations: ", duration2

  # Correctness test for interleaved
  var originalInterleaved = newSeq[float32](points * 2)
  for i in 0..<points:
    originalInterleaved[i * 2] = float32(i)
    originalInterleaved[i * 2 + 1] = float32(i * 2)
  fftInterleaved(interleavedTest)
  ifftInterleaved(interleavedTest)
  var maxErrorInterleaved = 0.0'f32
  for i in 0..<interleavedTest.len:
    let deviation = abs(interleavedTest[i] - originalInterleaved[i])
    if deviation > maxErrorInterleaved:
      maxErrorInterleaved = deviation
  echo "Maximum error after interleaved FFT/IFFT: ", maxErrorInterleaved
import math, complex, strutils
# Works with floats and complex numbers as input
proc fft[T: float | Complex[float]](x: openarray[T]): seq[Complex[float]] =
  ## Recursive radix-2 decimation-in-time (Cooley-Tukey) FFT.
  ##
  ## * `x` - real or complex samples; the length must be zero or a power of
  ##   two, because every level splits the input into equal even/odd halves.
  ## * Returns the spectrum, one `Complex[float]` per sample. An empty input
  ##   gives an empty result.
  ##
  ## Fails with an assertion on any other length. Such input used to return
  ## a partially filled, wrong spectrum without any error.
  let n = x.len
  if n == 0: return
  doAssert isPowerOfTwo(n),
    "fft: input length must be a power of two, got " & $n
  result = newSeq[Complex[float]](n)
  if n == 1:
    result[0] = (when T is float: complex(x[0]) else: x[0])
    return

  # Split into even- and odd-indexed samples; both halves have known size.
  let halfn = n div 2
  var evens = newSeq[T](halfn)
  var odds = newSeq[T](halfn)
  for i in 0 ..< halfn:
    evens[i] = x[2 * i]
    odds[i] = x[2 * i + 1]

  let even = fft(evens)
  let odd = fft(odds)

  # Butterfly: combine the half-size spectra with twiddle exp(-2*pi*i*k/n).
  for k in 0 ..< halfn:
    let a = exp(complex(0.0, -TAU * float(k) / float(n))) * odd[k]
    result[k] = even[k] + a
    result[k + halfn] = even[k] - a
proc ifft(x: var openarray[Complex[float]]): seq[Complex[float]] =
  ## Inverse FFT via ifft(x) = conj(fft(conj(x))) / n.
  ##
  ## * `x` - spectrum to invert; same length rules as `fft`. It is only read:
  ##   conjugation is done on a private copy, so the caller's data is no
  ##   longer left conjugated after the call. The `var` qualifier is kept
  ##   only so existing call sites still compile.
  ## * Returns the reconstructed samples.
  let n = x.len
  var conjugated = newSeq[Complex[float]](n)
  for p in 0..<n:
    conjugated[p] = x[p].conjugate
  result = fft(conjugated)
  for v in result.mitems:
    v = v.conjugate / n.float
when isMainModule:
  # Demo: magnitude spectrum of a half-on/half-off pulse, then the magnitudes
  # of the signal recovered by the inverse transform.
  var spectrum = fft(@[1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0])
  for bin in spectrum:
    echo abs(bin).formatFloat(ffDecimal, 3)
  let restored = ifft(spectrum)
  for sample in restored:
    echo abs(sample).formatFloat(ffDecimal, 3)
import std/[math]
# Window functions
# https://en.wikipedia.org/wiki/Window_function
# Window correction factors
# Correct window loss before ifft either amplitude or energy, not both
# https://community.sw.siemens.com/s/article/window-correction-factors
#
# Each window has an `...Amplitude` and an `...Energy` constant. The proc
# `calculateCorrectionFactors` in this file computes the same pair
# numerically for any window, as len / sum(w) and sqrt(len / sum(w * w)).
# Where a window takes a parameter, the value the constant was tabulated for
# is given in the trailing comment.
# NOTE(review): "Nuttal" is usually spelled "Nuttall"; the identifiers are
# exported, so the spelling is left unchanged.
const
  corrHannAmplitude* = 2.0
  corrHannEnergy* = 1.63
  corrHammingAmplitude* = 1.85
  corrHammingEnergy* = 1.59
  corrBlackmanAmplitude* = 2.80
  corrBlackmanEnergy* = 1.97
  corrFlattopAmplitude* = 4.18
  corrFlattopEnergy* = 2.26
  corrKaiserBesselAmplitude* = 2.49 # for beta=8.6
  corrKaiserBesselEnergy* = 1.86 # for beta=8.6
  corrNuttalAmplitude* = 2.81
  corrNuttalEnergy* = 1.98
  corrBlackmanNuttalAmplitude* = 2.82
  corrBlackmanNuttalEnergy* = 1.99
  corrBlackmanHarrisAmplitude* = 2.79
  corrBlackmanHarrisEnergy* = 1.97
  corrLanczosAmplitude* = 1.64
  corrLanczosEnergy* = 1.30
  corrTukeyAmplitude* = 1.33 # for alpha=0.5
  corrTukeyEnergy* = 1.22 # for alpha=0.5
  corrGaussianAmplitude* = 2.5 # for sigma=0.4
  corrGaussianEnergy* = 1.8 # for sigma=0.4
  corrBartlettAmplitude* = 2.0
  corrBartlettEnergy* = 1.73
  corrWelchAmplitude* = 1.5
  corrWelchEnergy* = 1.29
  corrBohmanAmplitude* = 2.52
  corrBohmanEnergy* = 1.79
  corrExponentialAmplitude* = 2.71 # for tau=size/4
  corrExponentialEnergy* = 2.0 # for tau=size/4
  corrDolphChebyshevAmplitude* = 2.0 # varies with sidelobe level
  corrDolphChebyshevEnergy* = 1.4 # varies with sidelobe level
# Some Kaiser beta values
# Ready-made `beta` arguments for `kaiserBesselWindow`; 8.6 is the value the
# Kaiser-Bessel correction factors in this file are tabulated for.
const
  kaiserBeta5* = 5.0 # General purpose
  kaiserBeta6* = 6.0 # Higher sidelobe suppression
  kaiserBeta8_6* = 8.6 # Very high sidelobe suppression
proc sinc(x: float): float {.inline.} =
  ## Normalized sinc: sin(pi*x) / (pi*x), with sinc(0) = 1.
  ##
  ## The Lanczos window built on this helper is defined with the normalized
  ## sinc, whose zeros fall at x = -1 and x = +1 so the window tapers to 0
  ## at both ends. The previous unnormalized sin(x)/x left the ends at
  ## about 0.84.
  let px = PI * x
  # Near zero return the limit value instead of evaluating 0/0.
  if abs(px) < 1e-6: 1.0 else: sin(px) / px
template hann*(i: int, size: int): float =
  ## Sample `i` (0-based) of a symmetric `size`-point Hann window.
  ## Divides by `size - 1`, so `size` must be at least 2.
  0.5 * (1.0 - cos(2.0 * PI * float(i) / float(size - 1)))

proc hannWindow*(size: int): seq[float] =
  ## Returns the `size`-point Hann window. A one-point window is `@[1.0]`;
  ## evaluating the formula there would divide by zero and give NaN.
  if size == 1:
    return @[1.0]
  result = newSeq[float](size)
  for i in 0..<size:
    result[i] = hann(i, size)
template hamming*(i: int, size: int): float =
  ## Sample `i` (0-based) of a symmetric `size`-point Hamming window,
  ## w = a0 - (1 - a0) * cos(2*pi*i / (size - 1)) with a0 = 25/46.
  ## The cosine coefficient is 21/46 = 1 - a0. The previous 0.46 belongs to
  ## the rounded pair 0.54/0.46 and made the window peak exceed 1.
  ## Divides by `size - 1`, so `size` must be at least 2.
  (25/46) - (21/46) * cos(2.0 * PI * float(i) / float(size - 1))

proc hammingWindow*(size: int): seq[float] =
  ## Returns the `size`-point Hamming window. A one-point window is
  ## `@[1.0]`; evaluating the formula there would give NaN.
  if size == 1:
    return @[1.0]
  result = newSeq[float](size)
  for i in 0..<size:
    result[i] = hamming(i, size)
template blackman*(i: int, size: int): float =
  ## Sample `i` (0-based) of a symmetric `size`-point Blackman window
  ## (three-term cosine sum). Divides by `size - 1`, so `size` must be at
  ## least 2.
  0.42659 - 0.49656 * cos(2.0 * PI * float(i) / float(size - 1)) +
    0.076849 * cos(4.0 * PI * float(i) / float(size - 1))

proc blackmanWindow*(size: int): seq[float] =
  ## Returns the `size`-point Blackman window. A one-point window is
  ## `@[1.0]`; evaluating the formula there would give NaN.
  if size == 1:
    return @[1.0]
  result = newSeq[float](size)
  for i in 0..<size:
    result[i] = blackman(i, size)
template nuttal*(i: int, size: int): float =
  ## Sample `i` (0-based) of a symmetric `size`-point Nuttall window
  ## (four-term cosine sum). Divides by `size - 1`, so `size` must be at
  ## least 2.
  0.355768 - 0.487396 * cos(2.0 * PI * float(i) / float(size - 1)) +
    0.144232 * cos(4.0 * PI * float(i) / float(size - 1)) -
    0.012604 * cos(6.0 * PI * float(i) / float(size - 1))

proc nuttalWindow*(size: int): seq[float] =
  ## Returns the `size`-point Nuttall window. A one-point window is
  ## `@[1.0]`; evaluating the formula there would give NaN.
  if size == 1:
    return @[1.0]
  result = newSeq[float](size)
  for i in 0..<size:
    result[i] = nuttal(i, size)
template blackmanNuttal*(i: int, size: int): float =
  ## Sample `i` (0-based) of a symmetric `size`-point Blackman-Nuttall
  ## window (four-term cosine sum). Divides by `size - 1`, so `size` must be
  ## at least 2.
  0.3635819 - 0.4891775 * cos(2.0 * PI * float(i) / float(size - 1)) +
    0.1365995 * cos(4.0 * PI * float(i) / float(size - 1)) -
    0.0106411 * cos(6.0 * PI * float(i) / float(size - 1))

proc blackmanNuttalWindow*(size: int): seq[float] =
  ## Returns the `size`-point Blackman-Nuttall window. A one-point window is
  ## `@[1.0]`; evaluating the formula there would give NaN.
  if size == 1:
    return @[1.0]
  result = newSeq[float](size)
  for i in 0..<size:
    result[i] = blackmanNuttal(i, size)
template blackmanHarris*(i: int, size: int): float =
  ## Sample `i` (0-based) of a symmetric `size`-point Blackman-Harris
  ## window (four-term cosine sum). Divides by `size - 1`, so `size` must be
  ## at least 2.
  0.35875 - 0.48829 * cos(2.0 * PI * float(i) / float(size - 1)) +
    0.14128 * cos(4.0 * PI * float(i) / float(size - 1)) -
    0.01168 * cos(6.0 * PI * float(i) / float(size - 1))

proc blackmanHarrisWindow*(size: int): seq[float] =
  ## Returns the `size`-point Blackman-Harris window. A one-point window is
  ## `@[1.0]`; evaluating the formula there would give NaN.
  if size == 1:
    return @[1.0]
  result = newSeq[float](size)
  for i in 0..<size:
    result[i] = blackmanHarris(i, size)
template flattop*(i: int, size: int): float =
  ## Sample `i` (0-based) of a symmetric `size`-point flat-top window
  ## (five-term cosine sum; the end samples are slightly negative).
  ## Divides by `size - 1`, so `size` must be at least 2.
  0.21557895 - 0.41663158 * cos(2.0 * PI * float(i) / float(size - 1)) +
    0.277263158 * cos(4.0 * PI * float(i) / float(size - 1)) -
    0.083578947 * cos(6.0 * PI * float(i) / float(size - 1)) +
    0.006947368 * cos(8.0 * PI * float(i) / float(size - 1))

proc flattopWindow*(size: int): seq[float] =
  ## Returns the `size`-point flat-top window. A one-point window is
  ## `@[1.0]`; evaluating the formula there would give NaN.
  if size == 1:
    return @[1.0]
  result = newSeq[float](size)
  for i in 0..<size:
    result[i] = flattop(i, size)
template lanczos*(i: int, size: int): float =
  ## Sample `i` (0-based) of a `size`-point Lanczos window: this file's
  ## `sinc` helper evaluated at 2*i/(size - 1) - 1, i.e. over [-1, 1].
  ## Divides by `size - 1`, so `size` must be at least 2.
  sinc(((2 * float(i)) / float(size - 1)) - 1)

proc lanczosWindow*(size: int): seq[float] =
  ## Returns the `size`-point Lanczos window. A one-point window is
  ## `@[1.0]`; evaluating the formula there would give NaN.
  if size == 1:
    return @[1.0]
  result = newSeq[float](size)
  for i in 0..<size:
    result[i] = lanczos(i, size)
proc besselIO(x: float): float =
  ## Modified Bessel function of the first kind, order zero, summed as the
  ## power series sum_k ((x/2)^(2k) / (k!)^2). Terms are added until one
  ## drops to 1e-10 or below.
  var
    sum = 1.0
    term = 1.0
    k = 1
  while term > 1e-10:
    # Ratio of consecutive series terms is x^2 / (4 k^2).
    term *= (x * x) / (4.0 * k.float * k.float)
    sum += term
    inc k
  return sum

proc kaiserBesselWindow*(size: int, beta: float): seq[float] =
  ## Returns the `size`-point Kaiser-Bessel window with shape parameter
  ## `beta`: w[i] = I0(beta * sqrt(1 - t^2)) / I0(beta), where t runs from
  ## -1 to 1 across the window.
  ## A one-point window is `@[1.0]`; the formula would divide by zero there
  ## and give NaN.
  if size == 1:
    return @[1.0]
  let alpha = (size - 1) / 2
  # The normalisation does not depend on i, so evaluate the series once.
  let norm = besselIO(beta)
  result = newSeq[float](size)
  for i in 0..<size:
    let t = (i.float - alpha) / alpha
    result[i] = besselIO(beta * sqrt(1 - t * t)) / norm
template tukey*(i: int, size: int, alpha: float): float =
  ## Sample `i` (0-based) of a `size`-point Tukey (tapered cosine) window.
  ## `alpha` is the fraction of the window covered by the two cosine
  ## tapers; samples between the tapers are 1.0.
  let pos = i.float
  let last = (size - 1).float
  if pos < alpha * last / 2.0:
    # rising taper
    0.5 * (1.0 + cos(PI * (2.0 * pos / (alpha * last) - 1.0)))
  elif pos > last * (1.0 - alpha / 2.0):
    # falling taper
    0.5 * (1.0 + cos(PI * (2.0 * pos / (alpha * last) - 2.0 / alpha + 1.0)))
  else:
    1.0

proc tukeyWindow*(size: int, alpha: float = 0.5): seq[float] =
  ## Returns the `size`-point Tukey window for taper fraction `alpha`.
  result = newSeq[float](size)
  for idx, sample in result.mpairs:
    sample = tukey(idx, size, alpha)
template gaussian*(i: int, size: int, sigma: float): float =
  ## Sample `i` (0-based) of a `size`-point Gaussian window centred on
  ## (size - 1) / 2. `sigma` is the standard deviation in samples and must
  ## be non-zero.
  let n = i.float - (size - 1).float / 2.0
  exp(-0.5 * (n / sigma) * (n / sigma))

proc gaussianWindow*(size: int, sigma: float = 0.4): seq[float] =
  ## Returns the `size`-point Gaussian window. Here `sigma` is relative:
  ## it is scaled by half the window span, (size - 1) / 2, before being
  ## passed to `gaussian`.
  ## A one-point window is `@[1.0]`; the scaled sigma would be 0 there and
  ## the formula would give NaN.
  if size == 1:
    return @[1.0]
  result = newSeq[float](size)
  for i in 0..<size:
    result[i] = gaussian(i, size, sigma * (size - 1).float / 2.0)
template bartlett*(i: int, size: int): float =
  ## Sample `i` (0-based) of a `size`-point Bartlett (triangular) window
  ## that is 0 at both ends and 1 in the middle.
  ## Divides by `size - 1`, so `size` must be at least 2.
  let n = i.float
  let N = (size - 1).float
  if n <= N / 2.0:
    2.0 * n / N
  else:
    2.0 - 2.0 * n / N

proc bartlettWindow*(size: int): seq[float] =
  ## Returns the `size`-point Bartlett window. A one-point window is
  ## `@[1.0]`; evaluating the formula there would give NaN.
  if size == 1:
    return @[1.0]
  result = newSeq[float](size)
  for i in 0..<size:
    result[i] = bartlett(i, size)
template welch*(i: int, size: int): float =
  ## Sample `i` (0-based) of a `size`-point Welch (parabolic) window:
  ## 1 - t^2 with t running from -1 to 1 across the window.
  ## Divides by `(size - 1) / 2`, so `size` must be at least 2.
  let n = i.float - (size - 1).float / 2.0
  let N = (size - 1).float / 2.0
  1.0 - (n / N) * (n / N)

proc welchWindow*(size: int): seq[float] =
  ## Returns the `size`-point Welch window. A one-point window is `@[1.0]`;
  ## evaluating the formula there would give NaN.
  if size == 1:
    return @[1.0]
  result = newSeq[float](size)
  for i in 0..<size:
    result[i] = welch(i, size)
template bohman*(i: int, size: int): float =
  ## Sample `i` (0-based) of a `size`-point Bohman window:
  ## (1 - x) * cos(pi*x) + sin(pi*x) / pi with x = |2*i/(size - 1) - 1|,
  ## and 0 once x reaches 1.
  ## Divides by `size - 1`, so `size` must be at least 2.
  let n = i.float
  let N = (size - 1).float
  let x = abs(2.0 * n / N - 1.0)
  if x < 1.0:
    (1.0 - x) * cos(PI * x) + (1.0 / PI) * sin(PI * x)
  else:
    0.0

proc bohmanWindow*(size: int): seq[float] =
  ## Returns the `size`-point Bohman window. A one-point window is
  ## `@[1.0]`; the formula hits 0/0 there and used to return 0.0.
  if size == 1:
    return @[1.0]
  result = newSeq[float](size)
  for i in 0..<size:
    result[i] = bohman(i, size)
template exponential*(i: int, size: int, tau: float): float =
  ## Sample `i` (0-based) of a `size`-point exponential window centred on
  ## (size - 1) / 2.
  ## * `tau > 0`: two-sided decay exp(-|n| / tau) away from the centre.
  ## * otherwise: exp(n / |tau|) up to and including the centre, and 0.0
  ##   after it.
  let offset = i.float - (size - 1).float / 2.0
  if tau > 0:
    exp(-abs(offset) / tau)
  elif offset <= 0:
    exp(offset / abs(tau))
  else:
    0.0

proc exponentialWindow*(size: int, tau: float): seq[float] =
  ## Returns the `size`-point exponential window with decay constant `tau`
  ## (in samples).
  result = newSeq[float](size)
  for idx, sample in result.mpairs:
    sample = exponential(idx, size, tau)
# Chebyshev polynomial of the first kind
proc chebyshev(n: int, x: float): float =
  ## T_n(x), evaluated with the trigonometric form inside [-1, 1] and the
  ## hyperbolic form outside it.
  if n == 0:
    result = 1.0
  elif n == 1:
    result = x
  elif abs(x) <= 1.0:
    result = cos(n.float * arccos(x))
  elif x > 1.0:
    result = cosh(n.float * arccosh(x))
  else:
    # T_n(-x) = (-1)^n * T_n(x)
    let sign = if n mod 2 == 0: 1.0 else: -1.0
    result = sign * cosh(n.float * arccosh(-x))

proc dolphChebyshevWindow*(size: int, sidelobeLevel: float): seq[float] =
  ## Returns the `size`-point Dolph-Chebyshev window, normalised so that its
  ## largest sample is 1.
  ##
  ## * `size` - number of samples; 0 gives an empty window and 1 gives
  ##   `@[1.0]`.
  ## * `sidelobeLevel` - sidelobe attenuation in dB. Only the magnitude is
  ##   used, so 60.0 and -60.0 mean the same thing.
  ##
  ## The window is the inverse DFT of the Chebyshev polynomial of order
  ## `size - 1`, sampled on `size` frequency points. The phase term
  ## `n - (size - 1) / 2` centres the result, which makes it symmetric for
  ## both odd and even `size`.
  result = newSeq[float](size)
  if size == 0:
    return
  if size == 1:
    result[0] = 1.0
    return

  let order = size - 1
  let ratio = pow(10.0, abs(sidelobeLevel) / 20.0)
  let beta = cosh(arccosh(ratio) / order.float)

  # Frequency samples depend only on k, so compute them once up front.
  var spectrum = newSeq[float](size)
  for k in 0..<size:
    spectrum[k] = chebyshev(order, beta * cos(PI * k.float / size.float))

  let centre = order.float / 2.0
  for n in 0..<size:
    var acc = 0.0
    for k in 0..<size:
      acc += spectrum[k] *
        cos(2.0 * PI * k.float * (n.float - centre) / size.float)
    result[n] = acc

  # Normalize
  let peak = result.max()
  for v in result.mitems:
    v = v / peak
# Calculate correction factors for any window
proc calculateCorrectionFactors*(window: seq[float]): tuple[amplitude: float, energy: float] =
  ## Derives both correction factors from the window samples:
  ## * `amplitude` = len / sum(w)
  ## * `energy`    = sqrt(len / sum(w * w))
  var total = 0.0
  var totalSquares = 0.0
  for idx in 0..<window.len:
    total += window[idx]
    totalSquares += window[idx] * window[idx]
  let count = window.len.float
  result = (amplitude: count / total, energy: sqrt(count / totalSquares))
@dwhall
Copy link

dwhall commented Jun 10, 2025

Thanks for the code examples.

@ingoogni
Copy link
Author

* [turkey](https://en.wikipedia.org/wiki/Wild_turkey) with the 'r'.

* [Tukey](https://en.wikipedia.org/wiki/Cooley%E2%80%93Tukey_FFT_algorithm), without the 'r'.

Cooley Tukey it is, now.
Cheers.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment