Skip to content

Commit

Permalink
None as default, where, support torch with state
Browse files Browse the repository at this point in the history
  • Loading branch information
JernKunpittaya committed May 14, 2024
1 parent 939f91f commit c849e60
Show file tree
Hide file tree
Showing 9 changed files with 963 additions and 351 deletions.
4 changes: 4 additions & 0 deletions examples/1.only_torch/data.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
{
"x": [0.5, 1, 2, 3, 4, 5, 6, 7],
"y": [2.7, 3.3, 1.1, 2.2, 3.8, 8.2, 4.4, 3.8]
}
281 changes: 281 additions & 0 deletions examples/1.only_torch/only_torch.ipynb

Large diffs are not rendered by default.

File renamed without changes.
294 changes: 294 additions & 0 deletions examples/2.torch+state/torch+state.ipynb

Large diffs are not rendered by default.

4 changes: 4 additions & 0 deletions examples/3.state/data.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
{
"x": [0.5, 1, 2, 3, 4, 5, 6],
"y": [2.7, 3.3, 1.1, 2.2, 3.8, 8.2, 4.4, 3.8]
}
294 changes: 294 additions & 0 deletions examples/3.state/state.ipynb

Large diffs are not rendered by default.

282 changes: 0 additions & 282 deletions examples/computation/computation.ipynb

This file was deleted.

44 changes: 17 additions & 27 deletions zkstats/computation.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
Covariance,
Correlation,
Regression,
Where,
IsResultPrecise,
)

Expand Down Expand Up @@ -139,15 +138,15 @@ def linear_regression(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
return self._call_op([x, y], Regression)

# WHERE operation
def where(self, filter: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
def where(self, _filter: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
"""
Calculate the where operation of x. The behavior should conform to `torch.where` in PyTorch.
:param filter: A boolean tensor serves as a filter
:param _filter: A boolean tensor serves as a filter
:param x: A tensor to be filtered
:return: filtered tensor
"""
return self._call_op([filter, x], Where)
return torch.where(_filter, x, x-x+MagicNumber)

def _call_op(self, x: list[torch.Tensor], op_type: Type[Operation]) -> Union[torch.Tensor, tuple[IsResultPrecise, torch.Tensor]]:
if self.current_op_index is None:
Expand Down Expand Up @@ -210,16 +209,12 @@ def _call_op(self, x: list[torch.Tensor], op_type: Type[Operation]) -> Union[tor
# print('Verifier side create')
precal_witness = json.loads(open(self.precal_witness_path, "r").read())
op = op_type.create(x, self.error, precal_witness, self.op_dict)
# dont need to include Where
if not isinstance(op, Where):
op_class_str =str(type(op)).split('.')[-1].split("'")[0]
if op_class_str not in self.op_dict:
self.op_dict[op_class_str] = 1
else:
self.op_dict[op_class_str]+=1
op_class_str =str(type(op)).split('.')[-1].split("'")[0]
if op_class_str not in self.op_dict:
self.op_dict[op_class_str] = 1
else:
self.op_dict[op_class_str]+=1
self.ops.append(op)
if isinstance(op, Where):
return torch.where(x[0], x[1], MagicNumber)
return op.result
else:
# Copy the current op index to a local variable since self.current_op_index will be incremented.
Expand Down Expand Up @@ -255,24 +250,15 @@ def is_precise() -> IsResultPrecise:
for i in range(len_bools):
res = self.bools[i]()
is_precise_aggregated = torch.logical_and(is_precise_aggregated, res)
if isinstance(op, Where):
# print('Only where')
return is_precise_aggregated, torch.where(x[0], x[1], x[1]-x[1]+MagicNumber)
else:
if self.isProver:
json.dump(self.precal_witness, open(self.precal_witness_path, 'w'))
return is_precise_aggregated, op.result+(x[0]-x[0])[0][0][0]
if self.isProver:
json.dump(self.precal_witness, open(self.precal_witness_path, 'w'))
return is_precise_aggregated, op.result+(x[0]-x[0])[0][0][0]

elif current_op_index > len_ops - 1:
# Sanity check that current op index does not exceed the length of ops
raise Exception(f"current_op_index out of bound: {current_op_index=} > {len_ops=}")
else:
# for where
if isinstance(op, Where):
# print('many ops incl where')
return torch.where(x[0], x[1], x[1]-x[1]+MagicNumber)
else:
return op.result+(x[0]-x[0])[0][0][0]
return op.result+(x[0]-x[0])[0][0][0]


class IModel(nn.Module):
Expand Down Expand Up @@ -314,7 +300,11 @@ def preprocess(self, x: list[torch.Tensor]) -> None:

def forward(self, *x: list[torch.Tensor]) -> tuple[IsResultPrecise, torch.Tensor]:
# print('x sy: ')
return computation(state, x)
result = computation(state, x)
if len(result) ==1:
return x[0][0][0][0]-x[0][0][0][0]+torch.tensor(1.0), result
else:
return result
# print('state:: ', state.aggregate_witness_path)
return state, Model

111 changes: 69 additions & 42 deletions zkstats/ops.py

Large diffs are not rendered by default.

0 comments on commit c849e60

Please sign in to comment.