Last active
May 20, 2017 20:41
-
-
Save MaskRay/fac2042058dd5d9e59953f18f3f3978a to your computer and use it in GitHub Desktop.
FFT & NTT benchmark
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <algorithm> | |
#include <cassert> | |
#include <chrono> | |
#include <cmath> | |
#include <complex> | |
#include <cstdint> | |
#include <cstdlib> | |
#include <iostream> | |
#include <numeric> | |
#include <string> | |
#include <type_traits> | |
#include <utility> | |
#include <vector> | |
using namespace std;

// Complex element type used by the FFT benchmarks.
using cd = complex<double>;

// Loop helpers. FOR gives the loop variable the cv/ref-stripped type of the
// upper bound, so REP(i, n) with a long n declares a long i.
#define ALL(x) (x).begin(), (x).end()
#define FOR(i, a, b) for (remove_cv<remove_reference<decltype(b)>::type>::type i = (a); i < (b); i++)
#define REP(i, n) FOR(i, 0, n)

// 998244353 = 7*17*2^23 + 1, used with generator 3. M_int = floor(2^61 / P_int)
// is the Barrett constant consumed by Montgomery::barrett30.
constexpr long P_int = 998244353, M_int = (1L<<61)/P_int, G_int = 3;
// Modulus for the 64-bit variants, used with generator 3
// (NOTE(review): primality / primitive-root property not verified here).
constexpr long P_long = 1000000000949747713, G_long = 3;

using u64 = uint64_t;
using i64 = int64_t;

