Skip to content

Commit

Permalink
Merge pull request #192 from Jake-Moss/mpoly_update
Browse files Browse the repository at this point in the history
Prefer cdef to def in mpoly operands
  • Loading branch information
oscarbenjamin authored Aug 21, 2024
2 parents 920a3a2 + c9de856 commit 630e20c
Show file tree
Hide file tree
Showing 6 changed files with 459 additions and 254 deletions.
32 changes: 31 additions & 1 deletion src/flint/flint_base/flint_base.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,37 @@ cdef class flint_mpoly_context(flint_elem):
cdef const char ** c_names

cdef class flint_mpoly(flint_elem):
pass
cdef _add_scalar_(self, other)
cdef _sub_scalar_(self, other)
cdef _mul_scalar_(self, other)

cdef _add_mpoly_(self, other)
cdef _sub_mpoly_(self, other)
cdef _mul_mpoly_(self, other)

cdef _divmod_mpoly_(self, other)
cdef _floordiv_mpoly_(self, other)
cdef _truediv_mpoly_(self, other)
cdef _mod_mpoly_(self, other)

cdef _rsub_scalar_(self, other)
cdef _rsub_mpoly_(self, other)

cdef _rdivmod_mpoly_(self, other)
cdef _rfloordiv_mpoly_(self, other)
cdef _rtruediv_mpoly_(self, other)
cdef _rmod_mpoly_(self, other)

cdef _pow_(self, other)

cdef _iadd_scalar_(self, other)
cdef _isub_scalar_(self, other)
cdef _imul_scalar_(self, other)

cdef _iadd_mpoly_(self, other)
cdef _isub_mpoly_(self, other)
cdef _imul_mpoly_(self, other)


cdef class flint_mat(flint_elem):
pass
Expand Down
218 changes: 131 additions & 87 deletions src/flint/flint_base/flint_base.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,6 @@ cdef class flint_scalar(flint_elem):
return self._invert_()



