Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

separate KVId into KeyId and ValueId #19180

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 26 additions & 22 deletions chia/_tests/core/data_layer/test_merkle_blob.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,16 @@
from chia.data_layer.data_layer_util import InternalNode, Side, internal_hash
from chia.data_layer.util.merkle_blob import (
InvalidIndexError,
KVId,
KeyId,
KeyOrValueId,
MerkleBlob,
NodeMetadata,
NodeType,
RawInternalMerkleNode,
RawLeafMerkleNode,
RawMerkleNodeProtocol,
TreeIndex,
ValueId,
data_size,
metadata_size,
null_parent,
Expand Down Expand Up @@ -135,8 +137,8 @@ def id(self) -> str:
raw=RawLeafMerkleNode(
hash=bytes(range(32)),
parent=TreeIndex(0x20212223),
key=KVId(0x2425262728292A2B),
value=KVId(0x2C2D2E2F30313233),
key=KeyId(KeyOrValueId(0x2425262728292A2B)),
value=ValueId(KeyOrValueId(0x2C2D2E2F30313233)),
),
),
]
Expand Down Expand Up @@ -169,8 +171,8 @@ def test_merkle_blob_one_leaf_loads() -> None:
leaf = RawLeafMerkleNode(
hash=bytes(range(32)),
parent=null_parent,
key=KVId(0x0405060708090A0B),
value=KVId(0x0405060708090A1B),
key=KeyId(KeyOrValueId(0x0405060708090A0B)),
value=ValueId(KeyOrValueId(0x0405060708090A1B)),
)
blob = bytearray(NodeMetadata(type=NodeType.leaf, dirty=False).pack() + pack_raw_node(leaf))

Expand All @@ -190,14 +192,14 @@ def test_merkle_blob_two_leafs_loads() -> None:
left_leaf = RawLeafMerkleNode(
hash=bytes(range(32)),
parent=TreeIndex(0),
key=KVId(0x0405060708090A0B),
value=KVId(0x0405060708090A1B),
key=KeyId(KeyOrValueId(0x0405060708090A0B)),
value=ValueId(KeyOrValueId(0x0405060708090A1B)),
)
right_leaf = RawLeafMerkleNode(
hash=bytes(range(32)),
parent=TreeIndex(0),
key=KVId(0x1415161718191A1B),
value=KVId(0x1415161718191A2B),
key=KeyId(KeyOrValueId(0x1415161718191A1B)),
value=ValueId(KeyOrValueId(0x1415161718191A2B)),
)
blob = bytearray()
blob.extend(NodeMetadata(type=NodeType.internal, dirty=True).pack() + pack_raw_node(root))
Expand All @@ -218,20 +220,20 @@ def test_merkle_blob_two_leafs_loads() -> None:
son_hash = bytes32(range(32))
root_hash = internal_hash(son_hash, son_hash)
expected_node = InternalNode(root_hash, son_hash, son_hash)
assert merkle_blob.get_lineage_by_key_id(KVId(0x0405060708090A0B)) == [expected_node]
assert merkle_blob.get_lineage_by_key_id(KVId(0x1415161718191A1B)) == [expected_node]
assert merkle_blob.get_lineage_by_key_id(KeyId(KeyOrValueId(0x0405060708090A0B))) == [expected_node]
assert merkle_blob.get_lineage_by_key_id(KeyId(KeyOrValueId(0x1415161718191A1B))) == [expected_node]


def generate_kvid(seed: int) -> tuple[KVId, KVId]:
kv_ids: list[KVId] = []
def generate_kvid(seed: int) -> tuple[KeyId, ValueId]:
kv_ids: list[KeyOrValueId] = []

for offset in range(2):
seed_bytes = (2 * seed + offset).to_bytes(8, byteorder="big", signed=True)
hash_obj = hashlib.sha256(seed_bytes)
hash_int = int.from_bytes(hash_obj.digest()[:8], byteorder="big", signed=True)
kv_ids.append(KVId(hash_int))
kv_ids.append(KeyOrValueId(hash_int))

