Skip to content

Commit

Permalink
Fix bug in secret-clear fixed-point division.
Browse files Browse the repository at this point in the history
  • Loading branch information
mkskeller committed May 20, 2024
1 parent 28f8664 commit 5ba7e71
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 19 deletions.
4 changes: 2 additions & 2 deletions Compiler/library.py
Original file line number Diff line number Diff line change
Expand Up @@ -1791,7 +1791,7 @@ def Norm(b, k, f, kappa, simplex_flag=False):
# For simplex, we can get rid of computing abs(b)
temp = None
if simplex_flag == False:
temp = comparison.LessThanZero(b, k, kappa)
temp = b.less_than(0, k)
elif simplex_flag == True:
temp = cint(0)

Expand All @@ -1807,7 +1807,7 @@ def Norm(b, k, f, kappa, simplex_flag=False):
z[i] = suffixes[i] - suffixes[i+1]
z[k - 1] = suffixes[k-1]

acc = sint.bit_compose(reversed(z))
acc = b.bit_compose(reversed(z))

part_reciprocal = absolute_val * acc
signed_acc = sign * acc
Expand Down
38 changes: 21 additions & 17 deletions Compiler/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -850,6 +850,17 @@ def __init__(self, *args, **kwargs):
raise CompilerError('functionality only available in arithmetic circuits')
super(_arithmetic_register, self).__init__(*args, **kwargs)

@classmethod
def get_type(cls, length):
return cls

@staticmethod
def two_power(n, size=None):
return floatingpoint.two_power(n)

def Norm(self, k, f, kappa=None, simplex_flag=False):
return library.Norm(self, k, f, kappa=kappa, simplex_flag=simplex_flag)

class _clear(_arithmetic_register):
""" Clear domain-dependent type. """
__slots__ = []
Expand Down Expand Up @@ -1268,12 +1279,20 @@ def right_shift(self, other, bit_length=None):
:param other: cint/regint/int """
return self >> other

def round(self, k, m, kappa=None, nearest=None, signed=False):
if signed:
self += 2 ** (k - 1)
res = self >> m
if signed:
res -= 2 ** (k - m - 1)
return res

@read_mem_value
def greater_than(self, other, bit_length=None):
return self > other

@vectorize
def bit_decompose(self, bit_length=None):
def bit_decompose(self, bit_length=None, kappa=None, maybe_mixed=None):
""" Clear bit decomposition.
:param bit_length: number of bits (default is global bit length)
Expand Down Expand Up @@ -2878,9 +2897,6 @@ def round(self, k, m, kappa=None, nearest=False, signed=False):
return floatingpoint.Trunc(self, k, m, kappa)
return self.TruncPr(k, m, kappa, signed=signed)

def Norm(self, k, f, kappa=None, simplex_flag=False):
return library.Norm(self, k, f, kappa, simplex_flag)

def __truediv__(self, other):
""" Secret fixed-point division.
Expand Down Expand Up @@ -2924,10 +2940,6 @@ def trunc_zeros(self, n_zeros, bit_length=None, signed=True):
bit_length = bit_length or program.bit_length
return comparison.TruncZeros(self, bit_length, n_zeros, signed)

@staticmethod
def two_power(n, size=None):
return floatingpoint.two_power(n)

def split_to_n_summands(self, length, n):
comparison.require_ring_size(length, 'splitting')
from .GC.types import sbits
Expand Down Expand Up @@ -3181,10 +3193,6 @@ class sgf2n(_secret, _gf2n):
reg_type = 'sg'
long_one = staticmethod(lambda: 1)

@classmethod
def get_type(cls, length):
return cls

@classmethod
def get_raw_input_from(cls, player):
res = cls()
Expand Down Expand Up @@ -4683,13 +4691,9 @@ def __truediv__(self, other):
other = self.coerce(other)
assert self.k == other.k
assert self.f == other.f
if isinstance(other, _fix):
if isinstance(other, (_fix, cfix)):
v = library.FPDiv(self.v, other.v, self.k, self.f, self.kappa,
nearest=self.round_nearest)
elif isinstance(other, cfix):
v = library.sint_cint_division(self.v, other.v, self.k, self.f,
self.kappa,
nearest=self.round_nearest)
else:
raise TypeError('Incompatible fixed point types in division')
return self._new(v, k=self.k, f=self.f)
Expand Down

0 comments on commit 5ba7e71

Please sign in to comment.