# FFT implementation based on https://codeforces.com/blog/entry/117947
import cmath

# Twiddle factors, stored in the order the butterfly passes of fft() consume
# them. The table starts as [1] and is doubled on demand, so it is shared
# between calls and only ever grows.
_FFT_RT = [1]
# `rt` is kept as a public alias of the table for code that reads it by name.
rt = _FFT_RT


def fft(P):
    """Return the discrete Fourier transform of the sequence P.

    Output element k is sum(P[j] * exp(2j*pi*j*k/n) for j in range(n)),
    i.e. P evaluated as a polynomial at the n-th roots of unity.

    P is not modified; a new list of complex numbers is returned.
    len(P) must be a power of two (checked with an assertion).
    """
    n = len(P)
    assert n and (n - 1) & n == 0, "fft input length must be a power of two"
    # Work on a complex copy so the result type does not depend on n
    # (a length-1 input goes through no butterfly at all).
    vals = [complex(x) for x in P]

    # Grow the shared twiddle table until it holds n/2 entries.
    roots = _FFT_RT
    while 2 * len(roots) < n:
        step = cmath.exp(2j * cmath.pi / (4 * len(roots)))
        roots.extend([r * step for r in roots])

    # Butterfly passes over blocks of size k = n, n/2, ..., 2.
    k = n
    while k > 1:
        half = k // 2
        for i in range(0, n, k):
            r = roots[i // k]
            for j1 in range(i, i + half):
                j2 = j1 + half
                z = r * vals[j2]
                vals[j2] = vals[j1] - z
                vals[j1] += z
        k = half

    # The passes leave the result in bit-reversed order; undo that.
    rev = [0] * n
    for i in range(1, n):
        rev[i] = rev[i // 2] // 2 + (i & 1) * n // 2
    return [vals[r] for r in rev]


def ifft(P):
    """Return the inverse of fft() applied to P.

    Implemented as a forward transform of P with its indices negated
    (P[0], P[-1], P[-2], ...) and every value divided by n.
    len(P) must be a power of two.
    """
    n = len(P)
    return fft([P[-i] / n for i in range(n)])


def fft_conv(P, Q):
    """Return the convolution (polynomial product) of sequences P and Q.

    The result is a list of len(P) + len(Q) - 1 complex numbers; for real
    inputs take `.real` (and round, for integer inputs) to drop the
    floating-point noise. If either input is empty the result is [].
    P and Q are not modified and may be any sequences (lists, tuples, ...).
    """
    if not P or not Q:
        # Without this guard m would be <= 0 and the final slice would
        # return a bogus non-empty result.
        return []

    m = len(P) + len(Q) - 1
    # Smallest power of two >= m; (m - 1).bit_length() avoids doubling the
    # transform size when m is already a power of two.
    n = 1 << (m - 1).bit_length()

    fp = fft(list(P) + [0] * (n - len(P)))
    fq = fft(list(Q) + [0] * (n - len(Q)))

    return ifft([p * q for p, q in zip(fp, fq)])[:m]
# NTT implementation based on https://codeforces.com/blog/entry/117947

# NTT prime
MOD = (119 << 23) + 1
assert MOD & 1

# Smallest quadratic non-residue mod MOD (Euler's criterion: g^((MOD-1)/2)
# is -1). Its powers supply the roots of unity used by ntt().
non_quad_res = 2
while pow(non_quad_res, MOD // 2, MOD) != MOD - 1:
    non_quad_res += 1

# Roots of unity mod MOD, stored in the order the butterfly passes of ntt()
# consume them. The table starts as [1] and is doubled on demand, so it is
# shared between calls and only ever grows.
_NTT_RT = [1]
# `rt` is kept as a public alias of the table for code that reads it by name.
rt = _NTT_RT


def ntt(P):
    """Return the number-theoretic transform of the sequence P modulo MOD.

    P is not modified; a new list with every value in [0, MOD) is returned.
    len(P) must be a power of two (checked with an assertion).

    Raises:
        ValueError: if len(P) does not divide MOD - 1, in which case no
            root of unity of that order exists modulo MOD.
    """
    n = len(P)
    assert n and (n - 1) & n == 0, "ntt input length must be a power of two"
    if (MOD - 1) % n:
        # MOD // (4 * len) below is only the exact quotient of MOD - 1 when
        # the divisor divides MOD - 1; otherwise the "root" would be wrong.
        raise ValueError(
            "ntt length %d is not supported by MOD %d" % (n, MOD)
        )
    # Reduce up front so a length-1 input (no butterflies) is reduced too.
    vals = [x % MOD for x in P]

    # Grow the shared root table until it holds n/2 entries.
    roots = _NTT_RT
    while 2 * len(roots) < n:
        # 4*len(roots)-th root of unity
        step = pow(non_quad_res, MOD // (4 * len(roots)), MOD)
        roots.extend([r * step % MOD for r in roots])

    # Butterfly passes over blocks of size k = n, n/2, ..., 2.
    k = n
    while k > 1:
        half = k // 2
        for i in range(0, n, k):
            r = roots[i // k]
            for j1 in range(i, i + half):
                j2 = j1 + half
                z = r * vals[j2]
                vals[j2] = (vals[j1] - z) % MOD
                vals[j1] = (vals[j1] + z) % MOD
        k = half

    # The passes leave the result in bit-reversed order; undo that.
    rev = [0] * n
    for i in range(1, n):
        rev[i] = rev[i // 2] // 2 + (i & 1) * n // 2
    return [vals[r] for r in rev]


def intt(P):
    """Return the inverse of ntt() applied to P, modulo MOD.

    Implemented as a forward transform of P with its indices negated
    (P[0], P[-1], P[-2], ...) and every value multiplied by the modular
    inverse of n. len(P) must be a power of two.
    """
    n = len(P)
    ninv = pow(n, MOD - 2, MOD)
    return ntt([P[-i] * ninv % MOD for i in range(n)])


def ntt_conv(P, Q):
    """Return the convolution (polynomial product) of P and Q modulo MOD.

    The result is a list of len(P) + len(Q) - 1 integers in [0, MOD).
    If either input is empty the result is []. P and Q are not modified
    and may be any sequences (lists, tuples, ...).

    Raises:
        ValueError: if the padded transform length is not supported by MOD
            (propagated from ntt()).
    """
    if not P or not Q:
        # Without this guard m would be <= 0 and the final slice would
        # return a bogus non-empty result.
        return []

    m = len(P) + len(Q) - 1
    # Smallest power of two >= m; (m - 1).bit_length() avoids doubling the
    # transform size when m is already a power of two.
    n = 1 << (m - 1).bit_length()

    fp = ntt(list(P) + [0] * (n - len(P)))
    fq = ntt(list(Q) + [0] * (n - len(Q)))

    return intt([p * q % MOD for p, q in zip(fp, fq)])[:m]