Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add datalayer proof of inclusion generation #902

Draft
wants to merge 2 commits into
base: long_lived/initial_datalayer
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
129 changes: 129 additions & 0 deletions crates/chia-datalayer/src/merkle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,74 @@ pub enum NodeType {
Leaf = 1,
}

#[cfg_attr(feature = "py-bindings", pyclass(get_all))]
#[derive(Clone, Debug, Hash, Eq, PartialEq)]
pub struct ProofOfInclusionLayer {
pub other_hash_side: Side,
pub other_hash: Hash,
pub combined_hash: Hash,
}

#[cfg(feature = "py-bindings")]
#[pymethods]
impl ProofOfInclusionLayer {
#[new]
pub fn py_init(other_hash_side: Side, other_hash: Hash, combined_hash: Hash) -> PyResult<Self> {
Ok(Self {
other_hash_side,
other_hash,
combined_hash,
})
}
}

#[cfg_attr(feature = "py-bindings", pyclass(get_all))]
#[derive(Clone, Debug, Hash, Eq, PartialEq)]
pub struct ProofOfInclusion {
pub node_hash: Hash,
pub layers: Vec<ProofOfInclusionLayer>,
}

impl ProofOfInclusion {
pub fn root_hash(&self) -> Hash {
if let Some(last) = self.layers.last() {
last.combined_hash
} else {
self.node_hash
}
}

pub fn valid(&self) -> bool {
let mut existing_hash = self.node_hash;

for layer in &self.layers {
let calculated_hash =
calculate_internal_hash(&existing_hash, layer.other_hash_side, &layer.other_hash);

if calculated_hash != layer.combined_hash {
return false;
}

existing_hash = calculated_hash;
}

existing_hash == self.root_hash()
}
}

#[cfg(feature = "py-bindings")]
#[pymethods]
impl ProofOfInclusion {
#[pyo3(name = "root_hash")]
pub fn py_root_hash(&self) -> Hash {
self.root_hash()
}
#[pyo3(name = "valid")]
pub fn py_valid(&self) -> bool {
self.valid()
}
}

#[allow(clippy::needless_pass_by_value)]
fn sha256_num<T: ToBytes>(input: T) -> Hash {
let mut hasher = Sha256::new();
Expand All @@ -306,6 +374,13 @@ fn internal_hash(left_hash: &Hash, right_hash: &Hash) -> Hash {
Hash(Bytes32::new(hasher.finalize()))
}

pub fn calculate_internal_hash(hash: &Hash, other_hash_side: Side, other_hash: &Hash) -> Hash {
match other_hash_side {
Side::Left => internal_hash(other_hash, hash),
Side::Right => internal_hash(hash, other_hash),
}
}

#[cfg_attr(feature = "py-bindings", pyclass(eq, eq_int))]
#[repr(u8)]
#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq, Streamable)]
Expand Down Expand Up @@ -351,6 +426,16 @@ impl InternalNode {
Err(Error::IndexIsNotAChild(index))
}
}

pub fn get_sibling_side(&self, index: TreeIndex) -> Result<Side, Error> {
if self.left == index {
Ok(Side::Right)
} else if self.right == index {
Ok(Side::Left)
} else {
Err(Error::IndexIsNotAChild(index))
}
}
}

#[cfg(feature = "py-bindings")]
Expand Down Expand Up @@ -461,6 +546,15 @@ impl Node {
*leaf
}

fn expect_internal(&self, message: &str) -> InternalNode {
let Node::Internal(internal) = self else {
let message = message.replace("<<self>>", &format!("{self:?}"));
panic!("{}", message)
};

*internal
}

fn try_into_leaf(self) -> Result<LeafNode, Error> {
match self {
Node::Leaf(leaf) => Ok(leaf),
Expand Down Expand Up @@ -1350,6 +1444,36 @@ impl MerkleBlob {
.copied()
.ok_or(Error::UnknownKey(key))
}

pub fn get_proof_of_inclusion(&self, key: KeyId) -> Result<ProofOfInclusion, Error> {
let mut index = *self.key_to_index.get(&key).ok_or(Error::UnknownKey(key))?;

// TODO: message
let node = self.get_node(index)?.expect_leaf("");

let parents = self.get_lineage_with_indexes(index)?;
let mut layers: Vec<ProofOfInclusionLayer> = Vec::new();
let mut parents_iter = parents[1..].iter();
parents_iter.next();
for (next_index, parent) in parents_iter {
// TODO: message
let parent = parent.expect_internal("");
let sibling_index = parent.sibling_index(index)?;
let sibling = self.get_node(sibling_index)?;
let layer = ProofOfInclusionLayer {
other_hash_side: parent.get_sibling_side(index)?,
other_hash: sibling.hash(),
combined_hash: parent.hash,
};
layers.push(layer);
index = *next_index;
}

Ok(ProofOfInclusion {
node_hash: node.hash,
layers,
})
}
}

impl PartialEq for MerkleBlob {
Expand Down Expand Up @@ -1554,6 +1678,11 @@ impl MerkleBlob {
pub fn py_get_key_index(&self, key: KeyId) -> PyResult<TreeIndex> {
Ok(self.get_key_index(key)?)
}

#[pyo3(name = "get_proof_of_inclusion")]
pub fn py_get_proof_of_inclusion(&self, key: KeyId) -> PyResult<ProofOfInclusion> {
Ok(self.get_proof_of_inclusion(key)?)
}
}

fn try_get_block(blob: &[u8], index: TreeIndex) -> Result<Block, Error> {
Expand Down
17 changes: 17 additions & 0 deletions wheel/python/chia_rs/datalayer.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,22 @@ class LeafNode:
def value(self) -> int64: ...


@final
class ProofOfInclusionLayer:
def __init__(self, parent: Optional[uint32], hash: bytes32, left: uint32, right: uint32) -> None: ...
other_hash_side: uint8
other_hash: bytes32
combined_hash: bytes32

@final
class ProofOfInclusion:
node_hash: bytes32
# children before parents
layers: list[ProofOfInclusionLayer]

def root_hash(self) -> bytes32: ...
def valid(self) -> bool: ...

@final
class MerkleBlob:
@property
Expand Down Expand Up @@ -90,6 +106,7 @@ class MerkleBlob:
def get_hash_at_index(self, index: uint32): ...
def get_keys_values(self) -> dict[int64, int64]: ...
def get_key_index(self, key: int64) -> uint32: ...
def get_proof_of_inclusion(self, key: int64) -> ProofOfInclusion: ...

def __len__(self) -> int: ...

Expand Down
2 changes: 2 additions & 0 deletions wheel/src/api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -646,6 +646,8 @@ pub fn add_datalayer_submodule(py: Python<'_>, parent: &Bound<'_, PyModule>) ->
datalayer.add_class::<MerkleBlob>()?;
datalayer.add_class::<InternalNode>()?;
datalayer.add_class::<LeafNode>()?;
datalayer.add_class::<ProofOfInclusionLayer>()?;
datalayer.add_class::<ProofOfInclusion>()?;

datalayer.add("BLOCK_SIZE", BLOCK_SIZE)?;
datalayer.add("DATA_SIZE", DATA_SIZE)?;
Expand Down
Loading