From 34a3b49d8ce6d219ee430156a68ec368e8321e61 Mon Sep 17 00:00:00 2001 From: Oscar Benjamin Date: Thu, 22 Aug 2024 01:31:45 +0100 Subject: [PATCH] Add nmod_poly_ctx --- src/flint/types/nmod_poly.pxd | 20 ++++- src/flint/types/nmod_poly.pyx | 152 ++++++++++++++++++++++------------ 2 files changed, 115 insertions(+), 57 deletions(-) diff --git a/src/flint/types/nmod_poly.pxd b/src/flint/types/nmod_poly.pxd index c0d1cd85..bd86887b 100644 --- a/src/flint/types/nmod_poly.pxd +++ b/src/flint/types/nmod_poly.pxd @@ -1,10 +1,26 @@ -from flint.flint_base.flint_base cimport flint_poly - +from flint.flintlib.nmod cimport nmod_t from flint.flintlib.nmod_poly cimport nmod_poly_t from flint.flintlib.flint cimport mp_limb_t +from flint.flint_base.flint_base cimport flint_poly + +from flint.types.nmod cimport nmod_ctx + + +cdef class nmod_poly_ctx: + cdef nmod_ctx ctx + cdef nmod_t mod + cdef bint _is_prime + + cdef nmod_poly_set_list(self, nmod_poly_t poly, list val) + cdef int any_as_nmod(self, mp_limb_t * val, obj) except -1 + cdef any_as_nmod_poly(self, obj) + + cdef class nmod_poly(flint_poly): cdef nmod_poly_t val + cdef nmod_poly_ctx ctx + cpdef long length(self) cpdef long degree(self) cpdef mp_limb_t modulus(self) diff --git a/src/flint/types/nmod_poly.pyx b/src/flint/types/nmod_poly.pyx index cb92a7fa..4e821692 100644 --- a/src/flint/types/nmod_poly.pyx +++ b/src/flint/types/nmod_poly.pyx @@ -5,49 +5,83 @@ from flint.types.fmpz cimport fmpz, any_as_fmpz from flint.types.fmpz_poly cimport any_as_fmpz_poly from flint.types.fmpz_poly cimport fmpz_poly from flint.types.nmod cimport any_as_nmod_ctx -from flint.types.nmod cimport nmod +from flint.types.nmod cimport nmod, nmod_ctx from flint.flintlib.nmod_vec cimport * from flint.flintlib.nmod_poly cimport * from flint.flintlib.nmod_poly_factor cimport * from flint.flintlib.fmpz_poly cimport fmpz_poly_get_nmod_poly -from flint.flintlib.ulong_extras cimport n_gcdinv +from flint.flintlib.ulong_extras cimport n_gcdinv, n_is_prime from flint.utils.flint_exceptions import DomainError -cdef any_as_nmod_poly(obj, nmod_t mod): - cdef nmod_poly r - cdef mp_limb_t v - # XXX: should check that modulus is the same here, and not all over the place - if typecheck(obj, nmod_poly): +_nmod_poly_ctx_cache = {} + + +cdef nmod_ctx any_as_nmod_poly_ctx(obj): + """Convert an int to an nmod_ctx.""" + if typecheck(obj, nmod_poly_ctx): return obj - if any_as_nmod(&v, obj, mod): - r = nmod_poly.__new__(nmod_poly) - nmod_poly_init(r.val, mod.n) - nmod_poly_set_coeff_ui(r.val, 0, v) - return r - x = any_as_fmpz_poly(obj) - if x is not NotImplemented: - r = nmod_poly.__new__(nmod_poly) - nmod_poly_init(r.val, mod.n) # XXX: create flint _nmod_poly_set_modulus for this? - fmpz_poly_get_nmod_poly(r.val, (x).val) - return r + if typecheck(obj, int): + ctx = _nmod_poly_ctx_cache.get(obj) + if ctx is None: + ctx = nmod_poly_ctx(obj) + _nmod_poly_ctx_cache[obj] = ctx + return ctx return NotImplemented -cdef nmod_poly_set_list(nmod_poly_t poly, list val): - cdef long i, n - cdef nmod_t mod - cdef mp_limb_t v - nmod_init(&mod, nmod_poly_modulus(poly)) # XXX - n = PyList_GET_SIZE(val) - nmod_poly_fit_length(poly, n) - for i from 0 <= i < n: - c = val[i] - if any_as_nmod(&v, val[i], mod): - nmod_poly_set_coeff_ui(poly, i, v) - else: - raise TypeError("unsupported coefficient in list") + +cdef class nmod_poly_ctx: + """ + Context object for creating :class:`~.nmod_poly` initalised + with modulus :math:`N`. + + >>> nmod_ctx(17) + nmod_ctx(17) + + """ + def __init__(self, mod): + cdef mp_limb_t m + m = mod + nmod_init(&self.mod, m) + self.ctx = nmod_ctx(mod) + self._is_prime = n_is_prime(m) + + cdef int any_as_nmod(self, mp_limb_t * val, obj) except -1: + return self.ctx.any_as_nmod(val, obj) + + cdef any_as_nmod_poly(self, obj): + cdef nmod_poly r + cdef mp_limb_t v + # XXX: should check that modulus is the same here, and not all over the place + if typecheck(obj, nmod_poly): + return obj + if self.ctx.any_as_nmod(&v, obj): + r = nmod_poly.__new__(nmod_poly) + nmod_poly_init(r.val, self.mod.n) + nmod_poly_set_coeff_ui(r.val, 0, v) + return r + x = any_as_fmpz_poly(obj) + if x is not NotImplemented: + r = nmod_poly.__new__(nmod_poly) + nmod_poly_init(r.val, self.mod.n) # XXX: create flint _nmod_poly_set_modulus for this? + fmpz_poly_get_nmod_poly(r.val, (x).val) + return r + return NotImplemented + + cdef nmod_poly_set_list(self, nmod_poly_t poly, list val): + cdef long i, n + cdef mp_limb_t v + n = PyList_GET_SIZE(val) + nmod_poly_fit_length(poly, n) + for i from 0 <= i < n: + c = val[i] + if self.any_as_nmod(&v, val[i]): + nmod_poly_set_coeff_ui(poly, i, v) + else: + raise TypeError("unsupported coefficient in list") + cdef class nmod_poly(flint_poly): """ @@ -79,24 +113,32 @@ cdef class nmod_poly(flint_poly): def __dealloc__(self): nmod_poly_clear(self.val) - def __init__(self, val=None, ulong mod=0): + def __init__(self, val=None, mod=0): cdef ulong m2 cdef mp_limb_t v + cdef nmod_poly_ctx ctx + if typecheck(val, nmod_poly): m2 = nmod_poly_modulus((val).val) if m2 != mod: raise ValueError("different moduli!") nmod_poly_init(self.val, m2) nmod_poly_set(self.val, (val).val) + self.ctx = (val).ctx else: if mod == 0: raise ValueError("a nonzero modulus is required") - nmod_poly_init(self.val, mod) + ctx = any_as_nmod_poly_ctx(mod) + if ctx is NotImplemented: + raise TypeError("cannot create nmod_poly_ctx from input of type %s", type(mod)) + + self.ctx = ctx + nmod_poly_init(self.val, ctx.mod.n) if typecheck(val, fmpz_poly): fmpz_poly_get_nmod_poly(self.val, (val).val) elif typecheck(val, list): - nmod_poly_set_list(self.val, val) - elif any_as_nmod(&v, val, self.val.mod): + ctx.nmod_poly_set_list(self.val, val) + elif ctx.any_as_nmod(&v, val): nmod_poly_fit_length(self.val, 1) nmod_poly_set_coeff_ui(self.val, 0, v) else: @@ -178,7 +220,7 @@ cdef class nmod_poly(flint_poly): cdef mp_limb_t v if i < 0: raise ValueError("cannot assign to index < 0 of polynomial") - if any_as_nmod(&v, x, self.val.mod): + if self.ctx.any_as_nmod(&v, x): nmod_poly_set_coeff_ui(self.val, i, v) else: raise TypeError("cannot set element of type %s" % type(x)) @@ -291,7 +333,7 @@ cdef class nmod_poly(flint_poly): 9*x^4 + 12*x^3 + 10*x^2 + 4*x + 1 """ cdef nmod_poly res - other = any_as_nmod_poly(other, (self).val.mod) + other = self.ctx.any_as_nmod_poly(other) if other is NotImplemented: raise TypeError("cannot convert input to nmod_poly") res = nmod_poly.__new__(nmod_poly) @@ -316,11 +358,11 @@ cdef class nmod_poly(flint_poly): 147*x^3 + 159*x^2 + 4*x + 7 """ cdef nmod_poly res - g = any_as_nmod_poly(other, self.val.mod) + g = self.ctx.any_as_nmod_poly(other) if g is NotImplemented: raise TypeError(f"cannot convert {other = } to nmod_poly") - h = any_as_nmod_poly(modulus, self.val.mod) + h = self.any_as_nmod_poly(modulus) if h is NotImplemented: raise TypeError(f"cannot convert {modulus = } to nmod_poly") @@ -334,11 +376,11 @@ cdef class nmod_poly(flint_poly): def __call__(self, other): cdef mp_limb_t c - if any_as_nmod(&c, other, self.val.mod): + if self.ctx.any_as_nmod(&c, other): v = nmod(0, self.modulus()) (v).val = nmod_poly_evaluate_nmod(self.val, c) return v - t = any_as_nmod_poly(other, self.val.mod) + t = self.ctx.any_as_nmod_poly(other) if t is not NotImplemented: r = nmod_poly.__new__(nmod_poly) nmod_poly_init_preinv((r).val, self.val.mod.n, self.val.mod.ninv) @@ -369,7 +411,7 @@ cdef class nmod_poly(flint_poly): def _add_(s, t): cdef nmod_poly r - t = any_as_nmod_poly(t, (s).val.mod) + t = s.ctx.any_as_nmod_poly(t) if t is NotImplemented: return t if (s).val.mod.n != (t).val.mod.n: @@ -395,20 +437,20 @@ cdef class nmod_poly(flint_poly): return r def __sub__(s, t): - t = any_as_nmod_poly(t, (s).val.mod) + t = s.ctx.any_as_nmod_poly(t) if t is NotImplemented: return t return s._sub_(t) def __rsub__(s, t): - t = any_as_nmod_poly(t, (s).val.mod) + t = s.any_as_nmod_poly(t) if t is NotImplemented: return t return t._sub_(s) def _mul_(s, t): cdef nmod_poly r - t = any_as_nmod_poly(t, (s).val.mod) + t = s.any_as_nmod_poly(t) if t is NotImplemented: return t if (s).val.mod.n != (t).val.mod.n: @@ -425,7 +467,7 @@ cdef class nmod_poly(flint_poly): return s._mul_(t) def __truediv__(s, t): - t = any_as_nmod_poly(t, (s).val.mod) + t = s.any_as_nmod_poly(t) if t is NotImplemented: return t res, r = s._divmod_(t) @@ -434,7 +476,7 @@ cdef class nmod_poly(flint_poly): return res def __rtruediv__(s, t): - t = any_as_nmod_poly(t, (s).val.mod) + t = s.any_as_nmod_poly(t) if t is NotImplemented: return t res, r = t._divmod_(s) @@ -454,13 +496,13 @@ cdef class nmod_poly(flint_poly): return r def __floordiv__(s, t): - t = any_as_nmod_poly(t, (s).val.mod) + t = s.any_as_nmod_poly(t) if t is NotImplemented: return t return s._floordiv_(t) def __rfloordiv__(s, t): - t = any_as_nmod_poly(t, (s).val.mod) + t = s.any_as_nmod_poly(t) if t is NotImplemented: return t return t._floordiv_(s) @@ -479,13 +521,13 @@ cdef class nmod_poly(flint_poly): return P, Q def __divmod__(s, t): - t = any_as_nmod_poly(t, (s).val.mod) + t = s.any_as_nmod_poly(t) if t is NotImplemented: return t return s._divmod_(t) def __rdivmod__(s, t): - t = any_as_nmod_poly(t, (s).val.mod) + t = s.any_as_nmod_poly(t) if t is NotImplemented: return t return t._divmod_(s) @@ -534,7 +576,7 @@ cdef class nmod_poly(flint_poly): if e < 0: raise ValueError("Exponent must be non-negative") - modulus = any_as_nmod_poly(modulus, (self).val.mod) + modulus = self.ctx.any_as_nmod_poly(modulus) if modulus is NotImplemented: raise TypeError("cannot convert input to nmod_poly") @@ -556,7 +598,7 @@ cdef class nmod_poly(flint_poly): # To optimise powering, we precompute the inverse of the reverse of the modulus if mod_rev_inv is not None: - mod_rev_inv = any_as_nmod_poly(mod_rev_inv, (self).val.mod) + mod_rev_inv = self.any_as_nmod_poly(mod_rev_inv) if mod_rev_inv is NotImplemented: raise TypeError(f"Cannot interpret {mod_rev_inv} as a polynomial") else: @@ -585,7 +627,7 @@ cdef class nmod_poly(flint_poly): """ cdef nmod_poly res - other = any_as_nmod_poly(other, (self).val.mod) + other = self.any_as_nmod_poly(other) if other is NotImplemented: raise TypeError("cannot convert input to nmod_poly") if self.val.mod.n != (other).val.mod.n: @@ -597,7 +639,7 @@ cdef class nmod_poly(flint_poly): def xgcd(self, other): cdef nmod_poly res1, res2, res3 - other = any_as_nmod_poly(other, (self).val.mod) + other = self.any_as_nmod_poly(other) if other is NotImplemented: raise TypeError("cannot convert input to fmpq_poly") res1 = nmod_poly.__new__(nmod_poly)