Skip to content

Commit

Permalink
Some bugfixes
Browse files Browse the repository at this point in the history
  • Loading branch information
zhPavel committed Feb 2, 2025
1 parent 45f2ee0 commit 53bb52c
Show file tree
Hide file tree
Showing 3 changed files with 181 additions and 39 deletions.
60 changes: 37 additions & 23 deletions src/adaptix/_internal/code_tools/code_gen_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from collections.abc import Iterable, Sequence
from contextlib import AbstractContextManager, contextmanager
from re import RegexFlag
from textwrap import dedent


class TextSliceWriter(AbstractContextManager[None]):
Expand All @@ -12,20 +13,21 @@ def write(self, text: str, /) -> None:


class LinesWriter(TextSliceWriter):
__slots__ = ("_new_line_replacer", "_slices")
__slots__ = ("_indent", "_slices")

def __init__(self, start_indent: str = ""):
self._slices: list[str] = []
self._new_line_replacer = f"\n{start_indent}"
self._indent = start_indent

def __enter__(self) -> None:
self._new_line_replacer += " "
self._indent += " "

def __exit__(self, exc_type, exc_val, exc_tb) -> None:
self._new_line_replacer = self._new_line_replacer[:-4]
self._indent = self._indent[:-4]

def write(self, text: str) -> None:
self._slices.append(text.replace("\n", self._new_line_replacer))
new_line_indented = text.replace("\n", f"\n{self._indent}")
self._slices.append(f"{self._indent}{new_line_indented}")

def make_string(self) -> str:
return "\n".join(self._slices)
Expand All @@ -43,6 +45,13 @@ def at_one_line(writer: TextSliceWriter):
writer.write(sub_writer.make_string())


@contextmanager
def at_multi_line(writer: TextSliceWriter):
sub_writer = LinesWriter()
yield sub_writer
writer.write(sub_writer.make_string())


class Statement(ABC):
@abstractmethod
def write_lines(self, writer: TextSliceWriter) -> None:
Expand Down Expand Up @@ -83,22 +92,22 @@ def __init__(self, template: str, **stmts: Statement):
self._template = template
self._name_to_stmt = stmts

_PLACEHOLDER_REGEX = re.compile(r"<\w+>", RegexFlag.MULTILINE)
_INDENT_REGEX = re.compile(r"^\s*", RegexFlag.MULTILINE)
_PLACEHOLDER_REGEX = re.compile(r"<(\w+)>", RegexFlag.MULTILINE)
_INDENT_REGEX = re.compile(r"^[ \t]*", RegexFlag.MULTILINE)

def _format_template(self) -> str:
return self._PLACEHOLDER_REGEX.sub(self._replace_placeholder, self._template)
return self._PLACEHOLDER_REGEX.sub(self._replace_placeholder, dedent(self._template).strip())

def _replace_placeholder(self, match: re.Match[str]) -> str:
stmt = self._name_to_stmt[match.group(0)]
start_idx = match.string.rfind("\n", 0, match.pos)
indent_match = self._INDENT_REGEX.search(match.string, start_idx)
stmt = self._name_to_stmt[match.group(1)]
start_idx = match.string.rfind("\n", 0, match.end(0))
indent_match = self._INDENT_REGEX.search(match.string, 0 if start_idx == -1 else start_idx)
if indent_match is None:
raise ValueError

writer = LinesWriter(indent_match.group(0))
stmt.write_lines(writer)
return writer.make_string()
return writer.make_string().lstrip()

def write_lines(self, writer: TextSliceWriter) -> None:
writer.write(self._format_template())
Expand All @@ -116,29 +125,33 @@ def __init__(self, template: str, **exprs: Expression):

class DictItem(ABC):
@abstractmethod
def write_item_line(self, sub_writer: TextSliceWriter) -> None:
def write_fragment(self, writer: TextSliceWriter) -> None:
...


class MappingUnpack(DictItem):
def __init__(self, expr: Expression):
self._expr = expr

