From 57a54c16c337fd33b394ed8a9690ff09f17acc90 Mon Sep 17 00:00:00 2001 From: pavel Date: Sun, 2 Feb 2025 16:30:21 +0300 Subject: [PATCH] Fix bugs --- .../_internal/morphing/model/dumper_gen.py | 171 +++++++++++------- .../morphing/model/test_dumper_provider.py | 2 +- 2 files changed, 110 insertions(+), 63 deletions(-) diff --git a/src/adaptix/_internal/morphing/model/dumper_gen.py b/src/adaptix/_internal/morphing/model/dumper_gen.py index 663dd6ce..ada53e70 100644 --- a/src/adaptix/_internal/morphing/model/dumper_gen.py +++ b/src/adaptix/_internal/morphing/model/dumper_gen.py @@ -36,7 +36,7 @@ OutputShape, ) from ...special_cases_optimization import as_is_stub, get_default_clause -from ...struct_trail import TrailElement, append_trail, extend_trail +from ...struct_trail import TrailElement, append_trail, extend_trail, render_trail_as_note from ...utils import Omittable, Omitted from ..json_schema.definitions import JSONSchema from ..json_schema.schema_model import JSONSchemaType, JSONValue @@ -59,7 +59,7 @@ class GenState: - def __init__(self, namespace: CascadeNamespace, debug_trail: DebugTrail): + def __init__(self, namespace: CascadeNamespace, debug_trail: DebugTrail, error_handler_name: str): self.namespace = namespace self.debug_trail = debug_trail @@ -73,6 +73,7 @@ def __init__(self, namespace: CascadeNamespace, debug_trail: DebugTrail): self.error_collectors: list[Statement] = [] self.overriden_error_collectors: dict[TrailElement, Callable[[Statement], Statement]] = {} self.trail_to_collector_idx: dict[TrailElement, int] = {} + self.error_handler_name = error_handler_name def _ensure_path_idx(self, path: CrownPath) -> str: try: @@ -149,9 +150,9 @@ def write_lines(self, writer: TextSliceWriter) -> None: class ErrorCatching(Statement): - def __init__(self, state: GenState, trail_element: Optional[TrailElement], stmt: AssignmentStatement): - self._state = state - self._trail_element = trail_element + def __init__(self, state: GenState, trail_element: Optional[TrailElement], stmt: Statement): + self.state = state + self.trail_element = trail_element self.stmt = stmt def _get_trail_element_expr(self, trail_element: TrailElement) -> Expression: @@ -159,27 +160,32 @@ def _get_trail_element_expr(self, trail_element: TrailElement) -> Expression: if literal_expr is not None: return RawExpr(literal_expr) - if trail_element in self._state.trail_element_to_name_idx: - idx = self._state.trail_element_to_name_idx[trail_element] + if trail_element in self.state.trail_element_to_name_idx: + idx = self.state.trail_element_to_name_idx[trail_element] v_trail_element = f"trail_element_{idx}" else: - idx = len(self._state.trail_element_to_name_idx) - self._state.trail_element_to_name_idx[trail_element] = idx + idx = len(self.state.trail_element_to_name_idx) + self.state.trail_element_to_name_idx[trail_element] = idx v_trail_element = f"trail_element_{idx}" - self._state.namespace.add_constant(v_trail_element, trail_element) + self.state.namespace.add_constant(v_trail_element, trail_element) return RawExpr(v_trail_element) def _get_append_trail(self, trail_element: TrailElement) -> Expression: + if self.state.debug_trail == DebugTrail.ALL: + return CodeExpr( + "render_trail_as_note(append_trail(e, ))", + trail_element=self._get_trail_element_expr(trail_element), + ) return CodeExpr( "append_trail(e, )", - trail_element=self._get_trail_element_expr(self._trail_element), + trail_element=self._get_trail_element_expr(trail_element), ) def _wrap_stmt(self, stmt: Statement) -> Statement: - if self._state.debug_trail == DebugTrail.DISABLE: + if self.state.debug_trail == DebugTrail.DISABLE: return stmt - if self._state.debug_trail == DebugTrail.FIRST: - if self._trail_element is None: + if self.state.debug_trail == DebugTrail.FIRST: + if self.trail_element is None: return stmt return CodeBlock( """ @@ -190,38 +196,41 @@ def _wrap_stmt(self, stmt: Statement) -> Statement: raise """, stmt=self.stmt, - append_trail=self._get_append_trail(self._trail_element), + append_trail=self._get_append_trail(self.trail_element), ) - if self._state.debug_trail == DebugTrail.ALL: + if self.state.debug_trail == DebugTrail.ALL: idx = self._process_error_collecting(stmt) return CodeBlock( """ try: except Exception as e: - raise error_handler(, data, ) from None + raise (, data, ) from None """, stmt=self.stmt, idx=RawExpr(repr(idx)), append_trail=( RawExpr("e") - if self._trail_element is None else - self._get_append_trail(self._trail_element) + if self.trail_element is None else + self._get_append_trail(self.trail_element) ), + error_handler=RawExpr(self.state.error_handler_name), ) raise ValueError def _process_error_collecting(self, stmt: Statement) -> int: - idx = len(self._state.error_collectors) + idx = len(self.state.error_collectors) - if self._trail_element is None: - self._state.error_collectors.append(stmt) + if self.trail_element is None: + self.state.error_collectors.append(stmt) else: - if self._trail_element in self._state.trail_to_collector_idx: - return self._state.trail_to_collector_idx[self._trail_element] + if self.trail_element in self.state.trail_to_collector_idx: + if self.trail_element not in self.state.overriden_error_collectors: + raise ValueError + return self.state.trail_to_collector_idx[self.trail_element] - self._state.error_collectors.append(self._get_error_collector(stmt, self._trail_element)) - self._state.trail_to_collector_idx[self._trail_element] = idx + self.state.error_collectors.append(self._get_error_collector(stmt, self.trail_element)) + self.state.trail_to_collector_idx[self.trail_element] = idx return idx def _get_error_collector(self, stmt: Statement, trail_element: TrailElement) -> Statement: @@ -230,8 +239,8 @@ def _get_error_collector(self, stmt: Statement, trail_element: TrailElement) -> append_trail=self._get_append_trail(trail_element), ) - if self._trail_element in self._state.overriden_error_collectors: - return self._state.overriden_error_collectors[self._trail_element](error_saving) + if self.trail_element in self.state.overriden_error_collectors: + return self.state.overriden_error_collectors[self.trail_element](error_saving) return CodeBlock( """ @@ -299,7 +308,7 @@ def _v_dumper(self, field: OutputField) -> str: return f"dumper_{field.id}" def _create_state(self, namespace: CascadeNamespace) -> GenState: - return GenState(namespace, self._debug_trail) + return GenState(namespace, self._debug_trail, f"error_handler") def _alloc_var(self, state: GenState, name: str) -> VarExpr: state.namespace.register_var(name) @@ -328,6 +337,7 @@ def produce_code(self, closure_name: str) -> tuple[str, Mapping[str, object]]: namespace.add_constant("CompatExceptionGroup", CompatExceptionGroup) namespace.add_constant("append_trail", append_trail) namespace.add_constant("extend_trail", extend_trail) + namespace.add_constant("render_trail_as_note", render_trail_as_note) for field_id, dumper in self._fields_dumpers.items(): namespace.add_constant(self._v_dumper(self._id_to_field[field_id]), dumper) @@ -473,12 +483,14 @@ def _get_error_handler(self, state: GenState) -> Statement: ] return CodeBlock( """ - def error_handler(idx, data, e): + def (idx, data, e): errors = [e] - return ExceptionGroup("", errors) + return ExceptionGroup(, errors) """, error_collectors=statements(*error_collectors), + error_msg=StringLiteral(f"while dumping model {self._model_identity}"), + error_handler=RawExpr(state.error_handler_name), ) def _get_access_expr(self, namespace: CascadeNamespace, field: OutputField) -> str: @@ -623,12 +635,7 @@ def _get_dict_crown_out_stmt(self, state: GenState, crown: OutDictCrown) -> OutS var=var, ) - def _wrap_with_dumper_call( - self, - state: GenState, - sub_crown: OutCrown, - var: VarExpr, - ) -> OutVarStatement: + def _get_dumper_call(self, state: GenState, sub_crown: OutCrown, var: VarExpr) -> OutVarStatement: if not isinstance(sub_crown, OutFieldCrown): return OutVarStatement(var=var, stmt=statements()) @@ -652,6 +659,35 @@ def _wrap_with_dumper_call( ), ) + def _merge_error_catching( + self, + out_stmt: OutVarStatement, + dumper_call: OutVarStatement, + ) -> OutVarStatement: + if ( + isinstance(out_stmt.stmt, ErrorCatching) + and isinstance(dumper_call.stmt, ErrorCatching) + and out_stmt.stmt.trail_element == dumper_call.stmt.trail_element + ): + return OutVarStatement( + var=dumper_call.var, + stmt=ErrorCatching( + state=out_stmt.stmt.state, + trail_element=out_stmt.stmt.trail_element, + stmt=statements( + out_stmt.stmt.stmt, + dumper_call.stmt.stmt, + ), + ), + ) + return OutVarStatement( + var=dumper_call.var, + stmt=statements( + out_stmt.stmt, + dumper_call.stmt, + ), + ) + def _process_dict_sub_crown( self, state: GenState, @@ -660,14 +696,15 @@ def _process_dict_sub_crown( sub_crown: OutCrown, out_stmt: OutStatement, ) -> None: - dumper_call = self._wrap_with_dumper_call( + dumper_call = self._get_dumper_call( state=state, sub_crown=sub_crown, var=out_stmt.var, ) if isinstance(out_stmt, OutVarStatement): - builder.before_stmts.append(out_stmt.stmt) - builder.before_stmts.append(dumper_call.stmt) + builder.before_stmts.append( + self._merge_error_catching(out_stmt=out_stmt, dumper_call=dumper_call).stmt, + ) builder.dict_items.append( DictKeyValue(StringLiteral(key), dumper_call.var), ) @@ -697,15 +734,19 @@ def _process_dict_sieved_sub_crown( sub_crown: OutCrown, out_stmt: OutStatement, ) -> None: - dumper_call = self._wrap_with_dumper_call(state, sub_crown, out_stmt.var) + dumper_call = self._get_dumper_call(state, sub_crown, out_stmt.var) condition = self._get_sieve_condition(state, sieve, key, out_stmt.var) - conditional_append = CodeBlock( - """ - if : - - """, - condition=condition, - dict_append=self._get_dict_append(state, key, out_stmt.var), + conditional_append = statements( + CodeBlock( + """ + if : + + + """, + condition=condition, + dumper_call=dumper_call.stmt, + dict_append=self._get_dict_append(state, key, dumper_call.var), + ), ) if isinstance(out_stmt, OptionalOutVarStatement): stmt = out_stmt.stmt_maker( @@ -714,7 +755,7 @@ def _process_dict_sieved_sub_crown( on_unexpected_error=..., ) builder.after_stmts.append(stmt) - if isinstance(out_stmt, OutVarStatement): + elif isinstance(out_stmt, OutVarStatement): builder.after_stmts.append( statements( out_stmt.stmt, @@ -722,6 +763,8 @@ def _process_dict_sieved_sub_crown( ), ) if isinstance(sub_crown, OutFieldCrown): + assert isinstance(out_stmt.stmt, ErrorCatching) + assert isinstance(dumper_call.stmt, ErrorCatching) trail_element = self._id_to_field[sub_crown.id].accessor.trail_element state.overriden_error_collectors[trail_element] = ( lambda error_saving: CodeBlock( @@ -737,9 +780,9 @@ def _process_dict_sieved_sub_crown( except Exception as e: """, - stmt=out_stmt.stmt, + stmt=out_stmt.stmt.stmt, condition=condition, - dumper_call=dumper_call.stmt, + dumper_call=dumper_call.stmt.stmt, error_saving=error_saving, ) ) @@ -756,7 +799,7 @@ def _get_dict_append( "[] = ", key=StringLiteral(key), crown=RawExpr(state.v_crown), - dumped_value=value, + value=value, ) def _get_sieve_condition(self, state: GenState, sieve: Sieve, key: str, test_var: VarExpr) -> Expression: @@ -794,22 +837,26 @@ def _get_sieve_condition(self, state: GenState, sieve: Sieve, key: str, test_var raise TypeError def _get_list_crown_out_stmt(self, state: GenState, crown: OutListCrown) -> OutStatement: - out_stmts = [ - self._get_crown_out_stmt(state, idx, sub_crown) - for idx, sub_crown in enumerate(crown.map) - ] dumped_out_stmts = [ - self._wrap_with_dumper_call( - state, - sub_crown, - out_stmt.var, + self._merge_error_catching( + out_stmt=out_stmt, + dumper_call=self._get_dumper_call( + state, + sub_crown, + out_stmt.var, + ), + ) + for sub_crown, out_stmt in zip( + crown.map, + ( + self._get_crown_out_stmt(state, idx, sub_crown) + for idx, sub_crown in enumerate(crown.map) + ), ) - for sub_crown, out_stmt in zip(crown.map, out_stmts) ] var = self._alloc_var(state, state.v_crown) return OutVarStatement( stmt=statements( - *(out_stmt.stmt for out_stmt in out_stmts), *(out_stmt.stmt for out_stmt in dumped_out_stmts), AssignmentStatement( var=var, diff --git a/tests/unit/morphing/model/test_dumper_provider.py b/tests/unit/morphing/model/test_dumper_provider.py index 5abd7b30..00334508 100644 --- a/tests/unit/morphing/model/test_dumper_provider.py +++ b/tests/unit/morphing/model/test_dumper_provider.py @@ -139,7 +139,7 @@ class MyAccessError(Exception): pass -@dataclass +@dataclass(frozen=True) class MyTrailElemMarker(TrailElementMarker): value: Any