Last active
January 25, 2025 09:56
-
-
Save jacky860226/3e3baffcc48d8bb4c84d5218c916716c to your computer and use it in GitHub Desktop.
Cooley-Tukey Algorithm
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/// Linear (acyclic) convolution of A and B computed through a transform.
///
/// @tparam AlgorithmTy  transform class exposing `policy`, `vector_type` and
///                      `run(const vector_type&, bool is_inv)`, e.g.
///                      CooleyTukeyAlgorithm<T, Policy>.
/// @param A, B  coefficient vectors; taken by value because they are
///              zero-padded and transformed in place.
/// @return the A.size() + B.size() - 1 coefficients of the product A * B, or an
///         empty vector when either operand is empty.
template <typename AlgorithmTy>
auto convolution(typename AlgorithmTy::vector_type A,
                 typename AlgorithmTy::vector_type B) {
  using Policy = typename AlgorithmTy::policy;
  using vector_type = typename AlgorithmTy::vector_type;
  if (A.empty() || B.empty()) return vector_type{};
  const size_t C_size = A.size() + B.size() - 1;
  // Smallest power of two >= C_size. Doubling needs O(log C_size) steps,
  // whereas incrementing until a power of two is reached needs O(C_size).
  size_t N = 1;
  while (N < C_size) N <<= 1;
  // Zero-padding to N >= C_size makes the cyclic product equal the linear one.
  A.resize(N);
  B.resize(N);
  A = AlgorithmTy().run(A, false);
  B = AlgorithmTy().run(B, false);
  // Point-wise product in the transformed domain, stored back into A.
  for (size_t i = 0; i < N; ++i) A[i] = Policy::mul(A[i], B[i]);
  vector_type C = AlgorithmTy().run(A, true);
  C.resize(C_size);
  return C;
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <algorithm> | |
#include <cassert> | |
#include <cstddef> | |
/// Iterative (bottom-up, in-place) radix-2 Cooley-Tukey transform.
///
/// @tparam T       scalar type of the policy; not used by the algorithm itself.
/// @tparam Policy  supplies `vector_type`, `check`, `one`, `omega`, `inverse`,
///                 `add`, `sub` and `mul` (e.g. FFT_Policy or NTT_Policy).
template <typename T, typename Policy>
class CooleyTukeyAlgorithm {
 public:
  using policy = Policy;
  using vector_type = typename Policy::vector_type;

  /// Forward (is_inv == false) or inverse (is_inv == true) transform of `in`.
  /// Precondition (asserted): in.size() is 0 or a power of two and accepted by
  /// Policy::check. The forward result is
  /// out[k] = sum_j in[j] * omega(N)^(j*k); the inverse undoes it, including
  /// the division by N.
  auto run(const vector_type& in, bool is_inv) {
    const size_t N = in.size();
    assert((N & (N - 1)) == 0 && Policy::check(N));
    vector_type out(in);
    // Nothing to transform; this also keeps out.begin() + 1 below in range.
    if (N == 0) return out;
    bit_reverse_permute(out);
    // Each pass merges pairs of length-half_step transforms into one of
    // length `step` using butterflies with twiddle wk = omega(step)^i.
    for (size_t step = 2; step <= N; step *= 2) {
      const auto wn = Policy::omega(step);
      auto wk = Policy::one();
      const size_t half_step = step / 2;
      for (size_t i = 0; i < half_step; ++i) {
        for (size_t k = i; k < N; k += step) {
          const size_t j = k + half_step;
          auto u = out[k], t = Policy::mul(wk, out[j]);
          out[k] = Policy::add(u, t);
          out[j] = Policy::sub(u, t);
        }
        wk = Policy::mul(wk, wn);
      }
    }
    if (is_inv) {
      // Reversing indices 1..N-1 turns evaluation at omega^k into evaluation
      // at omega^-k; scaling by 1/N completes the inverse transform.
      std::reverse(out.begin() + 1, out.end());
      const auto inv_N = Policy::inverse(N);
      for (size_t i = 0; i < N; ++i) out[i] = Policy::mul(out[i], inv_N);
    }
    return out;
  }

 private:
  // Reorders v (size 0 or a power of two) into bit-reversed index order.
  // j tracks rev(i): incrementing i is mirrored by clearing j's leading
  // 1-bits from the top and setting the first 0-bit, so the pass is O(N)
  // overall. It never shifts by the full word width (safe for N == 1) and
  // needs no external bit-reversal helper.
  static void bit_reverse_permute(vector_type& v) {
    const size_t N = v.size();
    for (size_t i = 1, j = 0; i < N; ++i) {
      size_t bit = N >> 1;
      for (; j & bit; bit >>= 1) j ^= bit;
      j ^= bit;
      if (i < j) std::iter_swap(v.begin() + i, v.begin() + j);
    }
  }
};
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <algorithm> | |
#include <cassert> | |
#include <cstddef> | |
/// Recursive radix-2 Cooley-Tukey transform with the same interface as the
/// iterative version: split coefficients by parity, transform both halves,
/// and combine them with butterflies.
///
/// @tparam T       scalar type of the policy; not used by the algorithm itself.
/// @tparam Policy  supplies `vector_type`, `check`, `one`, `omega`, `inverse`,
///                 `add`, `sub` and `mul`.
template <typename T, typename Policy>
class CooleyTukeyAlgorithmRecursive {
 public:
  using policy = Policy;
  using vector_type = typename Policy::vector_type;

 private:
  // Input:  coefficient form F(x) := [f[0], f[1], ..., f[n-1]].
  // Output: point-value form [fY[0], ..., fY[n-1]] with fY[i] = F(w(n)^i),
  //         where w(n) = Policy::omega(n). n must be a power of two.
  // f is taken by const reference so no recursion level copies its argument.
  static vector_type divide_and_conquer(const vector_type& f) {
    const size_t n = f.size();
    if (n <= 1) return f;
    const size_t half = n / 2;
    vector_type g(half), h(half);  // split by parity: even / odd coefficients
    for (size_t i = 0; i < half; ++i) {
      g[i] = f[2 * i];
      h[i] = f[2 * i + 1];
    }
    const vector_type gY = divide_and_conquer(g);  // gY[i] = G(w(n/2)^i)
    const vector_type hY = divide_and_conquer(h);  // hY[i] = H(w(n/2)^i)
    // F(x) = G(x^2) + x * H(x^2); the second half uses w(n)^(k+n/2) = -w(n)^k.
    vector_type fY(n);
    const auto wn = Policy::omega(n);
    auto wk = Policy::one();
    for (size_t k = 0; k < half; ++k) {
      auto u = gY[k], t = Policy::mul(wk, hY[k]);
      fY[k] = Policy::add(u, t);
      fY[k + half] = Policy::sub(u, t);
      wk = Policy::mul(wk, wn);
    }
    return fY;  // fY[i] = F(w(n)^i)
  }

 public:
  /// Forward (is_inv == false) or inverse (is_inv == true) transform of `in`.
  /// Precondition (asserted): in.size() is 0 or a power of two and accepted by
  /// Policy::check.
  auto run(const vector_type& in, bool is_inv) {
    const size_t N = in.size();
    assert((N & (N - 1)) == 0 && Policy::check(N));  // N must be a power of two
    vector_type out = divide_and_conquer(in);
    // Inverse transform: reverse indices 1..N-1 and divide by N. Skipped for
    // N == 0, where out.begin() + 1 would be past the end.
    if (is_inv && N > 0) {
      std::reverse(out.begin() + 1, out.end());
      const auto inv_N = Policy::inverse(N);
      for (size_t i = 0; i < N; ++i) out[i] = Policy::mul(out[i], inv_N);
    }
    return out;
  }
};
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <cmath> | |
#include <complex> | |
#include <vector> | |
/// Arithmetic policy for the complex fast Fourier transform.
///
/// @tparam T          real scalar type (e.g. double).
/// @tparam ComplexTy  complex type built from T; defaults to std::complex<T>.
template <typename T, typename ComplexTy = std::complex<T>>
struct FFT_Policy {
  using vector_type = std::vector<ComplexTy>;
  // std::acos is not constexpr in standard C++ (only GCC folds it as an
  // extension), so pi is spelled out as a literal to stay portable.
  static constexpr T pi = static_cast<T>(3.141592653589793238462643383279502884L);
  // Any transform size is acceptable for the complex FFT.
  static bool check(size_t /*N*/) { return true; }
  static auto one() { return ComplexTy(1, 0); }
  // exp(2*pi*i / N): the N-th root of unity used as the transform kernel.
  static auto omega(size_t N) {
    return std::exp(ComplexTy(0, 2 * pi / N));
  }
  static auto inverse(T value) { return T(1) / value; }
  static auto add(ComplexTy a, ComplexTy b) { return a + b; }
  static auto sub(ComplexTy a, ComplexTy b) { return a - b; }
  static auto mul(ComplexTy a, ComplexTy b) { return a * b; }
};
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Builds the bit-reversal table for a transform of size N: entry i is i with
// its low log2(N) bits reversed. Only meaningful when N is a power of two.
auto get_bit_reverse_table(size_t N) {
  const size_t top_bit = N >> 1;
  std::vector<size_t> rev(N);
  // rev[i] is rev[i / 2] shifted one place right, plus the top bit if i is odd.
  for (size_t idx = 1; idx < N; ++idx) {
    const size_t from_parent = rev[idx >> 1] >> 1;
    rev[idx] = (idx & 1) ? from_parent + top_bit : from_parent;
  }
  return rev;
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/// Permutes V in place into bit-reversed index order: V[i] is swapped with
/// V[rev(i)], where rev reverses the low log2(V.size()) bits of i.
/// Precondition (asserted): V.size() is 0 or a power of two.
///
/// The reversed index is tracked incrementally instead of being derived from a
/// full-word bit reversal. That avoids shift amounts tied to the element count
/// (which would index out of range or shift by >= the word width) and keeps
/// the loop index unsigned like V.size().
template <typename VectorTy>
void displacement(VectorTy &V) {
  const size_t N = V.size();
  assert((N & (N - 1)) == 0);
  // j == rev(i): incrementing i is mirrored by clearing j's leading 1-bits
  // from the top and then setting the first 0-bit (amortized O(1) per step).
  for (size_t i = 1, j = 0; i < N; ++i) {
    size_t bit = N >> 1;
    for (; j & bit; bit >>= 1) j ^= bit;
    j ^= bit;
    if (i < j) std::swap(V[i], V[j]);  // swap each pair exactly once
  }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// It is recommended that T at least be long long | |
#include <vector> | |
// Modular-arithmetic policy for a number-theoretic transform over Z/PZ.
// NOTE(review): omega/inverse assume P is prime and G is a primitive root mod P
// (e.g. P = 998244353, G = 3 as used in main). Products are formed as a * b
// in T before reduction, so T must be able to hold (P - 1)^2.
template <typename T, T P, T G>
struct NTT_Policy {
  using vector_type = std::vector<T>;

  // base^exp mod `mod` by square-and-multiply; expects exp >= 0.
  static T pow_mod(T base, T exp, T mod) {
    T result = 1;
    base %= mod;
    while (exp) {
      if (exp & 1) result = result * base % mod;
      base = base * base % mod;
      exp >>= 1;
    }
    return result;
  }

  // A size-N transform needs an N-th root of unity mod P; for N > 1 the test
  // P % N == 1 is equivalent to N dividing P - 1.
  static bool check(size_t N) {
    if (N <= 1) return true;
    return P % N == 1;
  }

  static auto one() { return T(1); }

  // G^((P - 1) / N), an N-th root of unity mod P.
  static auto omega(size_t N) {
    const auto exponent = (P - 1) / N;
    return pow_mod(G, exponent, P);
  }

  // Fermat inverse: value^(P - 2) mod P.
  static auto inverse(T value) { return pow_mod(value, P - 2, P); }

  static auto add(T a, T b) { return (a + b) % P; }

  static auto sub(T a, T b) {
    const auto diff = (a - b) % P;  // may be negative
    return (diff + P) % P;
  }

  static auto mul(T a, T b) { return a * b % P; }
};
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <type_traits> | |
// Reverses the order of all sizeof(T) * 8 bits of n (bit 0 <-> top bit).
// Works by swapping adjacent groups of s bits for s = width/2, ..., 2, 1,
// with a mask built on the fly instead of hard-coded 64-bit constants.
template <typename T>
inline T reverse_bits(T n) {
  using unsigned_T = typename std::make_unsigned<T>::type;
  constexpr unsigned width = sizeof(T) * 8;
  unsigned_T bits = static_cast<unsigned_T>(n);
  // `keep` selects the low s bits of every 2*s-bit group; start from all ones.
  unsigned_T keep = static_cast<unsigned_T>(~unsigned_T(0));
  for (unsigned s = width / 2; s != 0; s /= 2) {
    keep = static_cast<unsigned_T>(keep ^ (keep << s));
    bits = static_cast<unsigned_T>(((bits >> s) & keep) | ((bits << s) & ~keep));
  }
  return static_cast<T>(bits);
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <initializer_list> | |
#include <iostream> | |
// Reference O(|A| * |B|) convolution used to cross-check the fast transforms.
// Returns an empty vector if either operand is empty; otherwise a vector of
// length A.size() + B.size() - 1 whose k-th entry is sum over i + j == k of
// A[i] * B[j], accumulated with i in the outer loop and j in the inner loop.
template <typename ValueTy>
auto naiveMethod(std::vector<ValueTy> A, std::vector<ValueTy> B) {
  std::vector<ValueTy> product;
  if (!A.empty() && !B.empty()) {
    product.resize(A.size() + B.size() - 1);
    size_t row = 0;
    for (const auto& a : A) {
      size_t k = row++;
      for (const auto& b : B) product[k++] += a * b;
    }
  }
  return product;
}
template <typename AlgorithmTy> | |
void test(typename AlgorithmTy::vector_type A, | |
typename AlgorithmTy::vector_type B) { | |
auto Res = convolution<AlgorithmTy>(A, B); | |
for (auto x : Res) std::cout << x << ' '; | |
std::cout << std::endl; | |
} | |
int main() { | |
std::cout << std::fixed; | |
std::cout.precision(1); | |
using NTT = | |
CooleyTukeyAlgorithm<long long, | |
NTT_Policy<long long, (1 << 23) * 7 * 17 + 1, 3>>; | |
test<NTT>({1, 2, 3, 4}, {5, 6, 7, 8, 9}); | |
using FFT = CooleyTukeyAlgorithm<double, FFT_Policy<double>>; | |
test<FFT>({1, 2, 3, 4}, {5, 6, 7, 8, 9}); | |
auto C = naiveMethod<long long>({1, 2, 3, 4}, {5, 6, 7, 8, 9}); | |
for (auto x : C) std::cout << x << ' '; | |
std::cout << std::endl; | |
return 0; | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment