Skip to content

Commit

Permalink
Merge pull request #93 from oscarbenjamin/pr_pow3
Browse files Browse the repository at this point in the history
Fix pow(int, int, fmpz)
  • Loading branch information
oscarbenjamin authored Oct 2, 2023

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature.
2 parents 98c2883 + 0990591 commit dbbf8bb
Showing 2 changed files with 35 additions and 26 deletions.
4 changes: 4 additions & 0 deletions src/flint/test/test.py
Original file line number Diff line number Diff line change
@@ -136,12 +136,16 @@ def test_fmpz():
(2, 2, 3, 1),
(2, -1, 5, 3),
(2, 0, 5, 1),
(2, 5, 1000, 32),
]
for a, b, c, ab_mod_c in pow_mod_examples:
assert pow(a, b, c) == ab_mod_c
assert pow(flint.fmpz(a), b, c) == ab_mod_c
assert pow(a, flint.fmpz(b), c) == ab_mod_c
assert pow(a, b, flint.fmpz(c)) == ab_mod_c
assert pow(flint.fmpz(a), flint.fmpz(b), c) == ab_mod_c
assert pow(flint.fmpz(a), b, flint.fmpz(c)) == ab_mod_c
assert pow(a, flint.fmpz(b), flint.fmpz(c)) == ab_mod_c
assert pow(flint.fmpz(a), flint.fmpz(b), flint.fmpz(c)) == ab_mod_c

assert raises(lambda: pow(flint.fmpz(2), 2, 0), ValueError)
57 changes: 31 additions & 26 deletions src/flint/types/fmpz.pyx
Original file line number Diff line number Diff line change
@@ -360,53 +360,58 @@ cdef class fmpz(flint_scalar):
return u

def __pow__(s, t, m):
cdef fmpz_struct sval[1]
cdef fmpz_struct tval[1]
cdef fmpz_struct mval[1]
cdef int stype = FMPZ_UNKNOWN
cdef int ttype = FMPZ_UNKNOWN
cdef int mtype = FMPZ_UNKNOWN
cdef int success
u = NotImplemented
ttype = fmpz_set_any_ref(tval, t)
if ttype == FMPZ_UNKNOWN:
return NotImplemented

if m is None:
# fmpz_pow_fmpz throws if x is negative
if fmpz_sgn(tval) == -1:
if ttype == FMPZ_TMP: fmpz_clear(tval)
raise ValueError("negative exponent")
try:
stype = fmpz_set_any_ref(sval, s)
if stype == FMPZ_UNKNOWN:
return NotImplemented
ttype = fmpz_set_any_ref(tval, t)
if ttype == FMPZ_UNKNOWN:
return NotImplemented
if m is None:
# fmpz_pow_fmpz throws if x is negative
if fmpz_sgn(tval) == -1:
raise ValueError("negative exponent")

u = fmpz.__new__(fmpz)
success = fmpz_pow_fmpz((<fmpz>u).val, (<fmpz>s).val, tval)
u = fmpz.__new__(fmpz)
success = fmpz_pow_fmpz((<fmpz>u).val, (<fmpz>s).val, tval)

if not success:
if ttype == FMPZ_TMP: fmpz_clear(tval)
raise OverflowError("fmpz_pow_fmpz: exponent too large")
else:
# Modular exponentiation
mtype = fmpz_set_any_ref(mval, m)
if mtype != FMPZ_UNKNOWN:
if not success:
raise OverflowError("fmpz_pow_fmpz: exponent too large")

return u
else:
# Modular exponentiation
mtype = fmpz_set_any_ref(mval, m)
if mtype == FMPZ_UNKNOWN:
return NotImplemented

if fmpz_is_zero(mval):
if ttype == FMPZ_TMP: fmpz_clear(tval)
if mtype == FMPZ_TMP: fmpz_clear(mval)
raise ValueError("pow(): modulus cannot be zero")

# The Flint docs say that fmpz_powm will throw if m is zero
# but it also throws if m is negative. Python generally allows
# e.g. pow(2, 2, -3) == (2^2) % (-3) == -2. We could implement
# that here as well but it is not clear how useful it is.
if fmpz_sgn(mval) == -1:
if ttype == FMPZ_TMP: fmpz_clear(tval)
if mtype == FMPZ_TMP: fmpz_clear(mval)
raise ValueError("pow(): negative modulua not supported")
raise ValueError("pow(): negative modulus not supported")

u = fmpz.__new__(fmpz)
fmpz_powm((<fmpz>u).val, (<fmpz>s).val, tval, mval)
fmpz_powm((<fmpz>u).val, sval, tval, mval)

if ttype == FMPZ_TMP: fmpz_clear(tval)
if mtype == FMPZ_TMP: fmpz_clear(mval)
return u
return u
finally:
if stype == FMPZ_TMP: fmpz_clear(sval)
if ttype == FMPZ_TMP: fmpz_clear(tval)
if mtype == FMPZ_TMP: fmpz_clear(mval)

def __rpow__(s, t, m):
t = any_as_fmpz(t)

0 comments on commit dbbf8bb

Please sign in to comment.