diff --git a/chia/_tests/core/data_layer/test_merkle_blob.py b/chia/_tests/core/data_layer/test_merkle_blob.py index 0229a1ee1ddd..f4202413858d 100644 --- a/chia/_tests/core/data_layer/test_merkle_blob.py +++ b/chia/_tests/core/data_layer/test_merkle_blob.py @@ -16,7 +16,8 @@ from chia.data_layer.data_layer_util import InternalNode, Side, internal_hash from chia.data_layer.util.merkle_blob import ( InvalidIndexError, - KVId, + KeyId, + KeyOrValueId, MerkleBlob, NodeMetadata, NodeType, @@ -24,6 +25,7 @@ RawLeafMerkleNode, RawMerkleNodeProtocol, TreeIndex, + ValueId, data_size, metadata_size, null_parent, @@ -135,8 +137,8 @@ def id(self) -> str: raw=RawLeafMerkleNode( hash=bytes(range(32)), parent=TreeIndex(0x20212223), - key=KVId(0x2425262728292A2B), - value=KVId(0x2C2D2E2F30313233), + key=KeyId(KeyOrValueId(0x2425262728292A2B)), + value=ValueId(KeyOrValueId(0x2C2D2E2F30313233)), ), ), ] @@ -169,8 +171,8 @@ def test_merkle_blob_one_leaf_loads() -> None: leaf = RawLeafMerkleNode( hash=bytes(range(32)), parent=null_parent, - key=KVId(0x0405060708090A0B), - value=KVId(0x0405060708090A1B), + key=KeyId(KeyOrValueId(0x0405060708090A0B)), + value=ValueId(KeyOrValueId(0x0405060708090A1B)), ) blob = bytearray(NodeMetadata(type=NodeType.leaf, dirty=False).pack() + pack_raw_node(leaf)) @@ -190,14 +192,14 @@ def test_merkle_blob_two_leafs_loads() -> None: left_leaf = RawLeafMerkleNode( hash=bytes(range(32)), parent=TreeIndex(0), - key=KVId(0x0405060708090A0B), - value=KVId(0x0405060708090A1B), + key=KeyId(KeyOrValueId(0x0405060708090A0B)), + value=ValueId(KeyOrValueId(0x0405060708090A1B)), ) right_leaf = RawLeafMerkleNode( hash=bytes(range(32)), parent=TreeIndex(0), - key=KVId(0x1415161718191A1B), - value=KVId(0x1415161718191A2B), + key=KeyId(KeyOrValueId(0x1415161718191A1B)), + value=ValueId(KeyOrValueId(0x1415161718191A2B)), ) blob = bytearray() blob.extend(NodeMetadata(type=NodeType.internal, dirty=True).pack() + pack_raw_node(root)) @@ -218,20 +220,20 @@ def test_merkle_blob_two_leafs_loads() -> None: son_hash = bytes32(range(32)) root_hash = internal_hash(son_hash, son_hash) expected_node = InternalNode(root_hash, son_hash, son_hash) - assert merkle_blob.get_lineage_by_key_id(KVId(0x0405060708090A0B)) == [expected_node] - assert merkle_blob.get_lineage_by_key_id(KVId(0x1415161718191A1B)) == [expected_node] + assert merkle_blob.get_lineage_by_key_id(KeyId(KeyOrValueId(0x0405060708090A0B))) == [expected_node] + assert merkle_blob.get_lineage_by_key_id(KeyId(KeyOrValueId(0x1415161718191A1B))) == [expected_node] -def generate_kvid(seed: int) -> tuple[KVId, KVId]: - kv_ids: list[KVId] = [] +def generate_kvid(seed: int) -> tuple[KeyId, ValueId]: + kv_ids: list[KeyOrValueId] = [] for offset in range(2): seed_bytes = (2 * seed + offset).to_bytes(8, byteorder="big", signed=True) hash_obj = hashlib.sha256(seed_bytes) hash_int = int.from_bytes(hash_obj.digest()[:8], byteorder="big", signed=True) - kv_ids.append(KVId(hash_int)) + kv_ids.append(KeyOrValueId(hash_int)) - return kv_ids[0], kv_ids[1] + return KeyId(kv_ids[0]), ValueId(kv_ids[1]) def generate_hash(seed: int) -> bytes32: @@ -245,7 +247,7 @@ def test_insert_delete_loads_all_keys() -> None: num_keys = 200000 extra_keys = 100000 max_height = 25 - keys_values: dict[KVId, KVId] = {} + keys_values: dict[KeyId, ValueId] = {} random = Random() random.seed(100, version=2) @@ -304,7 +306,7 @@ def test_small_insert_deletes() -> None: for repeats in range(num_repeats): for num_inserts in range(1, max_inserts): - keys_values: dict[KVId, KVId] = {} + keys_values: dict[KeyId, ValueId] = {} for inserts in range(num_inserts): seed += 1 key, value = generate_kvid(seed) @@ -330,13 +332,13 @@ def test_proof_of_inclusion_merkle_blob() -> None: random.seed(100, version=2) merkle_blob = MerkleBlob(blob=bytearray()) - keys_values: dict[KVId, KVId] = {} + keys_values: dict[KeyId, ValueId] = {} for repeats in range(num_repeats): num_inserts = 1 + repeats * 100 num_deletes = 1 + repeats * 10 - kv_ids: list[tuple[KVId, KVId]] = [] + kv_ids: list[tuple[KeyId, ValueId]] = [] hashes: list[bytes32] = [] for _ in range(num_inserts): seed += 1 @@ -363,7 +365,7 @@ def test_proof_of_inclusion_merkle_blob() -> None: with pytest.raises(Exception, match=f"Key {kv_id} not present in the store"): merkle_blob.get_proof_of_inclusion(kv_id) - new_keys_values: dict[KVId, KVId] = {} + new_keys_values: dict[KeyId, ValueId] = {} for old_kv in keys_values.keys(): seed += 1 _, value = generate_kvid(seed) @@ -382,7 +384,9 @@ def test_proof_of_inclusion_merkle_blob() -> None: @pytest.mark.parametrize(argnames="index", argvalues=[TreeIndex(-1), TreeIndex(1), TreeIndex(null_parent)]) def test_get_raw_node_raises_for_invalid_indexes(index: TreeIndex) -> None: merkle_blob = MerkleBlob(blob=bytearray()) - merkle_blob.insert(KVId(0x1415161718191A1B), KVId(0x1415161718191A1B), bytes(range(12, data_size))) + merkle_blob.insert( + KeyId(KeyOrValueId(0x1415161718191A1B)), ValueId(KeyOrValueId(0x1415161718191A1B)), bytes(range(12, data_size)) + ) with pytest.raises(InvalidIndexError): merkle_blob.get_raw_node(index) @@ -497,6 +501,6 @@ def test_just_insert_a_bunch(merkle_blob_type: MerkleBlobCallable) -> None: total_time = 0.0 for i in range(100000): start = time.monotonic() - merkle_blob.insert(KVId(i), KVId(i), HASH) + merkle_blob.insert(KeyId(KeyOrValueId(i)), ValueId(KeyOrValueId(i)), HASH) end = time.monotonic() total_time += end - start diff --git a/chia/data_layer/data_store.py b/chia/data_layer/data_store.py index 852d93161ea2..9886bacaf68a 100644 --- a/chia/data_layer/data_store.py +++ b/chia/data_layer/data_store.py @@ -42,11 +42,13 @@ unspecified, ) from chia.data_layer.util.merkle_blob import ( - KVId, + KeyId, + KeyOrValueId, MerkleBlob, RawInternalMerkleNode, RawLeafMerkleNode, TreeIndex, + ValueId, ) from chia.types.blockchain_format.sized_bytes import bytes32 from chia.util.batches import to_batches @@ -191,7 +193,7 @@ async def insert_into_data_store_from_file( filename: Path, ) -> None: internal_nodes: dict[bytes32, tuple[bytes32, bytes32]] = {} - terminal_nodes: dict[bytes32, tuple[KVId, KVId]] = {} + terminal_nodes: dict[bytes32, tuple[KeyId, ValueId]] = {} with open(filename, "rb") as reader: while True: @@ -409,7 +411,7 @@ async def insert_root_from_merkle_blob( return await self._insert_root(store_id, root_hash, status) - async def get_kvid(self, blob: bytes, store_id: bytes32) -> Optional[KVId]: + async def get_kvid(self, blob: bytes, store_id: bytes32) -> Optional[KeyOrValueId]: async with self.db_wrapper.reader() as reader: cursor = await reader.execute( "SELECT kv_id FROM ids WHERE blob = ? AND store_id = ?", @@ -423,9 +425,9 @@ async def get_kvid(self, blob: bytes, store_id: bytes32) -> Optional[KVId]: if row is None: return None - return KVId(row[0]) + return KeyOrValueId(row[0]) - async def get_blob_from_kvid(self, kv_id: KVId, store_id: bytes32) -> Optional[bytes]: + async def get_blob_from_kvid(self, kv_id: KeyOrValueId, store_id: bytes32) -> Optional[bytes]: async with self.db_wrapper.reader() as reader: cursor = await reader.execute( "SELECT blob FROM ids WHERE kv_id = ? AND store_id = ?", @@ -441,7 +443,7 @@ async def get_blob_from_kvid(self, kv_id: KVId, store_id: bytes32) -> Optional[b return bytes(row[0]) - async def get_terminal_node(self, kid: KVId, vid: KVId, store_id: bytes32) -> TerminalNode: + async def get_terminal_node(self, kid: KeyId, vid: ValueId, store_id: bytes32) -> TerminalNode: key = await self.get_blob_from_kvid(kid, store_id) value = await self.get_blob_from_kvid(vid, store_id) if key is None or value is None: @@ -449,7 +451,7 @@ async def get_terminal_node(self, kid: KVId, vid: KVId, store_id: bytes32) -> Te return TerminalNode(hash=leaf_hash(key, value), key=key, value=value) - async def add_kvid(self, blob: bytes, store_id: bytes32) -> KVId: + async def add_kvid(self, blob: bytes, store_id: bytes32) -> KeyOrValueId: kv_id = await self.get_kvid(blob, store_id) if kv_id is not None: return kv_id @@ -468,9 +470,9 @@ async def add_kvid(self, blob: bytes, store_id: bytes32) -> KVId: raise Exception("Internal error") return kv_id - async def add_key_value(self, key: bytes, value: bytes, store_id: bytes32) -> tuple[KVId, KVId]: - kid = await self.add_kvid(key, store_id) - vid = await self.add_kvid(value, store_id) + async def add_key_value(self, key: bytes, value: bytes, store_id: bytes32) -> tuple[KeyId, ValueId]: + kid = KeyId(await self.add_kvid(key, store_id)) + vid = ValueId(await self.add_kvid(value, store_id)) hash = leaf_hash(key, value) async with self.db_wrapper.writer() as writer: await writer.execute( @@ -484,7 +486,7 @@ async def add_key_value(self, key: bytes, value: bytes, store_id: bytes32) -> tu ) return (kid, vid) - async def get_node_by_hash(self, hash: bytes32, store_id: bytes32) -> tuple[KVId, KVId]: + async def get_node_by_hash(self, hash: bytes32, store_id: bytes32) -> tuple[KeyId, ValueId]: async with self.db_wrapper.reader() as reader: cursor = await reader.execute( "SELECT * FROM hashes WHERE hash = ? AND store_id = ?", @@ -499,8 +501,8 @@ async def get_node_by_hash(self, hash: bytes32, store_id: bytes32) -> tuple[KVId if row is None: raise Exception(f"Cannot find node by hash {hash.hex()}") - kid = KVId(row["kid"]) - vid = KVId(row["vid"]) + kid = KeyId(row["kid"]) + vid = ValueId(row["vid"]) return (kid, vid) async def get_terminal_node_by_hash(self, node_hash: bytes32, store_id: bytes32) -> TerminalNode: @@ -1057,19 +1059,19 @@ async def get_keys( for kid in kv_ids.keys(): key = await self.get_blob_from_kvid(kid, store_id) if key is None: - raise Exception(f"Unknown key corresponding to KVId: {kid}") + raise Exception(f"Unknown key corresponding to KeyId: {kid}") keys.append(key) return keys - def get_reference_kid_side(self, merkle_blob: MerkleBlob, seed: bytes32) -> tuple[KVId, Side]: + def get_reference_kid_side(self, merkle_blob: MerkleBlob, seed: bytes32) -> tuple[KeyId, Side]: side_seed = bytes(seed)[0] side = Side.LEFT if side_seed < 128 else Side.RIGHT reference_node = merkle_blob.get_random_leaf_node(seed) kid = reference_node.key return (kid, side) - async def get_terminal_node_from_kid(self, merkle_blob: MerkleBlob, kid: KVId, store_id: bytes32) -> TerminalNode: + async def get_terminal_node_from_kid(self, merkle_blob: MerkleBlob, kid: KeyId, store_id: bytes32) -> TerminalNode: index = merkle_blob.key_to_index[kid] raw_node = merkle_blob.get_raw_node(index) assert isinstance(raw_node, RawLeafMerkleNode) @@ -1139,7 +1141,7 @@ async def delete( kid = await self.get_kvid(key, store_id) if kid is not None: - merkle_blob.delete(kid) + merkle_blob.delete(KeyId(kid)) new_root = await self.insert_root_from_merkle_blob(merkle_blob, store_id, status) @@ -1199,7 +1201,7 @@ async def insert_batch( first_action[hash] = change["action"] last_action[hash] = change["action"] - batch_keys_values: list[tuple[KVId, KVId]] = [] + batch_keys_values: list[tuple[KeyId, ValueId]] = [] batch_hashes: list[bytes32] = [] for change in changelist: @@ -1209,7 +1211,7 @@ async def insert_batch( reference_node_hash = change.get("reference_node_hash", None) side = change.get("side", None) - reference_kid: Optional[KVId] = None + reference_kid: Optional[KeyId] = None if reference_node_hash is not None: reference_kid, _ = await self.get_node_by_hash(reference_node_hash, store_id) @@ -1236,7 +1238,7 @@ async def insert_batch( key = change["key"] deletion_kid = await self.get_kvid(key, store_id) if deletion_kid is not None: - merkle_blob.delete(deletion_kid) + merkle_blob.delete(KeyId(deletion_kid)) elif change["action"] == "upsert": key = change["key"] new_value = change["value"] @@ -1324,9 +1326,10 @@ async def get_node_by_key( except MerkleBlobNotFoundError: raise KeyNotFoundError(key=key) - kid = await self.get_kvid(key, store_id) - if kid is None: + kvid = await self.get_kvid(key, store_id) + if kvid is None: raise KeyNotFoundError(key=key) + kid = KeyId(kvid) if not merkle_blob.key_exists(kid): raise KeyNotFoundError(key=key) return await self.get_terminal_node_from_kid(merkle_blob, kid, store_id) @@ -1389,9 +1392,10 @@ async def get_proof_of_inclusion_by_key( ) -> ProofOfInclusion: root = await self.get_tree_root(store_id=store_id) merkle_blob = await self.get_merkle_blob(root_hash=root.node_hash) - kid = await self.get_kvid(key, store_id) - if kid is None: + kvid = await self.get_kvid(key, store_id) + if kvid is None: raise Exception(f"Cannot find key: {key.hex()}") + kid = KeyId(kvid) return merkle_blob.get_proof_of_inclusion(kid) async def write_tree_to_file( diff --git a/chia/data_layer/util/merkle_blob.py b/chia/data_layer/util/merkle_blob.py index 02745d364277..4ed2197ed060 100644 --- a/chia/data_layer/util/merkle_blob.py +++ b/chia/data_layer/util/merkle_blob.py @@ -8,11 +8,14 @@ from chia.data_layer.data_layer_util import InternalNode, ProofOfInclusion, ProofOfInclusionLayer, Side, internal_hash from chia.types.blockchain_format.sized_bytes import bytes32 from chia.util.hash import std_hash +from chia.util.ints import int64 dirty_hash = bytes32(b"\x00" * 32) TreeIndex = NewType("TreeIndex", int) -KVId = NewType("KVId", int) +KeyOrValueId = int64 +KeyId = NewType("KeyId", KeyOrValueId) +ValueId = NewType("ValueId", KeyOrValueId) T = TypeVar("T") @@ -40,7 +43,7 @@ class NodeType(IntEnum): @dataclass(frozen=False) class MerkleBlob: blob: bytearray - key_to_index: dict[KVId, TreeIndex] = field(default_factory=dict) + key_to_index: dict[KeyId, TreeIndex] = field(default_factory=dict) free_indexes: list[TreeIndex] = field(default_factory=list) last_allocated_index: TreeIndex = TreeIndex(0) @@ -53,7 +56,7 @@ def __post_init__(self) -> None: def from_node_list( cls: type[MerkleBlob], internal_nodes: dict[bytes32, tuple[bytes32, bytes32]], - terminal_nodes: dict[bytes32, tuple[KVId, KVId]], + terminal_nodes: dict[bytes32, tuple[KeyId, ValueId]], root_hash: Optional[bytes32], ) -> MerkleBlob: merkle_blob = cls(blob=bytearray()) @@ -69,7 +72,7 @@ def from_node_list( def build_blob_from_node_list( self, internal_nodes: dict[bytes32, tuple[bytes32, bytes32]], - terminal_nodes: dict[bytes32, tuple[KVId, KVId]], + terminal_nodes: dict[bytes32, tuple[KeyId, ValueId]], node_hash: bytes32, ) -> TreeIndex: if node_hash not in terminal_nodes and node_hash not in internal_nodes: @@ -186,11 +189,11 @@ def calculate_lazy_hashes(self, index: TreeIndex = TreeIndex(0)) -> bytes32: self.update_metadata(index, dirty=False) return internal_node_hash - def get_proof_of_inclusion(self, kvID: KVId) -> ProofOfInclusion: - if kvID not in self.key_to_index: - raise Exception(f"Key {kvID} not present in the store") + def get_proof_of_inclusion(self, key_id: KeyId) -> ProofOfInclusion: + if key_id not in self.key_to_index: + raise Exception(f"Key {key_id} not present in the store") - index = self.key_to_index[kvID] + index = self.key_to_index[key_id] node = self.get_raw_node(index) assert isinstance(node, RawLeafMerkleNode) @@ -219,8 +222,8 @@ def get_lineage_with_indexes(self, index: TreeIndex) -> list[tuple[TreeIndex, Ra lineage.append((index, node)) return lineage - def get_lineage_by_key_id(self, kid: KVId) -> list[InternalNode]: - index = self.key_to_index[kid] + def get_lineage_by_key_id(self, key_id: KeyId) -> list[InternalNode]: + index = self.key_to_index[key_id] lineage = self.get_lineage_with_indexes(index) internal_nodes: list[InternalNode] = [] for _, node in lineage[1:]: @@ -238,8 +241,8 @@ def update_entry( left: Optional[TreeIndex] = None, right: Optional[TreeIndex] = None, hash: Optional[bytes] = None, - key: Optional[KVId] = None, - value: Optional[KVId] = None, + key: Optional[KeyId] = None, + value: Optional[ValueId] = None, ) -> None: node = self.get_raw_node(index) new_parent = parent if parent is not None else node.parent @@ -277,11 +280,11 @@ def get_random_leaf_node(self, seed: bytes) -> RawLeafMerkleNode: raise Exception("Cannot find leaf from seed") - def get_keys_indexes(self) -> dict[KVId, TreeIndex]: + def get_keys_indexes(self) -> dict[KeyId, TreeIndex]: if len(self.blob) == 0: return {} - key_to_index: dict[KVId, TreeIndex] = {} + key_to_index: dict[KeyId, TreeIndex] = {} queue: list[TreeIndex] = [TreeIndex(0)] while len(queue) > 0: node_index = queue.pop() @@ -311,11 +314,11 @@ def get_hashes_indexes(self) -> dict[bytes32, TreeIndex]: return hash_to_index - def get_keys_values(self) -> dict[KVId, KVId]: + def get_keys_values(self) -> dict[KeyId, ValueId]: if len(self.blob) == 0: return {} - keys_values: dict[KVId, KVId] = {} + keys_values: dict[KeyId, ValueId] = {} queue: list[TreeIndex] = [TreeIndex(0)] while len(queue) > 0: node_index = queue.pop() @@ -398,15 +401,15 @@ def insert_from_leaf(self, old_leaf_index: TreeIndex, new_index: TreeIndex, side if isinstance(new_node, RawLeafMerkleNode): self.key_to_index[new_node.key] = new_index - def key_exists(self, key: KVId) -> bool: + def key_exists(self, key: KeyId) -> bool: return key in self.key_to_index def insert( self, - key: KVId, - value: KVId, + key: KeyId, + value: ValueId, hash: bytes, - reference_kid: Optional[KVId] = None, + reference_kid: Optional[KeyId] = None, side: Optional[Side] = None, ) -> None: if key in self.key_to_index: @@ -468,7 +471,7 @@ def insert( ) self.insert_from_leaf(old_leaf_index, new_leaf_index, side) - def delete(self, key: KVId) -> None: + def delete(self, key: KeyId) -> None: leaf_index = self.key_to_index[key] leaf = self.get_raw_node(leaf_index) assert isinstance(leaf, RawLeafMerkleNode) @@ -518,7 +521,7 @@ def delete(self, key: KVId) -> None: self.update_entry(grandparent_index, right=sibling_index) self.mark_lineage_as_dirty(grandparent_index) - def upsert(self, key: KVId, value: KVId, hash: bytes) -> None: + def upsert(self, key: KeyId, value: ValueId, hash: bytes) -> None: if key not in self.key_to_index: self.insert(key, value, hash) return @@ -560,7 +563,7 @@ def get_nodes_with_indexes(self, index: TreeIndex = TreeIndex(0)) -> list[tuple[ return this + left_nodes + right_nodes - def batch_insert(self, keys_values: list[tuple[KVId, KVId]], hashes: list[bytes32]) -> None: + def batch_insert(self, keys_values: list[tuple[KeyId, ValueId]], hashes: list[bytes32]) -> None: indexes: list[TreeIndex] = [] if len(self.key_to_index) <= 1: @@ -713,11 +716,11 @@ class RawLeafMerkleNode: hash: bytes parent: TreeIndex # TODO: how/where are these mapping? maybe a kv table row id? - key: KVId - value: KVId + key: KeyId + value: ValueId # TODO: maybe bytes32? maybe that's not 'raw' - def as_tuple(self) -> tuple[bytes, TreeIndex, KVId, KVId]: + def as_tuple(self) -> tuple[bytes, TreeIndex, KeyId, ValueId]: return (self.hash, self.parent, self.key, self.value)