diff --git a/ethsnarks/mimc/__init__.py b/ethsnarks/mimc/__init__.py new file mode 100644 index 0000000..c900e03 --- /dev/null +++ b/ethsnarks/mimc/__init__.py @@ -0,0 +1 @@ +from .permutation import mimc, mimc_hash, mimc_hash_md \ No newline at end of file diff --git a/ethsnarks/mimc/contract.py b/ethsnarks/mimc/contract.py new file mode 100644 index 0000000..20436b8 --- /dev/null +++ b/ethsnarks/mimc/contract.py @@ -0,0 +1,169 @@ +# Copyright (c) 2018 Jordi Baylina +# Copyright (c) 2019 Harry Roberts +# License: LGPL-3.0+ +# Based on: https://github.com/iden3/circomlib/blob/master/src/mimc_gencontract.js + +import sys +import json +from binascii import hexlify + +from ..sha3 import keccak_256 +from ..evmasm import * +from ..field import SNARK_SCALAR_FIELD + +from .permutation import mimc_constants + + +def _mimc_opcodes_round(exponent): + # x = input + # k = key + # q = field modulus + # stack upon entry: x q q k q + # stack upon exit: r k q + if exponent == 7: + return [ + DUP(3), # k x q q k q + ADDMOD, # t=x+k q k q + DUP(1), # q t q k q + DUP(0), # q q t q k q + DUP(2), # t q q t q k q + DUP(0), # t t q q t q k q + MULMOD(), # a=t^2 q t q k q + DUP(1), # q a q t q k q + DUP(1), # a q a q t q k q + DUP(0), # a a q a q t q k q + MULMOD, # b=t^4 a q t q k q + MULMOD, # c=t^6 t q k q + MULMOD # r=t^7 k q + ] + elif exponent == 5: + return [ + DUP(3), # k x q q k q + ADDMOD, # t=x+k q k q + DUP(1), # q t q k q + DUP(0), # q q t q k q + DUP(2), # t q q t q k q + DUP(0), # t t q q t q k q + MULMOD(), # a=t^2 q t q k q + DUP(0), # a a q t q k q + MULMOD, # b=t^4 t q k q + MULMOD # r=t^5 k q + ] + + +def mimc_contract_opcodes(exponent): + assert exponent in (5, 7) + tag = keccak_256(f"MiMCpe{exponent}(uint256,uint256)".encode('ascii')).hexdigest() + + # Ensuring that `exponent ** n_rounds` > SNARK_SCALAR_FIELD + n_rounds = 110 if exponent == 5 else 91 + constants = mimc_constants(R=n_rounds) + + yield [PUSH(0x44), # callDataLength + PUSH(0), # callDataOffset + PUSH(0), # memoryOffset + CALLDATACOPY, + PUSH(1<<224), + PUSH(0), + MLOAD, + DIV, + PUSH(int(tag[:8], 16)), # function selector + EQ, + JMPI('start'), + INVALID] + + yield [LABEL('start'), + PUSH(SNARK_SCALAR_FIELD), # q + PUSH(0x24), + MLOAD] # k q + + yield [ + PUSH(0x04), # 0x04 k q + MLOAD # x k q + ] + + for c_i in constants: + yield [ + DUP(2), # q r k q + DUP(0), # q q r k q + DUP(0), # q q q r k q + SWAP(3), # r q q q k q + PUSH(c_i), # c r q q q k q + ADDMOD, # c+r q q k q + ] + yield _mimc_opcodes_round(exponent) + + # add k to result, then return + yield [ + ADDMOD, # r+k + PUSH(0), # r+k 0 + MSTORE, # + PUSH(0x20), # 0x20 + PUSH(0), # 0 0x20 + RETURN + ] + + +def mimc_abi(exponent): + assert exponent in (5, 7) + return [{ + "constant": True, + "inputs": [ + { + "name": "in_x", + "type": "uint256" + }, + { + "name": "in_k", + "type": "uint256" + } + ], + "name": f"MiMCpe{exponent}", + "outputs": [ + { + "name": "out_x", + "type": "uint256" + } + ], + "payable": False, + "stateMutability": "pure", + "type": "function" + }] + + +def mimc_contract(exponent): + gen = Codegen() + for _ in mimc_contract_opcodes(exponent): + gen.append(_) + return gen.createTxData() + + +def main(*args): + if len(args) < 2: + print("Usage: %s [outfile]" % (args[0],)) + return 1 + command = args[1] + exponent = int(args[2]) + if exponent not in (5, 7): + print("Error: exponent must be 5 or 7") + return 2 + outfile = sys.stdout + if len(args) > 3: + outfile = open(args[3], 'wb') + if command == "abi": + outfile.write(json.dumps(mimc_abi(exponent)) + "\n") + return 0 + elif command == "contract": + data = mimc_contract(exponent) + if outfile == sys.stdout: + data = '0x' + hexlify(data).decode('ascii') + outfile.write(data) + return 0 + else: + print("Error: unknown command", command) + if outfile != sys.stdout: + outfile.close() + + +if __name__ == "__main__": + sys.exit(main(*sys.argv)) diff --git a/ethsnarks/mimc.py b/ethsnarks/mimc/permutation.py similarity index 90% rename from ethsnarks/mimc.py rename to ethsnarks/mimc/permutation.py index 28f10bd..7108559 100644 --- a/ethsnarks/mimc.py +++ b/ethsnarks/mimc/permutation.py @@ -1,16 +1,10 @@ # Copyright (c) 2018 HarryR # License: LGPL-3.0+ -try: - # pysha3 - from sha3 import keccak_256 -except ImportError: - # pycryptodome - from Crypto.Hash import keccak - keccak_256 = lambda *args: keccak.new(*args, digest_bits=256) +from ..sha3 import keccak_256 +from ..field import SNARK_SCALAR_FIELD -SNARK_SCALAR_FIELD = 0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001 DEFAULT_EXPONENT = 7 DEFAULT_ROUNDS = 91 DEFAULT_SEED = b'mimc' @@ -130,6 +124,28 @@ def mimc_hash(x, k=0, seed=DEFAULT_SEED, p=SNARK_SCALAR_FIELD, e=DEFAULT_EXPONEN return k +""" +Merkle-Damgard structure, used to turn a cipher into a one-way-compression function + + m_i + | + | + v + k_{i-1} -->[E] + | + | + v + k_i + +The output is used as the key for the next message +The last output is used as the result +""" +def mimc_hash_md(x, k=0, seed=DEFAULT_SEED, p=SNARK_SCALAR_FIELD, e=DEFAULT_EXPONENT, R=DEFAULT_ROUNDS): + for x_i in x: + k = mimc(x_i, k, seed, p, e, R) + return k + + def _main(): import argparse parser = argparse.ArgumentParser("MiMC") diff --git a/ethsnarks/sha3.py b/ethsnarks/sha3.py new file mode 100644 index 0000000..d6177bc --- /dev/null +++ b/ethsnarks/sha3.py @@ -0,0 +1,7 @@ +try: + # pysha3 + from sha3 import keccak_256 +except ImportError: + # pycryptodome + from Crypto.Hash import keccak + keccak_256 = lambda *args: keccak.new(*args, digest_bits=256) diff --git a/package.json b/package.json index f3df1b9..9f54c7f 100644 --- a/package.json +++ b/package.json @@ -1,19 +1,21 @@ { "name": "ethsnarks", - "version": "0.1.0", + "version": "0.2.0", "description": "zkSNARKS for Ethereum", "main": "truffle.js", "repository": "https://github.com/HarryR/ethsnarks.git", "author": "HarryR@noreply.users.gihub.com", "license": "LGPL-3.0+", "dependencies": { - "solc": "^0.5.9", - "truffle": "^5.0.22" + "rlp": "^2.2.3", + "solc": "^0.5.10", + "truffle": "^5.0.29" }, "devDependencies": { "ganache-cli": "^6.4.4", "solhint": "^1.5.1", - "solidity-coverage": "^0.5.11" + "solidity-coverage": "^0.5.11", + "eth-gas-reporter": "^0.2.4" }, "scripts": { "compile": "truffle compile", diff --git a/src/gadgets/mimc.hpp b/src/gadgets/mimc.hpp index 0153f88..adfff48 100644 --- a/src/gadgets/mimc.hpp +++ b/src/gadgets/mimc.hpp @@ -44,12 +44,77 @@ namespace ethsnarks { */ -#define MIMC_ROUNDS 91 #define MIMC_SEED "mimc" +class MiMCe5_round : public GadgetT { +public: + static constexpr size_t N_ROUNDS = 110; + const VariableT x; + const VariableT k; + const FieldT& C; + const bool add_k_to_result; + const VariableT a; + const VariableT b; + const VariableT c; + +public: + MiMCe5_round( + ProtoboardT& pb, + const VariableT in_x, + const VariableT in_k, + const FieldT& in_C, + const bool in_add_k_to_result, + const std::string &annotation_prefix + ) : + GadgetT(pb, annotation_prefix), + x(in_x), k(in_k), C(in_C), + add_k_to_result(in_add_k_to_result), + a(make_variable(pb, FMT(annotation_prefix, ".a"))), + b(make_variable(pb, FMT(annotation_prefix, ".b"))), + c(make_variable(pb, FMT(annotation_prefix, ".c"))) + { } + + const VariableT& result() const + { + return c; + } + + void generate_r1cs_constraints() + { + auto t = x + k + C; + this->pb.add_r1cs_constraint(ConstraintT(t, t, a), ".a = t*t"); // x^2 + this->pb.add_r1cs_constraint(ConstraintT(a, a, b), ".b = a*a"); // x^4 + + if( add_k_to_result ) + { + this->pb.add_r1cs_constraint(ConstraintT(t, b, c - k), ".c = (b*t) + k"); // x^5 + } + else { + this->pb.add_r1cs_constraint(ConstraintT(t, b, c), ".c = b*t"); // x^5 + } + } + + void generate_r1cs_witness() const + { + const auto val_k = this->pb.val(k); + const auto t = this->pb.val(x) + val_k + C; + + const auto val_a = t * t; + this->pb.val(a) = val_a; + + const auto val_b = val_a * val_a; + this->pb.val(b) = val_b; + + const FieldT result = (val_b * t) + (add_k_to_result ? val_k : FieldT::zero()); + this->pb.val(c) = result; + } +}; + + class MiMCe7_round : public GadgetT { public: + static constexpr size_t N_ROUNDS = 91; const VariableT x; const VariableT k; const FieldT& C; @@ -85,16 +150,16 @@ class MiMCe7_round : public GadgetT { void generate_r1cs_constraints() { auto t = x + k + C; - this->pb.add_r1cs_constraint(ConstraintT(t, t, a), ".a = t*t"); // x^2 - this->pb.add_r1cs_constraint(ConstraintT(a, a, b), ".b = a*a"); // x^4 - this->pb.add_r1cs_constraint(ConstraintT(a, b, c), ".c = a*b"); // x^6 + this->pb.add_r1cs_constraint(ConstraintT(t, t, a), ".a = t*t == t^2"); // x^2 + this->pb.add_r1cs_constraint(ConstraintT(a, a, b), ".b = a*a == t^4"); // x^4 + this->pb.add_r1cs_constraint(ConstraintT(a, b, c), ".c = a*b == t^6"); // x^6 if( add_k_to_result ) { - this->pb.add_r1cs_constraint(ConstraintT(t, c, d - k), ".d = (c*t) + k"); // x^7 + this->pb.add_r1cs_constraint(ConstraintT(t, c, d - k), ".d = (c*t) + k == t^7 + k"); // x^7 } else { - this->pb.add_r1cs_constraint(ConstraintT(t, c, d), ".d = c*t"); // x^7 + this->pb.add_r1cs_constraint(ConstraintT(t, c, d), ".d = c*t == t^7"); // x^7 } } @@ -118,10 +183,11 @@ class MiMCe7_round : public GadgetT { }; -class MiMCe7_gadget : public GadgetT +template +class MiMC_gadget : public GadgetT { public: - std::vector m_rounds; + std::vector m_rounds; const VariableT k; void _setup_gadgets( @@ -142,7 +208,7 @@ class MiMCe7_gadget : public GadgetT } public: - MiMCe7_gadget( + MiMC_gadget( ProtoboardT& pb, const VariableT in_x, const VariableT in_k, @@ -155,7 +221,7 @@ class MiMCe7_gadget : public GadgetT _setup_gadgets(in_x, in_k, in_round_constants); } - MiMCe7_gadget( + MiMC_gadget( ProtoboardT& pb, const VariableT in_x, const VariableT in_k, @@ -196,17 +262,12 @@ class MiMCe7_gadget : public GadgetT */ static const std::vector& static_constants () { - static bool filled = false; static std::vector round_constants; - static std::mutex fill_lock; + static std::once_flag flag; - if( ! filled ) - { - fill_lock.lock(); + std::call_once(flag, [](){ constants_fill(round_constants); - filled = true; - fill_lock.unlock(); - } + }); return round_constants; } @@ -214,12 +275,12 @@ class MiMCe7_gadget : public GadgetT /** * Generate a sequence of round constants from an initial seed value. */ - static void constants_fill( std::vector& round_constants, const char* seed = MIMC_SEED, int round_count = MIMC_ROUNDS ) + static void constants_fill( std::vector& round_constants, const char* seed = MIMC_SEED ) { // XXX: replace '32' with digest size in bytes const size_t DIGEST_SIZE_BYTES = 32; - round_constants.reserve(round_count); + round_constants.reserve(RoundT::N_ROUNDS); unsigned char output_digest[DIGEST_SIZE_BYTES]; @@ -229,7 +290,7 @@ class MiMCe7_gadget : public GadgetT sha3_Update(&ctx, seed, strlen(seed)); memcpy(output_digest, sha3_Finalize(&ctx), DIGEST_SIZE_BYTES); - for( int i = 0; i < round_count; i++ ) + for( int i = 0; i < RoundT::N_ROUNDS; i++ ) { // Derive a sequence of hashes to use as round constants sha3_Init256(&ctx); @@ -256,36 +317,40 @@ class MiMCe7_gadget : public GadgetT } } - static const std::vector constants( const char* seed = MIMC_SEED, int round_count = MIMC_ROUNDS ) + static const std::vector constants( const char* seed = MIMC_SEED ) { std::vector round_constants; - constants_fill(round_constants, seed, round_count); + constants_fill(round_constants, seed); return round_constants; } }; -using MiMC_gadget = MiMCe7_gadget; +using MiMC_e5_gadget = MiMC_gadget; +using MiMC_e7_gadget = MiMC_gadget; -class MiMC_hash_MiyaguchiPreneel_gadget : public MiyaguchiPreneel_OWF +template +class MiMC_hash_MiyaguchiPreneel_gadget : public MiyaguchiPreneel_OWF { public: - using MiyaguchiPreneel_OWF::MiyaguchiPreneel_OWF; + using MiyaguchiPreneel_OWF::MiyaguchiPreneel_OWF; }; -class MiMC_hash_MerkleDamgard_gadget : public MerkleDamgard_OWF +template +class MiMC_hash_MerkleDamgard_gadget : public MerkleDamgard_OWF { public: - using MerkleDamgard_OWF::MerkleDamgard_OWF; + using MerkleDamgard_OWF::MerkleDamgard_OWF; }; // generic aliases for 'MiMC', masks specific implementation -using MiMC_hash_gadget = MiMC_hash_MiyaguchiPreneel_gadget; +using MiMC_e7_hash_gadget = MiMC_hash_MiyaguchiPreneel_gadget; +using MiMC_e5_hash_gadget = MiMC_hash_MiyaguchiPreneel_gadget; @@ -297,7 +362,7 @@ const FieldT mimc( const std::vector& round_constants, const FieldT& x, const VariableT var_k = make_variable(pb, k, "k"); pb.set_input_sizes(2); - MiMC_gadget the_gadget(pb, var_x, var_k, round_constants, "the_gadget"); + MiMC_e7_gadget the_gadget(pb, var_x, var_k, round_constants, "the_gadget"); the_gadget.generate_r1cs_witness(); the_gadget.generate_r1cs_constraints(); @@ -307,7 +372,7 @@ const FieldT mimc( const std::vector& round_constants, const FieldT& x, const FieldT mimc( const FieldT& x, const FieldT& k ) { - return mimc(MiMC_gadget::static_constants(), x, k); + return mimc(MiMC_e7_gadget::static_constants(), x, k); } @@ -325,7 +390,7 @@ const FieldT mimc_hash( const std::vector& m, const FieldT& k ) const auto var_k = make_variable(pb, k, "k"); pb.set_input_sizes(m.size() + 1); - MiMC_hash_gadget the_gadget(pb, var_k, var_m, "the_gadget"); + MiMC_e7_hash_gadget the_gadget(pb, var_k, var_m, "the_gadget"); the_gadget.generate_r1cs_witness(); the_gadget.generate_r1cs_constraints(); diff --git a/src/test/test_merkle_tree.cpp b/src/test/test_merkle_tree.cpp index fe9ff4f..3f35bf0 100644 --- a/src/test/test_merkle_tree.cpp +++ b/src/test/test_merkle_tree.cpp @@ -83,7 +83,7 @@ bool test_merkle_path_authenticator() { pb.val(expected_root) = root; size_t tree_depth = 1; - merkle_path_authenticator auth( + merkle_path_authenticator auth( pb, tree_depth, address_bits, merkle_tree_IVs(pb), leaf, expected_root, path, diff --git a/src/test/test_mimc.cpp b/src/test/test_mimc.cpp index 56b1578..fa13c81 100644 --- a/src/test/test_mimc.cpp +++ b/src/test/test_mimc.cpp @@ -8,7 +8,7 @@ using ethsnarks::ppT; using ethsnarks::FieldT; using ethsnarks::ProtoboardT; using ethsnarks::VariableT; -using ethsnarks::MiMC_gadget; +using ethsnarks::MiMC_e7_gadget; using ethsnarks::make_variable; @@ -28,7 +28,7 @@ bool test_MiMC(const MiMC_TestCase& test_case) const VariableT in_k = make_variable(pb, test_case.key, "k"); pb.set_input_sizes(2); - MiMC_gadget the_gadget(pb, in_x, in_k, "the_gadget"); + MiMC_e7_gadget the_gadget(pb, in_x, in_k, "the_gadget"); the_gadget.generate_r1cs_witness(); the_gadget.generate_r1cs_constraints(); diff --git a/src/test/test_mimc_hash.cpp b/src/test/test_mimc_hash.cpp index f80a820..881705e 100644 --- a/src/test/test_mimc_hash.cpp +++ b/src/test/test_mimc_hash.cpp @@ -20,7 +20,7 @@ bool test_mimc_hash() // Private inputs VariableT iv = make_variable(pb, FieldT("918403109389145570117360101535982733651217667914747213867238065296420114726"), "iv"); - MiMC_hash_gadget the_gadget(pb, iv, {m_0, m_1}, "gadget"); + MiMC_e7_hash_gadget the_gadget(pb, iv, {m_0, m_1}, "gadget"); the_gadget.generate_r1cs_witness(); the_gadget.generate_r1cs_constraints(); diff --git a/src/utils.cpp b/src/utils.cpp index d520258..3f322dc 100644 --- a/src/utils.cpp +++ b/src/utils.cpp @@ -158,6 +158,11 @@ const std::vector bits2blocks_padded(ProtoboardT& in_pb, const V // Set the bit immediately after the input bits to 1 if( in_bits_offset == in_bits.size() ) { in_pb.val(block[j]) = FieldT::one(); + in_pb.add_r1cs_constraint(ConstraintT(block[j], FieldT::one(), FieldT::one())); + } + else if( in_bits_offset < (block_end - length_bits) ) { + // Enforce padding bits are zero + in_pb.add_r1cs_constraint(ConstraintT(block[j], FieldT::zero(), FieldT::zero())); } j += 1; @@ -182,7 +187,9 @@ const std::vector bits2blocks_padded(ProtoboardT& in_pb, const V size_t k = 0; for( size_t j = (block_size - length_bits); j < block_size; j++ ) { - in_pb.val(last_block[j]) = FieldT(bitlen_bits[k++]); + const auto value = FieldT(bitlen_bits[k++]); + in_pb.add_r1cs_constraint(ConstraintT(last_block[j], 1, value)); + in_pb.val(last_block[j]) = value; } return out_blocks; diff --git a/test/test_mimc_evm.py b/test/test_mimc_evm.py new file mode 100644 index 0000000..df50c5a --- /dev/null +++ b/test/test_mimc_evm.py @@ -0,0 +1,39 @@ +import unittest +import logging +import sys +from web3 import Web3 + +from ethsnarks.field import FQ +from ethsnarks.mimc import mimc +from ethsnarks.mimc.contract import mimc_abi, mimc_contract + + +class TestMiMCEvm(unittest.TestCase): + def _deploy_contract(self, w3, exponent): + abi = mimc_abi(exponent) + contract = w3.eth.contract(abi=abi, bytecode=mimc_contract(exponent)) + tx_hash = contract.constructor().transact() + tx_receipt = w3.eth.waitForTransactionReceipt(tx_hash) + return w3.eth.contract(address=tx_receipt.contractAddress, abi=abi) + + def setUp(self): + super(TestMiMCEvm, self).setUp() + w3 = Web3(Web3.EthereumTesterProvider()) + self.contract_e7 = self._deploy_contract(w3, 7) + self.contract_e5 = self._deploy_contract(w3, 5) + + def test_e7(self): + m_i, k_i = int(FQ.random()), int(FQ.random()) + python_result = mimc(m_i, k_i, e=7, R=91) + evm_result = self.contract_e7.functions.MiMCpe7(m_i, k_i).call() + self.assertEqual(evm_result, python_result) + + def test_e5(self): + m_i, k_i = int(FQ.random()), int(FQ.random()) + python_result = mimc(m_i, k_i, e=5, R=110) + evm_result = self.contract_e5.functions.MiMCpe5(m_i, k_i).call() + self.assertEqual(evm_result, python_result) + + +if __name__ == "__main__": + unittest.main() diff --git a/truffle.js b/truffle.js index 03797ba..05c60ee 100644 --- a/truffle.js +++ b/truffle.js @@ -18,5 +18,11 @@ module.exports = { enabled: true, runs: 200 } + }, + mocha: { + reporter: 'eth-gas-reporter', + reporterOptions: { + onlyCalledMethods: false + } } } \ No newline at end of file