-
Notifications
You must be signed in to change notification settings - Fork 312
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
cecb623
commit dc477c8
Showing
2 changed files
with
83 additions
and
118 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,46 +1,42 @@ | ||
import cmath | ||
|
||
MOD = 10**9 + 7 | ||
|
||
|
||
def fft(a, inv=False): | ||
n = len(a) | ||
w = [cmath.rect(1, (-2 if inv else 2) * cmath.pi * i / n) for i in range(n >> 1)] | ||
# FFT implementation based on https://codeforces.com/blog/entry/117947 | ||
|
||
rt = [1] | ||
def fft(P): | ||
n = len(P) | ||
P = list(P) | ||
assert n and (n - 1) & n == 0 | ||
|
||
while 2 * len(rt) < n: | ||
import cmath | ||
root = cmath.exp(2j * cmath.pi / (4 * len(rt))) | ||
rt.extend([r * root for r in rt]) | ||
|
||
k = n | ||
while k > 1: | ||
for i in range(0, n, k): | ||
r = rt[i//k] | ||
for j1 in range(i, i + k//2): | ||
j2 = j1 + k//2 | ||
z = r * P[j2] | ||
P[j2] = P[j1] - z | ||
P[j1] = P[j1] + z | ||
k //= 2 | ||
|
||
rev = [0] * n | ||
for i in range(n): | ||
rev[i] = rev[i >> 1] >> 1 | ||
if i & 1: | ||
rev[i] |= n >> 1 | ||
if i < rev[i]: | ||
a[i], a[rev[i]] = a[rev[i]], a[i] | ||
|
||
step = 2 | ||
while step <= n: | ||
half, diff = step >> 1, n // step | ||
for i in range(0, n, step): | ||
pw = 0 | ||
for j in range(i, i + half): | ||
v = a[j + half] * w[pw] | ||
a[j + half] = a[j] - v | ||
a[j] += v | ||
pw += diff | ||
step <<= 1 | ||
|
||
if inv: | ||
for i in range(n): | ||
a[i] /= n | ||
for i in range(1, n): | ||
rev[i] = rev[i // 2] // 2 + (i & 1) * n // 2 | ||
return [P[r] for r in rev] | ||
|
||
def ifft(P): | ||
n = len(P) | ||
return fft([P[-i]/n for i in range(n)]) | ||
|
||
def fft_conv(a, b): | ||
s = len(a) + len(b) - 1 | ||
n = 1 << s.bit_length() | ||
a.extend([0.0] * (n - len(a))) | ||
b.extend([0.0] * (n - len(b))) | ||
def fft_conv(P, Q): | ||
m = len(P) + len(Q) - 1 | ||
n = 1 << m.bit_length() | ||
|
||
fft(a), fft(b) | ||
for i in range(n): | ||
a[i] *= b[i] | ||
fft(a, True) | ||
P = P + [0] * (n - len(P)) | ||
Q = Q + [0] * (n - len(Q)) | ||
P, Q = fft(P), fft(Q) | ||
|
||
a = [a[i].real for i in range(s)] | ||
return a | ||
return ifft([p*q for p,q in zip(P, Q)])[:m] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,82 +1,51 @@ | ||
MOD = 998244353 | ||
MODF = float(MOD) | ||
ROOT = 3.0 | ||
|
||
MAGIC = 6755399441055744.0 | ||
SHRT = 65536.0 | ||
|
||
MODF_INV = 1.0 / MODF | ||
SHRT_INV = 1.0 / SHRT | ||
|
||
fround = lambda x: (x + MAGIC) - MAGIC | ||
fmod = lambda a: a - MODF * fround(MODF_INV * a) | ||
fmul = lambda a, b, c=0.0: fmod(fmod(a * SHRT) * fround(SHRT_INV * b) + a * (b - SHRT * fround(b * SHRT_INV)) + c) | ||
|
||
|
||
def fpow(x, y): | ||
if y == 0: | ||
return 1.0 | ||
|
||
res = 1.0 | ||
while y > 1: | ||
if y & 1 == 1: | ||
res = fmul(res, x) | ||
x = fmul(x, x) | ||
y >>= 1 | ||
|
||
return fmul(res, x) | ||
|
||
|
||
def ntt(a, inv=False): | ||
n = len(a) | ||
w = [1.0] * (n >> 1) | ||
|
||
w[1] = fpow(ROOT, (MOD - 1) // n) | ||
if inv: | ||
w[1] = fpow(w[1], MOD - 2) | ||
|
||
for i in range(2, (n >> 1)): | ||
w[i] = fmul(w[i - 1], w[1]) | ||
|
||
# NTT implementation based on https://codeforces.com/blog/entry/117947 | ||
|
||
# NTT prime | ||
MOD = (119 << 23) + 1 | ||
assert MOD & 1 | ||
|
||
non_quad_res = 2 | ||
while pow(non_quad_res, MOD//2, MOD) != MOD - 1: | ||
non_quad_res += 1 | ||
rt = [1] | ||
|
||
def ntt(P): | ||
n = len(P) | ||
P = list(P) | ||
assert n and (n - 1) & n == 0 | ||
|
||
while 2 * len(rt) < n: | ||
# 4*len(rt)-th root of unity | ||
root = pow(non_quad_res, MOD // (4*len(rt)), MOD) | ||
rt.extend([r * root % MOD for r in rt]) | ||
|
||
k = n | ||
while k > 1: | ||
for i in range(0, n, k): | ||
r = rt[i//k] | ||
for j1 in range(i, i + k//2): | ||
j2 = j1 + k//2 | ||
z = r * P[j2] | ||
P[j2] = (P[j1] - z) % MOD | ||
P[j1] = (P[j1] + z) % MOD | ||
k //= 2 | ||
|
||
rev = [0] * n | ||
for i in range(n): | ||
rev[i] = rev[i >> 1] >> 1 | ||
if i & 1 == 1: | ||
rev[i] |= n >> 1 | ||
if i < rev[i]: | ||
a[i], a[rev[i]] = a[rev[i]], a[i] | ||
|
||
step = 2 | ||
while step <= n: | ||
half, diff = step >> 1, n // step | ||
for i in range(0, n, step): | ||
pw = 0 | ||
for j in range(i, i + half): | ||
v = fmul(w[pw], a[j + half]) | ||
a[j + half] = a[j] - v | ||
a[j] += v | ||
pw += diff | ||
|
||
step <<= 1 | ||
|
||
if inv: | ||
inv_n = fpow(n, MOD - 2) | ||
for i in range(n): | ||
a[i] = fmul(a[i], inv_n) | ||
|
||
|
||
def ntt_conv(a, b): | ||
s = len(a) + len(b) - 1 | ||
n = 1 << s.bit_length() | ||
for i in range(1, n): | ||
rev[i] = rev[i // 2] // 2 + (i & 1) * n // 2 | ||
return [P[r] for r in rev] | ||
|
||
a.extend([0.0] * (n - len(a))) | ||
b.extend([0.0] * (n - len(b))) | ||
def intt(P): | ||
n = len(P) | ||
ninv = pow(n, MOD - 2, MOD) | ||
return ntt([P[-i] * ninv % MOD for i in range(n)]) | ||
|
||
ntt(a) | ||
ntt(b) | ||
def ntt_conv(P, Q): | ||
m = len(P) + len(Q) - 1 | ||
n = 1 << m.bit_length() | ||
|
||
for i in range(n): | ||
a[i] = fmul(a[i], b[i]) | ||
P = P + [0] * (n - len(P)) | ||
Q = Q + [0] * (n - len(Q)) | ||
P, Q = ntt(P), ntt(Q) | ||
|
||
ntt(a, True) | ||
del a[s:] | ||
return intt([p * q % MOD for p,q in zip(P, Q)])[:m] |