From 4b7b693e61e21cd639468981e204afc809323c75 Mon Sep 17 00:00:00 2001 From: HarryR Date: Thu, 11 Jul 2019 14:42:22 +0100 Subject: [PATCH] Added EVM implementation of Poseidon, ported from Jordi's Circom JS code --- ethsnarks/evmasm.py | 205 ++++++++++++++++++++ ethsnarks/poseidon/contract.py | 301 ++++++++++++++++++++++++++++++ ethsnarks/poseidon/permutation.py | 12 +- requirements-dev.txt | 1 + test/test_poseidon_evm.py | 47 +++++ 5 files changed, 560 insertions(+), 6 deletions(-) create mode 100644 ethsnarks/evmasm.py create mode 100644 ethsnarks/poseidon/contract.py create mode 100644 test/test_poseidon_evm.py diff --git a/ethsnarks/evmasm.py b/ethsnarks/evmasm.py new file mode 100644 index 0000000..9983c56 --- /dev/null +++ b/ethsnarks/evmasm.py @@ -0,0 +1,205 @@ +# Copyright (c) 2018 Jordi Baylina +# Copyright (c) 2019 Harry Roberts +# License: LGPL-3.0+ +# +# Based on: https://github.com/iden3/circomlib/blob/master/src/evmasm.js + +from binascii import unhexlify +from collections import defaultdict + +class _opcode(object): + extra = None + def __init__(self, code, extra=None): + self._code = code + if not isinstance(extra, (bytes, bytearray)): + assert extra is None + self.extra = extra + + def data(self): + extra = self.extra if self.extra is not None else b'' + return bytes([self._code]) + extra + + def __call__(self): + return self + +class LABEL(_opcode): + _code = 0x5b + def __init__(self, name): + assert isinstance(name, (str, bytes, bytearray)) + self.name = name + +def _encode_offset(offset): + return bytes([offset >> 16, (offset >> 8) & 0xFF, offset & 0xFF]) + +class PUSHLABEL(_opcode): + def __init__(self, target): + self.target = target + + def data(self, offset): + assert offset >= 0 and offset < (1<<24) + return bytes([0x62]) + _encode_offset(offset) + +class JMP(PUSHLABEL): + _code = 0x56 + + def __init__(self, target=None): + super(JMP, self).__init__(target) + + def data(self, offset=None): + if offset is not None: + return super(JMP, self).data(offset) + bytes([self._code]) + return bytes([self._code]) + +class JMPI(JMP): + _code = 0x57 + +def DUP(n): + if n < 0 or n >= 16: + raise ValueError("DUP must be 0 to 16") + return _opcode(0x80 + n) + +def SWAP(n): + if n < 0 or n >= 16: + raise ValueError("SWAP must be 0 to 16") + return _opcode(0x8f + n) + +def PUSH(data): + if isinstance(data, int): + if data < 0 or data >= ((1<<256)-1): + raise ValueError("Push value out of range: %r" % (data,)) + hexdata = hex(data)[2:] + if (len(hexdata) % 2) != 0: + hexdata = '0' + hexdata + data = unhexlify(hexdata) + assert isinstance(data, (bytes, bytearray)) + return _opcode(0x5F + len(data), data) + +STOP = _opcode(0x00) +ADD = _opcode(0x01) +MUL = _opcode(0x02) +SUB = _opcode(0x03) +DIV = _opcode(0x04) +SDIV = _opcode(0x05) +MOD = _opcode(0x06) +SMOD = _opcode(0x07) +ADDMOD = _opcode(0x08) +MULMOD = _opcode(0x09) + +EXP = _opcode(0x0a) +SIGNEXTEND = _opcode(0x0b) +LT = _opcode(0x10) +GT = _opcode(0x11) +SLT = _opcode(0x12) +SGT = _opcode(0x13) +EQ = _opcode(0x14) +ISZERO = _opcode(0x15) +AND = _opcode(0x16) +OR = _opcode(0x17) +SHOR = _opcode(0x18) +NOT = _opcode(0x19) +BYTE = _opcode(0x1a) +KECCAK = _opcode(0x20) +SHA3 = _opcode(0x20) + +ADDRESS = _opcode(0x30) +BALANCE = _opcode(0x31) +ORIGIN = _opcode(0x32) +CALLER = _opcode(0x33) +CALLVALUE = _opcode(0x34) +CALLDATALOAD = _opcode(0x35) +CALLDATASIZE = _opcode(0x36) +CALLDATACOPY = _opcode(0x37) +CODESIZE = _opcode(0x38) +CODECOPY = _opcode(0x39) +GASPRICE = _opcode(0x3a) +EXTCODESIZE = _opcode(0x3b) +EXTCODECOPY = _opcode(0x3c) +RETURNDATASIZE = _opcode(0x3d) +RETURNDATACOPY = _opcode(0x3e) + +BLOCKHASH = _opcode(0x40) +COINBASE = _opcode(0x41) +TIMESTAMP = _opcode(0x42) +NUMBER = _opcode(0x43) +DIFFICULTY = _opcode(0x44) +GASLIMIT = _opcode(0x45) + +POP = _opcode(0x50) +MLOAD = _opcode(0x51) +MSTORE = _opcode(0x52) +MSTORE8 = _opcode(0x53) +SLOAD = _opcode(0x54) +SSTORE = _opcode(0x55) +PC = _opcode(0x58) +MSIZE = _opcode(0x59) +GAS = _opcode(0x5a) + +LOG0 = _opcode(0xa0) +LOG1 = _opcode(0xa1) +LOG2 = _opcode(0xa2) +LOG3 = _opcode(0xa3) +LOG4 = _opcode(0xa4) + +CREATE = _opcode(0xf0) +CALL = _opcode(0xf1) +CALLCODE = _opcode(0xf2) +RETURN = _opcode(0xf3) +DELEGATECALL = _opcode(0xf4) +STATICCALL = _opcode(0xfa) +REVERT = _opcode(0xfd) +INVALID = _opcode(0xfe) +SELFDESTRUCT = _opcode(0xff) + +class Codegen(object): + def __init__(self, code=None): + self.code = bytearray() + self._labels = dict() + self._jumps = defaultdict(list) + if code is not None: + self.append(code) + + def createTxData(self): + if len(self._jumps): + raise RuntimeError("Pending labels: " + ','.join(self._jumps.keys())) + + return type(self)([ + PUSH(len(self.code)), # length of code being deployed + DUP(0), + DUP(0), + CODESIZE, # total length + SUB, # codeOffset = (total_length - body_length) + PUSH(0), # memOffset + CODECOPY, + PUSH(0), + RETURN + ]).code + self.code + + def append(self, *args): + for arg in args: + if isinstance(arg, (list, tuple)): + # Allow x.append([opcode, opcode, ...]) + arg = self.append(*arg) + continue + if isinstance(arg, PUSHLABEL): + offset = None + if arg.target is not None: + if arg.target not in self._labels: + self._jumps[arg.target].append(len(self.code)) + offset = 0 # jump destination filled-in later + else: + offset = self._labels[arg.target] + from binascii import hexlify + self.code += arg.data(offset) + elif isinstance(arg, LABEL): + if arg.name in self._labels: + raise RuntimeError("Cannot re-define label %r" % (arg.name,)) + self._labels[arg.name] = len(self.code) + if arg.name in self._jumps: + for jump in self._jumps[arg.name]: + self.code[jump+1:jump+4] = _encode_offset(len(self.code)) + del self._jumps[arg.name] + self.code += arg.data() + elif isinstance(arg, _opcode): + self.code += arg.data() + else: + raise RuntimeError("Unknown opcode %r" % (arg,)) diff --git a/ethsnarks/poseidon/contract.py b/ethsnarks/poseidon/contract.py new file mode 100644 index 0000000..e439fef --- /dev/null +++ b/ethsnarks/poseidon/contract.py @@ -0,0 +1,301 @@ +# Copyright (c) 2018 Jordi Baylina +# Copyright (c) 2019 Harry Roberts +# License: LGPL-3.0+ + + +import sys +import json +from binascii import hexlify +from ..evmasm import * +from ..field import SNARK_SCALAR_FIELD +from .permutation import DefaultParams + + +def _add_round_key(r, t, K): + """ + function ark(r) { + C.push(toHex256(K[r])); // K, st, q + for (let i=0; i=nRoundsP+nRoundsF/2)) { + for (let j=0; j= (params.nRoundsP+(params.nRoundsF//2)): + for j in range(params.t): + yield _sigma(j, params.t) + else: + yield _sigma(0, params.t) + label = 'after_mix_%d' % (i,) + yield [ + PUSHLABEL(label), + PUSH(0), + MSTORE, + JMP('mix'), + LABEL(label) + ] + + """ + C.push("0x00"); + C.mstore(); // Save it to pos 0; + C.push("0x20"); + C.push("0x00"); + C.return(); + mix(); + """ + yield [PUSH(0), + MSTORE, # Save it to pos 0 + PUSH(0x20), + PUSH(0), + RETURN] + + for _ in _mix(params): + yield _ + + +def poseidon_contract(params=None): + gen = Codegen() + for _ in poseidon_contract_opcodes(params): + gen.append(_) + return gen.createTxData() + + +def poseidon_abi(): + return [ + { + "constant": True, + "inputs": [ + { + "name": "input", + "type": "uint256[]" + } + ], + "name": "poseidon", + "outputs": [ + { + "name": "", + "type": "uint256" + } + ], + "payable": False, + "stateMutability": "pure", + "type": "function" + } + ] + + +def main(*args): + if len(args) < 2: + print("Usage: %s [outfile]" % (args[0],)) + return 1 + command = args[1] + outfile = sys.stdout + if len(args) > 2: + outfile = open(args[2], 'wb') + if command == "abi": + outfile.write(json.dumps(poseidon_abi()) + "\n") + return 0 + elif command == "contract": + data = poseidon_contract() + if outfile == sys.stdout: + data = '0x' + hexlify(data).decode('ascii') + outfile.write(data) + return 0 + else: + print("Error: unknown command", command) + if outfile != sys.stdout: + outfile.close() + +if __name__ == "__main__": + sys.exit(main(*sys.argv)) diff --git a/ethsnarks/poseidon/permutation.py b/ethsnarks/poseidon/permutation.py index 4957d2b..587d449 100644 --- a/ethsnarks/poseidon/permutation.py +++ b/ethsnarks/poseidon/permutation.py @@ -14,7 +14,7 @@ - https://github.com/dusk-network/poseidon252 """ -from math import log2 +from math import log2, floor from collections import namedtuple from pyblake2 import blake2b from ..field import SNARK_SCALAR_FIELD @@ -27,16 +27,16 @@ def poseidon_params(p, t, nRoundsF, nRoundsP, seed, e, constants_C=None, constan assert nRoundsF % 2 == 0 and nRoundsF > 0 assert nRoundsP > 0 assert t >= 2 - assert isinstance(seed, bytes) + assert isinstance(seed, bytes) + n = floor(log2(p)) if security_target is None: - M = 128 # security target, in bits + M = n # security target, in bits else: M = security_target - - n = log2(p) assert n >= M + # Size of the state (in bits) N = n * t if p % 2 == 3: @@ -117,7 +117,7 @@ def poseidon_matrix(p, seed, t): for i in range(t)] -DefaultParams = poseidon_params(SNARK_SCALAR_FIELD, 6, 8, 57, b'poseidon', 5) +DefaultParams = poseidon_params(SNARK_SCALAR_FIELD, 6, 8, 57, b'poseidon', 5, security_target=126) def poseidon_sbox(state, i, params): diff --git a/requirements-dev.txt b/requirements-dev.txt index acbe190..6ed990f 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,3 +1,4 @@ coverage pyflakes pylint +web3 diff --git a/test/test_poseidon_evm.py b/test/test_poseidon_evm.py new file mode 100644 index 0000000..2f7d8bb --- /dev/null +++ b/test/test_poseidon_evm.py @@ -0,0 +1,47 @@ +import unittest +import logging +import sys +from web3 import Web3 +from ethsnarks.poseidon import poseidon +from ethsnarks.poseidon.contract import poseidon_abi, poseidon_contract + +""" +def vm_logger(): + # Enable trace logging for EVM opcodes + from eth.tools.logging import DEBUG2_LEVEL_NUM + level = DEBUG2_LEVEL_NUM + logger = logging.getLogger() + logger.setLevel(level) + + handler = logging.StreamHandler(sys.stdout) + handler.setLevel(level) + logger.addHandler(handler) + +vm_logger() +""" + + +class TestPoseidonEvm(unittest.TestCase): + def setUp(self): + super(TestPoseidonEvm, self).setUp() + w3 = Web3(Web3.EthereumTesterProvider()) + + bytecode = poseidon_contract() + abi = poseidon_abi() + + PoseidonContract = w3.eth.contract(abi=abi, bytecode=bytecode) + tx_hash = PoseidonContract.constructor().transact() + tx_receipt = w3.eth.waitForTransactionReceipt(tx_hash) + self.contract = w3.eth.contract( + address=tx_receipt.contractAddress, + abi=abi) + + def test_basic(self): + inputs = [1,2] + python_result = poseidon(inputs) + evm_result = self.contract.functions.poseidon(inputs).call() + self.assertEqual(evm_result, python_result) + + +if __name__ == "__main__": + unittest.main()