Skip to content

Commit

Permalink
Merge pull request #27 from mhchia/mpcstats-fix-loop-issue
Browse files Browse the repository at this point in the history
fix(mpcstats): for loop issue in circom
  • Loading branch information
mhchia authored Apr 25, 2024
2 parents ae7b175 + 968495b commit e9ae9ed
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 106 deletions.
6 changes: 5 additions & 1 deletion tests/onnx2circom/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ Clone circom-2-arithc. Use a fork for now. Will change to the official repo soon
cd ..
git clone https://github.com/mhchia/circom-2-arithc.git
cd circom-2-arithc
# Check out the mpcstats branch and initialize the env file
git checkout mpcstats
cp .env.example .env
circom_2_arithc_project_root=$(pwd)
```

Expand All @@ -36,8 +39,9 @@ cargo build --release
Clone the repo
```bash
cd ..
git clone https://github.com/data61/MP-SPDZ
git clone https://github.com/mhchia/MP-SPDZ
cd MP-SPDZ
git checkout arith-extcutor
mp_spdz_project_root=$(pwd)
```

Expand Down
3 changes: 2 additions & 1 deletion tests/onnx2circom/test_onnx_to_circom.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,6 @@ def compile_and_check(model_type: Type[nn.Module], data: torch.Tensor, tmp_path:
print("Transforming torch model to onnx...")
torch_model_to_onnx(model_type, data, onnx_path)
assert onnx_path.exists() is True, f"The output file {onnx_path} does not exist."
# onnx_to_keras(onnx_path, keras_path)
print("Transforming onnx model to circom...")
onnx_to_circom(onnx_path, circom_path)
assert circom_path.exists() is True, f"The output file {circom_path} does not exist."
Expand Down Expand Up @@ -97,6 +96,7 @@ def compile_and_check(model_type: Type[nn.Module], data: torch.Tensor, tmp_path:
# for convenience (which input is from which party). Now just put every input to party 0.
# Assume the input data is a 1-d tensor
user_config_path = MP_SPDZ_PROJECT_ROOT / f"Configs/{model_name}.json"
user_config_path.parent.mkdir(parents=True, exist_ok=True)
with open(user_config_path, 'w') as f:
json.dump({"inputs_from": {
"0": input_names,
Expand All @@ -107,6 +107,7 @@ def compile_and_check(model_type: Type[nn.Module], data: torch.Tensor, tmp_path:
# Prepare data for party 0
data_list = data.reshape(-1)
input_0_path = MP_SPDZ_PROJECT_ROOT / 'Player-Data/Input-P0-0'
input_0_path.parent.mkdir(parents=True, exist_ok=True)
with open(input_0_path, 'w') as f:
# TODO: change int to float
f.write(' '.join([str(int(x)) for x in data_list.tolist()]))
Expand Down
42 changes: 28 additions & 14 deletions zkstats/arithc_to_bristol.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,9 @@ def _parse_arithc_json(arithc_path: str):
main_outputs[anode.id] = node_name[2:]
if anode.is_const:
const_values[anode.id] = anode.const_value
const_names[anode.id] = node_name
# Make sure each constant has a unique name since different signals can have the same name
# E.g. `for (var i = 0; i < 16; i++)`, all 16 signals will have the same name `i`
const_names[anode.id] = f"{node_name}_{anode.id}"

for gate in data['gates']:
gate_id = gate['id']
Expand Down Expand Up @@ -292,33 +294,45 @@ def generate_circuit_info(self):
# }

rid_to_iid = {node.rid: node.iid for node in tt.sorted_wires}
# Map input name to wire index in MP-SPDZ circuit (including constant wires)
# Map highest level input name to wire index in MP-SPDZ circuit (including constant wires)
non_const_main_input_rids = [node_rid for node_rid in main_inputs if node_rid not in const_values]
input_name_to_wire_index = {
main_inputs[node_rid]: rid_to_iid[node_rid]
for node_rid in self.leaves if node_rid not in const_values
for node_rid in non_const_main_input_rids
}

# FIXME: outputs without a gate are skipped (e.g. outputs directly assigned from an input or a constant)
if len(input_name_to_wire_index) != len(non_const_main_input_rids):
raise Exception("Some inputs have the same name. Please make sure all inputs have distinct names.")

# Prepare constants: const_values is what we want
# Sanity check: every constant must be in leaves so we don't miss passing any of them to the MP-SPDZ circuit
# Check that every node has a distinct name
const_rids = [
node_rid for node_rid in const_values
if node_rid in rid_to_iid # Skip constant wires that are not used in any gates. E.g. constant outputs
]
const_name_to_value_wire_id = {
const_names[node_rid]: {
'value': const_value,
'wire_index': rid_to_iid[node_rid],
'value': const_values[node_rid],
'wire_index': rid_to_iid[node_rid]
}
for node_rid, const_value in const_values.items()
if node_rid in rid_to_iid # Skip constant wires that are not used in any gates. E.g. constant outputs
for node_rid in const_rids
}
if len(const_name_to_value_wire_id) != len(const_rids):
raise Exception("Some constants have the same name. Please make sure all constants have distinct names.")

# Prepare outputs
# Map output name to wire index in MP-SPDZ circuit
output_name_to_wire_index = {
output_name: rid_to_iid[node_rid]
for node_rid, output_name in main_outputs.items()
# Map highest level output name to wire index in MP-SPDZ circuit
non_const_main_output_rids = [
node_rid for node_rid in main_outputs
if node_rid in rid_to_iid # Skip output wires that are not used in any gates. E.g. constant outputs
]
output_name_to_wire_index = {
main_outputs[node_rid]: rid_to_iid[node_rid]
for node_rid in non_const_main_output_rids
}
print("!@# output_name_to_wire_index=", output_name_to_wire_index)
if len(output_name_to_wire_index) != len(non_const_main_output_rids):
raise Exception("Some outputs have the same name. Please make sure all outputs have distinct names.")

return {
"input_name_to_wire_index": input_name_to_wire_index,
"constants": const_name_to_value_wire_id,
Expand Down
136 changes: 46 additions & 90 deletions zkstats/onnx2circom/mpc.circom
Original file line number Diff line number Diff line change
Expand Up @@ -60,127 +60,83 @@ template TFReduceMean(nInputs) {
out[0] <== div.out[0];
}


// TODO: e should be 2.71828; 2 is used instead for now
template TFLog(e) {
signal input in[1][1];
// b must be passed in as secret

signal output out[1];

var upper_bound = 16;
// find b so that b = e^k and x/b <= 1
// Approximate natural log with Taylor series. For 0 < x <= 2
// - ln(x) = ln(1 + (x-1)) = 0 + (x-1) - (x-1)^2/2 + (x-1)^3/3 - (x-1)^4/4 + ...
// - To ensure x <= 2, we can use the following property of logarithm:
// ln(x) = ln(x / b) + ln(b) where b is an integer e^k. So, we need to calculate
// - Step 1: b = e^k such that x/b <= 1
// - Step 2: ln(x/b) using Taylor series
// - Step 3: ln(x) = ln(x/b) + ln(b) = ln(x/b) + k

// x can be only up to e^max_exponent
var max_exponent = 64;
var taylor_series_iterations = 40;

// Step 1: Find b so that b = e^k and x/b <= 1
// find b so that x/b < 1 and can be used in Taylor series
// b = e^k, e=2.71828
signal e_until[upper_bound];
signal x <== in[0][0];
// e_until[i] = e^i
signal e_until[max_exponent];
e_until[0] <== 1;
for (var i = 1; i < upper_bound; i++) {
for (var i = 1; i < max_exponent; i++) {
e_until[i] <== e_until[i-1] * e;
}
signal x <== in[0][0];
// find k s.t. e_until[k] >= x and e_until[k-1] < x
// signal k;

// e^k >= x and e^(k-1) < x
// sel : [0, 0, 1, 0, 0]
// k : [0, 1, 2, 3, 4]
// sum(sel*k) = k

signal sel[upper_bound];
signal sel[max_exponent];
sel[0] <== x <= 1;
component sel_comp[upper_bound];
for (var i = 1; i < upper_bound; i++) {
component sel_comp[max_exponent];
for (var i = 1; i < max_exponent; i++) {
sel_comp[i] = TFMul();
sel_comp[i].in[0][0] <== x > e_until[i-1];
sel_comp[i].in[1][0] <== x <= e_until[i];
sel[i] <== sel_comp[i].out[0];
}

// component sum = TFReduceSum(nInputs);
// for (var i = 0; i<nInputs; i++){
// sum.in[i][0] <== in[i][0];
// }

// FIXME: investigate why we get errors in MP-SPDZ
// if we use a for loop here
component k_by_sum = TFReduceSum(upper_bound);
// for (var i = 0; i < upper_bound; i++) {
// k_by_sum.in[i][0] <== sel[i] * i;
// }
k_by_sum.in[0][0] <== sel[0] * 0;
k_by_sum.in[1][0] <== sel[1] * 1;
k_by_sum.in[2][0] <== sel[2] * 2;
k_by_sum.in[3][0] <== sel[3] * 3;
k_by_sum.in[4][0] <== sel[4] * 4;
k_by_sum.in[5][0] <== sel[5] * 5;
k_by_sum.in[6][0] <== sel[6] * 6;
k_by_sum.in[7][0] <== sel[7] * 7;
k_by_sum.in[8][0] <== sel[8] * 8;
k_by_sum.in[9][0] <== sel[9] * 9;
k_by_sum.in[10][0] <== sel[10] * 10;
k_by_sum.in[11][0] <== sel[11] * 11;
k_by_sum.in[12][0] <== sel[12] * 12;
k_by_sum.in[13][0] <== sel[13] * 13;
k_by_sum.in[14][0] <== sel[14] * 14;
k_by_sum.in[15][0] <== sel[15] * 15;

component k_by_sum = TFReduceSum(max_exponent);
for (var i = 0; i < max_exponent; i++) {
k_by_sum.in[i][0] <== sel[i] * i;
}
signal k <== k_by_sum.out[0];

// FIXME: investigate why we get errors in MP-SPDZ
// if we use a for loop here
// component b_by_sum = TFReduceSum(upper_bound);
// for (var i = 0; i < upper_bound; i++) {
// b_by_sum.in[i][0] <== sel[i] * e_until[i];
// }
// signal b <== b_by_sum.out[0];
component b_by_sum = TFReduceSum(upper_bound);
b_by_sum.in[0][0] <== sel[0] * e_until[0];
b_by_sum.in[1][0] <== sel[1] * e_until[1];
b_by_sum.in[2][0] <== sel[2] * e_until[2];
b_by_sum.in[3][0] <== sel[3] * e_until[3];
b_by_sum.in[4][0] <== sel[4] * e_until[4];
b_by_sum.in[5][0] <== sel[5] * e_until[5];
b_by_sum.in[6][0] <== sel[6] * e_until[6];
b_by_sum.in[7][0] <== sel[7] * e_until[7];
b_by_sum.in[8][0] <== sel[8] * e_until[8];
b_by_sum.in[9][0] <== sel[9] * e_until[9];
b_by_sum.in[10][0] <== sel[10] * e_until[10];
b_by_sum.in[11][0] <== sel[11] * e_until[11];
b_by_sum.in[12][0] <== sel[12] * e_until[12];
b_by_sum.in[13][0] <== sel[13] * e_until[13];
b_by_sum.in[14][0] <== sel[14] * e_until[14];
b_by_sum.in[15][0] <== sel[15] * e_until[15];

// sum(sel*e^k) = b
component b_by_sum = TFReduceSum(max_exponent);
for (var i = 0; i < max_exponent; i++) {
b_by_sum.in[i][0] <== sel[i] * e_until[i];
}
signal b <== b_by_sum.out[0];

// Step 2: Calculate ln(x/b) using Taylor series
signal x_over_b <== x / b;

var taylor_series_iterations = 40;
signal x_over_b_minus_one_exp[taylor_series_iterations+1];
signal x_over_b_minus_one_exp[taylor_series_iterations];
x_over_b_minus_one_exp[0] <== 0;
x_over_b_minus_one_exp[1] <== (x / b) - 1;
for (var i = 2; i < taylor_series_iterations+1; i++) {
for (var i = 2; i < taylor_series_iterations; i++) {
x_over_b_minus_one_exp[i] <== x_over_b_minus_one_exp[i-1] * (1 - x_over_b);
}

signal taylor_series[taylor_series_iterations];
taylor_series[0] <== 0;
for (var i = 1; i < taylor_series_iterations; i++) {
taylor_series[i] <== x_over_b_minus_one_exp[i] / i;
}

// log(x) = log(x/b) + log(b)
// use Taylor series to approximate
// log(x) = log(1 + (x-1)) = (x-1) - (x-1)^2/2 + (x-1)^3/3 - (x-1)^4/4 + ...

// FIXME: investigate why we get errors in MP-SPDZ
// if we use a for loop here
// signal taylor_series[taylor_series_iterations+1];
// taylor_series[0] <== 0;
// for (var i = 1; i < taylor_series_iterations+1; i++) {
// taylor_series[i] <== x_over_b_minus_one_exp[i] / i;
// }
// out[0] <== k + taylor_series[0]+taylor_series[1]+taylor_series[2]+taylor_series[3]+taylor_series[4]+taylor_series[5]+taylor_series[6]+taylor_series[7]+taylor_series[8]+taylor_series[9]+taylor_series[10]+taylor_series[11]+taylor_series[12]+taylor_series[13]+taylor_series[14]+taylor_series[15]+taylor_series[16];
// signal taylor_series_sum;
// component taylor_series_sum_comp = TFReduceSum(taylor_series_iterations);
// for (var i = 0; i < taylor_series_iterations; i++) {
// taylor_series_sum_comp.in[i][0] <== taylor_series[i+1];
// }
// taylor_series_sum <== taylor_series_sum_comp.out[0];
// out[0] <== taylor_series_sum + k;
out[0] <== k + x_over_b_minus_one_exp[1]+x_over_b_minus_one_exp[2]/2+x_over_b_minus_one_exp[3]/3+x_over_b_minus_one_exp[4]/4+x_over_b_minus_one_exp[5]/5+x_over_b_minus_one_exp[6]/6 + x_over_b_minus_one_exp[7]/7 + x_over_b_minus_one_exp[8]/8 + x_over_b_minus_one_exp[9]/9 + x_over_b_minus_one_exp[10]/10+x_over_b_minus_one_exp[11]/11+x_over_b_minus_one_exp[12]/12+x_over_b_minus_one_exp[13]/13+x_over_b_minus_one_exp[14]/14+x_over_b_minus_one_exp[15]/15+x_over_b_minus_one_exp[16]/16 + x_over_b_minus_one_exp[17]/17 + x_over_b_minus_one_exp[18]/18 + x_over_b_minus_one_exp[19]/19 + x_over_b_minus_one_exp[20]/20+x_over_b_minus_one_exp[21]/21+x_over_b_minus_one_exp[22]/22+x_over_b_minus_one_exp[23]/23+x_over_b_minus_one_exp[24]/24+x_over_b_minus_one_exp[25]/25+x_over_b_minus_one_exp[26]/26+x_over_b_minus_one_exp[27]/27+x_over_b_minus_one_exp[28]/28+x_over_b_minus_one_exp[29]/29+x_over_b_minus_one_exp[30]/30+x_over_b_minus_one_exp[31]/31+x_over_b_minus_one_exp[32]/32+x_over_b_minus_one_exp[33]/33+x_over_b_minus_one_exp[34]/34+x_over_b_minus_one_exp[35]/35+x_over_b_minus_one_exp[36]/36+x_over_b_minus_one_exp[37]/37+x_over_b_minus_one_exp[38]/38+x_over_b_minus_one_exp[39]/39+x_over_b_minus_one_exp[40]/40;
signal taylor_series_sum;
component taylor_series_sum_comp = TFReduceSum(taylor_series_iterations);
for (var i = 0; i < taylor_series_iterations; i++) {
taylor_series_sum_comp.in[i][0] <== taylor_series[i];
}
taylor_series_sum <== taylor_series_sum_comp.out[0];

out[0] <== taylor_series_sum + k;
}

0 comments on commit e9ae9ed

Please sign in to comment.