Created
December 10, 2025 06:11
-
-
Save rygorous/d100b38b5928ceaa871230fda9254f3e to your computer and use it in GitHub Desktop.
Core for a 2x unrolled radix-2 (not really radix-4) FFT kernel
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| // The FFT alg used here was designed to be very FMA-friendly, but because we can't assume FMAs are present on | |
| // all target HW and want consistent results everywhere, we're using FMA-less algorithms for this application. | |
| // Notation throughout this file: | |
| // | |
| // Let z = a + bi. Then conj(z) = a - bi. | |
| // | |
| // We can swap the real and imaginary parts of z to yield s(z) = b + ai ("swap"). | |
| // Now because | |
| // | |
| // iz = i(a + bi) = -b + ai | |
| // | |
| // applying the same rule to conj(z) = a - bi gives i conj(z) = b + ai, so we get | |
| // s(z) = i conj(z), which makes s convenient to manipulate algebraically. Now | |
| // obviously from the definition, we have | |
| // | |
| // s(s(z)) = z | |
| // | |
| // Regular complex arithmetic rules further give the identities | |
| // | |
| // s(z + w) = i conj(z + w) = i conj(z) + i conj(w) = s(z) + s(w) | |
| // s(zw) = i conj(zw) = i conj(z) conj(w) = s(z) conj(w) = conj(z) s(w) | |
| // | |
| // Note the "swap identity" (by the multiplication rule applied twice) | |
| // | |
| // s(s(z) s(w)) = s(s(z) i conj(w)) = s(i s(zw)) = conj(i) s(s(zw)) = -i zw | |
| // | |
| // We mostly work with split real/imaginary parts throughout this file, so these | |
| // swaps are "free" (just a matter of renaming variables). This lets us reduce | |
| // complex multiplications -izw to regular complex multiplications zw with some | |
| // swapping of the real/imaginary parts. (izw can also be handled by computing | |
| // -izw as noted, and then folding a negate into the uses.) | |
// Selects the sign of the complex exponential used by the transform.
enum FftSign
{
    // exp(-i...) kernel: what is customarily called the "forward" FFT
    FftSign_Negative = 0,
    // exp(+i...) kernel: what is customarily called the "inverse" FFT
    FftSign_Positive = 1,
};
// Burst-layout radix 2 FFTs
//
// Data is stored in alternating bursts: kBurstSize real values, followed by the kBurstSize
// matching imaginary values, followed by the next kBurstSize reals, and so on.
static constexpr size_t kBurstSize = 16;
static constexpr size_t kBurstMask = kBurstSize - 1;

