Skip to content

Commit

Permalink
Fix bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
zhPavel committed Feb 2, 2025
1 parent 11c07fb commit 57a54c1
Show file tree
Hide file tree
Showing 2 changed files with 110 additions and 63 deletions.
171 changes: 109 additions & 62 deletions src/adaptix/_internal/morphing/model/dumper_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
OutputShape,
)
from ...special_cases_optimization import as_is_stub, get_default_clause
from ...struct_trail import TrailElement, append_trail, extend_trail
from ...struct_trail import TrailElement, append_trail, extend_trail, render_trail_as_note
from ...utils import Omittable, Omitted
from ..json_schema.definitions import JSONSchema
from ..json_schema.schema_model import JSONSchemaType, JSONValue
Expand All @@ -59,7 +59,7 @@


class GenState:
def __init__(self, namespace: CascadeNamespace, debug_trail: DebugTrail):
def __init__(self, namespace: CascadeNamespace, debug_trail: DebugTrail, error_handler_name: str):
self.namespace = namespace
self.debug_trail = debug_trail

Expand All @@ -73,6 +73,7 @@ def __init__(self, namespace: CascadeNamespace, debug_trail: DebugTrail):
self.error_collectors: list[Statement] = []
self.overriden_error_collectors: dict[TrailElement, Callable[[Statement], Statement]] = {}
self.trail_to_collector_idx: dict[TrailElement, int] = {}
self.error_handler_name = error_handler_name

def _ensure_path_idx(self, path: CrownPath) -> str:
try:
Expand Down Expand Up @@ -149,37 +150,42 @@ def write_lines(self, writer: TextSliceWriter) -> None:


class ErrorCatching(Statement):
def __init__(self, state: GenState, trail_element: Optional[TrailElement], stmt: AssignmentStatement):
self._state = state
self._trail_element = trail_element
def __init__(self, state: GenState, trail_element: Optional[TrailElement], stmt: Statement):
self.state = state
self.trail_element = trail_element
self.stmt = stmt

def _get_trail_element_expr(self, trail_element: TrailElement) -> Expression:
literal_expr = get_literal_expr(trail_element)
if literal_expr is not None:
return RawExpr(literal_expr)

if trail_element in self._state.trail_element_to_name_idx:
idx = self._state.trail_element_to_name_idx[trail_element]
if trail_element in self.state.trail_element_to_name_idx:
idx = self.state.trail_element_to_name_idx[trail_element]
v_trail_element = f"trail_element_{idx}"
else:
idx = len(self._state.trail_element_to_name_idx)
self._state.trail_element_to_name_idx[trail_element] = idx
idx = len(self.state.trail_element_to_name_idx)
self.state.trail_element_to_name_idx[trail_element] = idx
v_trail_element = f"trail_element_{idx}"
self._state.namespace.add_constant(v_trail_element, trail_element)
self.state.namespace.add_constant(v_trail_element, trail_element)
return RawExpr(v_trail_element)

def _get_append_trail(self, trail_element: TrailElement) -> Expression:
if self.state.debug_trail == DebugTrail.ALL:
return CodeExpr(
"render_trail_as_note(append_trail(e, <trail_element>))",
trail_element=self._get_trail_element_expr(trail_element),
)
return CodeExpr(
"append_trail(e, <trail_element>)",
trail_element=self._get_trail_element_expr(self._trail_element),
trail_element=self._get_trail_element_expr(trail_element),
)

def _wrap_stmt(self, stmt: Statement) -> Statement:
if self._state.debug_trail == DebugTrail.DISABLE:
if self.state.debug_trail == DebugTrail.DISABLE:
return stmt
if self._state.debug_trail == DebugTrail.FIRST:
if self._trail_element is None:
if self.state.debug_trail == DebugTrail.FIRST:
if self.trail_element is None:
return stmt
return CodeBlock(
"""
Expand All @@ -190,38 +196,41 @@ def _wrap_stmt(self, stmt: Statement) -> Statement:
raise
""",
stmt=self.stmt,
append_trail=self._get_append_trail(self._trail_element),
append_trail=self._get_append_trail(self.trail_element),
)
if self._state.debug_trail == DebugTrail.ALL:
if self.state.debug_trail == DebugTrail.ALL:
idx = self._process_error_collecting(stmt)
return CodeBlock(
"""
try:
<stmt>
except Exception as e:
raise error_handler(<idx>, data, <append_trail>) from None
raise <error_handler>(<idx>, data, <append_trail>) from None
""",
stmt=self.stmt,
idx=RawExpr(repr(idx)),
append_trail=(
RawExpr("e")
if self._trail_element is None else
self._get_append_trail(self._trail_element)
if self.trail_element is None else
self._get_append_trail(self.trail_element)
),
error_handler=RawExpr(self.state.error_handler_name),
)
raise ValueError

def _process_error_collecting(self, stmt: Statement) -> int:
idx = len(self._state.error_collectors)
idx = len(self.state.error_collectors)

if self._trail_element is None:
self._state.error_collectors.append(stmt)
if self.trail_element is None:
self.state.error_collectors.append(stmt)
else:
if self._trail_element in self._state.trail_to_collector_idx:
return self._state.trail_to_collector_idx[self._trail_element]
if self.trail_element in self.state.trail_to_collector_idx:
if self.trail_element not in self.state.overriden_error_collectors:
raise ValueError
return self.state.trail_to_collector_idx[self.trail_element]

self._state.error_collectors.append(self._get_error_collector(stmt, self._trail_element))
self._state.trail_to_collector_idx[self._trail_element] = idx
self.state.error_collectors.append(self._get_error_collector(stmt, self.trail_element))
self.state.trail_to_collector_idx[self.trail_element] = idx
return idx

def _get_error_collector(self, stmt: Statement, trail_element: TrailElement) -> Statement:
Expand All @@ -230,8 +239,8 @@ def _get_error_collector(self, stmt: Statement, trail_element: TrailElement) ->
append_trail=self._get_append_trail(trail_element),
)

if self._trail_element in self._state.overriden_error_collectors:
return self._state.overriden_error_collectors[self._trail_element](error_saving)
if self.trail_element in self.state.overriden_error_collectors:
return self.state.overriden_error_collectors[self.trail_element](error_saving)

return CodeBlock(
"""
Expand Down Expand Up @@ -299,7 +308,7 @@ def _v_dumper(self, field: OutputField) -> str:
return f"dumper_{field.id}"

def _create_state(self, namespace: CascadeNamespace) -> GenState:
return GenState(namespace, self._debug_trail)
return GenState(namespace, self._debug_trail, f"error_handler")

def _alloc_var(self, state: GenState, name: str) -> VarExpr:
state.namespace.register_var(name)
Expand Down Expand Up @@ -328,6 +337,7 @@ def produce_code(self, closure_name: str) -> tuple[str, Mapping[str, object]]:
namespace.add_constant("CompatExceptionGroup", CompatExceptionGroup)
namespace.add_constant("append_trail", append_trail)
namespace.add_constant("extend_trail", extend_trail)
namespace.add_constant("render_trail_as_note", render_trail_as_note)
for field_id, dumper in self._fields_dumpers.items():
namespace.add_constant(self._v_dumper(self._id_to_field[field_id]), dumper)

Expand Down Expand Up @@ -473,12 +483,14 @@ def _get_error_handler(self, state: GenState) -> Statement:
]
return CodeBlock(
"""
def error_handler(idx, data, e):
def <error_handler>(idx, data, e):
errors = [e]
<error_collectors>
return ExceptionGroup("", errors)
return ExceptionGroup(<error_msg>, errors)
""",
error_collectors=statements(*error_collectors),
error_msg=StringLiteral(f"while dumping model {self._model_identity}"),
error_handler=RawExpr(state.error_handler_name),
)

def _get_access_expr(self, namespace: CascadeNamespace, field: OutputField) -> str:
Expand Down Expand Up @@ -623,12 +635,7 @@ def _get_dict_crown_out_stmt(self, state: GenState, crown: OutDictCrown) -> OutS
var=var,
)

def _wrap_with_dumper_call(
self,
state: GenState,
sub_crown: OutCrown,
var: VarExpr,
) -> OutVarStatement:
def _get_dumper_call(self, state: GenState, sub_crown: OutCrown, var: VarExpr) -> OutVarStatement:
if not isinstance(sub_crown, OutFieldCrown):
return OutVarStatement(var=var, stmt=statements())

Expand All @@ -652,6 +659,35 @@ def _wrap_with_dumper_call(
),
)

def _merge_error_catching(
self,
out_stmt: OutVarStatement,
dumper_call: OutVarStatement,
) -> OutVarStatement:
if (
isinstance(out_stmt.stmt, ErrorCatching)
and isinstance(dumper_call.stmt, ErrorCatching)
and out_stmt.stmt.trail_element == dumper_call.stmt.trail_element
):
return OutVarStatement(
var=dumper_call.var,
stmt=ErrorCatching(
state=out_stmt.stmt.state,
trail_element=out_stmt.stmt.trail_element,
stmt=statements(
out_stmt.stmt.stmt,
dumper_call.stmt.stmt,
),
),
)
return OutVarStatement(
var=dumper_call.var,
stmt=statements(
out_stmt.stmt,
dumper_call.stmt,
),
)

def _process_dict_sub_crown(
self,
state: GenState,
Expand All @@ -660,14 +696,15 @@ def _process_dict_sub_crown(
sub_crown: OutCrown,
out_stmt: OutStatement,
) -> None:
dumper_call = self._wrap_with_dumper_call(
dumper_call = self._get_dumper_call(
state=state,
sub_crown=sub_crown,
var=out_stmt.var,
)
if isinstance(out_stmt, OutVarStatement):
builder.before_stmts.append(out_stmt.stmt)
builder.before_stmts.append(dumper_call.stmt)
builder.before_stmts.append(
self._merge_error_catching(out_stmt=out_stmt, dumper_call=dumper_call).stmt,
)
builder.dict_items.append(
DictKeyValue(StringLiteral(key), dumper_call.var),
)
Expand Down Expand Up @@ -697,15 +734,19 @@ def _process_dict_sieved_sub_crown(
sub_crown: OutCrown,
out_stmt: OutStatement,
) -> None:
dumper_call = self._wrap_with_dumper_call(state, sub_crown, out_stmt.var)
dumper_call = self._get_dumper_call(state, sub_crown, out_stmt.var)
condition = self._get_sieve_condition(state, sieve, key, out_stmt.var)
conditional_append = CodeBlock(
"""
if <condition>:
<dict_append>
""",
condition=condition,
dict_append=self._get_dict_append(state, key, out_stmt.var),
conditional_append = statements(
CodeBlock(
"""
if <condition>:
<dumper_call>
<dict_append>
""",
condition=condition,
dumper_call=dumper_call.stmt,
dict_append=self._get_dict_append(state, key, dumper_call.var),
),
)
if isinstance(out_stmt, OptionalOutVarStatement):
stmt = out_stmt.stmt_maker(
Expand All @@ -714,14 +755,16 @@ def _process_dict_sieved_sub_crown(
on_unexpected_error=...,
)
builder.after_stmts.append(stmt)
if isinstance(out_stmt, OutVarStatement):
elif isinstance(out_stmt, OutVarStatement):
builder.after_stmts.append(
statements(
out_stmt.stmt,
conditional_append,
),
)
if isinstance(sub_crown, OutFieldCrown):
assert isinstance(out_stmt.stmt, ErrorCatching)
assert isinstance(dumper_call.stmt, ErrorCatching)
trail_element = self._id_to_field[sub_crown.id].accessor.trail_element
state.overriden_error_collectors[trail_element] = (
lambda error_saving: CodeBlock(
Expand All @@ -737,9 +780,9 @@ def _process_dict_sieved_sub_crown(
except Exception as e:
<error_saving>
""",
stmt=out_stmt.stmt,
stmt=out_stmt.stmt.stmt,
condition=condition,
dumper_call=dumper_call.stmt,
dumper_call=dumper_call.stmt.stmt,
error_saving=error_saving,
)
)
Expand All @@ -756,7 +799,7 @@ def _get_dict_append(
"<crown>[<key>] = <value>",
key=StringLiteral(key),
crown=RawExpr(state.v_crown),
dumped_value=value,
value=value,
)

