Skip to content

Instantly share code, notes, and snippets.

@rygorous
Created December 10, 2025 06:11
Show Gist options
  • Select an option

  • Save rygorous/d100b38b5928ceaa871230fda9254f3e to your computer and use it in GitHub Desktop.

Select an option

Save rygorous/d100b38b5928ceaa871230fda9254f3e to your computer and use it in GitHub Desktop.
Core for a 2x unrolled radix-2 (not really radix-4) FFT kernel
// The FFT alg used here was designed to be very FMA-friendly, but because we can't assume FMAs are present on
// all target HW and want consistent results everywhere, we're using FMA-less algorithms for this application.
// Notation throughout this file:
//
// Let z = a + bi. Then conj(z) = a - bi.
//
// We can swap the real and imaginary parts of z to yield s(z) = b + ai ("swap").
// Now because
//
// iz = i(a + bi) = -b + ai
//
// we get s(z) = i z*, which is convenient to manipulate this algebraically. Now
// obviously from the definition, we have
//
// s(s(z)) = z
//
// Regular complex arithmetic rules further give the identities
//
// s(z + w) = i conj(z + w) = i conj(z) + i conj(w) = s(z) + s(w)
// s(zw) = i conj(zw) = i conj(z) conj(w) = s(z) conj(w) = conj(z) s(w)
//
// Note the "swap identity" (by the multiplication rule applied twice)
//
// s(s(z) s(w)) = s(s(z) i conj(w)) = s(i s(zw)) = conj(i) s(s(zw)) = -i zw
//
// We mostly work with split real/imaginary parts throughout this file, so these
// swaps are "free" (just a matter of renaming variables). This lets us reduce
// complex multiplications -izw to regular complex multiplications zw with some
// swapping of the real/imaginary parts. (izw can also be handled by computing
// -izw as noted, and then folding a negate into the uses.)
// Direction of the transform: selects the sign of the complex exponential
// used in the DFT kernel (twiddle selection happens in the pass functions).
enum FftSign
{
FftSign_Negative = 0, // Negative exponential (the customary "forward" FFT)
FftSign_Positive, // Positive exponential (customary "inverse" FFT)
};
// Burst-layout radix 2 FFTs.
// Memory layout: first kBurstSize real values, then the kBurstSize matching
// imaginary values, then the next kBurstSize reals, and so on.
static size_t constexpr kBurstSize = 16;
static size_t constexpr kBurstMask = kBurstSize - 1;

