Merge pull request #57 from ZKStats/auto_gen_selected_columns
Detect selected columns in `computation_to_model`
mhchia authored Sep 13, 2024
2 parents 54341df + 94fcdfb commit f2cb773
Showing 8 changed files with 1,036 additions and 833 deletions.
21 changes: 4 additions & 17 deletions README.md
@@ -30,22 +30,9 @@ poetry install

### Define Your Computation

User computation must be defined as **a function** using ZKStats operations and PyTorch functions. The function signature must be `Callable[[State, list[torch.Tensor]], torch.Tensor]`:

```python
import torch

from zkstats.computation import State

# User-defined computation
def user_computation(s: State, data: list[torch.Tensor]) -> torch.Tensor:
# Define your computation here
...

```

User computation must be defined as **a function** using ZKStats operations and PyTorch functions. The function signature must be `Callable[[State, Args], torch.Tensor]`:
- first argument is a `State` object, which contains the statistical functions that ZKStats supports.
- second argument is a list of PyTorch tensors, the input data. `data[0]` is the first column, `data[1]` is the second column, and so on.
- second argument is an `Args` object, a dictionary of PyTorch tensors keyed by column name: `args['column1']` is the column named `column1`, `args['column2']` is the column named `column2`, and so on.

For example, we have two columns of data and we want to compute the mean of the medians of the two columns:

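The concrete example is collapsed in this diff view. The sketch below is an illustration only, assuming the new `Args`-style signature with hypothetical column names `x` and `y`; the exact reshaping of the intermediate results is an assumption, not the README's actual example.

```python
import torch

from zkstats.computation import State, Args


# Sketch only: the column names ("x", "y") and the reshape are illustrative assumptions.
def user_computation(state: State, args: Args) -> torch.Tensor:
    x = args["x"]
    y = args["y"]
    # Median of each column via ZKStats ops
    median_x = state.median(x)
    median_y = state.median(y)
    # Stack the two medians and take their mean with another ZKStats op
    combined = torch.cat((median_x.unsqueeze(0), median_y.unsqueeze(0))).reshape(-1, 1)
    return state.mean(combined)
```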
@@ -116,9 +103,9 @@ Note here that we can also just let prover generate model, and then send that m
```python
from zkstats.core import computation_to_model
# For prover: generate prover_model, and write to precal_witness file
_, prover_model = computation_to_model(user_computation, precal_witness_path, True, selected_columns, error)
selected_columns, _, prover_model = computation_to_model(user_computation, precal_witness_path, data_shape, True, error)
# For verifier: generate verifier_model (the same as prover_model) by reading the precal_witness file
_, verifier_model = computation_to_model(user_computation, precal_witness_path, False, selected_columns, error)
selected_columns, _, verifier_model = computation_to_model(user_computation, precal_witness_path, data_shape, False, error)
```
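With this change, `computation_to_model` detects the columns referenced inside the computation and returns them as `selected_columns`, so the caller no longer passes a column list; instead it passes `data_shape`, a mapping from column name to column length (as the tests below do).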

#### Data Provider: generate settings
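The body of this section is collapsed in the diff view. As a rough sketch based on the calls exercised in `tests/helpers.py` below (and assuming these functions are exported from `zkstats.core`), the data provider commits to the data and then generates proving settings roughly like this; all paths and scale values here are illustrative:

```python
from pathlib import Path

from zkstats.core import generate_data_commitment, prover_gen_settings

# Illustrative paths and scales (not values from the README)
base = Path("out")
data_path = base / "data.json"
data_commitment_path = base / "commitments.json"
sel_data_path = base / "comb_data.json"
model_path = base / "model.onnx"
settings_path = base / "settings.json"
scales = [2]

# Commit to the data at the chosen scales.
generate_data_commitment(data_path, scales, data_commitment_path)

# Generate settings from the prover model returned by computation_to_model above,
# mirroring the argument order used in tests/helpers.py.
prover_gen_settings(
    data_path,
    selected_columns,   # from computation_to_model
    sel_data_path,
    prover_model,       # from computation_to_model
    model_path,
    scales,
    "resources",        # same literal as used in the tests
    settings_path,
)
```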
1,553 changes: 832 additions & 721 deletions poetry.lock

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion pyproject.toml
@@ -7,7 +7,8 @@ authors = ["Jern Kunpittaya", "Kevin Chia"]
[tool.poetry.dependencies]
python = "^3.9"
ezkl = "9.1.0"
torch = "^2.1.1"
# fix torch version to 2.2.0 due to a weird issue when upgrading to 2.4.1
torch = "2.2.0"
requests = "^2.31.0"
scipy = "^1.11.4"
numpy = "^1.26.2"
45 changes: 8 additions & 37 deletions tests/helpers.py
@@ -16,23 +16,21 @@
ERROR_CIRCUIT_RELAXED = 0.1


def data_to_json_file(data_path: Path, data: list[torch.Tensor]) -> dict[str, list]:
column_names = [f"columns_{i}" for i in range(len(data))]
def data_to_json_file(data_path: Path, data: dict[str, torch.Tensor]) -> dict[str, list]:
column_to_data = {
column: d.tolist()
for column, d in zip(column_names, data)
for column, d in data.items()
}
print("column_to_data:", column_to_data)
with open(data_path, "w") as f:
json.dump(column_to_data, f)
return column_to_data



def compute_model(
def compute(
basepath: Path,
data: list[torch.Tensor],
model: IModel,
data: dict[str, torch.Tensor],
model: Type[IModel],
# computation: TComputation,
scales_params: Optional[Sequence[int]] = None,
selected_columns_params: Optional[list[str]] = None,
):
@@ -47,10 +45,10 @@ def compute_model(
data_path = basepath / "data.json"
data_commitment_path = basepath / "commitments.json"

column_to_data = data_to_json_file(data_path, data)
data_to_json_file(data_path, data)
# If selected_columns_params is None, select all columns
if selected_columns_params is None:
selected_columns = list(column_to_data.keys())
selected_columns = list(data.keys())
else:
selected_columns = selected_columns_params

@@ -62,44 +60,17 @@
else:
scales = scales_params
scales_for_commitments = scales_params
# create_dummy((data_path), (dummy_data_path))
generate_data_commitment((data_path), scales_for_commitments, (data_commitment_path))
# _, prover_model = computation_to_model(computation, (precal_witness_path), True, selected_columns, error)

prover_gen_settings((data_path), selected_columns, (sel_data_path), model, (model_path), scales, "resources", (settings_path))

# No need, since verifier & prover share the same onnx
# _, verifier_model = computation_to_model(computation, (precal_witness_path), False, selected_columns, error)
# verifier_define_calculation((dummy_data_path), selected_columns, (sel_dummy_data_path),verifier_model, (verifier_model_path))

setup((model_path), (compiled_model_path), (settings_path),(vk_path), (pk_path ))

prover_gen_proof((model_path), (sel_data_path), (witness_path), (compiled_model_path), (settings_path), (proof_path), (pk_path))
# print('slett col: ', selected_columns)
verifier_verify((proof_path), (settings_path), (vk_path), selected_columns, (data_commitment_path))


def compute(
basepath: Path,
data: list[torch.Tensor],
computation: TComputation,
scales_params: Optional[Sequence[int]] = None,
selected_columns_params: Optional[list[str]] = None,
) -> State:
data_path = basepath / "data.json"
precal_witness_path = basepath / "precal_witness_path.json"

column_to_data = data_to_json_file(data_path, data)
# If selected_columns_params is None, select all columns
if selected_columns_params is None:
selected_columns = list(column_to_data.keys())
else:
selected_columns = selected_columns_params

state, model = computation_to_model(computation, precal_witness_path, True, selected_columns)
compute_model(basepath, data, model, scales_params, selected_columns_params)
return state



# Error tolerance between zkstats python implementation and python statistics module
103 changes: 91 additions & 12 deletions tests/test_computation.py
@@ -4,7 +4,7 @@

import pytest

from zkstats.computation import State, Args, computation_to_model
from zkstats.computation import State, computation_to_model, analyze_computation, TComputation, Args
from zkstats.ops import (
Mean,
Median,
@@ -25,9 +25,9 @@


def nested_computation(state: State, args: Args):
x = args['columns_0']
y = args['columns_1']
z = args['columns_2']
x = args["x"]
y = args["y"]
z = args["z"]
out_0 = state.median(x)
out_1 = state.geometric_mean(y)
out_2 = state.harmonic_mean(x)
@@ -63,8 +63,14 @@ def nested_computation(state: State, args: Args):
[ERROR_CIRCUIT_DEFAULT],
)
def test_nested_computation(tmp_path, column_0: torch.Tensor, column_1: torch.Tensor, column_2: torch.Tensor, error, scales):
precal_witness_path = tmp_path / "precal_witness_path.json"
x, y, z = column_0, column_1, column_2
state = compute(tmp_path, [x, y, z], nested_computation, scales)
data_shape = {"x": len(x), "y": len(y), "z": len(z)}
data = {"x": x, "y": y, "z": z}
selected_columns, state, model = computation_to_model(nested_computation, precal_witness_path, data_shape, True, error)
compute(tmp_path, data, model, scales, selected_columns)
# There are 12 ops in the computation

assert state.current_op_index == 12

ops = state.ops
@@ -152,10 +158,14 @@ def test_computation_with_where_1d(tmp_path, error, column_0, op_type: Callable[
def condition(_x: torch.Tensor):
return _x < 4

def where_and_op(state: State, args: Args):
x = args['columns_0']
column_name = "x"

def where_and_op(state, args):
x = args[column_name]
return op_type(state, state.where(condition(x), x))
state = compute(tmp_path, [column], where_and_op, scales)
precal_witness_path = tmp_path / "precal_witness_path.json"
_, state, model = computation_to_model(where_and_op, precal_witness_path, {column_name: column.shape}, True, error)
compute(tmp_path, {column_name: column}, model, scales)

res_op = state.ops[-1]
filtered = column[condition(column)]
@@ -174,18 +184,87 @@ def test_computation_with_where_2d(tmp_path, error, column_0, column_1, op_type:
def condition_0(_x: torch.Tensor):
return _x > 4

def where_and_op(state: State, args: Args):
x = args['columns_0']
y = args['columns_1']
def where_and_op(state: State, args: Args):
x = args["x"]
y = args["y"]
condition_x = condition_0(x)
filtered_x = state.where(condition_x, x)
filtered_y = state.where(condition_x, y)
return op_type(state, filtered_x, filtered_y)
state = compute(tmp_path, [column_0, column_1], where_and_op, scales)
precal_witness_path = tmp_path / "precal_witness_path.json"
data_shape = {"x": len(column_0), "y": len(column_1)}
data = {"x": column_0, "y": column_1}
selected_columns, state, model = computation_to_model(where_and_op, precal_witness_path, data_shape, True, error)
compute(tmp_path, data, model, scales, selected_columns)

res_op = state.ops[-1]
condition_x = condition_0(column_0)
filtered_x = column_0[condition_x]
filtered_y = column_1[condition_x]
expected_res = expected_func(filtered_x.tolist(), filtered_y.tolist())
assert_result(res_op.result.data, expected_res)


def test_analyze_computation_success():
def valid_computation(state, args):
x = args["column1"]
y = args["column2"]
return state.mean(x) + state.median(y)

result = analyze_computation(valid_computation)
assert set(result) == {"column1", "column2"}

def test_analyze_computation_no_columns():
def no_columns_computation(state, args):
return state.mean(state.median([1, 2, 3]))

result = analyze_computation(no_columns_computation)
assert result == []

def test_analyze_computation_multiple_uses():
def multiple_uses_computation(state, args):
x = args["column1"]
y = args["column2"]
z = args["column1"] # Using column1 twice
return state.mean(x) + state.median(y) + state.sum(z)

result = analyze_computation(multiple_uses_computation)
assert set(result) == {"column1", "column2"}

def test_analyze_computation_nested_args():
def nested_args_computation(state, args):
x = args["column1"]["nested"]
y = args["column2"]
return state.mean(x) + state.median(y)

result = analyze_computation(nested_args_computation)
assert set(result) == {"column1", "column2"}

def test_analyze_computation_invalid_params():
def invalid_params_computation(invalid1, invalid2):
return invalid1.mean(invalid2["column"])

with pytest.raises(ValueError, match="The computation function must have two parameters named 'state' and 'args'"):
analyze_computation(invalid_params_computation)

def test_analyze_computation_wrong_param_names():
def wrong_param_names(state, wrong_name):
return state.mean(wrong_name["column"])

with pytest.raises(ValueError, match="The computation function must have two parameters named 'state' and 'args'"):
analyze_computation(wrong_param_names)

def test_analyze_computation_dynamic_column_access():
def dynamic_column_access(state, args):
columns = ["column1", "column2"]
return sum(state.mean(args[col]) for col in columns)

# This won't catch dynamically accessed columns
result = analyze_computation(dynamic_column_access)
assert result == []

def test_analyze_computation_lambda():
lambda_computation = lambda state, args: state.mean(args["column"])

with pytest.raises(ValueError, match="Lambda functions are not supported in analyze_computation"):
analyze_computation(lambda_computation)
37 changes: 19 additions & 18 deletions tests/test_core.py
@@ -16,7 +16,8 @@ def test_get_data_commitment_maps(tmp_path, column_0, column_1, scales):
# "columns_0": [1, 2, 3, 4, 5],
# "columns_1": [6, 7, 8, 9, 10],
# }
data_json = data_to_json_file(data_path, [column_0, column_1])
data_json = {"columns_0": column_0, "columns_1": column_1}
data_to_json_file(data_path, data_json)
# data_commitment is a mapping[scale -> mapping[column_name, commitment_hex]]
# {
# scale_0: {
@@ -51,7 +52,8 @@ def test_get_data_commitment_maps_hardcoded(tmp_path):
data_commitment_path = tmp_path / "commitments.json"
column_0 = torch.tensor([3.0, 4.5, 1.0, 2.0, 7.5, 6.4, 5.5])
column_1 = torch.tensor([2.7, 3.3, 1.1, 2.2, 3.8, 8.2, 4.4])
data_to_json_file(data_path, [column_0, column_1])
data_json = {"columns_0": column_0, "columns_1": column_1}
data_to_json_file(data_path, data_json)
scales = [2, 3]
generate_data_commitment(data_path, scales, data_commitment_path)
with open(data_commitment_path, "r") as f:
@@ -63,30 +65,28 @@

def test_integration_select_partial_columns(tmp_path, column_0, column_1, error, scales):
data_path = tmp_path / "data.json"
data_json = data_to_json_file(data_path, [column_0, column_1])
columns = list(data_json.keys())
assert len(columns) == 2
# Select only the first column from two columns
selected_columns = [columns[0]]
data_json = {"columns_0": column_0, "columns_1": column_1}
data_shape = {"columns_0": len(column_0), "columns_1": len(column_1)}
data_to_json_file(data_path, data_json)

def simple_computation(state, args):
x = args['columns_0']
return state.mean(x)
return state.mean(args["columns_0"])
precal_witness_path = tmp_path / "precal_witness_path.json"
selected_columns, _, model = computation_to_model(simple_computation, precal_witness_path, data_shape, True, error)
# gen settings, setup, prove, verify
compute(tmp_path, [column_0, column_1], simple_computation, scales, selected_columns)
compute(tmp_path, data_json, model, scales, selected_columns)


def test_csv_data(tmp_path, column_0, column_1, error, scales):
data_json_path = tmp_path / "data.json"
data_csv_path = tmp_path / "data.csv"
data_json = data_to_json_file(data_json_path, [column_0, column_1])
data_json = {"columns_0": column_0, "columns_1": column_1}
data_shape = {"columns_0": len(column_0), "columns_1": len(column_1)}
data_to_json_file(data_json_path, data_json)
json_file_to_csv(data_json_path, data_csv_path)

selected_columns = list(data_json.keys())

def simple_computation(state, args):
x = args['columns_0']
return state.mean(x)
return state.mean(args["columns_0"])

sel_data_path = tmp_path / "comb_data.json"
model_path = tmp_path / "model.onnx"
@@ -98,7 +98,7 @@ def simple_computation(state, args):
generate_data_commitment(data_csv_path, scales, data_commitment_path)

# Test: `prover_gen_settings` works with csv
_, model_for_proving = computation_to_model(simple_computation, precal_witness_path, True, selected_columns, error)
selected_columns, _, model_for_proving = computation_to_model(simple_computation, precal_witness_path, data_shape, True, error)
prover_gen_settings(
data_path=data_csv_path,
selected_columns=selected_columns,
@@ -112,7 +112,7 @@

# Test: `verifier_define_calculation` works with csv
# Instantiate the model for verification since the state of `model_for_proving` is changed after `prover_gen_settings`
_, model_for_verification = computation_to_model(simple_computation, precal_witness_path, False, selected_columns, error)
selected_columns, _, model_for_verification = computation_to_model(simple_computation, precal_witness_path, data_shape, False, error)
verifier_define_calculation(data_csv_path, selected_columns, str(sel_data_path), model_for_verification, str(model_path))

def json_file_to_csv(data_json_path, data_csv_path):
@@ -135,7 +135,8 @@

def test__preprocess_data_file_to_json(tmp_path, column_0, column_1):
data_json_path = tmp_path / "data.json"
data_from_json = data_to_json_file(data_json_path, [column_0, column_1])
data_json = {"columns_0": column_0, "columns_1": column_1}
data_from_json = data_to_json_file(data_json_path, data_json)

# Test: csv can be converted to json
# 1. Generate a csv file from json