diff --git a/zkstats/arithc_to_bristol.py b/zkstats/arithc_to_bristol.py index 382b163..6a9abe6 100644 --- a/zkstats/arithc_to_bristol.py +++ b/zkstats/arithc_to_bristol.py @@ -74,7 +74,9 @@ def _parse_arithc_json(arithc_path: str): main_outputs[anode.id] = node_name[2:] if anode.is_const: const_values[anode.id] = anode.const_value - const_names[anode.id] = node_name + # Make sure each constant has a unique name since different signals can have the same name + # E.g. `for (var i = 0; i < 16; i++)`, all 16 signals will have the same name `i` + const_names[anode.id] = f"{node_name}_{anode.id}" for gate in data['gates']: gate_id = gate['id'] @@ -292,33 +294,45 @@ def generate_circuit_info(self): # } rid_to_iid = {node.rid: node.iid for node in tt.sorted_wires} - # Map input name to wire index in MP-SPDZ circuit (including constant wires) + # Map highest level input name to wire index in MP-SPDZ circuit (including constant wires) + non_const_main_input_rids = [node_rid for node_rid in main_inputs if node_rid not in const_values] input_name_to_wire_index = { main_inputs[node_rid]: rid_to_iid[node_rid] - for node_rid in self.leaves if node_rid not in const_values + for node_rid in non_const_main_input_rids } - - # FIXME: outputs without a gate are skipped (i.e. direct assigned from input or a constant, etc) + if len(input_name_to_wire_index) != len(non_const_main_input_rids): + raise Exception("Some inputs have the same name. Please make sure all inputs have distinct names.") # Prepare constants: const_values is what we want # Just sanity check for all constant must be in leaves so we don't miss passing any of them to MP-SPDZ circuit + # Check if every node has distinct names + const_rids = [ + node_rid for node_rid in const_values + if node_rid in rid_to_iid # Skip constant wires that are not used in any gates. E.g. constant outputs + ] const_name_to_value_wire_id = { const_names[node_rid]: { - 'value': const_value, - 'wire_index': rid_to_iid[node_rid], + 'value': const_values[node_rid], + 'wire_index': rid_to_iid[node_rid] } - for node_rid, const_value in const_values.items() - if node_rid in rid_to_iid # Skip constant wires that are not used in any gates. E.g. constant outputs + for node_rid in const_rids } + if len(const_name_to_value_wire_id) != len(const_rids): + raise Exception("Some constants have the same name. Please make sure all constants have distinct names.") # Prepare outputs - # Map output name to wire index in MP-SPDZ circuit - output_name_to_wire_index = { - output_name: rid_to_iid[node_rid] - for node_rid, output_name in main_outputs.items() + # Map highest level output name to wire index in MP-SPDZ circuit + non_const_main_output_rids = [ + node_rid for node_rid in main_outputs if node_rid in rid_to_iid # Skip output wires that are not used in any gates. E.g. constant outputs + ] + output_name_to_wire_index = { + main_outputs[node_rid]: rid_to_iid[node_rid] + for node_rid in non_const_main_output_rids } - print("!@# output_name_to_wire_index=", output_name_to_wire_index) + if len(output_name_to_wire_index) != len(non_const_main_output_rids): + raise Exception("Some outputs have the same name. Please make sure all outputs have distinct names.") + return { "input_name_to_wire_index": input_name_to_wire_index, "constants": const_name_to_value_wire_id, diff --git a/zkstats/onnx2circom/mpc.circom b/zkstats/onnx2circom/mpc.circom index f7a2467..1ff98ec 100644 --- a/zkstats/onnx2circom/mpc.circom +++ b/zkstats/onnx2circom/mpc.circom @@ -60,127 +60,83 @@ template TFReduceMean(nInputs) { out[0] <== div.out[0]; } + +// TODO: e should be 2.71828 instead of 2 for now template TFLog(e) { signal input in[1][1]; - // b must be passed in as secret signal output out[1]; - var upper_bound = 16; - // find b so that b = e^k and x/b <= 1 + // Approximate natural log with talyer series. For 0 < x <= 2 + // - ln(x) = ln(1 + (x-1)) = 0 + (x-1) - (x-1)^2/2 + (x-1)^3/3 - (x-1)^4/4 + ... + // - To ensure x <= 2, we can use the following property of logarithm: + // ln(x) = ln(x / b) + ln(b) where b is an integer e^k. So, we need to calculate + // - Step 1: b = e^k such that x/b <= 1 + // - Step 2: ln(x/b) using talyer series + // - Step 3: ln(x) = ln(x/b) + ln(b) = ln(x/b) + k + + // x can be only up to e^max_exponent + var max_exponent = 64; + var taylor_series_iterations = 40; + + // Step 1: Find b so that b = e^k and x/b <= 1 // find b so that x/b < 1 and can be used in talyer series // b = e^k, e=2.71828 - signal e_until[upper_bound]; + signal x <== in[0][0]; + // e_until[i] = e^i + signal e_until[max_exponent]; e_until[0] <== 1; - for (var i = 1; i < upper_bound; i++) { + for (var i = 1; i < max_exponent; i++) { e_until[i] <== e_until[i-1] * e; } - signal x <== in[0][0]; - // find k s.t. e_until[k] >= x and e_until[k-1] < x - // signal k; // e^k >= x and e^(k-1) < x // sel : [0, 0, 1, 0, 0] // k : [0, 1, 2, 3, 4] // sum(sel*k) = k - - signal sel[upper_bound]; + signal sel[max_exponent]; sel[0] <== x <= 1; - component sel_comp[upper_bound]; - for (var i = 1; i < upper_bound; i++) { + component sel_comp[max_exponent]; + for (var i = 1; i < max_exponent; i++) { sel_comp[i] = TFMul(); sel_comp[i].in[0][0] <== x > e_until[i-1]; sel_comp[i].in[1][0] <== x <= e_until[i]; sel[i] <== sel_comp[i].out[0]; } - - // component sum = TFReduceSum(nInputs); - // for (var i = 0; i