return kv_ids[0], kv_ids[1]
return KeyId(kv_ids[0]), ValueId(kv_ids[1])


def generate_hash(seed: int) -> bytes32:
Expand All @@ -245,7 +247,7 @@ def test_insert_delete_loads_all_keys() -> None:
num_keys = 200000
extra_keys = 100000
max_height = 25
keys_values: dict[KVId, KVId] = {}
keys_values: dict[KeyId, ValueId] = {}

random = Random()
random.seed(100, version=2)
Expand Down Expand Up @@ -304,7 +306,7 @@ def test_small_insert_deletes() -> None:

for repeats in range(num_repeats):
for num_inserts in range(1, max_inserts):
keys_values: dict[KVId, KVId] = {}
keys_values: dict[KeyId, ValueId] = {}
for inserts in range(num_inserts):
seed += 1
key, value = generate_kvid(seed)
Expand All @@ -330,13 +332,13 @@ def test_proof_of_inclusion_merkle_blob() -> None:
random.seed(100, version=2)

merkle_blob = MerkleBlob(blob=bytearray())
keys_values: dict[KVId, KVId] = {}
keys_values: dict[KeyId, ValueId] = {}

for repeats in range(num_repeats):
num_inserts = 1 + repeats * 100
num_deletes = 1 + repeats * 10

kv_ids: list[tuple[KVId, KVId]] = []
kv_ids: list[tuple[KeyId, ValueId]] = []
hashes: list[bytes32] = []
for _ in range(num_inserts):
seed += 1
Expand All @@ -363,7 +365,7 @@ def test_proof_of_inclusion_merkle_blob() -> None:
with pytest.raises(Exception, match=f"Key {kv_id} not present in the store"):
merkle_blob.get_proof_of_inclusion(kv_id)

new_keys_values: dict[KVId, KVId] = {}
new_keys_values: dict[KeyId, ValueId] = {}
for old_kv in keys_values.keys():
seed += 1
_, value = generate_kvid(seed)
Expand All @@ -382,7 +384,9 @@ def test_proof_of_inclusion_merkle_blob() -> None:
@pytest.mark.parametrize(argnames="index", argvalues=[TreeIndex(-1), TreeIndex(1), TreeIndex(null_parent)])
def test_get_raw_node_raises_for_invalid_indexes(index: TreeIndex) -> None:
merkle_blob = MerkleBlob(blob=bytearray())
merkle_blob.insert(KVId(0x1415161718191A1B), KVId(0x1415161718191A1B), bytes(range(12, data_size)))
merkle_blob.insert(
KeyId(KeyOrValueId(0x1415161718191A1B)), ValueId(KeyOrValueId(0x1415161718191A1B)), bytes(range(12, data_size))
)

with pytest.raises(InvalidIndexError):
merkle_blob.get_raw_node(index)
Expand Down Expand Up @@ -497,6 +501,6 @@ def test_just_insert_a_bunch(merkle_blob_type: MerkleBlobCallable) -> None:
total_time = 0.0
for i in range(100000):
start = time.monotonic()
merkle_blob.insert(KVId(i), KVId(i), HASH)
merkle_blob.insert(KeyId(KeyOrValueId(i)), ValueId(KeyOrValueId(i)), HASH)
end = time.monotonic()
total_time += end - start
52 changes: 28 additions & 24 deletions chia/data_layer/data_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,13 @@
unspecified,
)
from chia.data_layer.util.merkle_blob import (
KVId,
KeyId,
KeyOrValueId,
MerkleBlob,
RawInternalMerkleNode,
RawLeafMerkleNode,
TreeIndex,
ValueId,
)
from chia.types.blockchain_format.sized_bytes import bytes32
from chia.util.batches import to_batches
Expand Down Expand Up @@ -191,7 +193,7 @@ async def insert_into_data_store_from_file(
filename: Path,
) -> None:
internal_nodes: dict[bytes32, tuple[bytes32, bytes32]] = {}
terminal_nodes: dict[bytes32, tuple[KVId, KVId]] = {}
terminal_nodes: dict[bytes32, tuple[KeyId, ValueId]] = {}