// Map a linear complex-element index to its byte-for-byte position in the
// burst layout: the within-burst bits stay where they are, and every bit
// above them shifts up by one to make room for the interleaved imag burst.
static constexpr inline size_t burst_swizzle(size_t i)
{
    const size_t within = i & kBurstMask;   // position inside the burst
    const size_t above  = i & ~kBurstMask;  // burst number (in element units)
    // "above" lands on bits strictly higher than "within" after the shift,
    // so OR here is the same as +; equivalently this equals i + above.
    return within | (above << 1);
}
// Exchange the values of a and b in place.
// Uses rvalue casts (static_cast<T&&> is std::move spelled out, avoiding a
// <utility> dependency) so the exchange works for move-only or
// expensive-to-copy types too; for trivially copyable types like the SIMD
// Elem wrappers below this compiles to exactly the same code as the old
// copy-based version.
template<typename T>
static RADFORCEINLINE void swap(T& a, T& b)
{
    T t { static_cast<T&&>(a) };
    a = static_cast<T&&>(b);
    b = static_cast<T&&>(t);
}
// Twiddled radix-2 butterfly, written without fused multiply-adds:
//   bt = b * w       (complex multiply)
//   a' = a + bt
//   b' = a - bt
// Results overwrite a and b in place (split into real/imaginary parts).
template<typename T>
static RADFORCEINLINE void radix2_twiddle_unfused(T& ar, T& ai, T& br, T& bi, T wr, T wi)
{
    // bt = b*w expanded into parts; mul/add/sub only (no FMA) so results are
    // bit-identical across targets with and without fused ops.
    const T btwr = br*wr - bi*wi;
    const T btwi = bi*wr + br*wi;
    // Form the sums into temporaries first so ar/ai are still intact when we
    // compute the differences, then commit both halves.
    const T sum_r = ar + btwr;
    const T sum_i = ai + btwi;
    br = ar - btwr;
    bi = ai - btwi;
    ar = sum_r;
    ai = sum_i;
}
// The FFT kernel is parameterized by an "Elem" type that gives shared functionality and determines
// the vector width. The bitrev + initial passes need to work slightly differently as the vector width
// increases, so the real impl has that in here as well.
//
// ElemF32x4: 4-lane single-precision SSE element type wrapping one __m128.
struct ElemF32x4
{
static constexpr size_t kCount = 4; // number of float lanes per vector
__m128 v;
// Default ctor deliberately leaves v uninitialized (hot-loop temporary;
// avoids a wasted zero-store).
ElemF32x4() {}
explicit ElemF32x4(__m128 x) : v(x) {}
// Unaligned load/store of 4 consecutive floats.
static ElemF32x4 load(float const* ptr) { return ElemF32x4(_mm_loadu_ps(ptr)); }
void store(float* ptr) { _mm_storeu_ps(ptr, v); }
// Lanewise arithmetic.
ElemF32x4 operator+(ElemF32x4 b) const { return ElemF32x4(_mm_add_ps(v, b.v)); }
ElemF32x4 operator-(ElemF32x4 b) const { return ElemF32x4(_mm_sub_ps(v, b.v)); }
ElemF32x4 operator*(ElemF32x4 b) const { return ElemF32x4(_mm_mul_ps(v, b.v)); }
// Reverse the order of the 4 lanes.
ElemF32x4 reverse() const { return ElemF32x4(_mm_shuffle_ps(v, v, _MM_SHUFFLE(0, 1, 2, 3))); }
// Twiddled radix-2 butterfly. This element type has no FMA variant (see the
// file header: FMA-less everywhere for consistent results), so just forward
// to the unfused mul/add implementation.
static RADFORCEINLINE void radix2_twiddle(
ElemF32x4& ar, ElemF32x4& ai, ElemF32x4& br, ElemF32x4& bi, ElemF32x4 wr, ElemF32x4 wi
)
{
radix2_twiddle_unfused(ar, ai, br, bi, wr, wi);
}
// Load 4 interleaved complex numbers (re,im,re,im,...) from ptr and split
// them into separate real/imaginary vectors.
// NOTE: uses aligned loads, so ptr must be 16-byte aligned (unlike load()
// above, which tolerates unaligned pointers).
static RADFORCEINLINE void load_deinterleave(ElemF32x4& re, ElemF32x4& im, float const* ptr)
{
// Load pair of vectors
const __m128 a0 = _mm_load_ps(ptr);
const __m128 a1 = _mm_load_ps(ptr + 4);
// Deinterleave to real/imaginary parts: even lanes -> re, odd lanes -> im
re.v = _mm_shuffle_ps(a0, a1, _MM_SHUFFLE(2, 0, 2, 0));
im.v = _mm_shuffle_ps(a0, a1, _MM_SHUFFLE(3, 1, 3, 1));
}
// In-place 4x4 transpose of the matrix whose rows are A, B, C, D.
static RADFORCEINLINE void transpose4x4(ElemF32x4& A, ElemF32x4& B, ElemF32x4& C, ElemF32x4 & D)
{
// Pass 1: interleave rows two apart (A with C, B with D)
const __m128 t0 = _mm_unpacklo_ps(A.v, C.v); // A0 C0 A1 C1
const __m128 t1 = _mm_unpacklo_ps(B.v, D.v); // B0 D0 B1 D1
const __m128 t2 = _mm_unpackhi_ps(A.v, C.v); // A2 C2 A3 C3
const __m128 t3 = _mm_unpackhi_ps(B.v, D.v); // B2 D2 B3 D3
// Pass 2: interleave once more to finish the transpose
A.v = _mm_unpacklo_ps(t0, t1); // A0 B0 C0 D0
B.v = _mm_unpackhi_ps(t0, t1); // A1 B1 C1 D1
C.v = _mm_unpacklo_ps(t2, t3); // A2 B2 C2 D2
D.v = _mm_unpackhi_ps(t2, t3); // A3 B3 C3 D3
}
// Interleave split real/imaginary vectors back into (re,im) pairs and store
// 8 floats starting at dest. NOTE: aligned stores; dest must be 16-byte
// aligned.
static void store_interleaved(float * dest, ElemF32x4 re, ElemF32x4 im)
{
__m128 i0 = _mm_unpacklo_ps(re.v, im.v);
__m128 i1 = _mm_unpackhi_ps(re.v, im.v);
_mm_store_ps(dest + 0, i0);
_mm_store_ps(dest + 4, i1);
}
};
// One radix-2^2 DIT pass (two chained radix-2 stages -- "radix 4" in name
// only, see the comment in the loop) over the whole burst-layout array.
//   out    - data in burst layout, transformed in place
//   step   - butterfly span in complex elements for the first stage
//   swiz_N - burst_swizzle(N); loop bound in swizzled index space
//   sign   - transform direction; selects the twiddle conjugate via a table
//            offset plus an output B/D swap (explained below)
// NOTE(review): the twiddle1_*/twiddle2_* pointer setup assumes a specific
// packed layout of s_fft_twiddles (looks like per-size sine tables with the
// cosine a quarter period ahead, given the +step/2 and +step offsets) -- the
// table generator is not in this file; confirm against it.
template<typename T>
static void burst_r4_fft_single_pass(float * out, size_t step, size_t swiz_N, FftSign sign)
{
// Stage-1 twiddles cover angle step 2*step, stage-2 twiddles 4*step.
float const *twiddle1_i = s_fft_twiddles + step*2;
float const *twiddle1_r = twiddle1_i + step/2;
float const *twiddle2_i = s_fft_twiddles + step*4;
float const *twiddle2_r = twiddle2_i + step;
// NOTE: this doesn't work unless step >= T::kCount
// i.e. our initial "regular" level is determined by vector width
const size_t twiddle_mask = step - 1;
// Loop-advance mask: the swizzled set of index bits that vary within this
// pass (everything except the two quarter-select bits and the vector lanes).
size_t swiz_dec = burst_swizzle(~((3 * step) | (T::kCount - 1)));
// The four quarter-spaced input/output groups A, B, C, D.
float * outA = out;
float * outB = out + burst_swizzle(1 * step);
float * outC = out + burst_swizzle(2 * step);
float * outD = out + burst_swizzle(3 * step);
// Defaults to B/D swapped; see below
float * outWrB = outD;
float * outWrD = outB;
// Advance in sine table by half the phase if negative sign requested
// (which negates the sine values, i.e. conjugates the twiddles).
// Also swap B/D outputs which effectively swaps our twiddle from +i to -i, see below.
if (sign == FftSign_Negative)
{
twiddle1_i += step;
twiddle2_i += step*2;
swap(outWrB, outWrD);
}
// This is actually radix 2^2, not radix 4.
for (size_t j = 0, k = 0; j < swiz_N; )
{
// Load split real/imag parts of the four inputs; within a burst, the
// imaginary values sit kBurstSize floats after the reals.
T ar = T::load(&outA[j + 0*kBurstSize]);
T ai = T::load(&outA[j + 1*kBurstSize]);
T br = T::load(&outB[j + 0*kBurstSize]);
T bi = T::load(&outB[j + 1*kBurstSize]);
T cr = T::load(&outC[j + 0*kBurstSize]);
T ci = T::load(&outC[j + 1*kBurstSize]);
T dr = T::load(&outD[j + 0*kBurstSize]);
T di = T::load(&outD[j + 1*kBurstSize]);
// First stage twiddle b * w0, d * w0
T w0r = T::load(&twiddle1_r[k]);
T w0i = T::load(&twiddle1_i[k]);
// First stage butterflies
T::radix2_twiddle(ar, ai, br, bi, w0r, w0i);
T::radix2_twiddle(cr, ci, dr, di, w0r, w0i);
// Second stage twiddle
T w1r = T::load(&twiddle2_r[k]);
T w1i = T::load(&twiddle2_i[k]);
// Second stage butterfly and output
T::radix2_twiddle(ar, ai, cr, ci, w1r, w1i);
ar.store(&outA[j + 0*kBurstSize]);
ai.store(&outA[j + 1*kBurstSize]);
cr.store(&outC[j + 0*kBurstSize]);
ci.store(&outC[j + 1*kBurstSize]);
// w2 is exactly a quarter-circle away
// i.e. multiply by i. w2 = i*(w1r + i*w1i) = i*w1r - w1i = (-w1i) + i*w1r
//
// Note "swap identity" above. We want
// b', d' = b +- w2 * d
// = b +- i (w1 * d)
// = b -+ -i (w1 * d)
// = b -+ s(s(w1) s(d))
// <=>
// d', b' = b +- s(s(w1) s(d))
// = s(s(b) +- s(w1) s(d))
//
// (note d', b' swapped places).
// Therefore, we compute a regular twiddled r2 butterfly with real/imag parts
// of b, w1 and d swapped (on both the inputs and outputs). That takes care of
// the s() applications. Finally for the positive FFT sign we also swap the
// output B and D pointer (this happens outside). For the negative FFT sign
// we need to multiply by -i not i to begin with, so we get another swap of
// output B and D which cancels out the first one.
T::radix2_twiddle(bi, br, di, dr, w1i, w1r);
br.store(&outWrB[j + 0*kBurstSize]);
bi.store(&outWrB[j + 1*kBurstSize]);
dr.store(&outWrD[j + 0*kBurstSize]);
di.store(&outWrD[j + 1*kBurstSize]);
// Masked-counter increment: subtract-and-mask steps j through exactly the
// swizzled indices whose bits lie inside swiz_dec, in increasing order.
j = (j - swiz_dec) & swiz_dec;
// Twiddle index advances one vector per iteration, wrapping every step
// elements (each quarter reuses the same twiddle block).
k = (k + T::kCount) & twiddle_mask;
}
}
// The complex FFT driver func.
static void radaudio_fft(float *out, float const *in, size_t N, FftSign sign, radaudio_fft_impl::FftKernelSet const * kernels)
{
using namespace radaudio_fft_impl;
const size_t swiz_N = burst_swizzle(N);
RR_ASSERT(16 <= N && N <= kMaxFFTN);
RR_ASSERT((N & (N - 1)) == 0); // checks for pow2
// This parts needs to do the bit reverse and the initial passes
// (until sub-FFTs are not within a single vector reg anymore)
size_t const initial_step = kernels->initial(out, in, N, sign);
// For the size we support here, an iterative FFT is always fine since we're
// comfortably in the L1D cache (the largest FFT we do is 512 complex elements,
// which is 4K of data).
// Iteratively do all the CT passes for increasing N (DIT order), indexed by step size (which is N/4)
// this is a pointer to an instantiation of burst_r4_fft_single_pass above
for (size_t step = initial_step; step <= N / 4; step *= 4)
kernels->cfft_pass(out, step, swiz_N, sign);
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment