Created
April 5, 2021 12:32
-
-
Save fredrik-johansson/2b8d6db8a2ffaf50c9ed97854abacd62 to your computer and use it in GitHub Desktop.
Some new nmod_poly multiplication code
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include "flint/nmod_poly.h" | |
#include "flint/profiler.h" | |
/* | |
Multiplication/squaring using Kronecker substitution at 2^b and -2^b. | |
*/ | |
void | |
_nmod_poly_mul_KS2B(mp_ptr res, mp_srcptr op1, slong n1, | |
mp_srcptr op2, slong n2, nmod_t mod) | |
{ | |
int sqr, v3m_neg; | |
ulong bits, b, w; | |
slong n1o, n1e, n2o, n2e, n3o, n3e, n3, k1, k2, k3; | |
mp_ptr v1_buf0, v2_buf0, v1_buf1, v2_buf1, v1_buf2, v2_buf2; | |
mp_ptr v1o, v1e, v1p, v1m, v2o, v2e, v2p, v2m, v3o, v3e, v3p, v3m; | |
mp_ptr z, tmp; | |
TMP_INIT; | |
if (n2 == 1) | |
{ | |
/* code below needs n2 > 1, so fall back on scalar multiplication */ | |
_nmod_vec_scalar_mul_nmod(res, op1, n1, op2[0], mod); | |
return; | |
} | |
TMP_START; | |
sqr = (op1 == op2 && n1 == n2); | |
/* bits in each output coefficient */ | |
bits = 2 * (FLINT_BITS - mod.norm) + FLINT_CLOG2(n2); | |
/* we're evaluating at x = B and -B, where B = 2^b, and b = ceil(bits / 2) */ | |
b = (bits + 1) / 2; | |
/* number of ulongs required to store each output coefficient */ | |
w = (2*b - 1)/FLINT_BITS + 1; | |
/* | |
Write f1(x) = f1e(x^2) + x * f1o(x^2) | |
f2(x) = f2e(x^2) + x * f2o(x^2) | |
h(x) = he(x^2) + x * ho(x^2) | |
"e" = even, "o" = odd | |
*/ | |
n1o = n1 / 2; | |
n1e = n1 - n1o; | |
n2o = n2 / 2; | |
n2e = n2 - n2o; | |
n3 = n1 + n2 - 1; /* length of h */ | |
n3o = n3 / 2; | |
n3e = n3 - n3o; | |
/* | |
f1(B) and |f1(-B)| are at most ((n1 - 1) * b + mod->bits) bits long. | |
However, when evaluating f1e(B^2) and B * f1o(B^2) the bitpacking | |
routine needs room for the last chunk of 2b bits. Therefore we need to | |
allow room for (n1 + 1) * b bits. Ditto for f2. | |
*/ | |
k1 = ((n1 + 1)*b - 1)/FLINT_BITS + 1; | |
k2 = ((n2 + 1)*b - 1)/FLINT_BITS + 1; | |
k3 = k1 + k2; | |
/* allocate space */ | |
v1_buf0 = TMP_ALLOC(sizeof(mp_limb_t) * 3 * k3); /* k1 limbs */ | |
v2_buf0 = v1_buf0 + k1; /* k2 limbs */ | |
v1_buf1 = v2_buf0 + k2; /* k1 limbs */ | |
v2_buf1 = v1_buf1 + k1; /* k2 limbs */ | |
v1_buf2 = v2_buf1 + k2; /* k1 limbs */ | |
v2_buf2 = v1_buf2 + k1; /* k2 limbs */ | |
/* | |
arrange overlapping buffers to minimise memory use | |
"p" = plus, "m" = minus | |
*/ | |
v1e = v1_buf0; | |
v2e = v2_buf0; | |
v1o = v1_buf1; | |
v2o = v2_buf1; | |
v1p = v1_buf2; | |
v2p = v2_buf2; | |
v1m = v1_buf0; | |
v2m = v2_buf0; | |
v3m = v1_buf1; | |
v3p = v1_buf0; | |
v3e = v1_buf2; | |
v3o = v1_buf0; | |
z = TMP_ALLOC(sizeof(mp_limb_t) * w * n3e); | |
if (!sqr) | |
{ | |
/* multiplication version */ | |
/* evaluate f1e(B^2) and B * f1o(B^2) */ | |
_nmod_poly_KS2_pack(v1e, op1, n1e, 2, 2 * b, 0, k1); | |
_nmod_poly_KS2_pack(v1o, op1 + 1, n1o, 2, 2 * b, b, k1); | |
/* evaluate f2e(B^2) and B * f2o(B^2) */ | |
_nmod_poly_KS2_pack(v2e, op2, n2e, 2, 2 * b, 0, k2); | |
_nmod_poly_KS2_pack(v2o, op2 + 1, n2o, 2, 2 * b, b, k2); | |
/* | |
compute f1(B) = f1e(B^2) + B * f1o(B^2) | |
and f2(B) = f2e(B^2) + B * f2o(B^2) | |
*/ | |
mpn_add_n(v1p, v1e, v1o, k1); | |
mpn_add_n(v2p, v2e, v2o, k2); | |
/* | |
compute |f1(-B)| = |f1e(B^2) - B * f1o(B^2)| | |
and |f2(-B)| = |f2e(B^2) - B * f2o(B^2)| | |
*/ | |
v3m_neg = signed_mpn_sub_n(v1m, v1e, v1o, k1); | |
v3m_neg ^= signed_mpn_sub_n(v2m, v2e, v2o, k2); | |
/* | |
compute h(B) = f1(B) * f2(B) | |
compute |h(-B)| = |f1(-B)| * |f2(-B)| | |
v3m_neg is set if h(-B) is negative | |
*/ | |
mpn_mul(v3m, v1m, k1, v2m, k2); | |
mpn_mul(v3p, v1p, k1, v2p, k2); | |
} | |
else | |
{ | |
/* squaring version */ | |
/* evaluate f1e(B^2) and B * f1o(B^2) */ | |
_nmod_poly_KS2_pack(v1e, op1, n1e, 2, 2 * b, 0, k1); | |
_nmod_poly_KS2_pack(v1o, op1 + 1, n1o, 2, 2 * b, b, k1); | |
/* compute f1(B) = f1e(B^2) + B * f1o(B^2) */ | |
mpn_add_n(v1p, v1e, v1o, k1); | |
/* compute |f1(-B)| = |f1e(B^2) - B * f1o(B^2)| */ | |
signed_mpn_sub_n(v1m, v1e, v1o, k1); | |
/* | |
compute h(B) = f1(B)^2 | |
compute h(-B) = f1(-B)^2 | |
v3m_neg is cleared (since f1(-B)^2 is never negative) | |
*/ | |
mpn_sqr(v3m, v1m, k1); | |
mpn_sqr(v3p, v1p, k1); | |
v3m_neg = 0; | |
} | |
/* | |
he(B^2) and B * ho(B^2) are both at most b * (n3 + 1) bits long (since | |
the coefficients don't overlap). The buffers used below are at least | |
b * (n1 + n2 + 2) = b * (n3 + 3) bits long. So we definitely have | |
enough room for 2 * he(B^2) and 2 * B * ho(B^2). | |
*/ | |
/* compute 2 * he(B^2) = h(B) + h(-B) */ | |
if (v3m_neg) | |
mpn_sub_n(v3e, v3p, v3m, k3); | |
else | |
mpn_add_n(v3e, v3p, v3m, k3); | |
/* unpack coefficients of he, and reduce mod m */ | |
_nmod_poly_KS2_unpack(z, v3e, n3e, 2 * b, 1); | |
_nmod_poly_KS2_reduce(res, 2, z, n3e, w, mod); | |
/* compute 2 * b * ho(B^2) = h(B) - h(-B) */ | |
if (v3m_neg) | |
mpn_add_n(v3o, v3p, v3m, k3); | |
else | |
mpn_sub_n(v3o, v3p, v3m, k3); | |
/* unpack coefficients of ho, and reduce mod m */ | |
_nmod_poly_KS2_unpack(z, v3o, n3o, 2 * b, b + 1); | |
_nmod_poly_KS2_reduce(res + 1, 2, z, n3o, w, mod); | |
TMP_END; | |
} | |
void | |
nmod_poly_mul_KS2B(nmod_poly_t res, | |
const nmod_poly_t poly1, const nmod_poly_t poly2) | |
{ | |
slong len_out; | |
if ((poly1->length == 0) || (poly2->length == 0)) | |
{ | |
nmod_poly_zero(res); | |
return; | |
} | |
len_out = poly1->length + poly2->length - 1; | |
if (res == poly1 || res == poly2) | |
{ | |
nmod_poly_t temp; | |
nmod_poly_init2_preinv(temp, poly1->mod.n, poly1->mod.ninv, len_out); | |
if (poly1->length >= poly2->length) | |
_nmod_poly_mul_KS2B(temp->coeffs, poly1->coeffs, poly1->length, | |
poly2->coeffs, poly2->length, | |
poly1->mod); | |
else | |
_nmod_poly_mul_KS2B(temp->coeffs, poly2->coeffs, poly2->length, | |
poly1->coeffs, poly1->length, | |
poly1->mod); | |
nmod_poly_swap(res, temp); | |
nmod_poly_clear(temp); | |
} | |
else | |
{ | |
nmod_poly_fit_length(res, len_out); | |
if (poly1->length >= poly2->length) | |
_nmod_poly_mul_KS2B(res->coeffs, poly1->coeffs, poly1->length, | |
poly2->coeffs, poly2->length, | |
poly1->mod); | |
else | |
_nmod_poly_mul_KS2B(res->coeffs, poly2->coeffs, poly2->length, | |
poly1->coeffs, poly1->length, | |
poly1->mod); | |
} | |
res->length = len_out; | |
_nmod_poly_normalise(res); | |
} | |
static mp_limb_t | |
nmod_fmma(mp_limb_t a, mp_limb_t b, mp_limb_t c, mp_limb_t d, nmod_t mod) | |
{ | |
a = nmod_mul(a, b, mod); | |
NMOD_ADDMUL(a, c, d, mod); | |
return a; | |
} | |
mp_limb_t | |
_nmod_vec_dot_rev(mp_srcptr vec1, mp_srcptr vec2, slong len, nmod_t mod, int nlimbs) | |
{ | |
mp_limb_t res; | |
slong i; | |
if (len <= 2 && nlimbs >= 2) | |
{ | |
if (len == 2) | |
return nmod_fmma(vec1[0], vec2[1], vec1[1], vec2[0], mod); | |
if (len == 1) | |
return nmod_mul(vec1[0], vec2[0], mod); | |
return 0; | |
} | |
NMOD_VEC_DOT(res, i, len, vec1[i], vec2[len - 1 - i], mod, nlimbs); | |
return res; | |
} | |
void | |
_nmod_poly_sqr_classical(mp_ptr res, mp_srcptr poly, | |
slong len, nmod_t mod) | |
{ | |
slong i, j, bits, log_len, nlimbs, start, stop; | |
mp_limb_t c; | |
if (len == 1) | |
{ | |
res[0] = nmod_mul(poly[0], poly[0], mod); | |
return; | |
} | |
if (len == 2) | |
{ | |
mp_limb_t a, b, c; | |
a = poly[0]; | |
b = poly[1]; | |
c = nmod_mul(a, b, mod); | |
res[0] = nmod_mul(a, a, mod); | |
res[1] = nmod_add(c, c, mod); | |
res[2] = nmod_mul(b, b, mod); | |
return; | |
} | |
log_len = FLINT_BIT_COUNT(len); | |
bits = FLINT_BITS - (slong) mod.norm; | |
bits = 2 * bits + log_len; | |
if (bits <= FLINT_BITS) | |
{ | |
flint_mpn_zero(res, 2 * len - 1); | |
for (i = 0; i < len; i++) | |
{ | |
c = poly[i]; | |
res[2 * i] += c * c; | |
c *= 2; | |
for (j = i + 1; j < len; j++) | |
res[i + j] += poly[j] * c; | |
} | |
_nmod_vec_reduce(res, res, 2 * len - 1, mod); | |
return; | |
} | |
if (bits <= 2 * FLINT_BITS) | |
nlimbs = 2; | |
else | |
nlimbs = 3; | |
for (i = 0; i < 2 * len - 1; i++) | |
{ | |
start = FLINT_MAX(0, i - len + 1); | |
stop = FLINT_MIN(len - 1, (i + 1) / 2 - 1); | |
c = _nmod_vec_dot_rev(poly + start, poly + i - stop, stop - start + 1, mod, nlimbs); | |
c = nmod_add(c, c, mod); | |
if (i % 2 == 0 && i / 2 < len) | |
NMOD_ADDMUL(c, poly[i / 2], poly[i / 2], mod); | |
res[i] = c; | |
} | |
} | |
void | |
_nmod_poly_mul_classical2(mp_ptr res, mp_srcptr poly1, | |
slong len1, mp_srcptr poly2, slong len2, nmod_t mod) | |
{ | |
slong i, j, bits, log_len, nlimbs, n1, n2; | |
int squaring; | |
mp_limb_t c; | |
if (len1 == 1) | |
{ | |
res[0] = nmod_mul(poly1[0], poly2[0], mod); | |
return; | |
} | |
if (len2 == 1) | |
{ | |
_nmod_vec_scalar_mul_nmod(res, poly1, len1, poly2[0], mod); | |
return; | |
} | |
squaring = (poly1 == poly2 && len1 == len2); | |
log_len = FLINT_BIT_COUNT(len2); | |
bits = FLINT_BITS - (slong) mod.norm; | |
bits = 2 * bits + log_len; | |
if (bits <= FLINT_BITS) | |
{ | |
flint_mpn_zero(res, len1 + len2 - 1); | |
if (squaring) | |
{ | |
for (i = 0; i < len1; i++) | |
{ | |
c = poly1[i]; | |
res[2 * i] += c * c; | |
c *= 2; | |
for (j = i + 1; j < len1; j++) | |
res[i + j] += poly1[j] * c; | |
} | |
} | |
else | |
{ | |
for (i = 0; i < len1; i++) | |
{ | |
mp_limb_t c = poly1[i]; | |
for (j = 0; j < len2; j++) | |
res[i + j] += c * poly2[j]; | |
} | |
} | |
_nmod_vec_reduce(res, res, len1 + len2 - 1, mod); | |
return; | |
} | |
if (len2 == 2) | |
{ | |
_nmod_vec_scalar_mul_nmod(res, poly1, len1, poly2[0], mod); | |
_nmod_vec_scalar_addmul_nmod(res + 1, poly1, len1 - 1, poly2[1], mod); | |
res[len1 + len2 - 2] = nmod_mul(poly1[len1 - 1], poly2[len2 - 1], mod); | |
return; | |
} | |
if (bits <= 2 * FLINT_BITS) | |
nlimbs = 2; | |
else | |
nlimbs = 3; | |
if (squaring) | |
{ | |
for (i = 0; i < 2 * len1 - 1; i++) | |
{ | |
n1 = FLINT_MAX(0, i - len1 + 1); | |
n2 = FLINT_MIN(len1 - 1, (i + 1) / 2 - 1); | |
c = _nmod_vec_dot_rev(poly1 + n1, poly1 + i - n2, n2 - n1 + 1, mod, nlimbs); | |
c = nmod_add(c, c, mod); | |
if (i % 2 == 0 && i / 2 < len1) | |
NMOD_ADDMUL(c, poly1[i / 2], poly1[i / 2], mod); | |
res[i] = c; | |
} | |
} | |
else | |
{ | |
for (i = 0; i < len1 + len2 - 1; i++) | |
{ | |
n1 = FLINT_MIN(len1 - 1, i); | |
n2 = FLINT_MIN(len2 - 1, i); | |
res[i] = _nmod_vec_dot_rev(poly1 + i - n2, | |
poly2 + i - n1, | |
n1 + n2 - i + 1, mod, nlimbs); | |
} | |
} | |
} | |
void | |
nmod_poly_mul_classical2(nmod_poly_t res, | |
const nmod_poly_t poly1, const nmod_poly_t poly2) | |
{ | |
slong len_out; | |
if ((poly1->length == 0) || (poly2->length == 0)) | |
{ | |
nmod_poly_zero(res); | |
return; | |
} | |
len_out = poly1->length + poly2->length - 1; | |
if (res == poly1 || res == poly2) | |
{ | |
nmod_poly_t temp; | |
nmod_poly_init2_preinv(temp, poly1->mod.n, poly1->mod.ninv, len_out); | |
if (poly1->length >= poly2->length) | |
_nmod_poly_mul_classical2(temp->coeffs, poly1->coeffs, | |
poly1->length, poly2->coeffs, | |
poly2->length, poly1->mod); | |
else | |
_nmod_poly_mul_classical2(temp->coeffs, poly2->coeffs, | |
poly2->length, poly1->coeffs, | |
poly1->length, poly1->mod); | |
nmod_poly_swap(res, temp); | |
nmod_poly_clear(temp); | |
} | |
else | |
{ | |
nmod_poly_fit_length(res, len_out); | |
if (poly1->length >= poly2->length) | |
_nmod_poly_mul_classical2(res->coeffs, poly1->coeffs, poly1->length, | |
poly2->coeffs, poly2->length, poly1->mod); | |
else | |
_nmod_poly_mul_classical2(res->coeffs, poly2->coeffs, poly2->length, | |
poly1->coeffs, poly1->length, poly1->mod); | |
} | |
res->length = len_out; | |
_nmod_poly_normalise(res); | |
} | |
/*
    Returns the bit length of the bitwise OR of vec[0..len-1], i.e. the
    largest bit length among the entries. Stops scanning and returns
    FLINT_BITS as soon as the top bit has been seen.
*/
flint_bitcnt_t _nmod_vec_max_bits2(mp_srcptr vec, slong len)
{
    const mp_limb_t top_bit = UWORD(1) << (FLINT_BITS - 1);
    mp_limb_t seen = 0;
    slong k = 0;

    while (k < len)
    {
        seen |= vec[k];
        k++;

        if (seen & top_bit)
            return FLINT_BITS;
    }

    return FLINT_BIT_COUNT(seen);
}
void | |
_nmod_poly_mul_KSB(mp_ptr out, mp_srcptr in1, slong len1, | |
mp_srcptr in2, slong len2, flint_bitcnt_t bits, nmod_t mod) | |
{ | |
slong len_out = len1 + len2 - 1, limbs1, limbs2; | |
mp_ptr tmp, mpn1, mpn2, res; | |
int squaring; | |
TMP_INIT; | |
squaring = (in1 == in2 && len1 == len2); | |
if (bits == 0) | |
{ | |
flint_bitcnt_t bits1, bits2, loglen; | |
#if 0 | |
bits1 = _nmod_vec_max_bits2(in1, len1); | |
bits2 = squaring ? bits1 : _nmod_vec_max_bits2(in2, len2); | |
#else | |
bits1 = FLINT_BITS - (slong) mod.norm; | |
bits2 = bits1; | |
#endif | |
loglen = FLINT_BIT_COUNT(len2); | |
bits = bits1 + bits2 + loglen; | |
} | |
limbs1 = (len1 * bits - 1) / FLINT_BITS + 1; | |
limbs2 = (len2 * bits - 1) / FLINT_BITS + 1; | |
TMP_START; | |
tmp = TMP_ALLOC(sizeof(mp_limb_t) * (limbs1 + limbs2 + limbs1 + (squaring ? 0 : limbs2))); | |
res = tmp; | |
mpn1 = tmp + limbs1 + limbs2; | |
mpn2 = squaring ? mpn1 : (mpn1 + limbs1); | |
_nmod_poly_bit_pack(mpn1, in1, len1, bits); | |
if (!squaring) | |
_nmod_poly_bit_pack(mpn2, in2, len2, bits); | |
if (squaring) | |
mpn_sqr(res, mpn1, limbs1); | |
else | |
mpn_mul(res, mpn1, limbs1, mpn2, limbs2); | |
_nmod_poly_bit_unpack(out, len_out, res, bits, mod); | |
TMP_END; | |
} | |
void | |
nmod_poly_mul_KSB(nmod_poly_t res, | |
const nmod_poly_t poly1, const nmod_poly_t poly2, | |
flint_bitcnt_t bits) | |
{ | |
slong len_out; | |
if ((poly1->length == 0) || (poly2->length == 0)) | |
{ | |
nmod_poly_zero(res); | |
return; | |
} | |
len_out = poly1->length + poly2->length - 1; | |
if (res == poly1 || res == poly2) | |
{ | |
nmod_poly_t temp; | |
nmod_poly_init2_preinv(temp, poly1->mod.n, poly1->mod.ninv, len_out); | |
if (poly1->length >= poly2->length) | |
_nmod_poly_mul_KSB(temp->coeffs, poly1->coeffs, poly1->length, | |
poly2->coeffs, poly2->length, bits, | |
poly1->mod); | |
else | |
_nmod_poly_mul_KSB(temp->coeffs, poly2->coeffs, poly2->length, | |
poly1->coeffs, poly1->length, bits, | |
poly1->mod); | |
nmod_poly_swap(res, temp); | |
nmod_poly_clear(temp); | |
} | |
else | |
{ | |
nmod_poly_fit_length(res, len_out); | |
if (poly1->length >= poly2->length) | |
_nmod_poly_mul_KSB(res->coeffs, poly1->coeffs, poly1->length, | |
poly2->coeffs, poly2->length, bits, | |
poly1->mod); | |
else | |
_nmod_poly_mul_KSB(res->coeffs, poly2->coeffs, poly2->length, | |
poly1->coeffs, poly1->length, bits, | |
poly1->mod); | |
} | |
res->length = len_out; | |
_nmod_poly_normalise(res); | |
} | |
/*
    Timing helpers: variants of the TIMEIT_* pattern that store the per-call
    time in a variable instead of printing it.

    Usage:  TIMEIT_START1 <statement>; TIMEIT_STOP1(var)

    The statement is run in batches of __reps calls; __reps grows by a factor
    of 10 until one batch takes at least 10 units of __timer->cpu.
*/

/* var = cpu time of the last batch, scaled by 0.001, per repetition
   (seconds per call if cpu is in milliseconds -- TODO confirm unit) */
#define TIMEIT_PRINT1(__var, __timer, __reps) \
    __var = __timer->cpu*0.001/__reps;

/* Opens the repeat-until-long-enough loop; closed by TIMEIT_END_REPEAT1. */
#define TIMEIT_REPEAT1(__timer, __reps) \
    do \
    { \
        slong __timeit_k; \
        __reps = 1; \
        while (1) \
        { \
            timeit_start(__timer); \
            for (__timeit_k = 0; __timeit_k < __reps; __timeit_k++) \
            {

/* Closes the loops opened by TIMEIT_REPEAT1. */
#define TIMEIT_END_REPEAT1(__timer, __reps) \
            } \
            timeit_stop(__timer); \
            if (__timer->cpu >= 10) \
                break; \
            __reps *= 10; \
        } \
    } while (0);

/* Declares the timer and repetition counter in a fresh scope and starts. */
#define TIMEIT_START1 \
    do { \
        timeit_t __timer; slong __reps; \
        TIMEIT_REPEAT1(__timer, __reps)

/* Ends the measurement and stores the per-call time in __var. */
#define TIMEIT_STOP1(__var) \
        TIMEIT_END_REPEAT1(__timer, __reps) \
        TIMEIT_PRINT1(__var, __timer, __reps) \
    } while (0);