// Number of timed repetitions per benchmark; main() may override it from argv[1].
long times = 1;
// Largest benchmarked transform length; also sizes the static twiddle tables.
const long NN = 1<<23;
// Modular inverse of a modulo b via the iterative extended Euclidean
// algorithm. Assumes gcd(a, b) == 1 (not checked). A negative Bezout
// coefficient is shifted up by b once before being returned.
extern inline long inv(long a, long b)
{
  const long modulus = b;
  long coef = 1;  // Bezout coefficient paired with the current remainder a
  long prev = 0;  // coefficient paired with the previous remainder b
  while (a != 0) {
    const long quot = b / a;
    const long next_coef = prev - quot * coef;
    prev = coef;
    coef = next_coef;
    const long next_rem = b - quot * a;
    b = a;
    a = next_rem;
  }
  return prev < 0 ? prev + modulus : prev;
}
// (a * b) mod m for 32-bit operands; the product is formed in 64 bits so it
// cannot overflow.
extern inline int mul_mod(int a, int b, int m)
{
  const long wide = static_cast<long>(a) * b;
  return static_cast<int>(wide % m);
}
// (a * b) mod m for 64-bit operands without a 128-bit type. The quotient is
// estimated in long double and rounded to nearest. The remainder is then
// recovered exactly with wrapping unsigned arithmetic, and a negative
// remainder is corrected once by adding m.
// NOTE(review): relies on long double carrying a 64-bit (or wider) mantissa,
// as on x87 — confirm on other targets.
extern inline long mul_mod(long a, long b, long m)
{
  const long double approx = (long double)a * (long double)b / m;
  const unsigned long quotient = (unsigned long)(approx + 0.5);
  const unsigned long product = (unsigned long)a * b;
  const long rem = product - m * quotient;
  return rem < 0 ? rem + m : rem;
}
extern inline long pow_mod(long a, long b, long mod) | |
{ | |
long r = 1; | |
for (; b; b >>= 1) { | |
if (b & 1) | |
r = mul_mod(r, a, mod); | |
a = mul_mod(a, a, mod); | |
} | |
return r; | |
} | |
// Returns a length-n vector holding 0, 1, ..., n-1 (each converted to T).
// Used as the input for both the round-trip checks and the timing runs.
template<typename T>
vector<T> setup(long n)
{
  vector<T> values(n);
  int next = 0;
  for (T& v : values)
    v = next++;
  return values;
}
namespace Montgomery | |
{ | |
extern inline u64 barrett30(u64 a, u64 P, u64 M) | |
{ // 2^29 <= P < 2^30 | |
u64 r = a-((a>>28)*M>>33)*P; | |
if (r >= P) r -= P; | |
return r; | |
} | |
long pow_mod(long a, long b, long P, long M) | |
{ | |
long r = 1; | |
for (; b; b >>= 1) { | |
if (b & 1) | |
r = barrett30(r*a, P, M); | |
a = barrett30(a*a, P, M); | |
} | |
return r; | |
} | |
void ntt_dif2(int a[], long n, long P, long M, long G, int is) | |
{ | |
static int units[NN]; | |
long invP = inv(P, 1L<<32), w1 = pow_mod(G, is > 0 ? (P-1)/n : P-1-(P-1)/n, P, M), wt = (1L<<32)%P; | |
REP(i, n>>1) { | |
units[i] = wt; | |
if (barrett30(wt*w1, P, M) != wt*w1%P) { | |
int*t=0; | |
*t=1; | |
} | |
wt = barrett30(wt*w1, P, M); | |
} | |
for (long m = n, dwi = 1; m >= 2; m >>= 1, dwi <<= 1) | |
for (long r = 0; r < n; r += m) { | |
int *x = a+r, *y = a+r+(m>>1), *w = units; | |
REP(j, m>>1) { | |
long u = long(*x)+*y; | |
auto v = ((unsigned long)(*x)-*y+2*P)**w; | |
if (u >= 2*P) u -= 2*P; | |
*x++ = u; | |
*y++ = (v>>32)-(((v<<32)*invP>>32)*P>>32)+P; | |
w += dwi; | |
} | |
} | |
REP(i, n) | |
if (a[i] >= P) | |
a[i] -= P; | |
long logn = 63-__builtin_clzl(n); | |
REP(i, n) { | |
unsigned int x = i, t; | |
t = x & 0x00ff00ff; x = t << 16 | t >> 32-16 | t ^ x; | |
t = x & 0x0f0f0f0f; x = t << 8 | t >> 32- 8 | t ^ x; | |
t = x & 0x33333333; x = t << 4 | t >> 32- 4 | t ^ x; | |
t = x & 0x55555555; x = t << 2 | t >> 32- 2 | t ^ x; | |
x >>= 31-logn; | |
if (i < x) | |
swap(a[i], a[x]); | |
} | |
if (is < 0) { | |
long invn = inv(n, P); | |
REP(i, n) { | |
if (barrett30(a[i]*invn, P, M) != a[i]*invn%P) { | |
int *t=0; | |
*t=1; | |
} | |
a[i] = barrett30(a[i]*invn, P, M); | |
} | |
} | |
} | |
void check(int a[], long n) | |
{ | |
ntt_dif2(&a[0], n, P_int, M_int, G_int, 1); | |
ntt_dif2(&a[0], n, P_int, M_int, G_int, -1); | |
} | |
void run(int a[], long n) | |
{ | |
ntt_dif2(a, n, P_int, M_int, G_int, 1); | |
} | |
} | |
namespace NTT_dif2 | |
{ | |
template<typename T, T P, T G> | |
void ntt_dif2(T a[], long n, int is) | |
{ | |
static T units[NN/2]; | |
T logn = 63-__builtin_clzl(n), w1 = pow_mod(G, is > 0 ? (P-1)/n : P-1-(P-1)/n, P), wt = 1; | |
REP(i, n>>1) { | |
units[i] = wt; | |
wt = mul_mod(wt, w1, P); | |
} | |
for (long m = n, dwi = 1; m >= 2; m >>= 1, dwi <<= 1) | |
for (long r = 0; r < n; r += m) { | |
T *x = a+r, *y = a+r+(m>>1), *w = units; | |
REP(j, m>>1) { | |
T u = *x+*y, v = mul_mod(*x-*y+P, *w, P); | |
if (u >= P) u -= P; | |
*x++ = u; | |
*y++ = v; | |
w += dwi; | |
} | |
} | |
if (is < 0) { | |
T invn = pow_mod(n, P-2, P); | |
REP(i, n) | |
a[i] = mul_mod(a[i], invn, P); | |
} | |
REP(i, n) { | |
unsigned int x = i, t; | |
t = x & 0x00ff00ff; x = t << 16 | t >> 32-16 | t ^ x; | |
t = x & 0x0f0f0f0f; x = t << 8 | t >> 32- 8 | t ^ x; | |
t = x & 0x33333333; x = t << 4 | t >> 32- 4 | t ^ x; | |
t = x & 0x55555555; x = t << 2 | t >> 32- 2 | t ^ x; | |
x >>= 31-logn; | |
if (i < x) | |
swap(a[i], a[x]); | |
} | |
} | |
template<typename T, T P, T G> | |
void check(T a[], long n) | |
{ | |
ntt_dif2<T, P, G>(a, n, 1); | |
ntt_dif2<T, P, G>(a, n, -1); | |
} | |
template<typename T, T P, T G> | |
void run(T a[], long n) | |
{ | |
ntt_dif2<T, P, G>(a, n, 1); | |
} | |
} | |
namespace NTT_dif2_variable_P | |
{ | |
template<typename T> | |
void ntt_dif2_p(T a[], long n, T P, T G, int is) | |
{ | |
static T units[NN/2]; | |
T logn = 63-__builtin_clzl(n), w1 = pow_mod(G, is > 0 ? (P-1)/n : P-1-(P-1)/n, P), wt = 1; | |
REP(i, n>>1) { | |
units[i] = wt; | |
wt = mul_mod(wt, w1, P); | |
} | |
for (long m = n, dwi = 1; m >= 2; m >>= 1, dwi <<= 1) | |
for (long r = 0; r < n; r += m) { | |
T *x = a+r, *y = a+r+(m>>1), *w = units; | |
REP(j, m>>1) { | |
T u = *x+*y, v = mul_mod(*x-*y+P, *w, P); | |
if (u >= P) u -= P; | |
*x++ = u; | |
*y++ = v; | |
w += dwi; | |
} | |
} | |
if (is < 0) { | |
T invn = pow_mod(n, P-2, P); | |
REP(i, n) | |
a[i] = mul_mod(a[i], invn, P); | |
} | |
REP(i, n) { | |
unsigned int x = i, t; | |
t = x & 0x00ff00ff; x = t << 16 | t >> 32-16 | t ^ x; | |
t = x & 0x0f0f0f0f; x = t << 8 | t >> 32- 8 | t ^ x; | |
t = x & 0x33333333; x = t << 4 | t >> 32- 4 | t ^ x; | |
t = x & 0x55555555; x = t << 2 | t >> 32- 2 | t ^ x; | |
x >>= 31-logn; | |
if (i < x) | |
swap(a[i], a[x]); | |
} | |
} | |
template<typename T> | |
void check(T a[], long n) | |
{ | |
volatile long p_int = P_int; | |
ntt_dif2_p<T>(a, n, p_int, G_int, 1); | |
ntt_dif2_p<T>(a, n, p_int, G_int, -1); | |
} | |
template<> | |
void check(long a[], long n) | |
{ | |
volatile long p_long = P_long; | |
ntt_dif2_p<long>(a, n, p_long, G_long, 1); | |
ntt_dif2_p<long>(a, n, p_long, G_long, -1); | |
} | |
template<typename T> | |
void run(T a[], long n) | |
{ | |
volatile long p_int = P_int; | |
ntt_dif2_p<T>(a, n, p_int, G_int, 1); | |
} | |
template<> | |
void run(long a[], long n) | |
{ | |
volatile long p_long = P_long; | |
ntt_dif2_p<long>(a, n, p_long, G_long, 1); | |
} | |
} | |
namespace NTT_dit2 | |
{ | |
template<typename T, T P, T G> | |
void ntt_dit2(T a[], long n, int is) | |
{ | |
static T units[NN/2]; | |
T logn = 63-__builtin_clzl(n), w1 = pow_mod(G, is > 0 ? (P-1)/n : P-1-(P-1)/n, P), wt = 1; | |
REP(i, n) { | |
unsigned int x = i, t; | |
t = x & 0x00ff00ff; x = t << 16 | t >> 32-16 | t ^ x; | |
t = x & 0x0f0f0f0f; x = t << 8 | t >> 32- 8 | t ^ x; | |
t = x & 0x33333333; x = t << 4 | t >> 32- 4 | t ^ x; | |
t = x & 0x55555555; x = t << 2 | t >> 32- 2 | t ^ x; | |
x >>= 31-logn; | |
if (i < x) | |
swap(a[i], a[x]); | |
} | |
REP(i, n>>1) { | |
units[i] = wt; | |
wt = mul_mod(wt, w1, P); | |
} | |
for (long m = 2, dwi = n>>1; m <= n; m <<= 1, dwi >>= 1) | |
for (long r = 0; r < n; r += m) { | |
T *x = a+r, *y = a+r+(m>>1), *w = units; | |
REP(j, m>>1) { | |
T u = *x, v = mul_mod(*y, *w, P), x1 = u+v, y1 = u-v; | |
if (x1 >= P) x1 -= P; | |
if (y1 < 0) y1 += P; | |
*x++ = x1; | |
*y++ = y1; | |
w += dwi; | |
} | |
} | |
if (is < 0) { | |
T invn = pow_mod(n, P-2, P); | |
REP(i, n) | |
a[i] = mul_mod(a[i], invn, P); | |
} | |
} | |
template<typename T, T P, T G> | |
void check(T a[], long n) | |
{ | |
ntt_dit2<T, P, G>(a, n, 1); | |
ntt_dit2<T, P, G>(a, n, -1); | |
} | |
template<typename T, T P, T G> | |
void run(T a[], long n) | |
{ | |
ntt_dit2<int, P, G>(a, n, 1); | |
} | |
} | |
namespace FFT_dif2 | |
{ | |
void fft_dif2(cd a[], long n) | |
{ // sign = -1 | |
static cd units[NN/2]; | |
double ph = 2*M_PI/n; | |
REP(i, n/2) | |
units[i] = {cos(ph*i), sin(ph*i)}; | |
for (long m = n, dwi = 1; m >= 2; m >>= 1, dwi <<= 1) | |
for (long r = 0; r < n; r += m) { | |
cd *x = a+r, *y = a+r+(m>>1), *w = units; | |
REP(j, m>>1) { | |
cd v = *y, t = *x-v; | |
*y++ = {t.real()*w->real()-t.imag()*w->imag(), t.real()*w->imag()-t.imag()*w->real()}; | |
*x++ += v; | |
w += dwi; | |
} | |
} | |
} | |
void run(cd a[], long n) | |
{ | |
fft_dif2(a, n); | |
} | |
} | |
namespace FFT_dit2 | |
{ | |
void fft_dit2(cd a[], long n) | |
{ | |
static cd units[NN/2]; | |
double ph = 2*M_PI/n; | |
REP(i, n/2) | |
units[i] = {cos(ph*i), sin(ph*i)}; | |
for (long m = 2, dwi = n>>1; m <= n; m <<= 1, dwi >>= 1) | |
for (long r = 0; r < n; r += m) { | |
cd *x = a+r, *y = a+r+(m>>1), *w = units; | |
REP(j, m>>1) { | |
cd t{y->real()*w->real()-y->imag()*w->imag(), y->real()*w->imag()+y->imag()*w->real()}; | |
*y++ = *x-t; | |
*x++ += t; | |
w += dwi; | |
} | |
} | |
} | |
void run(cd a[], long n) | |
{ | |
fft_dit2(a, n); | |
} | |
} | |
template<typename T> | |
void check(long n, void(*fn)(T a[], long)) | |
{ | |
auto a = setup<T>(n); | |
fn(&a[0], n); | |
REP(i, n) | |
assert(a[i] == i); | |
} | |
template<typename T> | |
long test(long n, void(*fn)(T a[], long)) | |
{ | |
auto a = setup<T>(n); | |
auto start = chrono::steady_clock::now(); | |
REP(_, times) | |
fn(&a[0], n); | |
return chrono::duration_cast<chrono::microseconds>(chrono::steady_clock::now() - start).count() / times; | |
} | |
int main(int argc, char* argv[]) | |
{ | |
if (argc > 1) | |
times = atoi(argv[1]); | |
for (long n = 1<<4; n <= 1<<4; n <<= 1) { | |
check<int>(n, Montgomery::check); | |
check<int>(n, NTT_dif2::check<int, P_int, G_int>); | |
check<long>(n, NTT_dif2::check<long, P_long, G_long>); | |
check<int>(n, NTT_dit2::check<int, P_int, G_int>); | |
check<long>(n, NTT_dit2::check<long, P_long, G_long>); | |
check<int>(n, NTT_dif2_variable_P::check<int>); | |
check<long>(n, NTT_dif2_variable_P::check<long>); | |
} | |
for (long n = 1<<8; n <= NN; n <<= 1) { | |
vector<pair<long, string>> res; | |
res.emplace_back(test<int>(n, Montgomery::run), "Montgomery+Barrett NTT dif2 int"); | |
res.emplace_back(test<int>(n, NTT_dif2::run<int, P_int, G_int>), "NTT dif2 int"); | |
res.emplace_back(test<long>(n, NTT_dif2::run<long, P_long, G_long>), "NTT dif2 long"); | |
res.emplace_back(test<int>(n, NTT_dif2::run<int, P_int, G_int>), "NTT dit2 int"); | |
res.emplace_back(test<long>(n, NTT_dif2::run<long, P_long, G_long>), "NTT dit2 long"); | |
res.emplace_back(test<int>(n, NTT_dif2_variable_P::run<int>), "NTT dif2 int non-constant P"); | |
res.emplace_back(test<long>(n, NTT_dif2_variable_P::run<long>), "NTT dif2 long non-constant P"); | |
res.emplace_back(test<cd>(n, FFT_dif2::run), "FFT dif2"); | |
res.emplace_back(test<cd>(n, FFT_dit2::run), "FFT dit2"); | |
sort(ALL(res)); | |
for (auto& x: res) cout << n << '\t' << x.first << '\t' << x.second << '\n'; | |
cout << '\n'; | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment