Skip to content

Commit

Permalink
[CHIA-1124] Deduplicate keychain code a bit (#18455)
Browse files Browse the repository at this point in the history
The keychain has a couple of weird cleanliness mistakes. The first is
that we separately define `MAX_KEYS` three times across our code base
and then always use it with `+1`. The second is that many keychain
functions iterate through all existing keys to find the one it's looking
for. This results in a lot of code duplication that can be extracted to
its own function. This PR is just for cleanliness for the sake of making
upstream refactors easier.
  • Loading branch information
Quexington authored Aug 14, 2024
2 parents aa520d3 + 27dda22 commit 1f50bcc
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 61 deletions.
1 change: 0 additions & 1 deletion chia/cmds/passphrase_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
colorama.Fore.YELLOW + colorama.Style.BRIGHT + "(Unlock Keyring)" + colorama.Style.RESET_ALL + " Passphrase: "
) # noqa: E501
FAILED_ATTEMPT_DELAY = 0.5
MAX_KEYS = 100
MAX_RETRIES = 3
SAVE_MASTER_PASSPHRASE_WARNING = (
colorama.Fore.YELLOW
Expand Down
7 changes: 3 additions & 4 deletions chia/legacy/keyring.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,14 @@

from chia.cmds.cmds_util import prompt_yes_no
from chia.util.errors import KeychainUserNotFound
from chia.util.keychain import KeyData, KeyDataSecrets, KeyTypes, get_private_key_user
from chia.util.keychain import MAX_KEYS, KeyData, KeyDataSecrets, KeyTypes, get_private_key_user

LegacyKeyring = Union[MacKeyring, WinKeyring, CryptFileKeyring]


CURRENT_KEY_VERSION = "1.8"
DEFAULT_USER = f"user-chia-{CURRENT_KEY_VERSION}" # e.g. user-chia-1.8
DEFAULT_SERVICE = f"chia-{DEFAULT_USER}" # e.g. chia-user-chia-1.8
MAX_KEYS = 100


# casting to compensate for a combination of mypy and keyring issues
Expand Down Expand Up @@ -90,7 +89,7 @@ def get_key_data(keyring: LegacyKeyring, index: int) -> KeyData:

def get_keys(keyring: LegacyKeyring) -> List[KeyData]:
keys: List[KeyData] = []
for index in range(MAX_KEYS + 1):
for index in range(MAX_KEYS):
try:
keys.append(get_key_data(keyring, index))
except KeychainUserNotFound:
Expand All @@ -114,7 +113,7 @@ def print_keys(keyring: LegacyKeyring) -> None:

def remove_keys(keyring: LegacyKeyring) -> None:
removed = 0
for index in range(MAX_KEYS + 1):
for index in range(MAX_KEYS):
try:
keyring.delete_password(DEFAULT_SERVICE, get_private_key_user(DEFAULT_USER, index))
removed += 1
Expand Down
98 changes: 42 additions & 56 deletions chia/util/keychain.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from functools import cached_property
from hashlib import pbkdf2_hmac
from pathlib import Path
from typing import Any, Dict, List, Literal, Optional, Tuple, Type, TypeVar, Union, overload
from typing import Any, Dict, Iterator, List, Literal, Optional, Tuple, Type, TypeVar, Union, overload

import importlib_resources
from bitstring import BitArray # pyright: reportMissingImports=false
Expand Down Expand Up @@ -37,7 +37,7 @@
CURRENT_KEY_VERSION = "1.8"
DEFAULT_USER = f"user-chia-{CURRENT_KEY_VERSION}" # e.g. user-chia-1.8
DEFAULT_SERVICE = f"chia-{DEFAULT_USER}" # e.g. chia-user-chia-1.8
MAX_KEYS = 100
MAX_KEYS = 101
MIN_PASSPHRASE_LEN = 8


Expand Down Expand Up @@ -486,29 +486,34 @@ def delete_label(self, fingerprint: int) -> None:
"""
self.keyring_wrapper.keyring.delete_label(fingerprint)

def _iterate_through_key_datas(
self, include_secrets: bool = True, skip_public_only: bool = False
) -> Iterator[KeyData]:
for index in range(MAX_KEYS):
try:
key_data = self._get_key_data(index, include_secrets=include_secrets)
if key_data is None or (skip_public_only and key_data.secrets is None):
continue
yield key_data
except KeychainUserNotFound:
pass
return None

def get_first_private_key(self) -> Optional[Tuple[PrivateKey, bytes]]:
"""
Returns the first key in the keychain that has one of the passed in passphrases.
"""
for index in range(MAX_KEYS + 1):
try:
key_data = self._get_key_data(index)
return key_data.private_key, key_data.entropy
except KeychainUserNotFound:
pass
for key_data in self._iterate_through_key_datas(skip_public_only=True):
return key_data.private_key, key_data.entropy
return None

def get_private_key_by_fingerprint(self, fingerprint: int) -> Optional[Tuple[PrivateKey, bytes]]:
"""
Return first private key which have the given public key fingerprint.
"""
for index in range(MAX_KEYS + 1):
try:
key_data = self._get_key_data(index)
if key_data.fingerprint == fingerprint:
return key_data.private_key, key_data.entropy
except KeychainUserNotFound:
pass
for key_data in self._iterate_through_key_datas(skip_public_only=True):
if key_data.fingerprint == fingerprint:
return key_data.private_key, key_data.entropy
return None

def get_all_private_keys(self) -> List[Tuple[PrivateKey, bytes]]:
Expand All @@ -517,63 +522,47 @@ def get_all_private_keys(self) -> List[Tuple[PrivateKey, bytes]]:
A tuple of key, and entropy bytes (i.e. mnemonic) is returned for each key.
"""
all_keys: List[Tuple[PrivateKey, bytes]] = []
for index in range(MAX_KEYS + 1):
try:
key_data = self._get_key_data(index)
all_keys.append((key_data.private_key, key_data.entropy))
except (KeychainUserNotFound, KeychainSecretsMissing):
pass
for key_data in self._iterate_through_key_datas(skip_public_only=True):
all_keys.append((key_data.private_key, key_data.entropy))
return all_keys

def get_key(self, fingerprint: int, include_secrets: bool = False) -> KeyData:
"""
Return the KeyData of the first key which has the given public key fingerprint.
"""
for index in range(MAX_KEYS + 1):
try:
key_data = self._get_key_data(index, include_secrets)
if key_data.observation_root.get_fingerprint() == fingerprint:
return key_data
except KeychainUserNotFound:
pass
for key_data in self._iterate_through_key_datas(include_secrets=include_secrets, skip_public_only=False):
if key_data.observation_root.get_fingerprint() == fingerprint:
return key_data

raise KeychainFingerprintNotFound(fingerprint)

def get_keys(self, include_secrets: bool = False) -> List[KeyData]:
"""
Returns the KeyData of all keys which can be retrieved.
"""
all_keys: List[KeyData] = []
for index in range(MAX_KEYS + 1):
try:
key_data = self._get_key_data(index, include_secrets)
all_keys.append(key_data)
except KeychainUserNotFound:
pass
for key_data in self._iterate_through_key_datas(include_secrets=include_secrets, skip_public_only=False):
all_keys.append(key_data)

return all_keys

def get_all_public_keys(self) -> List[ObservationRoot]:
"""
Returns all public keys.
"""
all_keys: List[ObservationRoot] = []
for index in range(MAX_KEYS + 1):
try:
key_data = self._get_key_data(index)
all_keys.append(key_data.observation_root)
except KeychainUserNotFound:
pass
for key_data in self._iterate_through_key_datas(skip_public_only=False):
all_keys.append(key_data.observation_root)

return all_keys

def get_all_public_keys_of_type(self, key_type: Type[_T_ObservationRoot]) -> List[_T_ObservationRoot]:
all_keys: List[_T_ObservationRoot] = []
for index in range(MAX_KEYS + 1):
try:
key_data = self._get_key_data(index)
if key_data.key_type == TYPES_TO_KEY_TYPES[key_type]:
assert isinstance(key_data.observation_root, key_type)
all_keys.append(key_data.observation_root)
except KeychainUserNotFound:
pass
for key_data in self._iterate_through_key_datas(skip_public_only=False):
if key_data.key_type == TYPES_TO_KEY_TYPES[key_type]:
assert isinstance(key_data.observation_root, key_type)
all_keys.append(key_data.observation_root)

return all_keys

def get_first_public_key(self) -> Optional[G1Element]:
Expand All @@ -588,10 +577,11 @@ def delete_key_by_fingerprint(self, fingerprint: int) -> int:
Deletes all keys which have the given public key fingerprint and returns how many keys were removed.
"""
removed = 0
for index in range(MAX_KEYS + 1):
# We duplicate ._iterate_through_key_datas due to needing the index
for index in range(MAX_KEYS):
try:
key_data = self._get_key_data(index, include_secrets=False)
if key_data.fingerprint == fingerprint:
if key_data is not None and key_data.fingerprint == fingerprint:
try:
self.keyring_wrapper.keyring.delete_label(key_data.fingerprint)
except (KeychainException, NotImplementedError):
Expand Down Expand Up @@ -623,12 +613,8 @@ def delete_all_keys(self) -> None:
"""
Deletes all keys from the keychain.
"""
for index in range(MAX_KEYS + 1):
try:
key_data = self._get_key_data(index)
self.delete_key_by_fingerprint(key_data.fingerprint)
except KeychainUserNotFound:
pass
for key_data in self._iterate_through_key_datas(include_secrets=False, skip_public_only=False):
self.delete_key_by_fingerprint(key_data.fingerprint)

@staticmethod
def is_keyring_locked() -> bool:
Expand Down

0 comments on commit 1f50bcc

Please sign in to comment.