Skip to content

Commit

Permalink
Prefer cdef to def, removes runtime reflection in operands
Browse files Browse the repository at this point in the history
  • Loading branch information
Jake-Moss committed Aug 20, 2024
1 parent 7a3076c commit c9de856
Show file tree
Hide file tree
Showing 6 changed files with 372 additions and 214 deletions.
9 changes: 8 additions & 1 deletion src/flint/flint_base/flint_base.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ cdef class flint_mpoly(flint_elem):
cdef _add_scalar_(self, other)
cdef _sub_scalar_(self, other)
cdef _mul_scalar_(self, other)
cdef _pow_(self, other)

cdef _add_mpoly_(self, other)
cdef _sub_mpoly_(self, other)
Expand All @@ -28,7 +27,15 @@ cdef class flint_mpoly(flint_elem):
cdef _truediv_mpoly_(self, other)
cdef _mod_mpoly_(self, other)

cdef _rsub_scalar_(self, other)
cdef _rsub_mpoly_(self, other)

cdef _rdivmod_mpoly_(self, other)
cdef _rfloordiv_mpoly_(self, other)
cdef _rtruediv_mpoly_(self, other)
cdef _rmod_mpoly_(self, other)

cdef _pow_(self, other)

cdef _iadd_scalar_(self, other)
cdef _isub_scalar_(self, other)
Expand Down
187 changes: 112 additions & 75 deletions src/flint/flint_base/flint_base.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,6 @@ cdef class flint_scalar(flint_elem):
return self._invert_()



cdef class flint_poly(flint_elem):
"""
Base class for polynomials.
Expand Down Expand Up @@ -414,9 +413,6 @@ cdef class flint_mpoly(flint_elem):
cdef _mul_scalar_(self, other):
return NotImplemented

cdef _pow_(self, other):
return NotImplemented

cdef _add_mpoly_(self, other):
return NotImplemented

Expand All @@ -435,10 +431,28 @@ cdef class flint_mpoly(flint_elem):
cdef _truediv_mpoly_(self, other):
return NotImplemented

cdef _mod_mpoly_(self, other):
return NotImplemented

cdef _rsub_scalar_(self, other):
return NotImplemented

cdef _rsub_mpoly_(self, other):
return NotImplemented

cdef _rdivmod_mpoly_(self, other):
return NotImplemented

cdef _rfloordiv_mpoly_(self, other):
return NotImplemented

cdef _rtruediv_mpoly_(self, other):
return NotImplemented

cdef _mod_mpoly_(self, other):
cdef _rmod_mpoly_(self, other):
return NotImplemented

cdef _pow_(self, other):
return NotImplemented

cdef _iadd_scalar_(self, other):
Expand Down Expand Up @@ -471,32 +485,15 @@ cdef class flint_mpoly(flint_elem):
return self._add_scalar_(other)

def __radd__(self, other):
return self.__add__(other)

def iadd(self, other):
"""
In-place addition, mutates self.
>>> from flint import Ordering, fmpz_mpoly_ctx
>>> ctx = fmpz_mpoly_ctx.get_context(2, Ordering.lex, 'x')
>>> f = ctx.from_dict({(1, 0): 2, (0, 1): 3, (1, 1): 4})
>>> f
4*x0*x1 + 2*x0 + 3*x1
>>> f.iadd(5)
>>> f
4*x0*x1 + 2*x0 + 3*x1 + 5
"""
if typecheck(other, type(self)):
self.context().compatible_context_check(other.context())
self._iadd_mpoly_(other)
return
return self._add_mpoly_(other)

other_scalar = self.context().any_as_scalar(other)
if other_scalar is NotImplemented:
raise NotImplementedError(f"cannot add {type(self)} and {type(other)}")
other = self.context().any_as_scalar(other)
if other is NotImplemented:
return NotImplemented

self._iadd_scalar_(other_scalar)
return self._add_scalar_(other)

def __sub__(self, other):
if typecheck(other, type(self)):
Expand All @@ -510,32 +507,15 @@ cdef class flint_mpoly(flint_elem):
return self._sub_scalar_(other)

def __rsub__(self, other):
return -self.__sub__(other)

def isub(self, other):
"""
In-place subtraction, mutates self.
>>> from flint import Ordering, fmpz_mpoly_ctx
>>> ctx = fmpz_mpoly_ctx.get_context(2, Ordering.lex, 'x')
>>> f = ctx.from_dict({(1, 0): 2, (0, 1): 3, (1, 1): 4})
>>> f
4*x0*x1 + 2*x0 + 3*x1
>>> f.isub(5)
>>> f
4*x0*x1 + 2*x0 + 3*x1 - 5
"""
if typecheck(other, type(self)):
self.context().compatible_context_check(other.context())
self._isub_mpoly_(other)
return
return self._rsub_mpoly_(other)

other_scalar = self.context().any_as_scalar(other)
if other_scalar is NotImplemented:
raise NotImplementedError(f"cannot subtract {type(self)} and {type(other)}")
other = self.context().any_as_scalar(other)
if other is NotImplemented:
return NotImplemented

self._isub_scalar_(other_scalar)
return self._rsub_scalar_(other)

def __mul__(self, other):
if typecheck(other, type(self)):
Expand All @@ -549,32 +529,15 @@ cdef class flint_mpoly(flint_elem):
return self._mul_scalar_(other)

def __rmul__(self, other):
return self.__mul__(other)

def imul(self, other):
"""
In-place multiplication, mutates self.
>>> from flint import Ordering, fmpz_mpoly_ctx
>>> ctx = fmpz_mpoly_ctx.get_context(2, Ordering.lex, 'x')
>>> f = ctx.from_dict({(1, 0): 2, (0, 1): 3, (1, 1): 4})
>>> f
4*x0*x1 + 2*x0 + 3*x1
>>> f.imul(2)
>>> f
8*x0*x1 + 4*x0 + 6*x1
"""
if typecheck(other, type(self)):
self.context().compatible_context_check(other.context())
self._imul_mpoly_(other)
return
return self._mul_mpoly_(other)

other_scalar = self.context().any_as_scalar(other)
if other_scalar is NotImplemented:
raise NotImplementedError(f"cannot multiply {type(self)} and {type(other)}")
other = self.context().any_as_scalar(other)
if other is NotImplemented:
return NotImplemented

self._imul_scalar_(other_scalar)
return self._mul_scalar_(other)

def __pow__(self, other, modulus):
if modulus is not None:
Expand Down Expand Up @@ -611,7 +574,7 @@ cdef class flint_mpoly(flint_elem):

other = self.context().scalar_as_mpoly(other)
other._division_check(self)
return other._divmod_mpoly_(self)
return self._rdivmod_mpoly_(other)

def __truediv__(self, other):
if typecheck(other, type(self)):
Expand All @@ -634,7 +597,6 @@ cdef class flint_mpoly(flint_elem):

other = self.context().scalar_as_mpoly(other)
other._division_check(self)
# return other._truediv_mpoly_(self)
return self._rtruediv_mpoly_(other)

def __floordiv__(self, other):
Expand All @@ -658,7 +620,7 @@ cdef class flint_mpoly(flint_elem):

other = self.context().scalar_as_mpoly(other)
other._division_check(self)
return other._floordiv_mpoly_(self)
return self._rfloordiv_mpoly_(other)

def __mod__(self, other):
if typecheck(other, type(self)):
Expand All @@ -681,7 +643,82 @@ cdef class flint_mpoly(flint_elem):

other = self.context().scalar_as_mpoly(other)
other._division_check(self)
return other._mod_mpoly_(self)
return self._rmod_mpoly_(other)

def iadd(self, other):
"""
In-place addition, mutates self.
>>> from flint import Ordering, fmpz_mpoly_ctx
>>> ctx = fmpz_mpoly_ctx.get_context(2, Ordering.lex, 'x')
>>> f = ctx.from_dict({(1, 0): 2, (0, 1): 3, (1, 1): 4})
>>> f
4*x0*x1 + 2*x0 + 3*x1
>>> f.iadd(5)
>>> f
4*x0*x1 + 2*x0 + 3*x1 + 5
"""
if typecheck(other, type(self)):
self.context().compatible_context_check(other.context())
self._iadd_mpoly_(other)
return

other_scalar = self.context().any_as_scalar(other)
if other_scalar is NotImplemented:
raise NotImplementedError(f"cannot add {type(self)} and {type(other)}")

self._iadd_scalar_(other_scalar)

def isub(self, other):
"""
In-place subtraction, mutates self.
>>> from flint import Ordering, fmpz_mpoly_ctx
>>> ctx = fmpz_mpoly_ctx.get_context(2, Ordering.lex, 'x')
>>> f = ctx.from_dict({(1, 0): 2, (0, 1): 3, (1, 1): 4})
>>> f
4*x0*x1 + 2*x0 + 3*x1
>>> f.isub(5)
>>> f
4*x0*x1 + 2*x0 + 3*x1 - 5
"""
if typecheck(other, type(self)):
self.context().compatible_context_check(other.context())
self._isub_mpoly_(other)
return

other_scalar = self.context().any_as_scalar(other)
if other_scalar is NotImplemented:
raise NotImplementedError(f"cannot subtract {type(self)} and {type(other)}")

self._isub_scalar_(other_scalar)

def imul(self, other):
"""
In-place multiplication, mutates self.
>>> from flint import Ordering, fmpz_mpoly_ctx
>>> ctx = fmpz_mpoly_ctx.get_context(2, Ordering.lex, 'x')
>>> f = ctx.from_dict({(1, 0): 2, (0, 1): 3, (1, 1): 4})
>>> f
4*x0*x1 + 2*x0 + 3*x1
>>> f.imul(2)
>>> f
8*x0*x1 + 4*x0 + 6*x1
"""
if typecheck(other, type(self)):
self.context().compatible_context_check(other.context())
self._imul_mpoly_(other)
return

other_scalar = self.context().any_as_scalar(other)
if other_scalar is NotImplemented:
raise NotImplementedError(f"cannot multiply {type(self)} and {type(other)}")

self._imul_scalar_(other_scalar)

def __contains__(self, x):
"""
Expand Down
Loading

0 comments on commit c9de856

Please sign in to comment.