Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CHIA-1126] Remove PrivateKey type from wallet #18458

Merged
merged 1 commit into from
Aug 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions chia/_tests/wallet/test_main_wallet_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ async def generate_signed_transaction(
)
)

def puzzle_for_pk(self, pubkey: G1Element) -> Program: # pragma: no cover
def puzzle_for_pk(self, pubkey: ObservationRoot) -> Program: # pragma: no cover
raise ValueError("This won't work")

async def puzzle_for_puzzle_hash(self, puzzle_hash: bytes32) -> Program:
Expand Down Expand Up @@ -183,7 +183,7 @@ async def make_solution(
async def get_puzzle(self, new: bool) -> Program: # pragma: no cover
return ACS

def puzzle_hash_for_pk(self, pubkey: G1Element) -> bytes32: # pragma: no cover
def puzzle_hash_for_pk(self, pubkey: ObservationRoot) -> bytes32: # pragma: no cover
raise ValueError("This won't work")

def require_derivation_paths(self) -> bool:
Expand Down
3 changes: 2 additions & 1 deletion chia/_tests/wallet/test_signer_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,8 @@ async def test_p2dohp_wallet_signer_protocol(wallet_environments: WalletTestFram
@pytest.mark.anyio
async def test_p2blsdohp_execute_signing_instructions(wallet_environments: WalletTestFramework) -> None:
wallet: MainWalletProtocol = wallet_environments.environments[0].xch_wallet
root_sk: PrivateKey = wallet.wallet_state_manager.get_master_private_key()
root_sk = wallet.wallet_state_manager.get_master_private_key()
assert isinstance(root_sk, PrivateKey)
root_pk: G1Element = root_sk.get_g1()
root_fingerprint: bytes = root_pk.get_fingerprint().to_bytes(4, "big")

Expand Down
8 changes: 5 additions & 3 deletions chia/_tests/wallet/test_wallet.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from typing import Any, Dict, List, Optional, Tuple

import pytest
from chia_rs import AugSchemeMPL, G1Element, G2Element
from chia_rs import AugSchemeMPL, G1Element, G2Element, PrivateKey

from chia._tests.environments.wallet import WalletStateTransition, WalletTestFramework
from chia._tests.util.time_out_assert import time_out_assert
Expand Down Expand Up @@ -1875,9 +1875,11 @@ async def test_address_sliding_window(self, wallet_environments: WalletTestFrame
peak = full_node_api.full_node.blockchain.get_peak_height()
assert peak is not None

puzzle_hashes = []
puzzle_hashes: List[bytes32] = []
for i in range(211):
pubkey = master_sk_to_wallet_sk(wallet.wallet_state_manager.get_master_private_key(), uint32(i)).get_g1()
sk = wallet.wallet_state_manager.get_master_private_key()
assert isinstance(sk, PrivateKey)
pubkey = master_sk_to_wallet_sk(sk, uint32(i)).public_key()
puzzle: Program = wallet.puzzle_for_pk(pubkey)
puzzle_hash: bytes32 = puzzle.get_tree_hash()
puzzle_hashes.append(puzzle_hash)
Expand Down
2 changes: 1 addition & 1 deletion chia/_tests/wallet/test_wallet_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -649,7 +649,7 @@ async def test_get_last_used_fingerprint_if_exists(
assert node.wallet_state_manager.private_key is not None
assert (
await node.get_last_used_fingerprint_if_exists()
== node.wallet_state_manager.private_key.get_g1().get_fingerprint()
== node.wallet_state_manager.private_key.public_key().get_fingerprint()
)
await node.keychain_proxy.delete_all_keys()
assert await node.get_last_used_fingerprint_if_exists() is None
Expand Down
8 changes: 5 additions & 3 deletions chia/_tests/wallet/test_wallet_state_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import AsyncIterator, List

import pytest
from chia_rs import G2Element
from chia_rs import G2Element, PrivateKey

from chia._tests.environments.wallet import WalletTestFramework
from chia._tests.util.setup_nodes import OldSimulatorsAndWallets
Expand Down Expand Up @@ -68,11 +68,13 @@ async def test_get_private_key(simulator_and_wallet: OldSimulatorsAndWallets, ha
wallet_state_manager: WalletStateManager = wallet_node.wallet_state_manager
derivation_index = uint32(10000)
conversion_method = master_sk_to_wallet_sk if hardened else master_sk_to_wallet_sk_unhardened
expected_private_key = conversion_method(wallet_state_manager.get_master_private_key(), derivation_index)
sk = wallet_state_manager.get_master_private_key()
assert isinstance(sk, PrivateKey)
expected_private_key = conversion_method(sk, derivation_index)
record = DerivationRecord(
derivation_index,
bytes32(b"0" * 32),
expected_private_key.get_g1(),
bytes(expected_private_key.public_key()),
WalletType.STANDARD_WALLET,
uint32(1),
hardened,
Expand Down
4 changes: 3 additions & 1 deletion chia/pools/pool_wallet.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,9 +439,11 @@ async def create_new_pool_wallet_transaction(
return p2_singleton_puzzle_hash, launcher_coin_id

async def _get_owner_key_cache(self) -> Tuple[PrivateKey, uint32]:
private_key = self.wallet_state_manager.get_master_private_key()
assert isinstance(private_key, PrivateKey)
if self._owner_sk_and_index is None:
self._owner_sk_and_index = find_owner_sk(
[self.wallet_state_manager.get_master_private_key()],
[private_key],
(await self.get_current_state()).current.owner_pubkey,
)
assert self._owner_sk_and_index is not None
Expand Down
6 changes: 3 additions & 3 deletions chia/rpc/wallet_rpc_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -975,9 +975,9 @@ async def create_new_wallet(
if max_pwi + 1 >= (MAX_POOL_WALLETS - 1):
raise ValueError(f"Too many pool wallets ({max_pwi}), cannot create any more on this key.")

owner_sk: PrivateKey = master_sk_to_singleton_owner_sk(
self.service.wallet_state_manager.get_master_private_key(), uint32(max_pwi + 1)
)
master_sk = self.service.wallet_state_manager.get_master_private_key()
assert isinstance(master_sk, PrivateKey), "Pooling only works with BLS keys at this time"
owner_sk: PrivateKey = master_sk_to_singleton_owner_sk(master_sk, uint32(max_pwi + 1))
owner_pk: G1Element = owner_sk.get_g1()

initial_target_state = initial_pool_state_from_dict(
Expand Down
7 changes: 7 additions & 0 deletions chia/wallet/derivation_record.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,10 @@ class DerivationRecord:
def pubkey(self) -> G1Element:
assert isinstance(self._pubkey, G1Element)
return self._pubkey

@property
def pubkey_bytes(self) -> bytes:
if isinstance(self._pubkey, G1Element):
return bytes(self._pubkey)
else:
return self._pubkey
5 changes: 3 additions & 2 deletions chia/wallet/vault/vault_wallet.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from chia.types.spend_bundle import SpendBundle
from chia.util.hash import std_hash
from chia.util.ints import uint32, uint64, uint128
from chia.util.observation_root import ObservationRoot
from chia.wallet.coin_selection import select_coins
from chia.wallet.conditions import (
AssertCoinAnnouncement,
Expand Down Expand Up @@ -314,7 +315,7 @@ async def _generate_unsigned_transaction(

return all_spends

def puzzle_for_pk(self, pubkey: G1Element) -> Program:
def puzzle_for_pk(self, pubkey: ObservationRoot) -> Program:
raise NotImplementedError("vault wallet")

async def puzzle_for_puzzle_hash(self, puzzle_hash: bytes32) -> Program:
Expand Down Expand Up @@ -438,7 +439,7 @@ async def get_puzzle(self, new: bool) -> Program:
)
return puzzle

def puzzle_hash_for_pk(self, pubkey: G1Element) -> bytes32:
def puzzle_hash_for_pk(self, pubkey: ObservationRoot) -> bytes32:
raise ValueError("This won't work")

def require_derivation_paths(self) -> bool:
Expand Down
38 changes: 21 additions & 17 deletions chia/wallet/wallet.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from chia.util.condition_tools import conditions_dict_for_solution, pkm_pairs_for_conditions_dict
from chia.util.hash import std_hash
from chia.util.ints import uint32, uint64, uint128
from chia.util.observation_root import ObservationRoot
from chia.util.streamable import Streamable
from chia.wallet.coin_selection import select_coins
from chia.wallet.conditions import (
Expand Down Expand Up @@ -169,10 +170,12 @@ async def get_pending_change_balance(self) -> uint64:
def require_derivation_paths(self) -> bool:
return True

def puzzle_for_pk(self, pubkey: G1Element) -> Program:
def puzzle_for_pk(self, pubkey: ObservationRoot) -> Program:
assert isinstance(pubkey, G1Element), "Standard wallet cannot support non-BLS keys yet"
return puzzle_for_pk(pubkey)

def puzzle_hash_for_pk(self, pubkey: G1Element) -> bytes32:
def puzzle_hash_for_pk(self, pubkey: ObservationRoot) -> bytes32:
assert isinstance(pubkey, G1Element), "Standard wallet cannot support non-BLS keys yet"
return puzzle_hash_for_pk(pubkey)

async def convert_puzzle_hash(self, puzzle_hash: bytes32) -> bytes32:
Expand Down Expand Up @@ -397,6 +400,7 @@ async def sign_message(self, message: str, puzzle_hash: bytes32, mode: SigningMo
# CHIP-0002 message signing as documented at:
# https://github.com/Chia-Network/chips/blob/80e4611fe52b174bf1a0382b9dff73805b18b8c6/CHIPs/chip-0002.md#signmessage
private = await self.wallet_state_manager.get_private_key(puzzle_hash)
assert isinstance(private, PrivateKey)
synthetic_secret_key = calculate_synthetic_secret_key(private, DEFAULT_HIDDEN_PUZZLE_HASH)
synthetic_pk = synthetic_secret_key.get_g1()
if mode == SigningMode.CHIP_0002_HEX_INPUT:
Expand Down Expand Up @@ -557,7 +561,9 @@ async def path_hint_for_pubkey(self, pk: bytes) -> Optional[PathHint]:
root_fingerprint: bytes = self.wallet_state_manager.observation_root.get_fingerprint().to_bytes(4, "big")
if index is None:
# Pool wallet may have a secret key here
if self.wallet_state_manager.private_key is not None:
if self.wallet_state_manager.private_key is not None and isinstance(
self.wallet_state_manager.private_key, PrivateKey
):
for pool_wallet_index in range(MAX_POOL_WALLETS):
try_owner_sk = master_sk_to_singleton_owner_sk(
self.wallet_state_manager.private_key, uint32(pool_wallet_index)
Expand All @@ -578,20 +584,20 @@ async def execute_signing_instructions(
) -> List[SigningResponse]:
assert isinstance(self.wallet_state_manager.observation_root, G1Element)
root_pubkey: G1Element = self.wallet_state_manager.observation_root
pk_lookup: Dict[int, G1Element] = (
{root_pubkey.get_fingerprint(): root_pubkey} if self.wallet_state_manager.private_key is not None else {}
)
sk_lookup: Dict[int, PrivateKey] = (
{root_pubkey.get_fingerprint(): self.wallet_state_manager.get_master_private_key()}
if self.wallet_state_manager.private_key is not None
else {}
)
pk_lookup: Dict[int, G1Element] = {}
sk_lookup: Dict[int, PrivateKey] = {}
aggregate_responses_at_end: bool = True
responses: List[SigningResponse] = []

# TODO: expand path hints and sum hints recursively (a sum hint can give a new key to path hint)
# Next, expand our pubkey set with path hints
if self.wallet_state_manager.private_key is not None:
root_secret_key = self.wallet_state_manager.get_master_private_key()
assert isinstance(root_secret_key, PrivateKey)
root_fingerprint = root_pubkey.get_fingerprint()
pk_lookup[root_fingerprint] = root_pubkey
sk_lookup[root_fingerprint] = root_secret_key

for path_hint in signing_instructions.key_hints.path_hints:
if int.from_bytes(path_hint.root_fingerprint, "big") != root_pubkey.get_fingerprint():
if not partial_allowed:
Expand All @@ -600,12 +606,10 @@ async def execute_signing_instructions(
continue
else:
path = [int(step) for step in path_hint.path]
derive_child_sk = _derive_path(self.wallet_state_manager.get_master_private_key(), path)
derive_child_sk_unhardened = _derive_path_unhardened(
self.wallet_state_manager.get_master_private_key(), path
)
derive_child_pk = derive_child_sk.get_g1()
derive_child_pk_unhardened = derive_child_sk_unhardened.get_g1()
derive_child_sk = _derive_path(root_secret_key, path)
derive_child_sk_unhardened = _derive_path_unhardened(root_secret_key, path)
derive_child_pk = derive_child_sk.public_key()
derive_child_pk_unhardened = derive_child_sk_unhardened.public_key()
pk_lookup[derive_child_pk.get_fingerprint()] = derive_child_pk
pk_lookup[derive_child_pk_unhardened.get_fingerprint()] = derive_child_pk_unhardened
sk_lookup[derive_child_pk.get_fingerprint()] = derive_child_sk
Expand Down
3 changes: 2 additions & 1 deletion chia/wallet/wallet_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from chia.types.signing_mode import SigningMode
from chia.types.spend_bundle import SpendBundle
from chia.util.ints import uint32, uint64, uint128
from chia.util.observation_root import ObservationRoot
from chia.wallet.conditions import Condition
from chia.wallet.derivation_record import DerivationRecord
from chia.wallet.nft_wallet.nft_info import NFTCoinInfo
Expand Down Expand Up @@ -114,7 +115,7 @@ async def generate_signed_transaction(
**kwargs: Unpack[GSTOptionalArgs],
) -> None: ...

def puzzle_for_pk(self, pubkey: G1Element) -> Program: ...
def puzzle_for_pk(self, pubkey: ObservationRoot) -> Program: ...

async def puzzle_for_puzzle_hash(self, puzzle_hash: bytes32) -> Program: ...

Expand Down
2 changes: 1 addition & 1 deletion chia/wallet/wallet_puzzle_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ async def add_derivation_paths(self, records: List[DerivationRecord]) -> None:
sql_records.append(
(
record.index,
bytes(record._pubkey).hex(),
record.pubkey_bytes.hex(),
record.puzzle_hash.hex(),
record.wallet_type,
record.wallet_id,
Expand Down
33 changes: 20 additions & 13 deletions chia/wallet/wallet_state_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
from chia.util.lru_cache import LRUCache
from chia.util.observation_root import ObservationRoot
from chia.util.path import path_from_root
from chia.util.secret_info import SecretInfo
from chia.util.streamable import Streamable, UInt32Range, UInt64Range, VersionedBlob
from chia.wallet.cat_wallet.cat_constants import DEFAULT_CATS
from chia.wallet.cat_wallet.cat_info import CATCoinData, CATInfo, CRCATInfo
Expand Down Expand Up @@ -202,7 +203,7 @@ class WalletStateManager:

main_wallet: MainWalletProtocol
wallets: Dict[uint32, WalletProtocol[Any]]
private_key: Optional[PrivateKey]
private_key: Optional[SecretInfo[Any]]
observation_root: ObservationRoot

trade_manager: TradeManager
Expand All @@ -225,7 +226,7 @@ class WalletStateManager:

@staticmethod
async def create(
private_key: Optional[PrivateKey],
private_key: Optional[SecretInfo[Any]],
config: Dict[str, Any],
db_path: Path,
constants: ConsensusConstants,
Expand Down Expand Up @@ -293,7 +294,7 @@ async def create(
else:
self.observation_root = observation_root
else:
calculated_root_public_key: G1Element = private_key.get_g1()
calculated_root_public_key: ObservationRoot = private_key.public_key()
if observation_root is not None:
assert observation_root == calculated_root_public_key
self.observation_root = calculated_root_public_key
Expand Down Expand Up @@ -392,13 +393,17 @@ def get_public_key_unhardened(self, index: uint32) -> G1Element:
raise ValueError("Public key derivation is not supported for non-G1Element keys")
return master_pk_to_wallet_pk_unhardened(self.observation_root, index)

async def get_private_key(self, puzzle_hash: bytes32) -> PrivateKey:
async def get_private_key(self, puzzle_hash: bytes32) -> SecretInfo[Any]:
record = await self.puzzle_store.record_for_puzzle_hash(puzzle_hash)
if record is None:
raise ValueError(f"No key for puzzle hash: {puzzle_hash.hex()}")
sk = self.get_master_private_key()
# This will need to work when other key types are derivable but for now we will just sanitize and move on
assert isinstance(sk, PrivateKey)
if record.hardened:
return master_sk_to_wallet_sk(self.get_master_private_key(), record.index)
return master_sk_to_wallet_sk_unhardened(self.get_master_private_key(), record.index)
return master_sk_to_wallet_sk(sk, record.index)

return master_sk_to_wallet_sk_unhardened(sk, record.index)

async def get_public_key(self, puzzle_hash: bytes32) -> bytes:
record = await self.puzzle_store.record_for_puzzle_hash(puzzle_hash)
Expand All @@ -410,7 +415,7 @@ async def get_public_key(self, puzzle_hash: bytes32) -> bytes:
pk_bytes = bytes(record._pubkey)
return pk_bytes

def get_master_private_key(self) -> PrivateKey:
def get_master_private_key(self) -> SecretInfo[Any]:
if self.private_key is None: # pragma: no cover
raise ValueError("Wallet is currently in observer mode and access to private key is denied")

Expand Down Expand Up @@ -485,16 +490,18 @@ async def create_more_puzzle_hashes(
hardened_keys: Dict[int, G1Element] = {}
unhardened_keys: Dict[int, G1Element] = {}

if self.private_key is not None:
# Hardened
intermediate_sk = master_sk_to_wallet_sk_intermediate(self.private_key)
for index in range(start_index, last_index):
hardened_keys[index] = _derive_path(intermediate_sk, [index]).get_g1()

# This function shoul work for other types of observation roots too
# However to generalize this function beyond pubkeys is beyond the scope of current work
# So we're just going to sanitize and move on
assert isinstance(self.observation_root, G1Element)
if self.private_key is not None:
assert isinstance(self.private_key, PrivateKey)

if self.private_key is not None:
# Hardened
intermediate_sk = master_sk_to_wallet_sk_intermediate(self.private_key)
for index in range(start_index, last_index):
hardened_keys[index] = _derive_path(intermediate_sk, [index]).public_key()

# Unhardened
intermediate_pk_un = master_pk_to_wallet_pk_unhardened_intermediate(self.observation_root)
Expand Down
Loading