Skip to content

Commit

Permalink
fix: circom loop
Browse files Browse the repository at this point in the history
What's the issue: our `arithc_to_bristol` didn't consider that different nodes
can have the same name. For example, in `for (var i = 0; i < 16; i++)`, var i
always has the same name for every value in [0, 16) in arithc. Since we use node
names as keys in circuit_info.json, the different constants i=0, ..., i=15
kept overwriting each other, so in MP-SPDZ there was only one wire
value for all of these different nodes, and as a result some wires weren't
assigned values.

How it is fixed: just make each constant node name distinct by suffixing it
with the node's unique id. E.g. i=0 might be "i_123128318238" while i=15
might be "i_3824329329".
  • Loading branch information
mhchia committed Apr 23, 2024
1 parent 991e7a1 commit 968495b
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 104 deletions.
42 changes: 28 additions & 14 deletions zkstats/arithc_to_bristol.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,9 @@ def _parse_arithc_json(arithc_path: str):
main_outputs[anode.id] = node_name[2:]
if anode.is_const:
const_values[anode.id] = anode.const_value
const_names[anode.id] = node_name
# Make sure each constant has a unique name since different signals can have the same name
# E.g. `for (var i = 0; i < 16; i++)`, all 16 signals will have the same name `i`
const_names[anode.id] = f"{node_name}_{anode.id}"

for gate in data['gates']:
gate_id = gate['id']
Expand Down Expand Up @@ -292,33 +294,45 @@ def generate_circuit_info(self):
# }

rid_to_iid = {node.rid: node.iid for node in tt.sorted_wires}
# Map input name to wire index in MP-SPDZ circuit (including constant wires)
# Map highest level input name to wire index in MP-SPDZ circuit (including constant wires)
non_const_main_input_rids = [node_rid for node_rid in main_inputs if node_rid not in const_values]
input_name_to_wire_index = {
main_inputs[node_rid]: rid_to_iid[node_rid]
for node_rid in self.leaves if node_rid not in const_values
for node_rid in non_const_main_input_rids
}

# FIXME: outputs without a gate are skipped (i.e. direct assigned from input or a constant, etc)
if len(input_name_to_wire_index) != len(non_const_main_input_rids):
raise Exception("Some inputs have the same name. Please make sure all inputs have distinct names.")

# Prepare constants: const_values is what we want
# Just sanity check for all constant must be in leaves so we don't miss passing any of them to MP-SPDZ circuit
# Check if every node has distinct names
const_rids = [
node_rid for node_rid in const_values
if node_rid in rid_to_iid # Skip constant wires that are not used in any gates. E.g. constant outputs
]
const_name_to_value_wire_id = {
const_names[node_rid]: {
'value': const_value,
'wire_index': rid_to_iid[node_rid],
'value': const_values[node_rid],
'wire_index': rid_to_iid[node_rid]
}
for node_rid, const_value in const_values.items()
if node_rid in rid_to_iid # Skip constant wires that are not used in any gates. E.g. constant outputs
for node_rid in const_rids
}
if len(const_name_to_value_wire_id) != len(const_rids):
raise Exception("Some constants have the same name. Please make sure all constants have distinct names.")

