Skip to content

Commit

Permalink
Add connections field to IOKey, remove Connect object, update extend …
Browse files Browse the repository at this point in the history
…dict_conversions and corresponding tests
  • Loading branch information
kberat-synnada committed Dec 18, 2024
1 parent d65f1e7 commit fce6951
Show file tree
Hide file tree
Showing 18 changed files with 212 additions and 234 deletions.
3 changes: 1 addition & 2 deletions mithril/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
short,
)
from .framework.codegen import code_gen_map
from .framework.common import TBD, Connect, Connection, Constant, IOKey
from .framework.common import TBD, Connection, Constant, IOKey
from .framework.physical.model import PhysicalConstantType, PhysicalShapeType
from .models import BaseModel, PhysicalModel
from .models.train_model import TrainModel
Expand All @@ -60,7 +60,6 @@
"Backend",
"IOKey",
"TBD",
"Connect",
"Constant",
"epsilon_table",
]
Expand Down
23 changes: 9 additions & 14 deletions mithril/framework/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@
"IOHyperEdge",
"Connection",
"ConnectionData",
"Connect",
"Connections",
"ShapeNode",
"ShapeRepr",
Expand Down Expand Up @@ -1102,8 +1101,9 @@ def __init__(
value: TensorValueType | MainValueType | ToBeDetermined | str = TBD,
shape: ShapeTemplateType | None = None,
type: NestedListType | UnionType | type | None = None,
expose: bool = True,
expose: bool | None = None,
interval: list[float | int] | None = None,
connections: list[Connection | str] | None = None,
) -> None:
super().__init__()
self._name = name
Expand All @@ -1112,6 +1112,7 @@ def __init__(
self._type = type
self._expose = expose
self._interval = interval
self._connections: OrderedSet[ConnectionData | str] = OrderedSet()

# TODO: Shape should not be [] also!
if self._value is not TBD and self._shape is not None and self._shape != []:
Expand All @@ -1128,6 +1129,12 @@ def __init__(
f"type is {self._type} while type of value is {value_type}"
)

connections = connections or []
for item in connections:
conn: ConnectionData | str
conn = item.data if isinstance(item, Connection) else item
self._connections.add(conn)

def __hash__(self) -> int:
return hash(id(self))

Expand Down Expand Up @@ -1197,20 +1204,9 @@ def set_differentiable(self, differentiable: bool = True) -> None:
)


class Connect:
def __init__(self, *connections: Connection | str, key: IOKey | None = None):
self.connections: OrderedSet[ConnectionData | str] = OrderedSet()
self.key = key
for item in connections:
conn: ConnectionData | str
conn = item.data if isinstance(item, Connection) else item
self.connections.add(conn)


ConnectionType = (
str
| ConnectionData
| Connect
| MainValueType
| ExtendTemplate
| NullConnection
Expand All @@ -1223,7 +1219,6 @@ def __init__(self, *connections: Connection | str, key: IOKey | None = None):
ConnectionInstanceType = (
str
| ConnectionData
| Connect
| MainValueInstance
| ExtendTemplate
| NullConnection
Expand Down
20 changes: 6 additions & 14 deletions mithril/framework/logical/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
NOT_AVAILABLE,
NOT_GIVEN,
TBD,
Connect,
Connection,
ConnectionData,
Connections,
Expand Down Expand Up @@ -96,7 +95,7 @@ def __call__(self, **kwargs: ConnectionType) -> ExtendInfo:
continue
match con:
case Connection():
kwargs[key] = Connect(con, key=IOKey(value=val, expose=False))
kwargs[key] = IOKey(value=val, connections=[con])
# TODO: Maybe we could check con's value if matches with val
case item if isinstance(item, MainValueInstance) and con != val:
raise ValueError(
Expand All @@ -111,25 +110,18 @@ def __call__(self, **kwargs: ConnectionType) -> ExtendInfo:
f"Given IOKey for local key: '{key}' is not valid!"
)
else:
conns = [
item.conn if isinstance(item, ConnectionData) else item
for item in con._connections
]
kwargs[key] = IOKey(
name=con._name,
value=val,
shape=con._shape,
type=con._type,
expose=con._expose,
connections=conns,
)
case Connect():
if (io_key := con.key) is not None:
if io_key._value is not TBD and io_key._value != val:
raise ValueError(
"Given IOKey in Connect for "
f"local key: '{key}' is not valid!"
)
else:
io_key._value = val
else:
io_key = IOKey(value=val, expose=False)
kwargs[key] = Connect(*con.connections, key=io_key)
case ExtendTemplate():
raise ValueError(
"Multi-write detected for a valued "
Expand Down
34 changes: 15 additions & 19 deletions mithril/framework/logical/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
NOT_AVAILABLE,
NOT_GIVEN,
TBD,
Connect,
Connection,
ConnectionData,
ConnectionInstanceType,
Expand Down Expand Up @@ -353,9 +352,12 @@ def _add_connection(
outer_key = con_obj.key
expose = outer_key in self.conns.output_keys and not is_input
match_connection = True
elif isinstance(given_connection, IOKey):
elif isinstance(
given_connection, IOKey
) and given_connection._connections == OrderedSet([]):
outer_key = given_connection._name
expose = given_connection._expose
if (expose := given_connection._expose) is None:
expose = True
if outer_key is None or self.conns.get_connection(outer_key) is None:
create_connection = True # Create new connection.
else:
Expand All @@ -374,16 +376,17 @@ def _add_connection(
"Expose flag cannot be false when "
"no value is provided for input keys!"
)
elif isinstance(given_connection, Connect):
elif isinstance(
given_connection, IOKey
) and given_connection._connections != OrderedSet([]):
match_connection = True
if (iokey := given_connection.key) is not None:
expose = iokey._expose
if iokey._name is not None:
outer_key = iokey._name
if iokey._value is not TBD:
set_value = iokey._value
expose = given_connection._expose
if given_connection._name is not None:
outer_key = given_connection._name
if given_connection._value is not TBD:
set_value = given_connection._value
initial_conn: ConnectionData
for idx, conn in enumerate(given_connection.connections):
for idx, conn in enumerate(given_connection._connections):
if isinstance(conn, str):
_conn = self.conns.get_connection(conn)
else:
Expand All @@ -409,7 +412,7 @@ def _add_connection(
if _conn in d_map:
if initial_conn in d_map:
raise KeyError(
"Connect object can not have more than one output "
"IOKey object can not have more than one output "
"connection. Multi-write error!"
)
initial_conn, _conn = _conn, initial_conn
Expand Down Expand Up @@ -460,7 +463,6 @@ def _add_connection(
local_connection.metadata.data.value is not TBD
and con_obj not in self.conns.input_connections
and not isinstance(given_connection, IOKey)
and not isinstance(given_connection, Connect)
):
expose = False
# If any value provided, set.
Expand Down Expand Up @@ -783,12 +785,6 @@ def extend(
"but the model canonical connections is not determined. Please "
"provide connection/key explicitly, or set canonical connections."
)
elif isinstance(value, Connect) and value.key is not None:
if value.key._shape is not None:
shape_info |= {key: value.key._shape}

if value.key._type is not None:
type_info[key] = value.key._type

if (updated_conn := self.create_connection_model(kwargs[key])) is not None:
kwargs[key] = updated_conn
Expand Down
10 changes: 5 additions & 5 deletions mithril/utils/dict_conversions.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
from ..framework.common import (
TBD,
AllValueType,
Connect,
ConnectionData,
GenericTensorType,
IOHyperEdge,
Expand Down Expand Up @@ -138,7 +137,7 @@ def dict_to_model(modelparams: dict[str, Any]) -> BaseModel:
for m_key, v in submodels.items():
m = dict_to_model(v)
submodels_dict[m_key] = m
mappings: dict[str, IOKey | float | int | list | tuple | str | Connect] = {}
mappings: dict[str, IOKey | float | int | list | tuple | str] = {}
for k, conn in connections[m_key].items():
if conn in unnamed_keys and k in m._input_keys:
continue
Expand All @@ -148,17 +147,18 @@ def dict_to_model(modelparams: dict[str, Any]) -> BaseModel:

elif isinstance(conn, dict):
if "connect" in conn:
key_kwargs = {}
if (key := conn.get("key")) is not None:
key_kwargs = create_iokey_kwargs(conn["key"])
key = IOKey(**key_kwargs)
mappings[k] = Connect(
*[
mappings[k] = IOKey(
**key_kwargs,
connections=[
getattr(submodels_dict[value[0]], value[1])
if isinstance(value, Sequence)
else value
for value in conn["connect"]
],
key=key,
)
elif "name" in conn:
key_kwargs = create_iokey_kwargs(conn)
Expand Down
3 changes: 1 addition & 2 deletions tests/scripts/test_constant_inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
from mithril.framework.common import (
NOT_GIVEN,
TBD,
Connect,
Connection,
ConnectionType,
IOKey,
Expand Down Expand Up @@ -970,7 +969,7 @@ def test_nontensor_extend_from_input_multiple_connection():
model += mean1
model += mean2
model += mean3
model += mean4(axis=Connect(mean1.axis, mean2.axis, mean3.axis))
model += mean4(axis=IOKey(connections=[mean1.axis, mean2.axis, mean3.axis]))
assert (
mean1.axis.data.metadata
== mean2.axis.data.metadata
Expand Down
17 changes: 5 additions & 12 deletions tests/scripts/test_io_key.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
from mithril.models import (
Add,
Buffer,
Connect,
Linear,
Mean,
Model,
Expand Down Expand Up @@ -273,7 +272,7 @@ def test_7():
model += (relu1 := Relu())(input="in1", output="relu1_output")
model += (relu2 := Relu())(input="in2", output="relu2_output")
model += (relu3 := Relu())(
input="", output=Connect(relu1.input, relu2.input, key=IOKey(name="my_input"))
input="", output=IOKey(name="my_input", connections=[relu1.input, relu2.input])
)
assert (
model.dag[relu1]["input"].metadata
Expand Down Expand Up @@ -443,21 +442,15 @@ def test_iokey_shapes_3():
model += buff3(input="input3")

main_model = Model()
conn = Connect()
main_model += model(
input1=IOKey(name="input1", shape=["a", "b"]),
input2=IOKey(name="input2", shape=["b", "a"]),
input3=IOKey(name="input3", shape=[3, "a"]),
)

conn = Connect(
main_model.input1, # type: ignore
main_model.input2, # type: ignore
main_model.input3, # type: ignore
key=IOKey("input"),
)

main_model += Buffer()(input=conn, output="output1")
conns = [main_model.input1, main_model.input2, main_model.input3] # type: ignore
key = IOKey(name="input", connections=conns)
main_model += Buffer()(input=key, output="output1")

expected_shapes = {"$_Model_0_output": [3, 3], "output1": [3, 3], "input": [3, 3]}

Expand Down Expand Up @@ -1157,7 +1150,7 @@ def test_compare_models_5():
sigmoid = Sigmoid()
add = Add()
model2 += add(output=IOKey(name="output"))
conn = Connect(add.left, add.right)
conn = IOKey(connections=[add.left, add.right])
model2 += sigmoid(input="input", output=conn)
model2.set_shapes({"input": [2, 2]})

Expand Down
15 changes: 9 additions & 6 deletions tests/scripts/test_jittable.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@
from mithril.models import (
TBD,
Add,
Connect,
CustomPrimitiveModel,
IOKey,
Item,
Expand Down Expand Up @@ -258,7 +257,7 @@ def test_logical_model_jittable_1():
model += (add1 := Add())(left="l1", right="l2", output=IOKey(name="out1"))
model += (add2 := Add())(left="l3", right="l4")
with pytest.raises(Exception) as error_info:
model += Item()(input=Connect(add1.left, add2.left, key=IOKey(name="input")))
model += Item()(input=IOKey(name="input", connections=[add1.left, add2.left]))
modified_msg = re.sub("\\s*", "", str(error_info.value))
expected_msg = (
"Model with enforced Jit can not be extended by a non-jittable model! \
Expand All @@ -275,7 +274,8 @@ def test_logical_model_jittable_2():
model += (add1 := Add())(left="l1", right="l2", output=IOKey(name="out1"))
model += (add2 := Add())(left="l3", right="l4")
model.enforce_jit = False
model += Item()(input=Connect(add1.left, add2.left, key=IOKey(name="input")))
input = IOKey(name="input", connections=[add1.left, add2.left], expose=True)
model += Item()(input=input)
assert not model.enforce_jit


Expand All @@ -287,7 +287,8 @@ def test_logical_model_jittable_3():
model += (add1 := Add())(left="l1", right="l2", output=IOKey(name="out1"))
model += (add2 := Add())(left="l3", right="l4")
model.enforce_jit = False
model += Item()(input=Connect(add1.left, add2.left, key=IOKey(name="input")))
input = IOKey(name="input", connections=[add1.left, add2.left], expose=True)
model += Item()(input=input)
assert not model.enforce_jit


Expand All @@ -299,7 +300,8 @@ def test_physical_model_jit_1():
model += (add1 := Add())(left="l1", right="l2", output=IOKey(name="out1"))
model += (add2 := Add())(left="l3", right="l4")
model.enforce_jit = False
model += Item()(input=Connect(add1.left, add2.left, key=IOKey(name="input")))
input = IOKey(name="input", connections=[add1.left, add2.left], expose=True)
model += Item()(input=input)

backend = JaxBackend()
compiled_model = compile(model=model, backend=backend, jit=False)
Expand All @@ -318,7 +320,8 @@ def test_physical_model_jit_2():
model += (add1 := Add())(left="l1", right="l2", output=IOKey(name="out1"))
model += (add2 := Add())(left="l3", right="l4")
model.enforce_jit = False
model += Item()(input=Connect(add1.left, add2.left, key=IOKey(name="input")))
input = IOKey(name="input", connections=[add1.left, add2.left], expose=True)
model += Item()(input=input)

backend = JaxBackend()

Expand Down
Loading

0 comments on commit fce6951

Please sign in to comment.