cdef class flint_poly(flint_elem):
"""
Base class for polynomials.
Expand Down Expand Up @@ -405,52 +404,73 @@ cdef class flint_mpoly(flint_elem):
if not other:
raise ZeroDivisionError("nmod_mpoly division by zero")

def _add_scalar_(self, other):
cdef _add_scalar_(self, other):
return NotImplemented

cdef _sub_scalar_(self, other):
return NotImplemented

cdef _mul_scalar_(self, other):
return NotImplemented

cdef _add_mpoly_(self, other):
return NotImplemented

cdef _sub_mpoly_(self, other):
return NotImplemented

cdef _mul_mpoly_(self, other):
return NotImplemented

def _add_mpoly_(self, other):
cdef _divmod_mpoly_(self, other):
return NotImplemented

def _iadd_scalar_(self, other):
cdef _floordiv_mpoly_(self, other):
return NotImplemented

def _iadd_mpoly_(self, other):
cdef _truediv_mpoly_(self, other):
return NotImplemented

def _sub_scalar_(self, other):
cdef _mod_mpoly_(self, other):
return NotImplemented

def _sub_mpoly_(self, other):
cdef _rsub_scalar_(self, other):
return NotImplemented

def _isub_scalar_(self, other):
cdef _rsub_mpoly_(self, other):
return NotImplemented

def _isub_mpoly_(self, other):
cdef _rdivmod_mpoly_(self, other):
return NotImplemented

def _mul_scalar_(self, other):
cdef _rfloordiv_mpoly_(self, other):
return NotImplemented

def _imul_mpoly_(self, other):
cdef _rtruediv_mpoly_(self, other):
return NotImplemented

def _imul_scalar_(self, other):
cdef _rmod_mpoly_(self, other):
return NotImplemented

def _mul_mpoly_(self, other):
cdef _pow_(self, other):
return NotImplemented

def _pow_(self, other):
cdef _iadd_scalar_(self, other):
return NotImplemented

def _divmod_mpoly_(self, other):
cdef _isub_scalar_(self, other):
return NotImplemented

def _floordiv_mpoly_(self, other):
cdef _imul_scalar_(self, other):
return NotImplemented

def _truediv_mpoly_(self, other):
cdef _iadd_mpoly_(self, other):
return NotImplemented

cdef _isub_mpoly_(self, other):
return NotImplemented

cdef _imul_mpoly_(self, other):
return NotImplemented

def __add__(self, other):
Expand All @@ -465,32 +485,15 @@ cdef class flint_mpoly(flint_elem):
return self._add_scalar_(other)

def __radd__(self, other):
return self.__add__(other)

def iadd(self, other):
"""
In-place addition, mutates self.
>>> from flint import Ordering, fmpz_mpoly_ctx
>>> ctx = fmpz_mpoly_ctx.get_context(2, Ordering.lex, 'x')
>>> f = ctx.from_dict({(1, 0): 2, (0, 1): 3, (1, 1): 4})
>>> f
4*x0*x1 + 2*x0 + 3*x1
>>> f.iadd(5)
>>> f
4*x0*x1 + 2*x0 + 3*x1 + 5
"""
if typecheck(other, type(self)):
self.context().compatible_context_check(other.context())
self._iadd_mpoly_(other)
return
return self._add_mpoly_(other)

other_scalar = self.context().any_as_scalar(other)
if other_scalar is NotImplemented:
raise NotImplementedError(f"cannot add {type(self)} and {type(other)}")
other = self.context().any_as_scalar(other)
if other is NotImplemented:
return NotImplemented

self._iadd_scalar_(other_scalar)
return self._add_scalar_(other)

def __sub__(self, other):
if typecheck(other, type(self)):
Expand All @@ -504,32 +507,15 @@ cdef class flint_mpoly(flint_elem):
return self._sub_scalar_(other)

def __rsub__(self, other):
return -self.__sub__(other)

def isub(self, other):
"""
In-place subtraction, mutates self.
>>> from flint import Ordering, fmpz_mpoly_ctx
>>> ctx = fmpz_mpoly_ctx.get_context(2, Ordering.lex, 'x')
>>> f = ctx.from_dict({(1, 0): 2, (0, 1): 3, (1, 1): 4})
>>> f
4*x0*x1 + 2*x0 + 3*x1
>>> f.isub(5)
>>> f
4*x0*x1 + 2*x0 + 3*x1 - 5
"""
if typecheck(other, type(self)):
self.context().compatible_context_check(other.context())
self._isub_mpoly_(other)
return
return self._rsub_mpoly_(other)

other_scalar = self.context().any_as_scalar(other)
if other_scalar is NotImplemented:
raise NotImplementedError(f"cannot subtract {type(self)} and {type(other)}")
other = self.context().any_as_scalar(other)
if other is NotImplemented:
return NotImplemented

self._isub_scalar_(other_scalar)
return self._rsub_scalar_(other)

def __mul__(self, other):
if typecheck(other, type(self)):
Expand All @@ -543,32 +529,15 @@ cdef class flint_mpoly(flint_elem):
return self._mul_scalar_(other)

def __rmul__(self, other):
return self.__mul__(other)

def imul(self, other):
"""
In-place multiplication, mutates self.
>>> from flint import Ordering, fmpz_mpoly_ctx
>>> ctx = fmpz_mpoly_ctx.get_context(2, Ordering.lex, 'x')
>>> f = ctx.from_dict({(1, 0): 2, (0, 1): 3, (1, 1): 4})
>>> f
4*x0*x1 + 2*x0 + 3*x1
>>> f.imul(2)
>>> f
8*x0*x1 + 4*x0 + 6*x1
"""
if typecheck(other, type(self)):
self.context().compatible_context_check(other.context())
self._imul_mpoly_(other)
return
return self._mul_mpoly_(other)

other_scalar = self.context().any_as_scalar(other)
if other_scalar is NotImplemented:
raise NotImplementedError(f"cannot multiply {type(self)} and {type(other)}")
other = self.context().any_as_scalar(other)
if other is NotImplemented:
return NotImplemented

self._imul_scalar_(other_scalar)
return self._mul_scalar_(other)

def __pow__(self, other, modulus):
if modulus is not None:
Expand Down Expand Up @@ -605,7 +574,7 @@ cdef class flint_mpoly(flint_elem):

other = self.context().scalar_as_mpoly(other)
other._division_check(self)
return other._divmod_mpoly_(self)
return self._rdivmod_mpoly_(other)

def __truediv__(self, other):
if typecheck(other, type(self)):
Expand All @@ -628,7 +597,7 @@ cdef class flint_mpoly(flint_elem):

other = self.context().scalar_as_mpoly(other)
other._division_check(self)
return other._truediv_mpoly_(self)
return self._rtruediv_mpoly_(other)

def __floordiv__(self, other):
if typecheck(other, type(self)):
Expand All @@ -651,7 +620,7 @@ cdef class flint_mpoly(flint_elem):

other = self.context().scalar_as_mpoly(other)
other._division_check(self)
return other._floordiv_mpoly_(self)
return self._rfloordiv_mpoly_(other)

def __mod__(self, other):
if typecheck(other, type(self)):
Expand All @@ -674,7 +643,82 @@ cdef class flint_mpoly(flint_elem):

other = self.context().scalar_as_mpoly(other)
other._division_check(self)
return other._mod_mpoly_(self)
return self._rmod_mpoly_(other)

def iadd(self, other):
"""
In-place addition, mutates self.
>>> from flint import Ordering, fmpz_mpoly_ctx
>>> ctx = fmpz_mpoly_ctx.get_context(2, Ordering.lex, 'x')
>>> f = ctx.from_dict({(1, 0): 2, (0, 1): 3, (1, 1): 4})
>>> f
4*x0*x1 + 2*x0 + 3*x1
>>> f.iadd(5)
>>> f
4*x0*x1 + 2*x0 + 3*x1 + 5
"""
if typecheck(other, type(self)):
self.context().compatible_context_check(other.context())
self._iadd_mpoly_(other)
return

other_scalar = self.context().any_as_scalar(other)
if other_scalar is NotImplemented:
raise NotImplementedError(f"cannot add {type(self)} and {type(other)}")

self._iadd_scalar_(other_scalar)

def isub(self, other):
"""
In-place subtraction, mutates self.
>>> from flint import Ordering, fmpz_mpoly_ctx
>>> ctx = fmpz_mpoly_ctx.get_context(2, Ordering.lex, 'x')
>>> f = ctx.from_dict({(1, 0): 2, (0, 1): 3, (1, 1): 4})
>>> f
4*x0*x1 + 2*x0 + 3*x1
>>> f.isub(5)
>>> f
4*x0*x1 + 2*x0 + 3*x1 - 5
"""
if typecheck(other, type(self)):
self.context().compatible_context_check(other.context())
self._isub_mpoly_(other)
return

other_scalar = self.context().any_as_scalar(other)
if other_scalar is NotImplemented:
raise NotImplementedError(f"cannot subtract {type(self)} and {type(other)}")

self._isub_scalar_(other_scalar)

def imul(self, other):
"""
In-place multiplication, mutates self.
>>> from flint import Ordering, fmpz_mpoly_ctx
>>> ctx = fmpz_mpoly_ctx.get_context(2, Ordering.lex, 'x')
>>> f = ctx.from_dict({(1, 0): 2, (0, 1): 3, (1, 1): 4})
>>> f
4*x0*x1 + 2*x0 + 3*x1
>>> f.imul(2)
>>> f
8*x0*x1 + 4*x0 + 6*x1
"""
if typecheck(other, type(self)):
self.context().compatible_context_check(other.context())
self._imul_mpoly_(other)
return

other_scalar = self.context().any_as_scalar(other)
if other_scalar is NotImplemented:
raise NotImplementedError(f"cannot multiply {type(self)} and {type(other)}")

self._imul_scalar_(other_scalar)

def __contains__(self, x):
"""
Expand Down
Loading

0 comments on commit 630e20c

Please sign in to comment.