def _get_sieve_condition(self, state: GenState, sieve: Sieve, key: str, test_var: VarExpr) -> Expression:
Expand Down Expand Up @@ -794,22 +837,26 @@ def _get_sieve_condition(self, state: GenState, sieve: Sieve, key: str, test_var
raise TypeError

def _get_list_crown_out_stmt(self, state: GenState, crown: OutListCrown) -> OutStatement:
out_stmts = [
self._get_crown_out_stmt(state, idx, sub_crown)
for idx, sub_crown in enumerate(crown.map)
]
dumped_out_stmts = [
self._wrap_with_dumper_call(
state,
sub_crown,
out_stmt.var,
self._merge_error_catching(
out_stmt=out_stmt,
dumper_call=self._get_dumper_call(
state,
sub_crown,
out_stmt.var,
),
)
for sub_crown, out_stmt in zip(
crown.map,
(
self._get_crown_out_stmt(state, idx, sub_crown)
for idx, sub_crown in enumerate(crown.map)
),
)
for sub_crown, out_stmt in zip(crown.map, out_stmts)
]
var = self._alloc_var(state, state.v_crown)
return OutVarStatement(
stmt=statements(
*(out_stmt.stmt for out_stmt in out_stmts),
*(out_stmt.stmt for out_stmt in dumped_out_stmts),
AssignmentStatement(
var=var,
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/morphing/model/test_dumper_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ class MyAccessError(Exception):
pass


@dataclass
@dataclass(frozen=True)
class MyTrailElemMarker(TrailElementMarker):
value: Any

Expand Down

0 comments on commit 57a54c1

Please sign in to comment.