# Prepare outputs
# Map output name to wire index in MP-SPDZ circuit
output_name_to_wire_index = {
output_name: rid_to_iid[node_rid]
for node_rid, output_name in main_outputs.items()
# Map highest level output name to wire index in MP-SPDZ circuit
non_const_main_output_rids = [
node_rid for node_rid in main_outputs
if node_rid in rid_to_iid # Skip output wires that are not used in any gates. E.g. constant outputs
]
output_name_to_wire_index = {
main_outputs[node_rid]: rid_to_iid[node_rid]
for node_rid in non_const_main_output_rids
}
print("!@# output_name_to_wire_index=", output_name_to_wire_index)
if len(output_name_to_wire_index) != len(non_const_main_output_rids):
raise Exception("Some outputs have the same name. Please make sure all outputs have distinct names.")

return {
"input_name_to_wire_index": input_name_to_wire_index,
"constants": const_name_to_value_wire_id,
Expand Down
136 changes: 46 additions & 90 deletions zkstats/onnx2circom/mpc.circom
Original file line number Diff line number Diff line change
Expand Up @@ -60,127 +60,83 @@ template TFReduceMean(nInputs) {
out[0] <== div.out[0];
}


// TODO: e should be 2.71828 instead of 2 for now
template TFLog(e) {
signal input in[1][1];
// b must be passed in as secret

signal output out[1];

var upper_bound = 16;
// find b so that b = e^k and x/b <= 1
// Approximate natural log with a Taylor series. For 0 < x <= 2
// - ln(x) = ln(1 + (x-1)) = 0 + (x-1) - (x-1)^2/2 + (x-1)^3/3 - (x-1)^4/4 + ...
// - To ensure x <= 2, we can use the following property of logarithm:
// ln(x) = ln(x / b) + ln(b) where b is an integer e^k. So, we need to calculate
// - Step 1: b = e^k such that x/b <= 1
// - Step 2: ln(x/b) using the Taylor series
// - Step 3: ln(x) = ln(x/b) + ln(b) = ln(x/b) + k

// x can be only up to e^max_exponent
var max_exponent = 64;
var taylor_series_iterations = 40;

// Step 1: Find b so that b = e^k and x/b <= 1
// find b so that x/b < 1 and can be used in the Taylor series
// b = e^k, e=2.71828
signal e_until[upper_bound];
signal x <== in[0][0];
// e_until[i] = e^i
signal e_until[max_exponent];
e_until[0] <== 1;
for (var i = 1; i < upper_bound; i++) {
for (var i = 1; i < max_exponent; i++) {
e_until[i] <== e_until[i-1] * e;
}
signal x <== in[0][0];
// find k s.t. e_until[k] >= x and e_until[k-1] < x
// signal k;

// e^k >= x and e^(k-1) < x
// sel : [0, 0, 1, 0, 0]
// k : [0, 1, 2, 3, 4]
// sum(sel*k) = k

signal sel[upper_bound];
signal sel[max_exponent];
sel[0] <== x <= 1;
component sel_comp[upper_bound];
for (var i = 1; i < upper_bound; i++) {
component sel_comp[max_exponent];
for (var i = 1; i < max_exponent; i++) {
sel_comp[i] = TFMul();
sel_comp[i].in[0][0] <== x > e_until[i-1];
sel_comp[i].in[1][0] <== x <= e_until[i];
sel[i] <== sel_comp[i].out[0];
}

// component sum = TFReduceSum(nInputs);
// for (var i = 0; i<nInputs; i++){
// sum.in[i][0] <== in[i][0];
// }

// FIXME: investigate why we get errors in MP-SPDZ
// if we use a for loop here
component k_by_sum = TFReduceSum(upper_bound);
// for (var i = 0; i < upper_bound; i++) {
// k_by_sum.in[i][0] <== sel[i] * i;
// }
k_by_sum.in[0][0] <== sel[0] * 0;
k_by_sum.in[1][0] <== sel[1] * 1;
k_by_sum.in[2][0] <== sel[2] * 2;
k_by_sum.in[3][0] <== sel[3] * 3;
k_by_sum.in[4][0] <== sel[4] * 4;
k_by_sum.in[5][0] <== sel[5] * 5;
k_by_sum.in[6][0] <== sel[6] * 6;
k_by_sum.in[7][0] <== sel[7] * 7;
k_by_sum.in[8][0] <== sel[8] * 8;
k_by_sum.in[9][0] <== sel[9] * 9;
k_by_sum.in[10][0] <== sel[10] * 10;
k_by_sum.in[11][0] <== sel[11] * 11;
k_by_sum.in[12][0] <== sel[12] * 12;
k_by_sum.in[13][0] <== sel[13] * 13;
k_by_sum.in[14][0] <== sel[14] * 14;
k_by_sum.in[15][0] <== sel[15] * 15;

component k_by_sum = TFReduceSum(max_exponent);
for (var i = 0; i < max_exponent; i++) {
k_by_sum.in[i][0] <== sel[i] * i;
}
signal k <== k_by_sum.out[0];

// FIXME: investigate why we get errors in MP-SPDZ
// if we use a for loop here
// component b_by_sum = TFReduceSum(upper_bound);
// for (var i = 0; i < upper_bound; i++) {
// b_by_sum.in[i][0] <== sel[i] * e_until[i];
// }
// signal b <== b_by_sum.out[0];
component b_by_sum = TFReduceSum(upper_bound);
b_by_sum.in[0][0] <== sel[0] * e_until[0];
b_by_sum.in[1][0] <== sel[1] * e_until[1];
b_by_sum.in[2][0] <== sel[2] * e_until[2];
b_by_sum.in[3][0] <== sel[3] * e_until[3];
b_by_sum.in[4][0] <== sel[4] * e_until[4];
b_by_sum.in[5][0] <== sel[5] * e_until[5];
b_by_sum.in[6][0] <== sel[6] * e_until[6];
b_by_sum.in[7][0] <== sel[7] * e_until[7];
b_by_sum.in[8][0] <== sel[8] * e_until[8];
b_by_sum.in[9][0] <== sel[9] * e_until[9];
b_by_sum.in[10][0] <== sel[10] * e_until[10];
b_by_sum.in[11][0] <== sel[11] * e_until[11];
b_by_sum.in[12][0] <== sel[12] * e_until[12];
b_by_sum.in[13][0] <== sel[13] * e_until[13];
b_by_sum.in[14][0] <== sel[14] * e_until[14];
b_by_sum.in[15][0] <== sel[15] * e_until[15];

// sum(sel*e^k) = b
component b_by_sum = TFReduceSum(max_exponent);
for (var i = 0; i < max_exponent; i++) {
b_by_sum.in[i][0] <== sel[i] * e_until[i];
}
signal b <== b_by_sum.out[0];

// Step 2: Calculate ln(x/b) using the Taylor series
signal x_over_b <== x / b;

var taylor_series_iterations = 40;
signal x_over_b_minus_one_exp[taylor_series_iterations+1];
signal x_over_b_minus_one_exp[taylor_series_iterations];
x_over_b_minus_one_exp[0] <== 0;
x_over_b_minus_one_exp[1] <== (x / b) - 1;
for (var i = 2; i < taylor_series_iterations+1; i++) {
for (var i = 2; i < taylor_series_iterations; i++) {
x_over_b_minus_one_exp[i] <== x_over_b_minus_one_exp[i-1] * (1 - x_over_b);
}

signal taylor_series[taylor_series_iterations];
taylor_series[0] <== 0;
for (var i = 1; i < taylor_series_iterations; i++) {
taylor_series[i] <== x_over_b_minus_one_exp[i] / i;
}

// log(x) = log(x/b) + log(b)
// use the Taylor series to approximate
// log(x) = log(1 + (x-1)) = (x-1) - (x-1)^2/2 + (x-1)^3/3 - (x-1)^4/4 + ...

// FIXME: investigate why we get errors in MP-SPDZ
// if we use a for loop here
// signal taylor_series[taylor_series_iterations+1];
// taylor_series[0] <== 0;
// for (var i = 1; i < taylor_series_iterations+1; i++) {
// taylor_series[i] <== x_over_b_minus_one_exp[i] / i;
// }
// out[0] <== k + taylor_series[0]+taylor_series[1]+taylor_series[2]+taylor_series[3]+taylor_series[4]+taylor_series[5]+taylor_series[6]+taylor_series[7]+taylor_series[8]+taylor_series[9]+taylor_series[10]+taylor_series[11]+taylor_series[12]+taylor_series[13]+taylor_series[14]+taylor_series[15]+taylor_series[16];
// signal taylor_series_sum;
// component taylor_series_sum_comp = TFReduceSum(taylor_series_iterations);
// for (var i = 0; i < taylor_series_iterations; i++) {
// taylor_series_sum_comp.in[i][0] <== taylor_series[i+1];
// }
// taylor_series_sum <== taylor_series_sum_comp.out[0];
// out[0] <== taylor_series_sum + k;
out[0] <== k + x_over_b_minus_one_exp[1]+x_over_b_minus_one_exp[2]/2+x_over_b_minus_one_exp[3]/3+x_over_b_minus_one_exp[4]/4+x_over_b_minus_one_exp[5]/5+x_over_b_minus_one_exp[6]/6 + x_over_b_minus_one_exp[7]/7 + x_over_b_minus_one_exp[8]/8 + x_over_b_minus_one_exp[9]/9 + x_over_b_minus_one_exp[10]/10+x_over_b_minus_one_exp[11]/11+x_over_b_minus_one_exp[12]/12+x_over_b_minus_one_exp[13]/13+x_over_b_minus_one_exp[14]/14+x_over_b_minus_one_exp[15]/15+x_over_b_minus_one_exp[16]/16 + x_over_b_minus_one_exp[17]/17 + x_over_b_minus_one_exp[18]/18 + x_over_b_minus_one_exp[19]/19 + x_over_b_minus_one_exp[20]/20+x_over_b_minus_one_exp[21]/21+x_over_b_minus_one_exp[22]/22+x_over_b_minus_one_exp[23]/23+x_over_b_minus_one_exp[24]/24+x_over_b_minus_one_exp[25]/25+x_over_b_minus_one_exp[26]/26+x_over_b_minus_one_exp[27]/27+x_over_b_minus_one_exp[28]/28+x_over_b_minus_one_exp[29]/29+x_over_b_minus_one_exp[30]/30+x_over_b_minus_one_exp[31]/31+x_over_b_minus_one_exp[32]/32+x_over_b_minus_one_exp[33]/33+x_over_b_minus_one_exp[34]/34+x_over_b_minus_one_exp[35]/35+x_over_b_minus_one_exp[36]/36+x_over_b_minus_one_exp[37]/37+x_over_b_minus_one_exp[38]/38+x_over_b_minus_one_exp[39]/39+x_over_b_minus_one_exp[40]/40;
signal taylor_series_sum;
component taylor_series_sum_comp = TFReduceSum(taylor_series_iterations);
for (var i = 0; i < taylor_series_iterations; i++) {
taylor_series_sum_comp.in[i][0] <== taylor_series[i];
}
taylor_series_sum <== taylor_series_sum_comp.out[0];

out[0] <== taylor_series_sum + k;
}

0 comments on commit 968495b

Please sign in to comment.