Skip to content

Commit

Permalink
refine after merge
Browse files Browse the repository at this point in the history
  • Loading branch information
JernKunpittaya committed May 10, 2024
1 parent 7def4dd commit 939f91f
Showing 1 changed file with 10 additions and 48 deletions.
58 changes: 10 additions & 48 deletions zkstats/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,8 @@ def prover_gen_settings(
"""
data_tensor_array = _process_data(data_path, selected_columns, sel_data_path)

# export onnx file
_export_onnx(prover_model, data_tensor_array, prover_model_path)

# gen + calibrate setting
_gen_settings(sel_data_path, prover_model_path, scale, mode, settings_path)

Expand Down Expand Up @@ -282,6 +282,7 @@ def generate_data_commitment(data_path: str, scales: Sequence[int], data_commitm
:param scales: a list of scales to use for the commitments
:param data_commitment_path: path to store the generated data commitment maps
"""

# Convert `data_path` to json file `data_json_path`
data_path: Path = Path(data_path)
data_json_path = Path(data_path).with_suffix(DataExtension.JSON.value)
Expand Down Expand Up @@ -354,7 +355,7 @@ def _gen_settings(
# Poseidon is not homomorphic additive, maybe consider Pedersens or Dory commitment.
gip_run_args = ezkl.PyRunArgs()
gip_run_args.input_visibility = "hashed" # one commitment (values hashed) for each column
gip_run_args.param_visibility = "fixed" # no parameters shown
gip_run_args.param_visibility = "private" # no parameters shown
gip_run_args.output_visibility = "public" # should be `(torch.Tensor(1.0), output)`

# generate settings
Expand All @@ -374,49 +375,6 @@ def _gen_settings(
print("scale: ", scale)
print("setting: ", f_setting.read())

def _csv_file_to_json(old_file_path: Union[Path, str], out_data_json_path: Union[Path, str], *, delimiter: str = ",") -> None:
data_csv_path = Path(old_file_path)
with open(data_csv_path, 'r') as f_csv:
reader = csv.reader(f_csv, delimiter=delimiter, strict=True)
# Read all data from the reader to `rows`
rows_with_column_name = tuple(reader)
if len(rows_with_column_name) < 1:
raise ValueError("No column names in the CSV file")
if len(rows_with_column_name) < 2:
raise ValueError("No data in the CSV file")
column_names = rows_with_column_name[0]
rows = rows_with_column_name[1:]

columns = [
[
float(rows[j][i])
for j in range(len(rows))
]
for i in range(len(rows[0]))
]
data = {
column_name: column_data
for column_name, column_data in zip(column_names, columns)
}
with open(out_data_json_path, "w") as f_json:
json.dump(data, f_json)


class DataExtension(Enum):
CSV = ".csv"
JSON = ".json"


DATA_FORMAT_PREPROCESSING_FUNCTION: dict[DataExtension, Callable[[Union[Path, str], Path], None]] = {
DataExtension.CSV: _csv_file_to_json,
DataExtension.JSON: lambda old_file_path, out_data_json_path: Path(out_data_json_path).write_text(Path(old_file_path).read_text())
}

def _preprocess_data_file_to_json(data_path: Union[Path, str], out_data_json_path: Path):
data_file_extension = DataExtension(data_path.suffix)
preprocess_function = DATA_FORMAT_PREPROCESSING_FUNCTION[data_file_extension]
preprocess_function(data_path, out_data_json_path)


def _csv_file_to_json(old_file_path: Union[Path, str], out_data_json_path: Union[Path, str], *, delimiter: str = ",") -> None:
data_csv_path = Path(old_file_path)
Expand Down Expand Up @@ -463,13 +421,17 @@ def _preprocess_data_file_to_json(data_path: Union[Path, str], out_data_json_pat


def _process_data(
data_path: Union[str| Path],
data_path: Union[str | Path],
col_array: list[str],
sel_data_path: list[str],
) -> list[torch.Tensor]:
data_tensor_array=[]
sel_data = []
data_onefile = json.loads(open(data_path, "r").read())
data_path: Path = Path(data_path)
# Convert data file to json under the same directory but with suffix .json
data_json_path = Path(data_path).with_suffix(DataExtension.JSON.value)
_preprocess_data_file_to_json(data_path, data_json_path)
data_onefile = json.loads(open(data_json_path, "r").read())

for col in col_array:
data = data_onefile[col]
Expand All @@ -489,4 +451,4 @@ def _get_commitment_for_column(column: list[float], scale: int) -> str:
res_poseidon_hash = ezkl.poseidon_hash(serialized_data)[0]
# res_hex = ezkl.vecu64_to_felt(res_poseidon_hash[0])

return res_poseidon_hash
return res_poseidon_hash

0 comments on commit 939f91f

Please sign in to comment.