Skip to content

Commit

Permalink
update stem generation implementation
Browse files Browse the repository at this point in the history
Signed-off-by: Karim Taam <[email protected]>
  • Loading branch information
matkt committed Jun 20, 2024
1 parent 9f422b8 commit 1a36a4a
Show file tree
Hide file tree
Showing 7 changed files with 115 additions and 83 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import static org.hyperledger.besu.ethereum.trie.verkle.util.Parameters.VERSION_LEAF_KEY;

import org.hyperledger.besu.ethereum.trie.verkle.hasher.Hasher;
import org.hyperledger.besu.ethereum.trie.verkle.hasher.PedersenHasher;

import java.util.ArrayList;
import java.util.List;
Expand All @@ -48,12 +49,17 @@ public class TrieKeyAdapter {

private final Hasher hasher;

/** Creates a TrieKeyAdapter with the default Pedersen hasher. */
public TrieKeyAdapter() {
this.hasher = new PedersenHasher();
}

/**
* Creates a TrieKeyAdapter with the provided hasher.
*
* @param hasher The hasher used for key generation.
*/
public TrieKeyAdapter(Hasher hasher) {
public TrieKeyAdapter(final Hasher hasher) {
this.hasher = hasher;
}

Expand All @@ -73,51 +79,63 @@ public Hasher getHasher() {
* @param storageKey The storage key.
* @return The generated storage key.
*/
public Bytes32 storageKey(Bytes address, Bytes32 storageKey) {
final UInt256 pos = locateStorageKeyOffset(storageKey);
final Bytes32 base = hasher.trieKeyHash(address, pos);
final UInt256 suffix = locateStorageKeySuffix(storageKey);
return swapLastByte(base, suffix);
/**
 * Builds the trie key under which a storage slot value lives.
 *
 * @param address The account address.
 * @param storageKey The raw storage slot key.
 * @return The 32-byte trie key (31-byte stem plus one suffix byte).
 */
public Bytes32 storageKey(final Bytes address, final Bytes32 storageKey) {
  // The stem selects the branch; the suffix byte selects the leaf within it.
  return swapLastByte(getStorageStem(address, storageKey), getStorageKeySuffix(storageKey));
}

public UInt256 locateStorageKeyOffset(Bytes32 storageKey) {
UInt256 index = UInt256.fromBytes(storageKey);
if (index.compareTo(HEADER_STORAGE_SIZE) < 0) {
return index.add(HEADER_STORAGE_OFFSET).divide(VERKLE_NODE_WIDTH);
public UInt256 getStorageKeyTrieIndex(final Bytes32 storageKey) {
final UInt256 uintStorageKey = UInt256.fromBytes(storageKey);
if (uintStorageKey.compareTo(HEADER_STORAGE_SIZE) < 0) {
return uintStorageKey.add(HEADER_STORAGE_OFFSET).divide(VERKLE_NODE_WIDTH);
} else {
// We shift right by VERKLE_NODE_WIDTH_LOG2 (i.e. divide by VERKLE_NODE_WIDTH) to make space
// and prevent any potential overflow; the subsequent add is safeguarded against overflow.
return index
return uintStorageKey
.shiftRight(VERKLE_NODE_WIDTH_LOG2.intValue())
.add(MAIN_STORAGE_OFFSET_SHIFT_LEFT_VERKLE_NODE_WIDTH);
}
}

public UInt256 locateStorageKeySuffix(Bytes32 storageKey) {
UInt256 index = UInt256.fromBytes(storageKey);
if (index.compareTo(HEADER_STORAGE_SIZE) < 0) {
final UInt256 mod = index.add(HEADER_STORAGE_OFFSET).mod(VERKLE_NODE_WIDTH);
public UInt256 getStorageKeySuffix(final Bytes32 storageKey) {
final UInt256 uintStorageKey = UInt256.fromBytes(storageKey);
if (uintStorageKey.compareTo(HEADER_STORAGE_SIZE) < 0) {
final UInt256 mod = uintStorageKey.add(HEADER_STORAGE_OFFSET).mod(VERKLE_NODE_WIDTH);
return UInt256.fromBytes(mod.slice(mod.size() - 1));
} else {
return UInt256.fromBytes(storageKey.slice(Bytes32.SIZE - 1));
}
}

/**
 * Computes the stem of the trie branch containing the given storage slot.
 *
 * @param address The account address.
 * @param storageKey The raw storage slot key.
 * @return The stem bytes produced by the hasher.
 */
public Bytes getStorageStem(final Bytes address, final Bytes32 storageKey) {
  return hasher.computeStem(address, getStorageKeyTrieIndex(storageKey));
}

/**
* Generates a code chunk key for a given address and chunkId.
*
* @param address The address.
* @param chunkId The chunk ID.
* @return The generated code chunk key.
*/
public Bytes32 codeChunkKey(Bytes address, UInt256 chunkId) {
UInt256 pos = locateCodeChunkKeyOffset(chunkId);
Bytes32 base = hasher.trieKeyHash(address, pos.divide(VERKLE_NODE_WIDTH));
return swapLastByte(base, pos.mod(VERKLE_NODE_WIDTH));
/**
 * Generates a code chunk key for a given address and chunk ID.
 *
 * @param address The account address.
 * @param chunkId The code chunk ID.
 * @return The generated code chunk trie key.
 */
public Bytes32 codeChunkKey(final Bytes address, final UInt256 chunkId) {
  final Bytes stem = getCodeChunkStem(address, chunkId);
  final UInt256 suffix = getCodeChunkKeySuffix(chunkId);
  return swapLastByte(stem, suffix);
}

/**
 * Computes the trie index of the branch holding the given code chunk.
 *
 * @param chunkId The code chunk ID.
 * @return The trie index for that chunk.
 */
public UInt256 getCodeChunkKeyTrieIndex(final Bytes32 chunkId) {
  final UInt256 offset = CODE_OFFSET.add(UInt256.fromBytes(chunkId));
  return offset.divide(VERKLE_NODE_WIDTH);
}

/**
 * Computes the suffix byte of the trie key for the given code chunk.
 *
 * @param chunkId The code chunk ID.
 * @return The leaf suffix within the chunk's branch.
 */
public UInt256 getCodeChunkKeySuffix(final Bytes32 chunkId) {
  final UInt256 offset = CODE_OFFSET.add(UInt256.fromBytes(chunkId));
  return offset.mod(VERKLE_NODE_WIDTH);
}

public UInt256 locateCodeChunkKeyOffset(Bytes32 chunkId) {
return CODE_OFFSET.add(UInt256.fromBytes(chunkId));
/**
 * Computes the stem of the trie branch containing the given code chunk.
 *
 * @param address The account address.
 * @param chunkId The code chunk ID.
 * @return The stem bytes produced by the hasher.
 */
public Bytes getCodeChunkStem(final Bytes address, final UInt256 chunkId) {
  return hasher.computeStem(address, getCodeChunkKeyTrieIndex(chunkId));
}

/**
Expand All @@ -127,9 +145,13 @@ public UInt256 locateCodeChunkKeyOffset(Bytes32 chunkId) {
* @param leafKey The leaf key.
* @return The generated header key.
*/
Bytes32 headerKey(Bytes address, UInt256 leafKey) {
Bytes32 base = hasher.trieKeyHash(address, UInt256.valueOf(0).toBytes());
return swapLastByte(base, leafKey);
/**
 * Generates a header key for a given address and header leaf index.
 *
 * @param address The account address.
 * @param leafKey The header leaf index.
 * @return The generated header trie key.
 */
public Bytes32 headerKey(final Bytes address, final UInt256 leafKey) {
  return swapLastByte(getHeaderStem(address), leafKey);
}

/**
 * Computes the stem of the account header branch.
 *
 * @param address The account address.
 * @return The stem bytes produced by the hasher.
 */
public Bytes getHeaderStem(final Bytes address) {
  // All header fields (version, balance, nonce, ...) live under trie index zero.
  return hasher.computeStem(address, UInt256.ZERO.toBytes());
}

/**
Expand All @@ -139,9 +161,9 @@ Bytes32 headerKey(Bytes address, UInt256 leafKey) {
* @param subIndex The subIndex.
* @return The modified key.
*/
public Bytes32 swapLastByte(Bytes32 base, Bytes subIndex) {
public Bytes32 swapLastByte(final Bytes base, final Bytes subIndex) {
final Bytes lastByte = subIndex.slice(subIndex.size() - 1, 1);
return (Bytes32) Bytes.concatenate(base.slice(0, 31), lastByte);
return (Bytes32) Bytes.concatenate(base, lastByte);
}

/**
Expand All @@ -150,7 +172,7 @@ public Bytes32 swapLastByte(Bytes32 base, Bytes subIndex) {
* @param address The address.
* @return The generated version key.
*/
public Bytes32 versionKey(Bytes address) {
public Bytes32 versionKey(final Bytes address) {
return headerKey(address, VERSION_LEAF_KEY);
}

Expand All @@ -160,7 +182,7 @@ public Bytes32 versionKey(Bytes address) {
* @param address The address.
* @return The generated balance key.
*/
public Bytes32 balanceKey(Bytes address) {
public Bytes32 balanceKey(final Bytes address) {
return headerKey(address, BALANCE_LEAF_KEY);
}

Expand All @@ -170,7 +192,7 @@ public Bytes32 balanceKey(Bytes address) {
* @param address The address.
* @return The generated nonce key.
*/
public Bytes32 nonceKey(Bytes address) {
public Bytes32 nonceKey(final Bytes address) {
return headerKey(address, NONCE_LEAF_KEY);
}

Expand All @@ -180,7 +202,7 @@ public Bytes32 nonceKey(Bytes address) {
* @param address The address.
* @return The generated code Keccak key.
*/
public Bytes32 codeKeccakKey(Bytes address) {
public Bytes32 codeKeccakKey(final Bytes address) {
return headerKey(address, CODE_KECCAK_LEAF_KEY);
}

Expand All @@ -190,11 +212,11 @@ public Bytes32 codeKeccakKey(Bytes address) {
* @param address The address.
* @return The generated code size key.
*/
public Bytes32 codeSizeKey(Bytes address) {
public Bytes32 codeSizeKey(final Bytes address) {
return (headerKey(address, CODE_SIZE_LEAF_KEY));
}

public int getNbChunk(Bytes bytecode) {
public int getNbChunk(final Bytes bytecode) {
return bytecode.isEmpty() ? 0 : (1 + ((bytecode.size() - 1) / CHUNK_SIZE));
}
/**
Expand All @@ -204,15 +226,15 @@ public int getNbChunk(Bytes bytecode) {
* @param bytecode Code's bytecode
* @return List of 32-bytes code chunks
*/
public List<UInt256> chunkifyCode(Bytes bytecode) {
public List<UInt256> chunkifyCode(final Bytes bytecode) {
if (bytecode.isEmpty()) {
return new ArrayList<>();
}

// Chunking variables
final int CHUNK_SIZE = 31;
int nChunks = getNbChunk(bytecode);
int padSize = nChunks * CHUNK_SIZE - bytecode.size();
final int nChunks = getNbChunk(bytecode);
final int padSize = nChunks * CHUNK_SIZE - bytecode.size();
final Bytes code = Bytes.concatenate(bytecode, Bytes.repeat((byte) 0, padSize));
final List<UInt256> chunks = new ArrayList<>(nChunks);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@
*/
package org.hyperledger.besu.ethereum.trie.verkle.adapter;

import static org.hyperledger.besu.ethereum.trie.verkle.util.Parameters.VERKLE_NODE_WIDTH;

import org.hyperledger.besu.ethereum.trie.verkle.hasher.Hasher;

import java.util.ArrayList;
Expand All @@ -40,25 +38,24 @@ public TrieKeyBatchAdapter(final Hasher hasher) {
super(hasher);
}

public Map<Bytes32, Bytes32> manyTrieKeyHashes(
public Map<Bytes32, Bytes> manyStems(
final Bytes address,
final List<Bytes32> headerKeys,
final List<Bytes32> storageKeys,
final List<Bytes32> codeChunkIds) {

final Set<Bytes32> offsets = new HashSet<>();
final Set<Bytes32> trieIndex = new HashSet<>();

if (headerKeys.size() > 0) {
offsets.add(UInt256.ZERO);
trieIndex.add(UInt256.ZERO);
}
for (Bytes32 storageKey : storageKeys) {
offsets.add(locateStorageKeyOffset(storageKey));
trieIndex.add(getStorageKeyTrieIndex(storageKey));
}
for (Bytes32 codeChunkId : codeChunkIds) {
final UInt256 codeChunkOffset = locateCodeChunkKeyOffset(codeChunkId);
offsets.add(codeChunkOffset.divide(VERKLE_NODE_WIDTH));
trieIndex.add(getCodeChunkKeyTrieIndex(codeChunkId));
}

return getHasher().manyTrieKeyHashes(address, new ArrayList<>(offsets));
return getHasher().manyStems(address, new ArrayList<>(trieIndex));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,25 +15,32 @@
*/
package org.hyperledger.besu.ethereum.trie.verkle.hasher;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;

import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder;
import org.apache.tuweni.bytes.Bytes;
import org.apache.tuweni.bytes.Bytes32;

public class CachedPedersenHasher implements Hasher {
private final Map<Bytes32, Bytes32> preloadedTrieKeyHashes;
private final Cache<Bytes32, Bytes> stemCache;
private final Hasher fallbackHasher;

public CachedPedersenHasher(final Map<Bytes32, Bytes32> preloadedTrieKeyHashes) {
this.preloadedTrieKeyHashes = preloadedTrieKeyHashes;
this.fallbackHasher = new PedersenHasher();
public CachedPedersenHasher(final int cacheSize) {
this(cacheSize, new HashMap<>());
}

public CachedPedersenHasher(final int cacheSize, final Map<Bytes32, Bytes> preloadedStems) {
this(cacheSize, preloadedStems, new PedersenHasher());
}

public CachedPedersenHasher(
final Map<Bytes32, Bytes32> preloadedTrieKeyHashes, final Hasher fallbackHasher) {
this.preloadedTrieKeyHashes = preloadedTrieKeyHashes;
final int cacheSize, final Map<Bytes32, Bytes> preloadedStems, final Hasher fallbackHasher) {
this.stemCache = CacheBuilder.newBuilder().maximumSize(cacheSize).build();
this.stemCache.putAll(preloadedStems);
this.fallbackHasher = fallbackHasher;
}

Expand Down Expand Up @@ -62,18 +69,22 @@ public Bytes32 compress(Bytes commitment) {
}

@Override
public Bytes32 trieKeyHash(final Bytes bytes, final Bytes32 bytes32) {
final Bytes32 hash = preloadedTrieKeyHashes.get(bytes32);
if (hash != null) {
return hash;
public Bytes computeStem(final Bytes address, final Bytes32 trieKeyIndex) {
Bytes stem = stemCache.getIfPresent(trieKeyIndex);
if (stem != null) {
return stem;
} else {
return fallbackHasher.trieKeyHash(bytes, bytes32);
stem = fallbackHasher.computeStem(address, trieKeyIndex);
stemCache.put(trieKeyIndex, stem);
return stem;
}
}

@Override
public Map<Bytes32, Bytes32> manyTrieKeyHashes(final Bytes bytes, final List<Bytes32> list) {
return fallbackHasher.manyTrieKeyHashes(bytes, list);
public Map<Bytes32, Bytes> manyStems(final Bytes address, final List<Bytes32> trieKeyIndexes) {
final Map<Bytes32, Bytes> trieKeyHashes = fallbackHasher.manyStems(address, trieKeyIndexes);
stemCache.putAll(trieKeyHashes);
return trieKeyHashes;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,20 +87,20 @@ Bytes commitUpdate(
List<Bytes32> hashMany(Bytes[] commitments);

/**
* Calculates the hash for an address and index.
* Calculates the stem for an address and index.
*
* @param address Account address.
* @param index index in storage.
* @return the computed stem
*/
Bytes32 trieKeyHash(Bytes address, Bytes32 index);
Bytes computeStem(Bytes address, Bytes32 index);

/**
* Calculates the hash for an address and indexes.
* Calculates the stem for an address and indexes.
*
* @param address Account address.
* @param indexes list of indexes in storage.
* @return Map from each trie index to its computed stem
*/
Map<Bytes32, Bytes32> manyTrieKeyHashes(Bytes address, List<Bytes32> indexes);
Map<Bytes32, Bytes> manyStems(Bytes address, List<Bytes32> indexes);
}
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,9 @@ public class PedersenHasher implements Hasher {
// making the total number of chunks equal to five.
private static final int NUM_CHUNKS = 5;

// Size of the stem is 31 bytes
private static final int STEM_SIZE = 31;

/**
* Commit to a vector of values.
*
Expand Down Expand Up @@ -141,7 +144,7 @@ public List<Bytes32> hashMany(final Bytes[] commitments) {
* @return The trie-key hash
*/
@Override
public Bytes32 trieKeyHash(Bytes address, Bytes32 index) {
public Bytes computeStem(Bytes address, Bytes32 index) {

// Pad the address so that it is 32 bytes
final Bytes32 addr = Bytes32.leftPad(address);
Expand All @@ -151,7 +154,7 @@ public Bytes32 trieKeyHash(Bytes address, Bytes32 index) {
final Bytes hash =
Bytes.wrap(
LibIpaMultipoint.hash(LibIpaMultipoint.commit(Bytes.concatenate(chunks).toArray())));
return Bytes32.wrap(hash);
return hash.slice(0, STEM_SIZE);
}

/**
Expand All @@ -162,7 +165,7 @@ public Bytes32 trieKeyHash(Bytes address, Bytes32 index) {
* @return The list of trie-key hashes
*/
@Override
public Map<Bytes32, Bytes32> manyTrieKeyHashes(Bytes address, List<Bytes32> indexes) {
public Map<Bytes32, Bytes> manyStems(Bytes address, List<Bytes32> indexes) {

// Pad the address so that it is 32 bytes
final Bytes32 addr = Bytes32.leftPad(address);
Expand All @@ -176,12 +179,12 @@ public Map<Bytes32, Bytes32> manyTrieKeyHashes(Bytes address, List<Bytes32> inde

final Bytes hashMany = Bytes.wrap(LibIpaMultipoint.hashMany(outputStream.toByteArray()));

final Map<Bytes32, Bytes32> hashes = new HashMap<>();
final Map<Bytes32, Bytes> stems = new HashMap<>();
for (int i = 0; i < indexes.size(); i++) {
// Each stem is the first STEM_SIZE (31) bytes of the corresponding 32-byte hash segment
hashes.put(indexes.get(i), Bytes32.wrap(hashMany.slice(i * Bytes32.SIZE, Bytes32.SIZE)));
stems.put(indexes.get(i), Bytes.wrap(hashMany.slice(i * Bytes32.SIZE, STEM_SIZE)));
}
return hashes;
return stems;
} catch (IOException e) {
throw new RuntimeException("unable to generate trie key hash", e);
}
Expand Down
Loading

0 comments on commit 1a36a4a

Please sign in to comment.