/*
    Tuning heuristic: returns 1, 2 or 4 (the KS variant to use) from the
    modulus bit size and a length. The KS2/KS4 crossover threshold is doubled
    when FLINT_BITS >= 62.
*/
static int choose_KS2(slong bits, slong len)
{
    const slong work = len * bits;

    if (work >= 800)
    {
        if (work * bits >= 100000 * (1 + (FLINT_BITS >= 62)))
            return 4;

        return 2;
    }

    return 1;
}
/*
    Tuning heuristic: returns 1, 2 or 4 (the KS variant to use) from the
    modulus bit size and a length, with fixed thresholds 800 and 100000.
*/
static int choose_KS(slong bits, slong len)
{
    const slong work = len * bits;

    if (work >= 800)
    {
        if (work * bits >= 100000)
            return 4;

        return 2;
    }

    return 1;
}
/*
    Multiplication dispatcher: writes poly1 * poly2 (len1 + len2 - 1
    coefficients) to res, choosing between the classical routine and the
    KS / KS2 / KS4 Kronecker-substitution variants.

    The routines called here expect the longer operand first (len1 >= len2);
    the wrapper nmod_poly_mul2 guarantees this.

    Selection:
      - len2 <= 5                                  -> classical
      - otherwise, with bits = bit size of the modulus and
        cutoff_len = min(len1, 2 * len2):
          3 * cutoff_len < 2 * max(bits, 10)       -> classical
          cutoff_len * bits < 800                  -> KS  (one evaluation point)
          cutoff_len * (bits + 1)^2 < 100000       -> KS2 (two points)
          else                                     -> KS4 (four points)

    Earlier revisions carried an alternative dispatch via choose_KS(bits, len1)
    after the final return; it was unreachable and has been removed. The
    author's note from that code: with unbalanced operands, KS tuning seems to
    respond better to the length of the longer operand.
*/
void _nmod_poly_mul2(mp_ptr res, mp_srcptr poly1, slong len1,
                     mp_srcptr poly2, slong len2, nmod_t mod)
{
    slong bits, cutoff_len;

    if (len2 <= 5)
    {
        _nmod_poly_mul_classical2(res, poly1, len1, poly2, len2, mod);
        return;
    }

    bits = FLINT_BITS - (slong) mod.norm;
    cutoff_len = FLINT_MIN(len1, 2 * len2);

    if (3 * cutoff_len < 2 * FLINT_MAX(bits, 10))
        _nmod_poly_mul_classical2(res, poly1, len1, poly2, len2, mod);
    else if (cutoff_len * bits < 800)
        _nmod_poly_mul_KSB(res, poly1, len1, poly2, len2, 0, mod);
    else if (cutoff_len * (bits + 1) * (bits + 1) < 100000)
        _nmod_poly_mul_KS2B(res, poly1, len1, poly2, len2, mod);
    else
        _nmod_poly_mul_KS4(res, poly1, len1, poly2, len2, mod);
}
/*
    Sets res to poly1 * poly2 using the dispatcher _nmod_poly_mul2.

    A zero-length operand gives the zero polynomial. The longer operand is
    always passed first. If res is one of the inputs, the product is built in
    a temporary and swapped into res.
*/
void nmod_poly_mul2(nmod_poly_t res, const nmod_poly_t poly1, const nmod_poly_t poly2)
{
    nmod_poly_t temp;
    slong len1, len2, len_out;
    int aliased, flipped;

    len1 = poly1->length;
    len2 = poly2->length;

    if (len1 == 0 || len2 == 0)
    {
        nmod_poly_zero(res);
        return;
    }

    len_out = poly1->length + poly2->length - 1;
    aliased = (res == poly1 || res == poly2);
    flipped = (len1 < len2);

    /* pick the destination: scratch polynomial when aliased, res otherwise */
    if (aliased)
        nmod_poly_init2(temp, poly1->mod.n, len_out);
    else
        nmod_poly_fit_length(res, len_out);

    _nmod_poly_mul2(aliased ? temp->coeffs : res->coeffs,
                    flipped ? poly2->coeffs : poly1->coeffs,
                    flipped ? len2 : len1,
                    flipped ? poly1->coeffs : poly2->coeffs,
                    flipped ? len1 : len2,
                    poly1->mod);

    if (aliased)
    {
        nmod_poly_swap(temp, res);
        nmod_poly_clear(temp);
    }

    res->length = len_out;
    _nmod_poly_normalise(res);
}
void | |
_nmod_poly_mullow_classical2(mp_ptr res, mp_srcptr poly1, | |
slong len1, mp_srcptr poly2, slong len2, slong n, nmod_t mod) | |
{ | |
slong i, j, bits, log_len, nlimbs, n1, n2; | |
int squaring; | |
mp_limb_t c; | |
len1 = FLINT_MIN(len1, n); | |
len2 = FLINT_MIN(len2, n); | |
if (n == 1) | |
{ | |
res[0] = nmod_mul(poly1[0], poly2[0], mod); | |
return; | |
} | |
if (len2 == 1) | |
{ | |
_nmod_vec_scalar_mul_nmod(res, poly1, len1, poly2[0], mod); | |
return; | |
} | |
squaring = (poly1 == poly2 && len1 == len2); | |
log_len = FLINT_BIT_COUNT(len2); | |
bits = FLINT_BITS - (slong) mod.norm; | |
bits = 2 * bits + log_len; | |
if (bits <= FLINT_BITS) | |
{ | |
flint_mpn_zero(res, n); | |
if (squaring) | |
{ | |
for (i = 0; i < len1; i++) | |
{ | |
c = poly1[i]; | |
if (2 * i < n) | |
res[2 * i] += c * c; | |
c *= 2; | |
for (j = i + 1; j < FLINT_MIN(len1, n - i); j++) | |
res[i + j] += poly1[j] * c; | |
} | |
} | |
else | |
{ | |
for (i = 0; i < len1; i++) | |
{ | |
mp_limb_t c = poly1[i]; | |
for (j = 0; j < FLINT_MIN(len2, n - i); j++) | |
res[i + j] += c * poly2[j]; | |
} | |
} | |
_nmod_vec_reduce(res, res, n, mod); | |
return; | |
} | |
if (len2 == 2) | |
{ | |
_nmod_vec_scalar_mul_nmod(res, poly1, len1, poly2[0], mod); | |
_nmod_vec_scalar_addmul_nmod(res + 1, poly1, len1 - 1, poly2[1], mod); | |
if (n == len1 + len2 - 1) | |
res[len1 + len2 - 2] = nmod_mul(poly1[len1 - 1], poly2[len2 - 1], mod); | |
return; | |
} | |
if (bits <= 2 * FLINT_BITS) | |
nlimbs = 2; | |
else | |
nlimbs = 3; | |
if (squaring) | |
{ | |
for (i = 0; i < n; i++) | |
{ | |
n1 = FLINT_MAX(0, i - len1 + 1); | |
n2 = FLINT_MIN(len1 - 1, (i + 1) / 2 - 1); | |
c = _nmod_vec_dot_rev(poly1 + n1, poly1 + i - n2, n2 - n1 + 1, mod, nlimbs); | |
c = nmod_add(c, c, mod); | |
if (i % 2 == 0 && i / 2 < len1) | |
NMOD_ADDMUL(c, poly1[i / 2], poly1[i / 2], mod); | |
res[i] = c; | |
} | |
} | |
else | |
{ | |
for (i = 0; i < n; i++) | |
{ | |
n1 = FLINT_MIN(len1 - 1, i); | |
n2 = FLINT_MIN(len2 - 1, i); | |
res[i] = _nmod_vec_dot_rev(poly1 + i - n2, | |
poly2 + i - n1, | |
n1 + n2 - i + 1, mod, nlimbs); | |
} | |
} | |
} | |
void | |
_nmod_poly_mullow_KSB(mp_ptr out, mp_srcptr in1, slong len1, | |
mp_srcptr in2, slong len2, flint_bitcnt_t bits, slong n, nmod_t mod) | |
{ | |
slong limbs1, limbs2; | |
mp_ptr tmp, mpn1, mpn2, res; | |
int squaring; | |
TMP_INIT; | |
len1 = FLINT_MIN(len1, n); | |
len2 = FLINT_MIN(len2, n); | |
squaring = (in1 == in2 && len1 == len2); | |
if (bits == 0) | |
{ | |
flint_bitcnt_t bits1, bits2, loglen; | |
#if 0 | |
bits1 = _nmod_vec_max_bits2(in1, len1); | |
bits2 = squaring ? bits1 : _nmod_vec_max_bits2(in2, len2); | |
#else | |
bits1 = FLINT_BITS - (slong) mod.norm; | |
bits2 = bits1; | |
#endif | |
loglen = FLINT_BIT_COUNT(len2); | |
bits = bits1 + bits2 + loglen; | |
} | |
limbs1 = (len1 * bits - 1) / FLINT_BITS + 1; | |
limbs2 = (len2 * bits - 1) / FLINT_BITS + 1; | |
TMP_START; | |
tmp = TMP_ALLOC(sizeof(mp_limb_t) * (limbs1 + limbs2 + limbs1 + (squaring ? 0 : limbs2))); | |
res = tmp; | |
mpn1 = tmp + limbs1 + limbs2; | |
mpn2 = squaring ? mpn1 : (mpn1 + limbs1); | |
_nmod_poly_bit_pack(mpn1, in1, len1, bits); | |
if (!squaring) | |
_nmod_poly_bit_pack(mpn2, in2, len2, bits); | |
if (squaring) | |
mpn_sqr(res, mpn1, limbs1); | |
else | |
mpn_mul(res, mpn1, limbs1, mpn2, limbs2); | |
_nmod_poly_bit_unpack(out, n, res, bits, mod); | |
TMP_END; | |
} | |
/*
    Truncated-multiplication dispatcher: writes the lowest n coefficients of
    poly1 * poly2 to res. Lengths are clamped to n first. The classical
    routine is used when len2 <= 5 or when n < 10 + bits^2 / 10 (bits = bit
    size of the modulus); otherwise Kronecker substitution is used.
*/
void _nmod_poly_mullow2(mp_ptr res, mp_srcptr poly1, slong len1,
                        mp_srcptr poly2, slong len2, slong n, nmod_t mod)
{
    slong bits;
    int classical;

    len1 = FLINT_MIN(len1, n);
    len2 = FLINT_MIN(len2, n);

    classical = (len2 <= 5);

    if (!classical)
    {
        bits = FLINT_BITS - (slong) mod.norm;
        classical = (n < 10 + bits * bits / 10);
    }

    if (classical)
        _nmod_poly_mullow_classical2(res, poly1, len1, poly2, len2, n, mod);
    else
        _nmod_poly_mullow_KSB(res, poly1, len1, poly2, len2, 0, n, mod);
}
/*
    Sets res to the lowest trunc coefficients of poly1 * poly2 using the
    dispatcher _nmod_poly_mullow2.

    trunc is capped at the full product length. A zero-length operand or
    trunc == 0 gives the zero polynomial. The longer operand is passed first.
    If res is one of the inputs, the result is built in a temporary and
    swapped into res.
*/
void nmod_poly_mullow2(nmod_poly_t res,
                       const nmod_poly_t poly1, const nmod_poly_t poly2, slong trunc)
{
    nmod_poly_t temp;
    slong len1, len2, len_out;
    int aliased, flipped;

    len1 = poly1->length;
    len2 = poly2->length;
    len_out = poly1->length + poly2->length - 1;

    if (len_out < trunc)
        trunc = len_out;

    if (len1 == 0 || len2 == 0 || trunc == 0)
    {
        nmod_poly_zero(res);
        return;
    }

    aliased = (res == poly1 || res == poly2);
    flipped = (len1 < len2);

    /* pick the destination: scratch polynomial when aliased, res otherwise */
    if (aliased)
        nmod_poly_init2(temp, poly1->mod.n, trunc);
    else
        nmod_poly_fit_length(res, trunc);

    _nmod_poly_mullow2(aliased ? temp->coeffs : res->coeffs,
                       flipped ? poly2->coeffs : poly1->coeffs,
                       flipped ? len2 : len1,
                       flipped ? poly1->coeffs : poly2->coeffs,
                       flipped ? len1 : len2,
                       trunc, poly1->mod);

    if (aliased)
    {
        nmod_poly_swap(temp, res);
        nmod_poly_clear(temp);
    }

    res->length = trunc;
    _nmod_poly_normalise(res);
}
/*
    Sets res to the lowest trunc coefficients of poly1 * poly2 using
    _nmod_poly_mullow_classical2.

    trunc is capped at the full product length. A zero-length operand or
    trunc == 0 gives the zero polynomial. The longer operand is passed first.
    If res is one of the inputs, the result is built in a temporary and
    swapped into res.
*/
void nmod_poly_mullow_classical2(nmod_poly_t res,
                                 const nmod_poly_t poly1, const nmod_poly_t poly2, slong trunc)
{
    nmod_poly_t temp;
    slong len1, len2, len_out;
    int aliased, flipped;

    len1 = poly1->length;
    len2 = poly2->length;
    len_out = poly1->length + poly2->length - 1;

    if (len_out < trunc)
        trunc = len_out;

    if (len1 == 0 || len2 == 0 || trunc == 0)
    {
        nmod_poly_zero(res);
        return;
    }

    aliased = (res == poly1 || res == poly2);
    flipped = (len1 < len2);

    /* pick the destination: scratch polynomial when aliased, res otherwise */
    if (aliased)
        nmod_poly_init2(temp, poly1->mod.n, trunc);
    else
        nmod_poly_fit_length(res, trunc);

    _nmod_poly_mullow_classical2(aliased ? temp->coeffs : res->coeffs,
                                 flipped ? poly2->coeffs : poly1->coeffs,
                                 flipped ? len2 : len1,
                                 flipped ? poly1->coeffs : poly2->coeffs,
                                 flipped ? len1 : len2,
                                 trunc, poly1->mod);

    if (aliased)
    {
        nmod_poly_swap(temp, res);
        nmod_poly_clear(temp);
    }

    res->length = trunc;
    _nmod_poly_normalise(res);
}
/*
    Sets res to the lowest trunc coefficients of poly1 * poly2 using
    _nmod_poly_mullow_KSB with bits = 0 (automatic packing width).

    trunc is capped at the full product length. A zero-length operand or
    trunc == 0 gives the zero polynomial. The longer operand is passed first.
    If res is one of the inputs, the result is built in a temporary and
    swapped into res.
*/
void nmod_poly_mullow_KSB(nmod_poly_t res,
                          const nmod_poly_t poly1, const nmod_poly_t poly2, slong trunc)
{
    nmod_poly_t temp;
    slong len1, len2, len_out;
    int aliased, flipped;

    len1 = poly1->length;
    len2 = poly2->length;
    len_out = poly1->length + poly2->length - 1;

    if (len_out < trunc)
        trunc = len_out;

    if (len1 == 0 || len2 == 0 || trunc == 0)
    {
        nmod_poly_zero(res);
        return;
    }

    aliased = (res == poly1 || res == poly2);
    flipped = (len1 < len2);

    /* pick the destination: scratch polynomial when aliased, res otherwise */
    if (aliased)
        nmod_poly_init2(temp, poly1->mod.n, trunc);
    else
        nmod_poly_fit_length(res, trunc);

    _nmod_poly_mullow_KSB(aliased ? temp->coeffs : res->coeffs,
                          flipped ? poly2->coeffs : poly1->coeffs,
                          flipped ? len2 : len1,
                          flipped ? poly1->coeffs : poly2->coeffs,
                          flipped ? len1 : len2,
                          0, trunc, poly1->mod);

    if (aliased)
    {
        nmod_poly_swap(temp, res);
        nmod_poly_clear(temp);
    }

    res->length = trunc;
    _nmod_poly_normalise(res);
}
/*
    Modulus bit sizes exercised by the benchmark in main(), terminated by 0.
    For each entry b the modulus is 2^b - 1 (UWORD_MAX for b == 64).

    Other orderings/subsets that have been used with this benchmark:
        { 64, 60, 32, 28, 16, 8, 4, 2, 0 }
        { 2, 4, 8, 16, 32, 64, 0 }
        { 64, 8, 2, 0 }
*/
int checkbits[] = { 2, 4, 8, 16, 28, 32, 60, 64, 0 };
void | |
randpoly(nmod_poly_t f, flint_rand_t state, slong n) | |
{ | |
slong i; | |
nmod_poly_zero(f); | |
for (i = 0; i < n; i++) | |
nmod_poly_set_coeff_ui(f, i, n_randlimb(state) % f->mod.n); | |
if (f->length < n) | |
nmod_poly_set_coeff_ui(f, n - 1, 1); | |
} | |
/*
    TIMET(res, expr): times expr three times with TIMEIT_START1/TIMEIT_STOP1
    and stores the smallest of the three measurements in res.

    Requires variables tx, ty and tz (assignable from double) in the
    enclosing scope. Expands to a single statement; follow it with ';'.
*/
#define TIMET(res, expr) \
    do { \
        TIMEIT_START1 expr; TIMEIT_STOP1(tx) \
        TIMEIT_START1 expr; TIMEIT_STOP1(ty) \
        TIMEIT_START1 expr; TIMEIT_STOP1(tz) \
        res = FLINT_MIN(tx, FLINT_MIN(ty, tz)); \
    } while (0)