with open(filename, "rb") as reader:
while True:
Expand Down Expand Up @@ -409,7 +411,7 @@ async def insert_root_from_merkle_blob(

return await self._insert_root(store_id, root_hash, status)

async def get_kvid(self, blob: bytes, store_id: bytes32) -> Optional[KVId]:
async def get_kvid(self, blob: bytes, store_id: bytes32) -> Optional[KeyOrValueId]:
async with self.db_wrapper.reader() as reader:
cursor = await reader.execute(
"SELECT kv_id FROM ids WHERE blob = ? AND store_id = ?",
Expand All @@ -423,9 +425,9 @@ async def get_kvid(self, blob: bytes, store_id: bytes32) -> Optional[KVId]:
if row is None:
return None

return KVId(row[0])
return KeyOrValueId(row[0])

async def get_blob_from_kvid(self, kv_id: KVId, store_id: bytes32) -> Optional[bytes]:
async def get_blob_from_kvid(self, kv_id: KeyOrValueId, store_id: bytes32) -> Optional[bytes]:
async with self.db_wrapper.reader() as reader:
cursor = await reader.execute(
"SELECT blob FROM ids WHERE kv_id = ? AND store_id = ?",
Expand All @@ -441,15 +443,15 @@ async def get_blob_from_kvid(self, kv_id: KVId, store_id: bytes32) -> Optional[b

return bytes(row[0])

async def get_terminal_node(self, kid: KVId, vid: KVId, store_id: bytes32) -> TerminalNode:
async def get_terminal_node(self, kid: KeyId, vid: ValueId, store_id: bytes32) -> TerminalNode:
key = await self.get_blob_from_kvid(kid, store_id)
value = await self.get_blob_from_kvid(vid, store_id)
if key is None or value is None:
raise Exception("Cannot find the key/value pair")

return TerminalNode(hash=leaf_hash(key, value), key=key, value=value)

async def add_kvid(self, blob: bytes, store_id: bytes32) -> KVId:
async def add_kvid(self, blob: bytes, store_id: bytes32) -> KeyOrValueId:
kv_id = await self.get_kvid(blob, store_id)
if kv_id is not None:
return kv_id
Expand All @@ -468,9 +470,9 @@ async def add_kvid(self, blob: bytes, store_id: bytes32) -> KVId:
raise Exception("Internal error")
return kv_id

async def add_key_value(self, key: bytes, value: bytes, store_id: bytes32) -> tuple[KVId, KVId]:
kid = await self.add_kvid(key, store_id)
vid = await self.add_kvid(value, store_id)
async def add_key_value(self, key: bytes, value: bytes, store_id: bytes32) -> tuple[KeyId, ValueId]:
kid = KeyId(await self.add_kvid(key, store_id))
vid = ValueId(await self.add_kvid(value, store_id))
hash = leaf_hash(key, value)
async with self.db_wrapper.writer() as writer:
await writer.execute(
Expand All @@ -484,7 +486,7 @@ async def add_key_value(self, key: bytes, value: bytes, store_id: bytes32) -> tu
)
return (kid, vid)

async def get_node_by_hash(self, hash: bytes32, store_id: bytes32) -> tuple[KVId, KVId]:
async def get_node_by_hash(self, hash: bytes32, store_id: bytes32) -> tuple[KeyId, ValueId]:
async with self.db_wrapper.reader() as reader:
cursor = await reader.execute(
"SELECT * FROM hashes WHERE hash = ? AND store_id = ?",
Expand All @@ -499,8 +501,8 @@ async def get_node_by_hash(self, hash: bytes32, store_id: bytes32) -> tuple[KVId
if row is None:
raise Exception(f"Cannot find node by hash {hash.hex()}")

kid = KVId(row["kid"])
vid = KVId(row["vid"])
kid = KeyId(row["kid"])
vid = ValueId(row["vid"])
return (kid, vid)

async def get_terminal_node_by_hash(self, node_hash: bytes32, store_id: bytes32) -> TerminalNode:
Expand Down Expand Up @@ -1057,19 +1059,19 @@ async def get_keys(
for kid in kv_ids.keys():
key = await self.get_blob_from_kvid(kid, store_id)
if key is None:
raise Exception(f"Unknown key corresponding to KVId: {kid}")
raise Exception(f"Unknown key corresponding to KeyId: {kid}")
keys.append(key)

return keys

def get_reference_kid_side(self, merkle_blob: MerkleBlob, seed: bytes32) -> tuple[KVId, Side]:
def get_reference_kid_side(self, merkle_blob: MerkleBlob, seed: bytes32) -> tuple[KeyId, Side]:
side_seed = bytes(seed)[0]
side = Side.LEFT if side_seed < 128 else Side.RIGHT
reference_node = merkle_blob.get_random_leaf_node(seed)
kid = reference_node.key
return (kid, side)

async def get_terminal_node_from_kid(self, merkle_blob: MerkleBlob, kid: KVId, store_id: bytes32) -> TerminalNode:
async def get_terminal_node_from_kid(self, merkle_blob: MerkleBlob, kid: KeyId, store_id: bytes32) -> TerminalNode:
index = merkle_blob.key_to_index[kid]
raw_node = merkle_blob.get_raw_node(index)
assert isinstance(raw_node, RawLeafMerkleNode)
Expand Down Expand Up @@ -1139,7 +1141,7 @@ async def delete(

kid = await self.get_kvid(key, store_id)
if kid is not None:
merkle_blob.delete(kid)
merkle_blob.delete(KeyId(kid))

new_root = await self.insert_root_from_merkle_blob(merkle_blob, store_id, status)

Expand Down Expand Up @@ -1199,7 +1201,7 @@ async def insert_batch(
first_action[hash] = change["action"]
last_action[hash] = change["action"]

batch_keys_values: list[tuple[KVId, KVId]] = []
batch_keys_values: list[tuple[KeyId, ValueId]] = []
batch_hashes: list[bytes32] = []

for change in changelist:
Expand All @@ -1209,7 +1211,7 @@ async def insert_batch(

reference_node_hash = change.get("reference_node_hash", None)
side = change.get("side", None)
reference_kid: Optional[KVId] = None
reference_kid: Optional[KeyId] = None
if reference_node_hash is not None:
reference_kid, _ = await self.get_node_by_hash(reference_node_hash, store_id)

Expand All @@ -1236,7 +1238,7 @@ async def insert_batch(
key = change["key"]
deletion_kid = await self.get_kvid(key, store_id)
if deletion_kid is not None:
merkle_blob.delete(deletion_kid)
merkle_blob.delete(KeyId(deletion_kid))
elif change["action"] == "upsert":
key = change["key"]
new_value = change["value"]
Expand Down Expand Up @@ -1324,9 +1326,10 @@ async def get_node_by_key(
except MerkleBlobNotFoundError:
raise KeyNotFoundError(key=key)

kid = await self.get_kvid(key, store_id)
if kid is None:
kvid = await self.get_kvid(key, store_id)
if kvid is None:
raise KeyNotFoundError(key=key)
kid = KeyId(kvid)
if not merkle_blob.key_exists(kid):
raise KeyNotFoundError(key=key)
return await self.get_terminal_node_from_kid(merkle_blob, kid, store_id)
Expand Down Expand Up @@ -1389,9 +1392,10 @@ async def get_proof_of_inclusion_by_key(
) -> ProofOfInclusion:
root = await self.get_tree_root(store_id=store_id)
merkle_blob = await self.get_merkle_blob(root_hash=root.node_hash)
kid = await self.get_kvid(key, store_id)
if kid is None:
kvid = await self.get_kvid(key, store_id)
if kvid is None:
raise Exception(f"Cannot find key: {key.hex()}")
kid = KeyId(kvid)
return merkle_blob.get_proof_of_inclusion(kid)

async def write_tree_to_file(
Expand Down
Loading
Loading