Skip to content

Commit

Permalink
Added Poseidon merkle-tree (with test for binary tree)
Browse files Browse the repository at this point in the history
  • Loading branch information
HarryR authored and HarryR committed Jul 10, 2019
1 parent 89da7b7 commit 26b0fb9
Show file tree
Hide file tree
Showing 6 changed files with 274 additions and 28 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ The following gadgets are available
* [2-bit lookup table](src/gadgets/lookup_2bit.cpp)
* [3-bit lookup table](src/gadgets/lookup_3bit.cpp)
* [MiMC](https://eprint.iacr.org/2016/492) hash and cipher
* [Poseidon](https://eprint.iacr.org/2019/458.pdf) hash function
* [Miyaguchi-Preneel one-way function](https://en.wikipedia.org/wiki/One-way_compression_function)
* Merkle tree
* SHA256 (Ethereum compatible, full round)
Expand Down
29 changes: 25 additions & 4 deletions ethsnarks/merkletree.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import math
from collections import namedtuple

from .poseidon import poseidon, DefaultParams as poseidon_DefaultParams
from .mimc import mimc_hash
from .field import FQ, SNARK_SCALAR_FIELD

Expand Down Expand Up @@ -46,6 +47,7 @@ def valid(self, item):
return isinstance(item, int) and item > 0 and item < SNARK_SCALAR_FIELD


# TODO: move to ethsnarks.mimc ?
class MerkleHasher_MiMC(Abstract_MerkleHasher):
def __init__(self, tree_depth, node_width=2):
if node_width != 2:
Expand All @@ -57,6 +59,25 @@ def hash_node(self, depth, *args):
return mimc_hash(args, self._IVs[depth])


# TODO: move to ethsnarks.poseidon?
class MerkleHasher_Poseidon(Abstract_MerkleHasher):
def __init__(self, params, depth, node_width=2):
assert node_width > 0
if params is None:
params = poseidon_DefaultParams
if node_width >= (params.t - 1) or node_width <= 0:
raise ValueError("Node width must be in range: 0 < width < (t-1)")
self._params = params
self._tree_depth = depth

@classmethod
def factory(cls, params=None):
return lambda *args, **kwa: cls(params, *args, **kwa)

def hash_node(self, depth, *args):
return poseidon(args, params=self._params)


DEFAULT_HASHER = MerkleHasher_MiMC


Expand All @@ -82,14 +103,14 @@ class MerkleTree(object):
Each element of the proof supplies the index that the previous output will be inserted
into the list of other elements in the hash to re-construct the root
"""
def __init__(self, n_items, width=2, tree_hasher=None):
def __init__(self, n_items, width=2, hasher=None):
assert n_items >= width
assert (n_items % width) == 0
if tree_hasher is None:
tree_hasher = DEFAULT_HASHER
if hasher is None:
hasher = DEFAULT_HASHER
self._width = width
self._tree_depth = int(math.log(n_items, width))
self._hasher = tree_hasher(self._tree_depth, width)
self._hasher = hasher(self._tree_depth, width)
self._n_items = n_items
self._cur = 0
self._leaves = [list() for _ in range(0, self._tree_depth + 1)]
Expand Down
5 changes: 5 additions & 0 deletions ethsnarks/poseidon/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from .permutation import *

if __name__ == "__main__":
# TODO: implement 'constants' and 'matrix'
pass
196 changes: 196 additions & 0 deletions ethsnarks/poseidon/permutation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,196 @@
#!/usr/bin/env python

"""
Implements the Poseidon permutation:
Starkad and Poseidon: New Hash Functions for Zero Knowledge Proof Systems
- Lorenzo Grassi, Daniel Kales, Dmitry Khovratovich, Arnab Roy, Christian Rechberger, and Markus Schofnegger
- https://eprint.iacr.org/2019/458.pdf
Other implementations:
- https://github.com/shamatar/PoseidonTree/
- https://github.com/iden3/circomlib/blob/master/src/poseidon.js
- https://github.com/dusk-network/poseidon252
"""

from math import log2
from collections import namedtuple
from pyblake2 import blake2b
from ..field import SNARK_SCALAR_FIELD


PoseidonParamsType = namedtuple('_PoseidonParams', ('p', 't', 'nRoundsF', 'nRoundsP', 'seed', 'e', 'constants_C', 'constants_M'))


def poseidon_params(p, t, nRoundsF, nRoundsP, seed, e, constants_C=None, constants_M=None, security_target=None):
assert nRoundsF % 2 == 0 and nRoundsF > 0
assert nRoundsP > 0
assert t >= 2
assert isinstance(seed, bytes)

if security_target is None:
M = 128 # security target, in bits
else:
M = security_target

n = log2(p)
assert n >= M

N = n * t

if p % 2 == 3:
assert e == 3
grobner_attack_ratio_rounds = 0.32
grobner_attack_ratio_sboxes = 0.18
interpolation_attack_ratio = 0.63
elif p % 5 != 1:
assert e == 5
grobner_attack_ratio_rounds = 0.21
grobner_attack_ratio_sboxes = 0.14
interpolation_attack_ratio = 0.43
else:
# XXX: in other cases use, can we use 7?
raise ValueError('Invalid p for congruency')

# Verify that the parameter choice exceeds the recommendations to prevent attacks
# iacr.org/2019/458 § 3 Cryptanalysis Summary of Starkad and Poseidon Hashes (pg 10)
# Figure 1
#print('(nRoundsF + nRoundsP)', (nRoundsF + nRoundsP))
#print('Interpolation Attackable Rounds', ((interpolation_attack_ratio * min(n, M)) + log2(t)))
assert (nRoundsF + nRoundsP) > ((interpolation_attack_ratio * min(n, M)) + log2(t))
# Figure 3
#print('grobner_attack_ratio_rounds', ((2 + min(M, n)) * grobner_attack_ratio_rounds))
assert (nRoundsF + nRoundsP) > ((2 + min(M, n)) * grobner_attack_ratio_rounds)
# Figure 4
#print('grobner_attack_ratio_sboxes', (M * grobner_attack_ratio_sboxes))
assert (nRoundsF + (t * nRoundsP)) > (M * grobner_attack_ratio_sboxes)

# iacr.org/2019/458 § 4.1 Minimize "Number of S-Boxes"
# In order to minimize the number of S-boxes for given `n` and `t`, the goal is to and
# the best ratio between RP and RF that minimizes:
# number of S-Boxes = t · RF + RP
# - Use S-box x^q
# - Select R_F to 6 or rhigher
# - Select R_P that minimizes tRF +RP such that no inequation (1),(3),(4),(5) is satisfied.

if constants_C is None:
constants_C = list(poseidon_constants(p, seed + b'_constants', nRoundsF + nRoundsP))
if constants_M is None:
constants_M = poseidon_matrix(p, seed + b'_matrix_0000', t)

# iacr.org/2019/458 § 4.1 6 SNARKs Application via Poseidon-π
# page 16 formula (8) and (9)
n_constraints = (nRoundsF * t) + nRoundsP
if e == 5:
n_constraints *= 3
elif e == 3:
n_constraints *= 2
#print('n_constraints', n_constraints)

return PoseidonParamsType(p, t, nRoundsF, nRoundsP, seed, e, constants_C, constants_M)


def H(arg):
if isinstance(arg, int):
arg = arg.to_bytes(32, 'little')
# XXX: ensure that (digest_size*8) >= log2(p)
hashed = blake2b(data=arg, digest_size=32).digest()
return int.from_bytes(hashed, 'little')


def poseidon_constants(p, seed, n):
assert isinstance(n, int)
for _ in range(n):
seed = H(seed)
yield seed % p


def poseidon_matrix(p, seed, t):
"""
iacr.org/2019/458 § 2.3 About the MDS Matrix (pg 8)
Also:
- https://en.wikipedia.org/wiki/Cauchy_matrix
"""
c = list(poseidon_constants(p, seed, t * 2))
return [[pow((c[i] - c[t+j]) % p, p - 2, p) for j in range(t)]
for i in range(t)]


DefaultParams = poseidon_params(SNARK_SCALAR_FIELD, 6, 8, 57, b'poseidon', 5)


def poseidon_sbox(state, i, params):
"""
iacr.org/2019/458 § 2.2 The Hades Strategy (pg 6)
In more details, assume R_F = 2 · R_f is an even number. Then
- the first R_f rounds have a full S-Box layer,
- the middle R_P rounds have a partial S-Box layer (i.e., 1 S-Box layer),
- the last R_f rounds have a full S-Box layer
"""
half_F = params.nRoundsF // 2
e, p = params.e, params.p
if i < half_F or i >= (half_F + params.nRoundsP):
for j, _ in enumerate(state):
state[j] = pow(_, e, p)
else:
state[0] = pow(state[0], e, p)


def poseidon_mix(state, M, p):
"""
The mixing layer is a matrix vector product of the state with the mixing matrix
- https://mathinsight.org/matrix_vector_multiplication
"""
return [ sum([M[i][j] * _ for j, _ in enumerate(state)]) % p
for i in range(len(M)) ]


def poseidon(inputs, params=None, chained=False, trace=False):
"""
Main instansiation of the Poseidon permutation
The state is `t` elements wide, there are `F` full-rounds
followed by `P` partial rounds, then `F` full rounds again.
[ ARK ] --,
| | | | | | |
[ SBOX ] - Full Round
| | | | | | |
[ MIX ] --`
[ ARK ] --,
| | | | | | |
[ SBOX ] - Partial Round
| | Only 1 element is substituted in partial round
[ MIX ] --`
There are F+P rounds for the full permutation.
You can provide `r = N - 2s` bits of input per round, where `s` is the desired
security level, in most cases this means you can provide `t-1` inputs with
appropriately chosen parameters. The permutation can be 'chained' together
to form a sponge construct.
"""
if params is None:
params = DefaultParams
assert isinstance(params, PoseidonParamsType)
assert len(inputs) > 0
if not chained:
# Don't allow inputs to exceed the rate, unless in chained mode
assert len(inputs) < params.t
state = [0] * params.t
state[:len(inputs)] = inputs
for i, C_i in enumerate(params.constants_C):
state = [_ + C_i for _ in state] # ARK(.)
poseidon_sbox(state, i, params)
state = poseidon_mix(state, params.constants_M, params.p)
if trace:
for j, val in enumerate(state):
print('%d %d' % (i, j), '=', val)
if chained:
# Provide the full state as output in 'chained' mode
return state
return state[0]
49 changes: 25 additions & 24 deletions test/test_merkle.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import unittest

import hashlib
from ethsnarks.merkletree import MerkleTree, DEFAULT_HASHER
from ethsnarks.merkletree import MerkleTree, DEFAULT_HASHER, MerkleHasher_Poseidon
from ethsnarks.field import FQ, SNARK_SCALAR_FIELD


Expand Down Expand Up @@ -47,29 +47,30 @@ def test_known1(self):

def test_update(self):
# Verify that items in the tree can be updated
tree = MerkleTree(2)
tree.append(FQ.random())
tree.append(FQ.random())
proof_0_before = tree.proof(0)
proof_1_before = tree.proof(1)
root_before = tree.root
self.assertTrue(proof_0_before.verify(tree.root))
self.assertTrue(proof_1_before.verify(tree.root))

leaf_0_after = FQ.random()
tree.update(0, leaf_0_after)
root_after_0 = tree.root
proof_0_after = tree.proof(0)
self.assertTrue(proof_0_after.verify(tree.root))
self.assertNotEqual(root_before, root_after_0)

leaf_1_after = FQ.random()
tree.update(1, leaf_1_after)
root_after_1 = tree.root
proof_1_after = tree.proof(1)
self.assertTrue(proof_1_after.verify(tree.root))
self.assertNotEqual(root_before, root_after_1)
self.assertNotEqual(root_after_0, root_after_1)
for hasher in [DEFAULT_HASHER, MerkleHasher_Poseidon.factory()]:
tree = MerkleTree(2, hasher=hasher)
tree.append(FQ.random())
tree.append(FQ.random())
proof_0_before = tree.proof(0)
proof_1_before = tree.proof(1)
root_before = tree.root
self.assertTrue(proof_0_before.verify(tree.root))
self.assertTrue(proof_1_before.verify(tree.root))

leaf_0_after = FQ.random()
tree.update(0, leaf_0_after)
root_after_0 = tree.root
proof_0_after = tree.proof(0)
self.assertTrue(proof_0_after.verify(tree.root))
self.assertNotEqual(root_before, root_after_0)

leaf_1_after = FQ.random()
tree.update(1, leaf_1_after)
root_after_1 = tree.root
proof_1_after = tree.proof(1)
self.assertTrue(proof_1_after.verify(tree.root))
self.assertNotEqual(root_before, root_after_1)
self.assertNotEqual(root_after_0, root_after_1)

def test_known_2pow28(self):
tree = MerkleTree(2<<28)
Expand Down
22 changes: 22 additions & 0 deletions test/test_poseidon.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# Copyright (c) 2019 Harry Roberts
# License: LGPL-3.0+

import unittest

from ethsnarks.poseidon import DefaultParams, poseidon


class TestPedersenHash(unittest.TestCase):
def test_constants(self):
self.assertEqual(DefaultParams.constants_C[0], 14397397413755236225575615486459253198602422701513067526754101844196324375522)
self.assertEqual(DefaultParams.constants_C[-1], 10635360132728137321700090133109897687122647659471659996419791842933639708516)
self.assertEqual(DefaultParams.constants_M[0][0], 19167410339349846567561662441069598364702008768579734801591448511131028229281)
self.assertEqual(DefaultParams.constants_M[-1][-1], 20261355950827657195644012399234591122288573679402601053407151083849785332516)


def test_permutation(self):
self.assertEqual(poseidon([1,2]), 12242166908188651009877250812424843524687801523336557272219921456462821518061)


if __name__ == "__main__":
unittest.main()

0 comments on commit 26b0fb9

Please sign in to comment.