int main() | |
{ | |
nmod_t mod; | |
nmod_poly_t f, g, h; | |
flint_rand_t state; | |
flint_randinit(state); | |
slong i, j, n, ii, bits; | |
double t1, t2, tt; | |
slong iter; | |
slong iters = 1000; | |
slong iters2 = 20; | |
for (ii = 0; (bits = checkbits[ii]) != 0; ii++) | |
{ | |
for (n = 1; n <= 30000; n = FLINT_MAX(n+1, n*1.1)) | |
{ | |
double tx, ty, tz, told, tnew, told2, tnew2, told10, tnew10, tolds, tnews; | |
if (bits == 64) | |
nmod_init(&mod, UWORD_MAX); | |
else | |
nmod_init(&mod, (UWORD(1) << bits) - UWORD(1)); | |
nmod_poly_init(f, mod.n); | |
nmod_poly_init(g, mod.n); | |
nmod_poly_init(h, mod.n); | |
printf("%ld %ld ", bits, n); fflush(stdout); | |
randpoly(f, state, n); | |
randpoly(g, state, n); | |
TIMET(told, nmod_poly_mullow(h, f, g, n)); | |
TIMET(tnew, nmod_poly_mullow2(h, f, g, n)); | |
printf("%.3f ", told / tnew); | |
randpoly(f, state, n); | |
randpoly(g, state, FLINT_MAX(n / 2, 1)); | |
TIMET(told, nmod_poly_mullow(h, f, g, n)); | |
TIMET(tnew, nmod_poly_mullow2(h, f, g, n)); | |
printf("%.3f ", told / tnew); | |
randpoly(f, state, n); | |
TIMET(told, nmod_poly_mullow(h, f, f, n)); | |
TIMET(tnew, nmod_poly_mullow2(h, f, f, n)); | |
printf("%.3f ", told / tnew); | |
randpoly(f, state, n); | |
randpoly(g, state, n); | |
TIMET(told, nmod_poly_mul(h, f, g)); | |
TIMET(tnew, nmod_poly_mul2(h, f, g)); | |
printf("%.3f ", told / tnew); | |
randpoly(f, state, n); | |
randpoly(g, state, 2 * n); | |
TIMET(told, nmod_poly_mul(h, f, g)); | |
TIMET(tnew, nmod_poly_mul2(h, f, g)); | |
printf("%.3f ", told / tnew); | |
randpoly(f, state, 10 * n); | |
randpoly(g, state, n); | |
TIMET(told, nmod_poly_mul(h, f, g)); | |
TIMET(tnew, nmod_poly_mul2(h, f, g)); | |
printf("%.3f ", told / tnew); | |
randpoly(f, state, n); | |
TIMET(told, nmod_poly_mul(h, f, f)); | |
TIMET(tnew, nmod_poly_mul2(h, f, f)); | |
printf("%.3f ", told / tnew); | |
printf("\n"); | |
nmod_poly_clear(f); | |
nmod_poly_clear(g); | |
nmod_poly_clear(h); | |
} | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment