diff --git a/pyrival/algebra/mod_sqrt.py b/pyrival/algebra/mod_sqrt.py index 38bc175..8c1840f 100644 --- a/pyrival/algebra/mod_sqrt.py +++ b/pyrival/algebra/mod_sqrt.py @@ -1,8 +1,8 @@ def mod_sqrt(a, p): """returns x s.t. x**2 == a (mod p)""" a %= p - if a == 0: - return 0 + if a < 2: + return a assert pow(a, (p - 1) // 2, p) == 1 if p & 3 == 3: return pow(a, (p + 1) // 4, p)