// Maps a logical element index i to the position of its real part in the burst layout.
// The matching imaginary part sits kBurstSize floats further on.
static inline constexpr size_t burst_swizzle(size_t i)
{
    // Conceptually the offset within the burst is kept and the burst number is doubled:
    //
    //   (i & kBurstMask) + ((i & ~kBurstMask) << 1)
    //
    // Splitting i as (i & kBurstMask) + (i & ~kBurstMask) and writing the shift as
    // (i & ~kBurstMask) + (i & ~kBurstMask) shows this is just i plus its burst-aligned part.
    return (i & ~kBurstMask) + i;
}
| template<typename T> | |
| static RADFORCEINLINE void swap(T& a, T& b) | |
| { | |
| T t { a }; | |
| a = b; | |
| b = t; | |
| } | |
| // Twiddle bt = b * w | |
| // a' = a + bt | |
| // b' = a - bt | |
| template<typename T> | |
| static RADFORCEINLINE void radix2_twiddle_unfused(T& ar, T& ai, T& br, T& bi, T wr, T wi) | |
| { | |
| T btr = br*wr - bi*wi; | |
| T bti = bi*wr + br*wi; | |
| T in_ar = ar; | |
| T in_ai = ai; | |
| ar = in_ar + btr; | |
| ai = in_ai + bti; | |
| br = in_ar - btr; | |
| bi = in_ai - bti; | |
| } | |
| // The FFT kernel is parameterized by an "Elem" type that gives shared functionality and determines | |
| // the vector width. The bitrev + initial passes need to work slightly differently as the vector width | |
| // increases, so the real impl has that in here as well | |
| struct ElemF32x4 | |
| { | |
| static constexpr size_t kCount = 4; | |
| __m128 v; | |
| ElemF32x4() {} | |
| explicit ElemF32x4(__m128 x) : v(x) {} | |
| static ElemF32x4 load(float const* ptr) { return ElemF32x4(_mm_loadu_ps(ptr)); } | |
| void store(float* ptr) { _mm_storeu_ps(ptr, v); } | |
| ElemF32x4 operator+(ElemF32x4 b) const { return ElemF32x4(_mm_add_ps(v, b.v)); } | |
| ElemF32x4 operator-(ElemF32x4 b) const { return ElemF32x4(_mm_sub_ps(v, b.v)); } | |
| ElemF32x4 operator*(ElemF32x4 b) const { return ElemF32x4(_mm_mul_ps(v, b.v)); } | |
| ElemF32x4 reverse() const { return ElemF32x4(_mm_shuffle_ps(v, v, _MM_SHUFFLE(0, 1, 2, 3))); } | |
| static RADFORCEINLINE void radix2_twiddle( | |
| ElemF32x4& ar, ElemF32x4& ai, ElemF32x4& br, ElemF32x4& bi, ElemF32x4 wr, ElemF32x4 wi | |
| ) | |
| { | |
| radix2_twiddle_unfused(ar, ai, br, bi, wr, wi); | |
| } | |
| static RADFORCEINLINE void load_deinterleave(ElemF32x4& re, ElemF32x4& im, float const* ptr) | |
| { | |
| // Load pair of vectors | |
| const __m128 a0 = _mm_load_ps(ptr); | |
| const __m128 a1 = _mm_load_ps(ptr + 4); | |
| // Deinterleave to real/imaginary parts | |
| re.v = _mm_shuffle_ps(a0, a1, _MM_SHUFFLE(2, 0, 2, 0)); | |
| im.v = _mm_shuffle_ps(a0, a1, _MM_SHUFFLE(3, 1, 3, 1)); | |
| } | |
| static RADFORCEINLINE void transpose4x4(ElemF32x4& A, ElemF32x4& B, ElemF32x4& C, ElemF32x4 & D) | |
| { | |
| // Pass 1 | |
| const __m128 t0 = _mm_unpacklo_ps(A.v, C.v); | |
| const __m128 t1 = _mm_unpacklo_ps(B.v, D.v); | |
| const __m128 t2 = _mm_unpackhi_ps(A.v, C.v); | |
| const __m128 t3 = _mm_unpackhi_ps(B.v, D.v); | |
| // Pass 2 | |
| A.v = _mm_unpacklo_ps(t0, t1); | |
| B.v = _mm_unpackhi_ps(t0, t1); | |
| C.v = _mm_unpacklo_ps(t2, t3); | |
| D.v = _mm_unpackhi_ps(t2, t3); | |
| } | |
| static void store_interleaved(float * dest, ElemF32x4 re, ElemF32x4 im) | |
| { | |
| __m128 i0 = _mm_unpacklo_ps(re.v, im.v); | |
| __m128 i1 = _mm_unpackhi_ps(re.v, im.v); | |
| _mm_store_ps(dest + 0, i0); | |
| _mm_store_ps(dest + 4, i1); | |
| } | |
| }; | |
| template<typename T> | |
| static void burst_r4_fft_single_pass(float * out, size_t step, size_t swiz_N, FftSign sign) | |
| { | |
| float const *twiddle1_i = s_fft_twiddles + step*2; | |
| float const *twiddle1_r = twiddle1_i + step/2; | |
| float const *twiddle2_i = s_fft_twiddles + step*4; | |
| float const *twiddle2_r = twiddle2_i + step; | |
| // NOTE: this doesn't work unless step >= T::kCount | |
| // i.e. our initial "regular" level is determined by vector width | |
| const size_t twiddle_mask = step - 1; | |
| size_t swiz_dec = burst_swizzle(~((3 * step) | (T::kCount - 1))); | |
| float * outA = out; | |
| float * outB = out + burst_swizzle(1 * step); | |
| float * outC = out + burst_swizzle(2 * step); | |
| float * outD = out + burst_swizzle(3 * step); | |
| // Defaults to B/D swapped; see below | |
| float * outWrB = outD; | |
| float * outWrD = outB; | |
| // Advance in sine table by half the phase if negative sign requested | |
| // Also swap B/D outputs which effectively swaps our twiddle from +i to -i, see below. | |
| if (sign == FftSign_Negative) | |
| { | |
| twiddle1_i += step; | |
| twiddle2_i += step*2; | |
| swap(outWrB, outWrD); | |
| } | |
| // This is actually radix 2^2, not radix 4. | |
| for (size_t j = 0, k = 0; j < swiz_N; ) | |
| { | |
| T ar = T::load(&outA[j + 0*kBurstSize]); | |
| T ai = T::load(&outA[j + 1*kBurstSize]); | |
| T br = T::load(&outB[j + 0*kBurstSize]); | |
| T bi = T::load(&outB[j + 1*kBurstSize]); | |
| T cr = T::load(&outC[j + 0*kBurstSize]); | |
| T ci = T::load(&outC[j + 1*kBurstSize]); | |
| T dr = T::load(&outD[j + 0*kBurstSize]); | |
| T di = T::load(&outD[j + 1*kBurstSize]); | |
| // First stage twiddle b * w0, d * w0 | |
| T w0r = T::load(&twiddle1_r[k]); | |
| T w0i = T::load(&twiddle1_i[k]); | |
| // First stage butterflies | |
| T::radix2_twiddle(ar, ai, br, bi, w0r, w0i); | |
| T::radix2_twiddle(cr, ci, dr, di, w0r, w0i); | |
| // Second stage twiddle | |
| T w1r = T::load(&twiddle2_r[k]); | |
| T w1i = T::load(&twiddle2_i[k]); | |
| // Second stage butterfly and output | |
| T::radix2_twiddle(ar, ai, cr, ci, w1r, w1i); | |
| ar.store(&outA[j + 0*kBurstSize]); | |
| ai.store(&outA[j + 1*kBurstSize]); | |
| cr.store(&outC[j + 0*kBurstSize]); | |
| ci.store(&outC[j + 1*kBurstSize]); | |
| // w2 is exactly a quarter-circle away | |
| // i.e. multiply by i. w2 = i*(w1r + i*w1i) = i*w1r - w1i = (-w1i) + i*w1r | |
| // | |
| // Note "swap identity" above. We want | |
| // b', d' = b +- w2 * d | |
| // = b +- i (w1 * d) | |
| // = b -+ -i (w1 * d) | |
| // = b -+ s(s(w1) s(d)) | |
| // <=> | |
| // d', b' = b +- s(s(w1) s(d)) | |
| // = s(s(b) +- s(w1) s(d)) | |
| // | |
| // (note d', b' swapped places). | |
| // Therefore, we compute a regular twiddled r2 butterfly with real/imag parts | |
| // of b, w1 and d swapped (on both the inputs and outputs). That takes care of | |
| // the s() applications. Finally for the positive FFT sign we also swap the | |
| // output B and D pointer (this happens outside). For the negative FFT sign | |
| // we need to multiply by -i not i to begin with, so we get another swap of | |
| // output B and D which cancels out the first one. | |
| T::radix2_twiddle(bi, br, di, dr, w1i, w1r); | |
| br.store(&outWrB[j + 0*kBurstSize]); | |
| bi.store(&outWrB[j + 1*kBurstSize]); | |
| dr.store(&outWrD[j + 0*kBurstSize]); | |
| di.store(&outWrD[j + 1*kBurstSize]); | |
| j = (j - swiz_dec) & swiz_dec; | |
| k = (k + T::kCount) & twiddle_mask; | |
| } | |
| } | |
| // The complex FFT driver func. | |
| static void radaudio_fft(float *out, float const *in, size_t N, FftSign sign, radaudio_fft_impl::FftKernelSet const * kernels) | |
| { | |
| using namespace radaudio_fft_impl; | |
| const size_t swiz_N = burst_swizzle(N); | |
| RR_ASSERT(16 <= N && N <= kMaxFFTN); | |
| RR_ASSERT((N & (N - 1)) == 0); // checks for pow2 | |
| // This parts needs to do the bit reverse and the initial passes | |
| // (until sub-FFTs are not within a single vector reg anymore) | |
| size_t const initial_step = kernels->initial(out, in, N, sign); | |
| // For the size we support here, an iterative FFT is always fine since we're | |
| // comfortably in the L1D cache (the largest FFT we do is 512 complex elements, | |
| // which is 4K of data). | |
| // Iteratively do all the CT passes for increasing N (DIT order), indexed by step size (which is N/4) | |
| // this is a pointer to an instantiation of burst_r4_fft_single_pass above | |
| for (size_t step = initial_step; step <= N / 4; step *= 4) | |
| kernels->cfft_pass(out, step, swiz_N, sign); | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment