Last active
June 21, 2024 11:17
-
-
Save vkobel/f609468a1d3b0fc8ed9c5e1177d3673a to your computer and use it in GitHub Desktop.
Short incremental Merkle tree Python implementation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import hashlib | |
import random | |
class IncrementalMerkleTree: | |
def __init__(self, depth: int) -> None: | |
if depth < 1: | |
raise ValueError("Depth must be at least 1") | |
self.depth: int = depth | |
self.leaf_count: int = 2 ** depth | |
self.empty_leaf: bytes = self._hash(b'') | |
# The total number of nodes in the tree (both internal and leaf nodes) is 2^(d+1) - 1. | |
self.tree: list[bytes] = [self.empty_leaf * 32] * (2 * self.leaf_count - 1) | |
self.next_leaf_index = 0 | |
@staticmethod | |
def _hash(left: bytes, right: bytes = b'') -> bytes: | |
return hashlib.sha256(left + right).digest() | |
def update(self, leaf: str) -> None: | |
if self.next_leaf_index >= self.leaf_count: | |
raise ValueError("Tree is full") | |
index: int = self.leaf_count - 1 + self.next_leaf_index | |
self.tree[index] = self._hash(leaf.encode()) | |
while index > 0: | |
parent: int = (index - 1) // 2 | |
left_child: bytes = self.tree[2 * parent + 1] | |
right_child: bytes = self.tree[2 * parent + 2] | |
self.tree[parent] = self._hash(left_child, right_child) | |
index = parent | |
self.next_leaf_index += 1 | |
def get_root(self) -> str: | |
return self.tree[0].hex() | |
def get_proof(self, leaf_index: int) -> list[dict[str, str]]: | |
if leaf_index < 0 or leaf_index >= self.next_leaf_index: | |
raise ValueError("Leaf index out of range") | |
proof = [] | |
index: int = self.leaf_count - 1 + leaf_index | |
while index > 0: | |
sibling_index: int = index - 1 if index % 2 == 0 else index + 1 | |
is_left: bool = sibling_index < index | |
proof.append({ | |
'sibling': self.tree[sibling_index].hex(), | |
'is_left': is_left | |
}) | |
index = (index - 1) // 2 | |
return proof | |
@staticmethod | |
def verify_proof(leaf: str, proof: list[dict[str, str]], root: str) -> bool: | |
current: bytes = IncrementalMerkleTree._hash(leaf.encode()) | |
for node in proof: | |
sibling: bytes = bytes.fromhex(node['sibling']) | |
if node['is_left']: | |
current = IncrementalMerkleTree._hash(sibling, current) | |
else: | |
current = IncrementalMerkleTree._hash(current, sibling) | |
return current.hex() == root | |
def print_tree(self) -> None: | |
def format_hash(hash_bytes: bytes) -> str: | |
return hash_bytes.hex()[:8] # Show first 8 characters for brevity | |
levels = [] | |
for i in range(self.depth + 1): | |
start: int = 2**i - 1 | |
end: int = 2**(i + 1) - 1 | |
levels.append([format_hash(h) for h in self.tree[start:end]]) | |
max_level = len(levels[-1]) | |
max_width = max_level * 10 # Assuming each hash is 8 characters long | |
def center_text(text, width): | |
if len(text) >= width: | |
return text | |
space = (width - len(text)) // 2 | |
return ' ' * space + text + ' ' * space | |
for i, level in enumerate(levels): | |
level_width = len(level) * 10 | |
spacing = (max_width - level_width) // len(level) | |
padded_level = [center_text(h, 10 + spacing) for h in level] | |
print(''.join(padded_level).center(max_width)) | |
def test_merkle_tree(depth=3, nb_leaves=0) -> None:
    """Build a tree, insert leaves and check every inclusion proof.

    Args:
        depth: depth passed to ``IncrementalMerkleTree``.
        nb_leaves: number of leaves to insert; when it is 0 or negative a
            random count between 1 and the tree's leaf capacity is used.

    Raises:
        AssertionError: if a proof for an inserted leaf fails to verify, or
            if a proof verifies for a leaf that was never inserted.
    """
    tree = IncrementalMerkleTree(depth)

    # generate a number of leaves <= tree leaf count (random)
    if nb_leaves <= 0:
        nb_leaves = random.randint(1, tree.leaf_count)

    # Adjacent literals keep each replacement field on one line; a field
    # broken across lines is only accepted from Python 3.12 onwards.
    print(
        f"\nGenerating {nb_leaves} leaves, tree depth: {depth}, "
        f"max leaves: {tree.leaf_count}"
    )

    leaves = [f"leaf_{i}" for i in range(nb_leaves)]
    for leaf in leaves:
        tree.update(leaf)

    # Only small trees are printed; larger ones would flood the terminal.
    if depth <= 5:
        print("Tree structure:")
        tree.print_tree()

    root: str = tree.get_root()
    print(f"Root: {root}")

    for i, leaf in enumerate(leaves):
        proof = tree.get_proof(i)
        is_valid: bool = IncrementalMerkleTree.verify_proof(leaf, proof, root)
        if not is_valid:
            # Show the failing proof before the assertion aborts the run.
            print(f" Proof: {proof}")
        assert is_valid

    # Test invalid leaf
    invalid_leaf = "invalid_leaf"
    # Reuse the proof of the first leaf (index 0) with a different leaf value.
    invalid_proof = tree.get_proof(0)
    is_valid = IncrementalMerkleTree.verify_proof(
        invalid_leaf, invalid_proof, root)
    assert not is_valid
if __name__ == "__main__":
    # Scenarios run in order: a large nearly-full tree, a deep tree with a
    # single leaf, the smallest possible tree, and a small tree with a
    # randomly chosen number of leaves.
    scenarios = (
        {"depth": 14, "nb_leaves": 16_000},
        {"depth": 18, "nb_leaves": 1},
        {"depth": 1, "nb_leaves": 1},
        {"depth": 4},
    )
    for scenario in scenarios:
        test_merkle_tree(**scenario)
    print("All tests passed!")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment