Skip to content

Commit

Permalink
New fft and ntt
Browse files Browse the repository at this point in the history
  • Loading branch information
bjorn-martinsson committed Jul 9, 2023
1 parent cecb623 commit dc477c8
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 118 deletions.
78 changes: 37 additions & 41 deletions pyrival/algebra/fft.py
Original file line number Diff line number Diff line change
@@ -1,46 +1,42 @@
import cmath

MOD = 10**9 + 7


def fft(a, inv=False):
n = len(a)
w = [cmath.rect(1, (-2 if inv else 2) * cmath.pi * i / n) for i in range(n >> 1)]
# FFT implementation based on https://codeforces.com/blog/entry/117947

rt = [1]
def fft(P):
n = len(P)
P = list(P)
assert n and (n - 1) & n == 0

while 2 * len(rt) < n:
import cmath
root = cmath.exp(2j * cmath.pi / (4 * len(rt)))
rt.extend([r * root for r in rt])

k = n
while k > 1:
for i in range(0, n, k):
r = rt[i//k]
for j1 in range(i, i + k//2):
j2 = j1 + k//2
z = r * P[j2]
P[j2] = P[j1] - z
P[j1] = P[j1] + z
k //= 2

rev = [0] * n
for i in range(n):
rev[i] = rev[i >> 1] >> 1
if i & 1:
rev[i] |= n >> 1
if i < rev[i]:
a[i], a[rev[i]] = a[rev[i]], a[i]

step = 2
while step <= n:
half, diff = step >> 1, n // step
for i in range(0, n, step):
pw = 0
for j in range(i, i + half):
v = a[j + half] * w[pw]
a[j + half] = a[j] - v
a[j] += v
pw += diff
step <<= 1

if inv:
for i in range(n):
a[i] /= n
for i in range(1, n):
rev[i] = rev[i // 2] // 2 + (i & 1) * n // 2
return [P[r] for r in rev]

def ifft(P):
n = len(P)
return fft([P[-i]/n for i in range(n)])

def fft_conv(a, b):
s = len(a) + len(b) - 1
n = 1 << s.bit_length()
a.extend([0.0] * (n - len(a)))
b.extend([0.0] * (n - len(b)))
def fft_conv(P, Q):
m = len(P) + len(Q) - 1
n = 1 << m.bit_length()

fft(a), fft(b)
for i in range(n):
a[i] *= b[i]
fft(a, True)
P = P + [0] * (n - len(P))
Q = Q + [0] * (n - len(Q))
P, Q = fft(P), fft(Q)

a = [a[i].real for i in range(s)]
return a
return ifft([p*q for p,q in zip(P, Q)])[:m]
123 changes: 46 additions & 77 deletions pyrival/algebra/ntt.py
Original file line number Diff line number Diff line change
@@ -1,82 +1,51 @@
MOD = 998244353
MODF = float(MOD)
ROOT = 3.0

MAGIC = 6755399441055744.0
SHRT = 65536.0

MODF_INV = 1.0 / MODF
SHRT_INV = 1.0 / SHRT

fround = lambda x: (x + MAGIC) - MAGIC
fmod = lambda a: a - MODF * fround(MODF_INV * a)
fmul = lambda a, b, c=0.0: fmod(fmod(a * SHRT) * fround(SHRT_INV * b) + a * (b - SHRT * fround(b * SHRT_INV)) + c)


def fpow(x, y):
if y == 0:
return 1.0

res = 1.0
while y > 1:
if y & 1 == 1:
res = fmul(res, x)
x = fmul(x, x)
y >>= 1

return fmul(res, x)


def ntt(a, inv=False):
n = len(a)
w = [1.0] * (n >> 1)

w[1] = fpow(ROOT, (MOD - 1) // n)
if inv:
w[1] = fpow(w[1], MOD - 2)

for i in range(2, (n >> 1)):
w[i] = fmul(w[i - 1], w[1])

# NTT implementation based on https://codeforces.com/blog/entry/117947

# NTT prime
MOD = (119 << 23) + 1
assert MOD & 1

non_quad_res = 2
while pow(non_quad_res, MOD//2, MOD) != MOD - 1:
non_quad_res += 1
rt = [1]

def ntt(P):
n = len(P)
P = list(P)
assert n and (n - 1) & n == 0

while 2 * len(rt) < n:
# 4*len(rt)-th root of unity
root = pow(non_quad_res, MOD // (4*len(rt)), MOD)
rt.extend([r * root % MOD for r in rt])

k = n
while k > 1:
for i in range(0, n, k):
r = rt[i//k]
for j1 in range(i, i + k//2):
j2 = j1 + k//2
z = r * P[j2]
P[j2] = (P[j1] - z) % MOD
P[j1] = (P[j1] + z) % MOD
k //= 2

rev = [0] * n
for i in range(n):
rev[i] = rev[i >> 1] >> 1
if i & 1 == 1:
rev[i] |= n >> 1
if i < rev[i]:
a[i], a[rev[i]] = a[rev[i]], a[i]

step = 2
while step <= n:
half, diff = step >> 1, n // step
for i in range(0, n, step):
pw = 0
for j in range(i, i + half):
v = fmul(w[pw], a[j + half])
a[j + half] = a[j] - v
a[j] += v
pw += diff

step <<= 1

if inv:
inv_n = fpow(n, MOD - 2)
for i in range(n):
a[i] = fmul(a[i], inv_n)


def ntt_conv(a, b):
s = len(a) + len(b) - 1
n = 1 << s.bit_length()
for i in range(1, n):
rev[i] = rev[i // 2] // 2 + (i & 1) * n // 2
return [P[r] for r in rev]

a.extend([0.0] * (n - len(a)))
b.extend([0.0] * (n - len(b)))
def intt(P):
n = len(P)
ninv = pow(n, MOD - 2, MOD)
return ntt([P[-i] * ninv % MOD for i in range(n)])

ntt(a)
ntt(b)
def ntt_conv(P, Q):
m = len(P) + len(Q) - 1
n = 1 << m.bit_length()

for i in range(n):
a[i] = fmul(a[i], b[i])
P = P + [0] * (n - len(P))
Q = Q + [0] * (n - len(Q))
P, Q = ntt(P), ntt(Q)

ntt(a, True)
del a[s:]
return intt([p * q % MOD for p,q in zip(P, Q)])[:m]

0 comments on commit dc477c8

Please sign in to comment.