diff --git a/chia/cmds/passphrase_funcs.py b/chia/cmds/passphrase_funcs.py index f3b85590fa4c..4e24804d4fd5 100644 --- a/chia/cmds/passphrase_funcs.py +++ b/chia/cmds/passphrase_funcs.py @@ -20,7 +20,6 @@ colorama.Fore.YELLOW + colorama.Style.BRIGHT + "(Unlock Keyring)" + colorama.Style.RESET_ALL + " Passphrase: " ) # noqa: E501 FAILED_ATTEMPT_DELAY = 0.5 -MAX_KEYS = 100 MAX_RETRIES = 3 SAVE_MASTER_PASSPHRASE_WARNING = ( colorama.Fore.YELLOW diff --git a/chia/legacy/keyring.py b/chia/legacy/keyring.py index 1821ce98e061..8ff1a021e15d 100644 --- a/chia/legacy/keyring.py +++ b/chia/legacy/keyring.py @@ -25,7 +25,7 @@ from chia.cmds.cmds_util import prompt_yes_no from chia.util.errors import KeychainUserNotFound -from chia.util.keychain import KeyData, KeyDataSecrets, KeyTypes, get_private_key_user +from chia.util.keychain import MAX_KEYS, KeyData, KeyDataSecrets, KeyTypes, get_private_key_user LegacyKeyring = Union[MacKeyring, WinKeyring, CryptFileKeyring] @@ -33,7 +33,6 @@ CURRENT_KEY_VERSION = "1.8" DEFAULT_USER = f"user-chia-{CURRENT_KEY_VERSION}" # e.g. user-chia-1.8 DEFAULT_SERVICE = f"chia-{DEFAULT_USER}" # e.g. chia-user-chia-1.8 -MAX_KEYS = 100 # casting to compensate for a combination of mypy and keyring issues @@ -90,7 +89,7 @@ def get_key_data(keyring: LegacyKeyring, index: int) -> KeyData: def get_keys(keyring: LegacyKeyring) -> List[KeyData]: keys: List[KeyData] = [] - for index in range(MAX_KEYS + 1): + for index in range(MAX_KEYS): try: keys.append(get_key_data(keyring, index)) except KeychainUserNotFound: @@ -114,7 +113,7 @@ def print_keys(keyring: LegacyKeyring) -> None: def remove_keys(keyring: LegacyKeyring) -> None: removed = 0 - for index in range(MAX_KEYS + 1): + for index in range(MAX_KEYS): try: keyring.delete_password(DEFAULT_SERVICE, get_private_key_user(DEFAULT_USER, index)) removed += 1 diff --git a/chia/util/keychain.py b/chia/util/keychain.py index 9c312563de2c..3a965a7e9bc2 100644 --- a/chia/util/keychain.py +++ b/chia/util/keychain.py @@ -7,7 +7,7 @@ from functools import cached_property from hashlib import pbkdf2_hmac from pathlib import Path -from typing import Any, Dict, List, Literal, Optional, Tuple, Type, TypeVar, Union, overload +from typing import Any, Dict, Iterator, List, Literal, Optional, Tuple, Type, TypeVar, Union, overload import importlib_resources from bitstring import BitArray # pyright: reportMissingImports=false @@ -37,7 +37,7 @@ CURRENT_KEY_VERSION = "1.8" DEFAULT_USER = f"user-chia-{CURRENT_KEY_VERSION}" # e.g. user-chia-1.8 DEFAULT_SERVICE = f"chia-{DEFAULT_USER}" # e.g. chia-user-chia-1.8 -MAX_KEYS = 100 +MAX_KEYS = 101 MIN_PASSPHRASE_LEN = 8 @@ -486,29 +486,34 @@ def delete_label(self, fingerprint: int) -> None: """ self.keyring_wrapper.keyring.delete_label(fingerprint) + def _iterate_through_key_datas( + self, include_secrets: bool = True, skip_public_only: bool = False + ) -> Iterator[KeyData]: + for index in range(MAX_KEYS): + try: + key_data = self._get_key_data(index, include_secrets=include_secrets) + if key_data is None or (skip_public_only and key_data.secrets is None): + continue + yield key_data + except KeychainUserNotFound: + pass + return None + def get_first_private_key(self) -> Optional[Tuple[PrivateKey, bytes]]: """ Returns the first key in the keychain that has one of the passed in passphrases. """ - for index in range(MAX_KEYS + 1): - try: - key_data = self._get_key_data(index) - return key_data.private_key, key_data.entropy - except KeychainUserNotFound: - pass + for key_data in self._iterate_through_key_datas(skip_public_only=True): + return key_data.private_key, key_data.entropy return None def get_private_key_by_fingerprint(self, fingerprint: int) -> Optional[Tuple[PrivateKey, bytes]]: """ Return first private key which have the given public key fingerprint. """ - for index in range(MAX_KEYS + 1): - try: - key_data = self._get_key_data(index) - if key_data.fingerprint == fingerprint: - return key_data.private_key, key_data.entropy - except KeychainUserNotFound: - pass + for key_data in self._iterate_through_key_datas(skip_public_only=True): + if key_data.fingerprint == fingerprint: + return key_data.private_key, key_data.entropy return None def get_all_private_keys(self) -> List[Tuple[PrivateKey, bytes]]: @@ -517,25 +522,18 @@ def get_all_private_keys(self) -> List[Tuple[PrivateKey, bytes]]: A tuple of key, and entropy bytes (i.e. mnemonic) is returned for each key. """ all_keys: List[Tuple[PrivateKey, bytes]] = [] - for index in range(MAX_KEYS + 1): - try: - key_data = self._get_key_data(index) - all_keys.append((key_data.private_key, key_data.entropy)) - except (KeychainUserNotFound, KeychainSecretsMissing): - pass + for key_data in self._iterate_through_key_datas(skip_public_only=True): + all_keys.append((key_data.private_key, key_data.entropy)) return all_keys def get_key(self, fingerprint: int, include_secrets: bool = False) -> KeyData: """ Return the KeyData of the first key which has the given public key fingerprint. """ - for index in range(MAX_KEYS + 1): - try: - key_data = self._get_key_data(index, include_secrets) - if key_data.observation_root.get_fingerprint() == fingerprint: - return key_data - except KeychainUserNotFound: - pass + for key_data in self._iterate_through_key_datas(include_secrets=include_secrets, skip_public_only=False): + if key_data.observation_root.get_fingerprint() == fingerprint: + return key_data + raise KeychainFingerprintNotFound(fingerprint) def get_keys(self, include_secrets: bool = False) -> List[KeyData]: @@ -543,12 +541,9 @@ def get_keys(self, include_secrets: bool = False) -> List[KeyData]: Returns the KeyData of all keys which can be retrieved. """ all_keys: List[KeyData] = [] - for index in range(MAX_KEYS + 1): - try: - key_data = self._get_key_data(index, include_secrets) - all_keys.append(key_data) - except KeychainUserNotFound: - pass + for key_data in self._iterate_through_key_datas(include_secrets=include_secrets, skip_public_only=False): + all_keys.append(key_data) + return all_keys def get_all_public_keys(self) -> List[ObservationRoot]: @@ -556,24 +551,18 @@ def get_all_public_keys(self) -> List[ObservationRoot]: Returns all public keys. """ all_keys: List[ObservationRoot] = [] - for index in range(MAX_KEYS + 1): - try: - key_data = self._get_key_data(index) - all_keys.append(key_data.observation_root) - except KeychainUserNotFound: - pass + for key_data in self._iterate_through_key_datas(skip_public_only=False): + all_keys.append(key_data.observation_root) + return all_keys def get_all_public_keys_of_type(self, key_type: Type[_T_ObservationRoot]) -> List[_T_ObservationRoot]: all_keys: List[_T_ObservationRoot] = [] - for index in range(MAX_KEYS + 1): - try: - key_data = self._get_key_data(index) - if key_data.key_type == TYPES_TO_KEY_TYPES[key_type]: - assert isinstance(key_data.observation_root, key_type) - all_keys.append(key_data.observation_root) - except KeychainUserNotFound: - pass + for key_data in self._iterate_through_key_datas(skip_public_only=False): + if key_data.key_type == TYPES_TO_KEY_TYPES[key_type]: + assert isinstance(key_data.observation_root, key_type) + all_keys.append(key_data.observation_root) + return all_keys def get_first_public_key(self) -> Optional[G1Element]: @@ -588,10 +577,11 @@ def delete_key_by_fingerprint(self, fingerprint: int) -> int: Deletes all keys which have the given public key fingerprint and returns how many keys were removed. """ removed = 0 - for index in range(MAX_KEYS + 1): + # We duplicate ._iterate_through_key_datas due to needing the index + for index in range(MAX_KEYS): try: key_data = self._get_key_data(index, include_secrets=False) - if key_data.fingerprint == fingerprint: + if key_data is not None and key_data.fingerprint == fingerprint: try: self.keyring_wrapper.keyring.delete_label(key_data.fingerprint) except (KeychainException, NotImplementedError): @@ -623,12 +613,8 @@ def delete_all_keys(self) -> None: """ Deletes all keys from the keychain. """ - for index in range(MAX_KEYS + 1): - try: - key_data = self._get_key_data(index) - self.delete_key_by_fingerprint(key_data.fingerprint) - except KeychainUserNotFound: - pass + for key_data in self._iterate_through_key_datas(include_secrets=False, skip_public_only=False): + self.delete_key_by_fingerprint(key_data.fingerprint) @staticmethod def is_keyring_locked() -> bool: