Skip to content

Commit

Permalink
Merge pull request #174 from GiacomoPope/improve_powmod_and_compose
Browse files Browse the repository at this point in the history
add compose_mod and powmod with large exp
  • Loading branch information
oscarbenjamin authored Aug 6, 2024
2 parents 1178e20 + e3c9d03 commit 069d24d
Show file tree
Hide file tree
Showing 4 changed files with 332 additions and 21 deletions.
1 change: 1 addition & 0 deletions src/flint/flintlib/nmod_poly.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ cdef extern from "flint/nmod_poly.h":
int nmod_poly_equal_trunc(const nmod_poly_t poly1, const nmod_poly_t poly2, slong n)
int nmod_poly_is_zero(const nmod_poly_t poly)
int nmod_poly_is_one(const nmod_poly_t poly)
int nmod_poly_is_gen(const nmod_poly_t poly)
void _nmod_poly_shift_left(mp_ptr res, mp_srcptr poly, slong len, slong k)
void nmod_poly_shift_left(nmod_poly_t res, const nmod_poly_t poly, slong k)
void _nmod_poly_shift_right(mp_ptr res, mp_srcptr poly, slong len, slong k)
Expand Down
47 changes: 39 additions & 8 deletions src/flint/test/test_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -1422,7 +1422,15 @@ def test_nmod_poly():
assert raises(lambda: [] * s, TypeError)
assert raises(lambda: [] // s, TypeError)
assert raises(lambda: [] % s, TypeError)
assert raises(lambda: pow(P([1,2],3), 3, 4), NotImplementedError)
assert raises(lambda: [] % s, TypeError)
assert raises(lambda: s.reverse(-1), ValueError)
assert raises(lambda: s.compose("A"), TypeError)
assert raises(lambda: s.compose_mod(s, "A"), TypeError)
assert raises(lambda: s.compose_mod("A", P([3,6,9],17)), TypeError)
assert raises(lambda: s.compose_mod(s, P([0], 17)), ZeroDivisionError)
assert raises(lambda: pow(s, -1, P([3,6,9],17)), ValueError)
assert raises(lambda: pow(s, 1, "A"), TypeError)
assert raises(lambda: pow(s, "A", P([3,6,9],17)), TypeError)
assert str(P([1,2,3],17)) == "3*x^2 + 2*x + 1"
assert P([1,2,3],17).repr() == "nmod_poly([1, 2, 3], 17)"
p = P([3,4,5],17)
Expand Down Expand Up @@ -2087,6 +2095,18 @@ def test_fmpz_mod_poly():
assert f*f == f**2
assert f*f == f**fmpz(2)

# pow_mod
# assert ui and fmpz exp agree for polynomials and generators
R_gen = R_test.gen()
assert pow(f, 2**60, g) == pow(pow(f, 2**30, g), 2**30, g)
assert pow(R_gen, 2**60, g) == pow(pow(R_gen, 2**30, g), 2**30, g)

# Check other typechecks for pow_mod
assert raises(lambda: pow(f, -2, g), ValueError)
assert raises(lambda: pow(f, 1, "A"), TypeError)
assert raises(lambda: pow(f, "A", g), TypeError)
assert raises(lambda: f.pow_mod(2**32, g, mod_rev_inv="A"), TypeError)

# Shifts
assert raises(lambda: R_test([1,2,3]).left_shift(-1), ValueError)
assert raises(lambda: R_test([1,2,3]).right_shift(-1), ValueError)
Expand Down Expand Up @@ -2118,6 +2138,13 @@ def test_fmpz_mod_poly():
# compose
assert raises(lambda: h.compose("AAA"), TypeError)

# compose mod
mod = R_test([1,2,3,4])
assert f.compose(h) % mod == f.compose_mod(h, mod)
assert raises(lambda: h.compose_mod("AAA", mod), TypeError)
assert raises(lambda: h.compose_mod(f, "AAA"), TypeError)
assert raises(lambda: h.compose_mod(f, R_test(0)), ZeroDivisionError)

# Reverse
assert raises(lambda: h.reverse(degree=-100), ValueError)
assert R_test([-1,-2,-3]).reverse() == R_test([-3,-2,-1])
Expand All @@ -2135,9 +2162,9 @@ def test_fmpz_mod_poly():
assert raises(lambda: f.mulmod(f, "AAA"), TypeError)
assert raises(lambda: f.mulmod("AAA", g), TypeError)

# powmod
assert f.powmod(2, g) == (f*f) % g
assert raises(lambda: f.powmod(2, "AAA"), TypeError)
# pow_mod
assert f.pow_mod(2, g) == (f*f) % g
assert raises(lambda: f.pow_mod(2, "AAA"), TypeError)

# divmod
S, T = f.divmod(g)
Expand Down Expand Up @@ -2635,9 +2662,14 @@ def setbad(obj, i, val):
assert P([1, 1]) ** 2 == P([1, 2, 1])
assert raises(lambda: P([1, 1]) ** -1, ValueError)
assert raises(lambda: P([1, 1]) ** None, TypeError)

# # XXX: Not sure what this should do in general:
assert raises(lambda: pow(P([1, 1]), 2, 3), NotImplementedError)

# XXX: Not sure what this should do in general:
p = P([1, 1])
mod = P([1, 1])
if type(p) not in [flint.fmpz_mod_poly, flint.nmod_poly]:
assert raises(lambda: pow(p, 2, mod), NotImplementedError)
else:
assert p * p % mod == pow(p, 2, mod)

assert P([1, 2, 1]).gcd(P([1, 1])) == P([1, 1])
assert raises(lambda: P([1, 2, 1]).gcd(None), TypeError)
Expand Down Expand Up @@ -2667,7 +2699,6 @@ def setbad(obj, i, val):
if is_field:
assert P([1, 2, 1]).integral() == P([0, 1, 1, S(1)/3])


def _all_mpolys():
return [
(flint.fmpz_mpoly, flint.fmpz_mpoly_ctx, flint.fmpz, False),
Expand Down
97 changes: 86 additions & 11 deletions src/flint/types/fmpz_mod_poly.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -536,7 +536,7 @@ cdef class fmpz_mod_poly(flint_poly):

def __pow__(self, e, mod=None):
if mod is not None:
raise NotImplementedError
return self.pow_mod(e, mod)

cdef fmpz_mod_poly res
if e < 0:
Expand Down Expand Up @@ -778,11 +778,11 @@ cdef class fmpz_mod_poly(flint_poly):

return evaluations

def compose(self, input):
def compose(self, other):
"""
Returns the composition of two polynomials
To be precise about the order of composition, given ``self``, and ``input``
To be precise about the order of composition, given ``self``, and ``other``
by `f(x)`, `g(x)`, returns `f(g(x))`.
>>> R = fmpz_mod_poly_ctx(163)
Expand All @@ -794,12 +794,45 @@ cdef class fmpz_mod_poly(flint_poly):
9*x^4 + 12*x^3 + 10*x^2 + 4*x + 1
"""
cdef fmpz_mod_poly res
val = self.ctx.any_as_fmpz_mod_poly(input)
val = self.ctx.any_as_fmpz_mod_poly(other)
if val is NotImplemented:
raise TypeError(f"Cannot compose the polynomial with input: {input}")
raise TypeError(f"Cannot compose the polynomial with input: {other}")

res = self.ctx.new_ctype_poly()
fmpz_mod_poly_compose(res.val, self.val, (<fmpz_mod_poly>val).val, self.ctx.mod.val)
return res

def compose_mod(self, other, modulus):
"""
Returns the composition of two polynomials modulo a third.
To be precise about the order of composition, given ``self``, and ``other``
and ``modulus`` by `f(x)`, `g(x)` and `h(x)`, returns `f(g(x)) \mod h(x)`.
We require that `h(x)` is non-zero.
>>> R = fmpz_mod_poly_ctx(163)
>>> f = R([1,2,3,4,5])
>>> g = R([3,2,1])
>>> h = R([1,0,1,0,1])
>>> f.compose_mod(g, h)
63*x^3 + 100*x^2 + 17*x + 63
>>> g.compose_mod(f, h)
147*x^3 + 159*x^2 + 4*x + 7
"""
cdef fmpz_mod_poly res
val = self.ctx.any_as_fmpz_mod_poly(other)
if val is NotImplemented:
raise TypeError(f"cannot compose the polynomial with input: {other}")

h = self.ctx.any_as_fmpz_mod_poly(modulus)
if h is NotImplemented:
raise TypeError(f"cannot reduce the polynomial with input: {modulus}")

if h.is_zero():
raise ZeroDivisionError("cannot reduce modulo zero")

res = self.ctx.new_ctype_poly()
fmpz_mod_poly_compose_mod(res.val, self.val, (<fmpz_mod_poly>val).val, (<fmpz_mod_poly>h).val, self.ctx.mod.val)
return res

cpdef long length(self):
Expand Down Expand Up @@ -1104,30 +1137,72 @@ cdef class fmpz_mod_poly(flint_poly):
)
return res

def powmod(self, e, modulus):
def pow_mod(self, e, modulus, mod_rev_inv=None):
"""
Returns ``self`` raised to the power ``e`` modulo ``modulus``:
:math:`f^e \mod g`
:math:`f^e \mod g`/
``mod_rev_inv`` is the inverse of the reverse of the modulus,
precomputing it and passing it to ``pow_mod()`` can optimise
powering of polynomials with large exponents.
>>> R = fmpz_mod_poly_ctx(163)
>>> x = R.gen()
>>> f = 30*x**6 + 104*x**5 + 76*x**4 + 33*x**3 + 70*x**2 + 44*x + 65
>>> g = 43*x**6 + 91*x**5 + 77*x**4 + 113*x**3 + 71*x**2 + 132*x + 60
>>> mod = x**4 + 93*x**3 + 78*x**2 + 72*x + 149
>>>
>>> f.powmod(123, mod)
>>> f.pow_mod(123, mod)
3*x^3 + 25*x^2 + 115*x + 161
>>> f.pow_mod(2**64, mod)
52*x^3 + 96*x^2 + 136*x + 9
>>> mod_rev_inv = mod.reverse().inverse_series_trunc(4)
>>> f.pow_mod(2**64, mod, mod_rev_inv)
52*x^3 + 96*x^2 + 136*x + 9
"""
cdef fmpz_mod_poly res

if e < 0:
raise ValueError("Exponent must be non-negative")

modulus = self.ctx.any_as_fmpz_mod_poly(modulus)
if modulus is NotImplemented:
raise TypeError(f"Cannot interpret {modulus} as a polynomial")

# Output polynomial
res = self.ctx.new_ctype_poly()
fmpz_mod_poly_powmod_ui_binexp(
res.val, self.val, <ulong>e, (<fmpz_mod_poly>modulus).val, res.ctx.mod.val
)

# For small exponents, use a simple binary exponentiation method
if e.bit_length() < 32:
fmpz_mod_poly_powmod_ui_binexp(
res.val, self.val, <ulong>e, (<fmpz_mod_poly>modulus).val, res.ctx.mod.val
)
return res

# For larger exponents we need to cast e to an fmpz first
e_fmpz = any_as_fmpz(e)
if e_fmpz is NotImplemented:
raise TypeError(f"exponent cannot be cast to an fmpz type: {e = }")

# To optimise powering, we precompute the inverse of the reverse of the modulus
if mod_rev_inv is not None:
mod_rev_inv = self.ctx.any_as_fmpz_mod_poly(mod_rev_inv)
if mod_rev_inv is NotImplemented:
raise TypeError(f"Cannot interpret {mod_rev_inv} as a polynomial")
else:
mod_rev_inv = modulus.reverse().inverse_series_trunc(modulus.length())

# Use windowed exponentiation optimisation when self = x
if self.is_gen():
fmpz_mod_poly_powmod_x_fmpz_preinv(
res.val, (<fmpz>e_fmpz).val, (<fmpz_mod_poly>modulus).val, (<fmpz_mod_poly>mod_rev_inv).val, res.ctx.mod.val
)
return res

# Otherwise using binary exponentiation for all other inputs
fmpz_mod_poly_powmod_fmpz_binexp_preinv(
res.val, self.val, (<fmpz>e_fmpz).val, (<fmpz_mod_poly>modulus).val, (<fmpz_mod_poly>mod_rev_inv).val, res.ctx.mod.val
)
return res

def divmod(self, other):
Expand Down
Loading

0 comments on commit 069d24d

Please sign in to comment.