From 11123f85f72453590a32249f5b15b98c254a8dcb Mon Sep 17 00:00:00 2001 From: Kyle Altendorf Date: Wed, 14 Aug 2024 18:04:49 -0400 Subject: [PATCH] refine streamable field typing to avoid an ignore (#18445) * refine streamable field typing to avoid an ignore * try again * tidy * maybe more * overload * non-overload * typing_extensions --- chia/wallet/util/clvm_streamable.py | 23 ++++++++++++++++++----- 1 file changed, 18 insertions(+), 5 deletions(-) diff --git a/chia/wallet/util/clvm_streamable.py b/chia/wallet/util/clvm_streamable.py index 8f0c5e4f7114..109541aee8f0 100644 --- a/chia/wallet/util/clvm_streamable.py +++ b/chia/wallet/util/clvm_streamable.py @@ -6,6 +6,7 @@ from typing import Any, Callable, Dict, Generic, List, Optional, Type, TypeVar, Union, get_args, get_type_hints from hsms.clvm_serde import from_program_for_type, to_program_for_type +from typing_extensions import TypeGuard from chia.types.blockchain_format.program import Program from chia.util.streamable import ( @@ -53,14 +54,14 @@ def byte_serialize_clvm_streamable( def json_serialize_with_clvm_streamable( - streamable: Any, + streamable: object, next_recursion_step: Optional[Callable[..., Dict[str, Any]]] = None, translation_layer: Optional[TranslationLayer] = None, **next_recursion_env: Any, ) -> Union[str, Dict[str, Any]]: if next_recursion_step is None: next_recursion_step = recurse_jsonify - if hasattr(streamable, "_clvm_streamable"): + if is_clvm_streamable(streamable): # If we are using clvm_serde, we stop JSON serialization at this point and instead return the clvm blob return byte_serialize_clvm_streamable(streamable, translation_layer=translation_layer).hex() else: @@ -97,6 +98,18 @@ def is_compound_type(typ: Any) -> bool: return is_type_SpecificOptional(typ) or is_type_Tuple(typ) or is_type_List(typ) +# TODO: this is more than _just_ a Streamable, but it is also a Streamable and that's +# useful for now +def is_clvm_streamable_type(v: Type[object]) -> TypeGuard[Type[Streamable]]: + return issubclass(v, Streamable) and hasattr(v, "_clvm_streamable") + + +# TODO: this is more than _just_ a Streamable, but it is also a Streamable and that's +# useful for now +def is_clvm_streamable(v: object) -> TypeGuard[Streamable]: + return isinstance(v, Streamable) and hasattr(v, "_clvm_streamable") + + def json_deserialize_with_clvm_streamable( json_dict: Union[str, Dict[str, Any]], streamable_type: Type[_T_Streamable], @@ -112,7 +125,7 @@ def json_deserialize_with_clvm_streamable( for old_field in old_streamable_fields: if is_compound_type(old_field.type): inner_type = get_args(old_field.type)[0] - if hasattr(inner_type, "_clvm_streamable"): + if is_clvm_streamable_type(inner_type): new_streamable_fields.append( dataclasses.replace( old_field, @@ -128,11 +141,11 @@ def json_deserialize_with_clvm_streamable( ) else: new_streamable_fields.append(old_field) - elif hasattr(old_field.type, "_clvm_streamable"): + elif is_clvm_streamable_type(old_field.type): new_streamable_fields.append( dataclasses.replace( old_field, - convert_function=functools.partial( # type: ignore[type-var] + convert_function=functools.partial( json_deserialize_with_clvm_streamable, streamable_type=old_field.type, translation_layer=translation_layer,