Skip to content

Commit

Permalink
Add nmod_poly_ctx
Browse files Browse the repository at this point in the history
  • Loading branch information
oscarbenjamin committed Aug 22, 2024
1 parent c2b1f67 commit 34a3b49
Show file tree
Hide file tree
Showing 2 changed files with 115 additions and 57 deletions.
20 changes: 18 additions & 2 deletions src/flint/types/nmod_poly.pxd
Original file line number Diff line number Diff line change
@@ -1,10 +1,26 @@
from flint.flint_base.flint_base cimport flint_poly

from flint.flintlib.nmod cimport nmod_t
from flint.flintlib.nmod_poly cimport nmod_poly_t
from flint.flintlib.flint cimport mp_limb_t

from flint.flint_base.flint_base cimport flint_poly

from flint.types.nmod cimport nmod_ctx


cdef class nmod_poly_ctx:
cdef nmod_ctx ctx
cdef nmod_t mod
cdef bint _is_prime

cdef nmod_poly_set_list(self, nmod_poly_t poly, list val)
cdef int any_as_nmod(self, mp_limb_t * val, obj) except -1
cdef any_as_nmod_poly(self, obj)


cdef class nmod_poly(flint_poly):
cdef nmod_poly_t val
cdef nmod_poly_ctx ctx

cpdef long length(self)
cpdef long degree(self)
cpdef mp_limb_t modulus(self)
152 changes: 97 additions & 55 deletions src/flint/types/nmod_poly.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -5,49 +5,83 @@ from flint.types.fmpz cimport fmpz, any_as_fmpz
from flint.types.fmpz_poly cimport any_as_fmpz_poly
from flint.types.fmpz_poly cimport fmpz_poly
from flint.types.nmod cimport any_as_nmod_ctx
from flint.types.nmod cimport nmod
from flint.types.nmod cimport nmod, nmod_ctx

from flint.flintlib.nmod_vec cimport *
from flint.flintlib.nmod_poly cimport *
from flint.flintlib.nmod_poly_factor cimport *
from flint.flintlib.fmpz_poly cimport fmpz_poly_get_nmod_poly
from flint.flintlib.ulong_extras cimport n_gcdinv
from flint.flintlib.ulong_extras cimport n_gcdinv, n_is_prime

from flint.utils.flint_exceptions import DomainError


cdef any_as_nmod_poly(obj, nmod_t mod):
cdef nmod_poly r
cdef mp_limb_t v
# XXX: should check that modulus is the same here, and not all over the place
if typecheck(obj, nmod_poly):
_nmod_poly_ctx_cache = {}


cdef nmod_ctx any_as_nmod_poly_ctx(obj):
"""Convert an int to an nmod_ctx."""
if typecheck(obj, nmod_poly_ctx):
return obj
if any_as_nmod(&v, obj, mod):
r = nmod_poly.__new__(nmod_poly)
nmod_poly_init(r.val, mod.n)
nmod_poly_set_coeff_ui(r.val, 0, v)
return r
x = any_as_fmpz_poly(obj)
if x is not NotImplemented:
r = nmod_poly.__new__(nmod_poly)
nmod_poly_init(r.val, mod.n) # XXX: create flint _nmod_poly_set_modulus for this?
fmpz_poly_get_nmod_poly(r.val, (<fmpz_poly>x).val)
return r
if typecheck(obj, int):
ctx = _nmod_poly_ctx_cache.get(obj)
if ctx is None:
ctx = nmod_poly_ctx(obj)
_nmod_poly_ctx_cache[obj] = ctx
return ctx
return NotImplemented

cdef nmod_poly_set_list(nmod_poly_t poly, list val):
cdef long i, n
cdef nmod_t mod
cdef mp_limb_t v
nmod_init(&mod, nmod_poly_modulus(poly)) # XXX
n = PyList_GET_SIZE(val)
nmod_poly_fit_length(poly, n)
for i from 0 <= i < n:
c = val[i]
if any_as_nmod(&v, val[i], mod):
nmod_poly_set_coeff_ui(poly, i, v)
else:
raise TypeError("unsupported coefficient in list")

cdef class nmod_poly_ctx:
"""
Context object for creating :class:`~.nmod_poly` initalised
with modulus :math:`N`.
>>> nmod_ctx(17)
nmod_ctx(17)
"""
def __init__(self, mod):
cdef mp_limb_t m
m = mod
nmod_init(&self.mod, m)
self.ctx = nmod_ctx(mod)
self._is_prime = n_is_prime(m)

cdef int any_as_nmod(self, mp_limb_t * val, obj) except -1:
return self.ctx.any_as_nmod(val, obj)

cdef any_as_nmod_poly(self, obj):
cdef nmod_poly r
cdef mp_limb_t v
# XXX: should check that modulus is the same here, and not all over the place
if typecheck(obj, nmod_poly):
return obj
if self.ctx.any_as_nmod(&v, obj):
r = nmod_poly.__new__(nmod_poly)
nmod_poly_init(r.val, self.mod.n)
nmod_poly_set_coeff_ui(r.val, 0, v)
return r
x = any_as_fmpz_poly(obj)
if x is not NotImplemented:
r = nmod_poly.__new__(nmod_poly)
nmod_poly_init(r.val, self.mod.n) # XXX: create flint _nmod_poly_set_modulus for this?
fmpz_poly_get_nmod_poly(r.val, (<fmpz_poly>x).val)
return r
return NotImplemented

cdef nmod_poly_set_list(self, nmod_poly_t poly, list val):
cdef long i, n
cdef mp_limb_t v
n = PyList_GET_SIZE(val)
nmod_poly_fit_length(poly, n)
for i from 0 <= i < n:
c = val[i]
if self.any_as_nmod(&v, val[i]):
nmod_poly_set_coeff_ui(poly, i, v)
else:
raise TypeError("unsupported coefficient in list")


cdef class nmod_poly(flint_poly):
"""
Expand Down Expand Up @@ -79,24 +113,32 @@ cdef class nmod_poly(flint_poly):
def __dealloc__(self):
nmod_poly_clear(self.val)

def __init__(self, val=None, ulong mod=0):
def __init__(self, val=None, mod=0):
cdef ulong m2
cdef mp_limb_t v
cdef nmod_poly_ctx ctx

if typecheck(val, nmod_poly):
m2 = nmod_poly_modulus((<nmod_poly>val).val)
if m2 != mod:
raise ValueError("different moduli!")
nmod_poly_init(self.val, m2)
nmod_poly_set(self.val, (<nmod_poly>val).val)
self.ctx = (<nmod_poly>val).ctx
else:
if mod == 0:
raise ValueError("a nonzero modulus is required")
nmod_poly_init(self.val, mod)
ctx = any_as_nmod_poly_ctx(mod)
if ctx is NotImplemented:
raise TypeError("cannot create nmod_poly_ctx from input of type %s", type(mod))

self.ctx = ctx
nmod_poly_init(self.val, ctx.mod.n)
if typecheck(val, fmpz_poly):
fmpz_poly_get_nmod_poly(self.val, (<fmpz_poly>val).val)
elif typecheck(val, list):
nmod_poly_set_list(self.val, val)
elif any_as_nmod(&v, val, self.val.mod):
ctx.nmod_poly_set_list(self.val, val)
elif ctx.any_as_nmod(&v, val):
nmod_poly_fit_length(self.val, 1)
nmod_poly_set_coeff_ui(self.val, 0, v)
else:
Expand Down Expand Up @@ -178,7 +220,7 @@ cdef class nmod_poly(flint_poly):
cdef mp_limb_t v
if i < 0:
raise ValueError("cannot assign to index < 0 of polynomial")
if any_as_nmod(&v, x, self.val.mod):
if self.ctx.any_as_nmod(&v, x):
nmod_poly_set_coeff_ui(self.val, i, v)
else:
raise TypeError("cannot set element of type %s" % type(x))
Expand Down Expand Up @@ -291,7 +333,7 @@ cdef class nmod_poly(flint_poly):
9*x^4 + 12*x^3 + 10*x^2 + 4*x + 1
"""
cdef nmod_poly res
other = any_as_nmod_poly(other, (<nmod_poly>self).val.mod)
other = self.ctx.any_as_nmod_poly(other)
if other is NotImplemented:
raise TypeError("cannot convert input to nmod_poly")
res = nmod_poly.__new__(nmod_poly)
Expand All @@ -316,11 +358,11 @@ cdef class nmod_poly(flint_poly):
147*x^3 + 159*x^2 + 4*x + 7
"""
cdef nmod_poly res
g = any_as_nmod_poly(other, self.val.mod)
g = self.ctx.any_as_nmod_poly(other)
if g is NotImplemented:
raise TypeError(f"cannot convert {other = } to nmod_poly")

h = any_as_nmod_poly(modulus, self.val.mod)
h = self.any_as_nmod_poly(modulus)
if h is NotImplemented:
raise TypeError(f"cannot convert {modulus = } to nmod_poly")

Expand All @@ -334,11 +376,11 @@ cdef class nmod_poly(flint_poly):

def __call__(self, other):
cdef mp_limb_t c
if any_as_nmod(&c, other, self.val.mod):
if self.ctx.any_as_nmod(&c, other):
v = nmod(0, self.modulus())
(<nmod>v).val = nmod_poly_evaluate_nmod(self.val, c)
return v
t = any_as_nmod_poly(other, self.val.mod)
t = self.ctx.any_as_nmod_poly(other)
if t is not NotImplemented:
r = nmod_poly.__new__(nmod_poly)
nmod_poly_init_preinv((<nmod_poly>r).val, self.val.mod.n, self.val.mod.ninv)
Expand Down Expand Up @@ -369,7 +411,7 @@ cdef class nmod_poly(flint_poly):

def _add_(s, t):
cdef nmod_poly r
t = any_as_nmod_poly(t, (<nmod_poly>s).val.mod)
t = s.ctx.any_as_nmod_poly(t)
if t is NotImplemented:
return t
if (<nmod_poly>s).val.mod.n != (<nmod_poly>t).val.mod.n:
Expand All @@ -395,20 +437,20 @@ cdef class nmod_poly(flint_poly):
return r

def __sub__(s, t):
t = any_as_nmod_poly(t, (<nmod_poly>s).val.mod)
t = s.ctx.any_as_nmod_poly(t)
if t is NotImplemented:
return t
return s._sub_(t)

def __rsub__(s, t):
t = any_as_nmod_poly(t, (<nmod_poly>s).val.mod)
t = s.any_as_nmod_poly(t)
if t is NotImplemented:
return t
return t._sub_(s)

def _mul_(s, t):
cdef nmod_poly r
t = any_as_nmod_poly(t, (<nmod_poly>s).val.mod)
t = s.any_as_nmod_poly(t)
if t is NotImplemented:
return t
if (<nmod_poly>s).val.mod.n != (<nmod_poly>t).val.mod.n:
Expand All @@ -425,7 +467,7 @@ cdef class nmod_poly(flint_poly):
return s._mul_(t)

def __truediv__(s, t):
t = any_as_nmod_poly(t, (<nmod_poly>s).val.mod)
t = s.any_as_nmod_poly(t)
if t is NotImplemented:
return t
res, r = s._divmod_(t)
Expand All @@ -434,7 +476,7 @@ cdef class nmod_poly(flint_poly):
return res

def __rtruediv__(s, t):
t = any_as_nmod_poly(t, (<nmod_poly>s).val.mod)
t = s.any_as_nmod_poly(t)
if t is NotImplemented:
return t
res, r = t._divmod_(s)
Expand All @@ -454,13 +496,13 @@ cdef class nmod_poly(flint_poly):
return r

def __floordiv__(s, t):
t = any_as_nmod_poly(t, (<nmod_poly>s).val.mod)
t = s.any_as_nmod_poly(t)
if t is NotImplemented:
return t
return s._floordiv_(t)

def __rfloordiv__(s, t):
t = any_as_nmod_poly(t, (<nmod_poly>s).val.mod)
t = s.any_as_nmod_poly(t)
if t is NotImplemented:
return t
return t._floordiv_(s)
Expand All @@ -479,13 +521,13 @@ cdef class nmod_poly(flint_poly):
return P, Q

def __divmod__(s, t):
t = any_as_nmod_poly(t, (<nmod_poly>s).val.mod)
t = s.any_as_nmod_poly(t)
if t is NotImplemented:
return t
return s._divmod_(t)

def __rdivmod__(s, t):
t = any_as_nmod_poly(t, (<nmod_poly>s).val.mod)
t = s.any_as_nmod_poly(t)
if t is NotImplemented:
return t
return t._divmod_(s)
Expand Down Expand Up @@ -534,7 +576,7 @@ cdef class nmod_poly(flint_poly):
if e < 0:
raise ValueError("Exponent must be non-negative")

modulus = any_as_nmod_poly(modulus, (<nmod_poly>self).val.mod)
modulus = self.ctx.any_as_nmod_poly(modulus)
if modulus is NotImplemented:
raise TypeError("cannot convert input to nmod_poly")

Expand All @@ -556,7 +598,7 @@ cdef class nmod_poly(flint_poly):

# To optimise powering, we precompute the inverse of the reverse of the modulus
if mod_rev_inv is not None:
mod_rev_inv = any_as_nmod_poly(mod_rev_inv, (<nmod_poly>self).val.mod)
mod_rev_inv = self.any_as_nmod_poly(mod_rev_inv)
if mod_rev_inv is NotImplemented:
raise TypeError(f"Cannot interpret {mod_rev_inv} as a polynomial")
else:
Expand Down Expand Up @@ -585,7 +627,7 @@ cdef class nmod_poly(flint_poly):
"""
cdef nmod_poly res
other = any_as_nmod_poly(other, (<nmod_poly>self).val.mod)
other = self.any_as_nmod_poly(other)
if other is NotImplemented:
raise TypeError("cannot convert input to nmod_poly")
if self.val.mod.n != (<nmod_poly>other).val.mod.n:
Expand All @@ -597,7 +639,7 @@ cdef class nmod_poly(flint_poly):

def xgcd(self, other):
cdef nmod_poly res1, res2, res3
other = any_as_nmod_poly(other, (<nmod_poly>self).val.mod)
other = self.any_as_nmod_poly(other)
if other is NotImplemented:
raise TypeError("cannot convert input to fmpq_poly")
res1 = nmod_poly.__new__(nmod_poly)
Expand Down

0 comments on commit 34a3b49

Please sign in to comment.