def write_item_line(self, sub_writer: TextSliceWriter) -> None:
sub_writer.write("**")
self._expr.write_lines(sub_writer)
def write_fragment(self, writer: TextSliceWriter) -> None:
writer.write("**")
with at_multi_line(writer) as sub_writer:
self._expr.write_lines(sub_writer)
writer.write(",")


class DictKeyValue(DictItem):
def __init__(self, key: Expression, value: Expression):
self._key = key
self._value = value

def write_item_line(self, sub_writer: TextSliceWriter) -> None:
self._key.write_lines(sub_writer)
sub_writer.write(": ")
self._value.write_lines(sub_writer)
sub_writer.write(",")
def write_fragment(self, writer: TextSliceWriter) -> None:
with at_multi_line(writer) as sub_writer:
self._key.write_lines(sub_writer)
writer.write(": ")
with at_multi_line(writer) as sub_writer:
self._value.write_lines(sub_writer)
writer.write(",")


class DictLiteral(Expression):
Expand All @@ -150,7 +163,7 @@ def write_lines(self, writer: TextSliceWriter) -> None:
with writer:
for item in self._items:
with at_one_line(writer) as sub_writer:
item.write_item_line(sub_writer)
item.write_fragment(sub_writer)
writer.write("}")


Expand All @@ -163,7 +176,8 @@ def write_lines(self, writer: TextSliceWriter) -> None:
with writer:
for item in self._items:
with at_one_line(writer) as sub_writer:
item.write_lines(sub_writer)
with at_multi_line(sub_writer) as sub_sub_writer:
item.write_lines(sub_sub_writer)
sub_writer.write(",")
writer.write("]")

Expand Down
38 changes: 22 additions & 16 deletions src/adaptix/_internal/morphing/model/dumper_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,9 +120,9 @@ def write_lines(self, writer: TextSliceWriter) -> None:
raise ValueError
self._state.error_handlings.append((self.stmt, self._trail_element_expr))
if self._state.finalized:
writer.write(f"raise error_handler({idx}, data, {exc}) from None")
else:
writer.write(f"errors.append({exc})")
else:
writer.write(f"raise error_handler({idx}, data, {exc}) from None")


class OutVarStatement(NamedTuple):
Expand Down Expand Up @@ -227,7 +227,7 @@ def <closure_name>(data):
if self._debug_trail == DebugTrail.ALL:
error_handler_writer = LinesWriter()
self._get_error_handler(state).write_lines(error_handler_writer)
result += error_handler_writer.make_string()
result += "\n" + error_handler_writer.make_string()
return result, namespace.all_constants

