Skip to content

Commit

Permalink
feat: Add item method to Extend Template (#104)
Browse files Browse the repository at this point in the history
  • Loading branch information
aturker-synnada authored Dec 24, 2024
1 parent 980e226 commit f65eafe
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 3 deletions.
5 changes: 4 additions & 1 deletion mithril/framework/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -870,7 +870,7 @@ def __getitem__(
start, stop, step = key.start, key.stop, key.step
return ExtendTemplate(connections=[self, start, stop, step], model="slice")
elif isinstance(key, int | tuple):
return ExtendTemplate(connections=[self, key], model="item")
return ExtendTemplate(connections=[self, key], model="get_item")
else:
raise TypeError(f"Unsupported key type: {type(key)}")

Expand Down Expand Up @@ -1072,6 +1072,9 @@ def transpose(self, axes: tuple[int, ...] | TemplateBase | None = None):
def split(self, split_size: int, axis: int):
return ExtendTemplate(connections=[self, split_size, axis], model="split")

def item(self):
return ExtendTemplate(connections=[self], model="item")


class ExtendTemplate(TemplateBase):
output_connection: ConnectionData | None
Expand Down
6 changes: 4 additions & 2 deletions mithril/framework/logical/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
FloorDivide,
Greater,
GreaterEqual,
Item,
Length,
Less,
LessEqual,
Expand Down Expand Up @@ -110,6 +111,7 @@
"size": Size,
"tensor": ToTensor,
"list": TensorToList,
"item": Item,
"mean": Mean,
"sqrt": Sqrt,
"exp": Exponential,
Expand Down Expand Up @@ -138,8 +140,8 @@


coercion_table: dict[tuple[str, type[Tensor] | type[Scalar]], type[PrimitiveModel]] = {
("item", Tensor): TensorItem,
("item", Scalar): ScalarItem,
("get_item", Tensor): TensorItem,
("get_item", Scalar): ScalarItem,
("slice", Tensor): TensorSlice,
("slice", Scalar): PrimitiveSlice,
}
Expand Down
21 changes: 21 additions & 0 deletions tests/scripts/test_extend_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
Greater,
GreaterEqual,
IOKey,
Item,
Less,
LessEqual,
Linear,
Expand Down Expand Up @@ -1676,3 +1677,23 @@ def test_immediate_values_with_extend_template_and_regular_case():
== big_model_1.conns.latent_input_keys
== {"$1"}
)


def test_item():
model1 = Model(enforce_jit=False)

buffer_model_1 = Buffer()
item_model = Item()
totensor = ToTensor()

model1 += buffer_model_1(input="input")
model1 += item_model(input=buffer_model_1.output)
model1 += totensor(input=item_model.output, output=IOKey("output"))

model2 = Model(enforce_jit=False)
buffer_model_1 = Buffer()
model2 += buffer_model_1(input="input")
conn = buffer_model_1.output.item()
model2 += ToTensor()(input=conn, output=IOKey("output"))

check_logical_models(model1, model2)

0 comments on commit f65eafe

Please sign in to comment.