def _get_body_statement(self, state: GenState) -> Statement:
Expand Down Expand Up @@ -383,11 +383,11 @@ def _get_extra_extract_extraction(self, state: GenState, extra_move: ExtraExtrac

def _get_error_handler(self, state: GenState) -> Statement:
error_scanning_stmts = []
for idx, (stmt, trail_element) in enumerate(state.error_handlings):
for idx, (stmt, _) in enumerate(state.error_handlings):
error_scanning_stmts.append(
CodeBlock(
"""
if idx > <idx>:
if idx < <idx>:
<stmt>
""",
idx=RawExpr(repr(idx)),
Expand Down Expand Up @@ -513,7 +513,7 @@ def _get_crown_out_stmt(self, state: GenState, key: CrownPathElem, crown: OutCro
def _get_dict_crown_out_stmt(self, state: GenState, crown: OutDictCrown) -> OutStatement:
builder = DictBuilder()
for key, sub_crown in crown.map.items():
if key in crown.sieves:
if key not in crown.sieves:
self._process_dict_sub_crown(
state=state,
builder=builder,
Expand Down Expand Up @@ -560,15 +560,20 @@ def _wrap_with_dumper_call(
if not isinstance(sub_crown, OutFieldCrown):
return expr

if self._fields_dumpers[sub_crown.id] == as_is_stub:
return expr
field = self._id_to_field[sub_crown.id]
trail_element = self._get_trail_element_expr(state.namespace, field)
dumper_call = CodeExpr(
"<dumper>(<expr>)",
dumper=RawExpr(self._v_dumper(field)),
expr=expr,
)
if self._fields_dumpers[sub_crown.id] == as_is_stub:
if self._debug_trail == DebugTrail.DISABLE:
return expr

dumper_call = expr
else:
dumper_call = CodeExpr(
"<dumper>(<expr>)",
dumper=RawExpr(self._v_dumper(field)),
expr=expr,
)

out_variable = f"dumped_{field.id}"
state.namespace.register_var(out_variable)

Expand Down Expand Up @@ -637,7 +642,7 @@ def _process_dict_sub_crown(
builder.dict_items.append(DictKeyValue(StringLiteral(key), RawExpr(dumper_call.var)))
else:
raise TypeError
if isinstance(out_stmt, OutVarStatement):
elif isinstance(out_stmt, OutVarStatement):
builder.before_stmts.append(out_stmt.stmt)
self._process_dict_sub_crown(
state=state,
Expand All @@ -646,7 +651,7 @@ def _process_dict_sub_crown(
sub_crown=sub_crown,
out_stmt=RawExpr(out_stmt.var),
)
if isinstance(out_stmt, OptionalOutVarStatement):
elif isinstance(out_stmt, OptionalOutVarStatement):
error_handler_call = self._get_error_handler_call_for_sub_crown(state, sub_crown)
stmt = out_stmt.stmt_maker(
on_access_ok=self._get_dict_append(
Expand All @@ -661,7 +666,8 @@ def _process_dict_sub_crown(
)
error_handler_call.stmt = stmt
builder.after_stmts.append(stmt)
raise TypeError
else:
raise TypeError

def _process_dict_sieved_sub_crown(
self,
Expand Down
122 changes: 122 additions & 0 deletions tests/unit/code_tools/test_code_gen_tree.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
from textwrap import dedent

from adaptix._internal.code_tools.code_gen_tree import (
CodeBlock,
DictKeyValue,
DictLiteral,
LinesWriter,
ListLiteral,
MappingUnpack,
RawExpr,
Statement,
)


def assert_string(stmt: Statement, string: str) -> str:
writer = LinesWriter()
stmt.write_lines(writer)
assert writer.make_string() == dedent(string).strip()


def test_dict_literal():
assert_string(
DictLiteral(
[
DictKeyValue(RawExpr("a"), RawExpr("1")),
DictKeyValue(RawExpr("b"), RawExpr("2")),
MappingUnpack(RawExpr("c")),
],
),
"""
{
a: 1,
b: 2,
**c,
}
""",
)


def test_dict_literal_nested():
assert_string(
DictLiteral(
[
DictKeyValue(RawExpr("a"), RawExpr("1")),
DictKeyValue(
RawExpr("b"),
DictLiteral(
[
DictKeyValue(RawExpr("c"), RawExpr("2")),
DictKeyValue(RawExpr("d"), RawExpr("3")),
],
),
),
MappingUnpack(RawExpr("e")),
],
),
"""
{
a: 1,
b: {
c: 2,
d: 3,
},
**e,
}
""",
)


def test_list_literal():
assert_string(
ListLiteral(
[
RawExpr("a"),
RawExpr("b"),
],
),
"""
[
a,
b,
]
""",
)


def test_code_block():
assert_string(
CodeBlock(
"""
if <condition>:
<true_case>
else:
<false_case>
""",
condition=RawExpr("zzz"),
true_case=ListLiteral(
[
RawExpr("a"),
RawExpr("b"),
],
),
false_case=DictLiteral(
[
DictKeyValue(RawExpr("a"), RawExpr("1")),
DictKeyValue(RawExpr("b"), RawExpr("2")),
],
),
),
"""
if zzz:
[
a,
b,
]
else:
{
a: 1,
b: 2,
}
""",
)

0 comments on commit 53bb52c

Please sign in to comment.