diff --git a/examples/flux/auto_encoder.py b/examples/flux/auto_encoder.py index 1ed7fbed..5cda7307 100644 --- a/examples/flux/auto_encoder.py +++ b/examples/flux/auto_encoder.py @@ -66,7 +66,7 @@ def attn_block(n_channels: int, name: str | None = None): key = block.key # type: ignore[attr-defined] value = block.value # type: ignore[attr-defined] - shape = query.shape() + shape = query.shape query = query.transpose((0, 2, 3, 1)).reshape((shape[0], 1, -1, shape[1])) key = key.transpose((0, 2, 3, 1)).reshape((shape[0], 1, -1, shape[1])) @@ -92,7 +92,7 @@ def downsample(n_channels: int): def upsample(n_channels: int, name: str | None = None): block = Model(enforce_jit=False, name=name) # TODO: Remove enfor jit false input = IOKey("input") - input_shape = input.shape() + input_shape = input.shape B, C, H, W = input_shape[0], input_shape[1], input_shape[2], input_shape[3] input = input[:, :, :, None, :, None] diff --git a/examples/flux/layers.py b/examples/flux/layers.py index 33a87547..dcf0758f 100644 --- a/examples/flux/layers.py +++ b/examples/flux/layers.py @@ -65,8 +65,8 @@ def apply_rope() -> Model: xk = IOKey("xk") freqs_cis = IOKey("freqs_cis") - xq_shape = xq.shape() - xk_shape = xk.shape() + xq_shape = xq.shape + xk_shape = xk.shape B, L, H = xq_shape[0], xq_shape[1], xq_shape[2] block += Reshape()(xq, shape=(B, L, H, -1, 1, 2), output="xq_") B, L, H = xk_shape[0], xk_shape[1], xk_shape[2] @@ -96,7 +96,7 @@ def attention() -> Model: ) # We can get named connection as model.'connection_name' - context_shape = block.context.shape() # type: ignore[attr-defined] + context_shape = block.context.shape # type: ignore[attr-defined] block += Transpose(axes=(0, 2, 1, 3))(block.context) # type: ignore[attr-defined] # NOTE: Reshape input is automatically connected to Transpose output block += Reshape()( @@ -137,7 +137,7 @@ def modulation(dim: int, double: bool, name: str | None = None): def rearrange(num_heads: int): block = Model() input = IOKey("input") - input_shaepe = input.shape() + input_shaepe = input.shape B, L = input_shaepe[0], input_shaepe[1] block += Reshape()(shape=(B, L, 3, num_heads, -1)) block += Transpose(axes=(2, 0, 3, 1, 4))(output=IOKey("output")) @@ -209,7 +209,7 @@ def double_stream_block( block += Concat(axis=2, n=2)(input1=txt_v, input2=img_v, output="v_concat") block += attention()(q="q_concat", k="k_concat", v="v_concat", pe=pe, output="attn") - # TODO: use'[:, txt.shape()[1] :]' when fixed. + # TODO: use'[:, txt.shape[1] :]' when fixed. img_attn = block.attn[:, 256:] # type: ignore[attr-defined] block += Linear(hidden_size, name="img_attn_proj")(img_attn, output="img_proj") @@ -234,7 +234,7 @@ def double_stream_block( ) img = img + block.img_mod_2[2] * block.img_mlp # type: ignore[attr-defined] - # TODO: Use txt.shape()[1]] + # TODO: Use txt.shape[1]] txt_attn = block.attn[:, :256] # type: ignore[attr-defined] block += Linear(hidden_size, name="txt_attn_proj")(txt_attn, output="txt_proj") @@ -355,7 +355,7 @@ def rope(dim: int, theta: int) -> Model: omega = 1.0 / (theta ** (block.arange / dim)) # type: ignore out = input[..., None] * omega - out_shape = out.shape() + out_shape = out.shape B, N, D = out_shape[0], out_shape[1], out_shape[2] block += Cosine()(out, output="cos") diff --git a/examples/gpt/model.py b/examples/gpt/model.py index 4d503a28..4a73aaa2 100644 --- a/examples/gpt/model.py +++ b/examples/gpt/model.py @@ -41,7 +41,7 @@ def causal_attention(input_dim, num_heads, bias=True): model += Linear(input_dim * 3, name="c_attn")("input", output="c_attn_out") t_axes = (0, 2, 1, 3) - shp_con = model.input.shape() # type: ignore + shp_con = model.input.shape # type: ignore reshape_con = (shp_con[0], shp_con[1], num_heads, -1) model += Split(3, axis=-1)(model.c_attn_out, output="split_out") # type: ignore diff --git a/mithril/framework/common.py b/mithril/framework/common.py index 3a3c75f3..90192e30 100644 --- a/mithril/framework/common.py +++ b/mithril/framework/common.py @@ -40,7 +40,7 @@ constant_type_table, epsilon_table, ) -from ..utils.utils import OrderedSet, PaddingType, find_dominant_type +from ..utils.utils import PaddingType, find_dominant_type from .utils import ( NestedListType, align_shapes, @@ -1001,6 +1001,7 @@ def abs(self): def len(self): return ExtendTemplate(connections=[self], model="len") + @property def shape(self): return ExtendTemplate(connections=[self], model="shape") @@ -1100,6 +1101,14 @@ def __init__( self.output_connection = None +@dataclass +class BaseKey: + value: TensorValueType | MainValueType | ToBeDetermined | str = TBD + shape: ShapeTemplateType | None = None + type: NestedListType | UnionType | type | None = None + interval: list[float | int] | None = None + + class IOKey(TemplateBase): def __init__( self, @@ -1109,38 +1118,35 @@ def __init__( type: NestedListType | UnionType | type | None = None, expose: bool | None = None, interval: list[float | int] | None = None, - connections: list[Connection | str] | None = None, + connections: set[Connection | str] | None = None, ) -> None: super().__init__() - self._name = name - self._value = value - self._shape = shape - self._type = type - self._expose = expose - self._interval = interval - self._connections: OrderedSet[ConnectionData | str] = OrderedSet() + self.name = name + self.expose = expose + if connections is None: + connections = set() + self.connections: set[Connection | str] = connections + self.data = BaseKey(value, shape, type, interval) # TODO: Shape should not be [] also! - if self._value is not TBD and self._shape is not None and self._shape != []: + if ( + self.data.value is not TBD + and self.data.shape is not None + and self.data.shape != [] + ): raise ValueError( f"Scalar values are shapeless, shape should be None or []. " - f"Got {self._shape}." + f"Got {self.data.shape}." ) - if self._value is not TBD and self._type is not None: - value_type = find_type(self._value) - if find_intersection_type(value_type, self._type) is None: + if self.data.value is not TBD and self.data.type is not None: + value_type = find_type(self.data.value) + if find_intersection_type(value_type, self.data.type) is None: raise TypeError( f"type of the given value and given type does not match. Given " - f"type is {self._type} while type of value is {value_type}" + f"type is {self.data.type} while type of value is {value_type}" ) - connections = connections or [] - for item in connections: - conn: ConnectionData | str - conn = item.data if isinstance(item, Connection) else item - self._connections.add(conn) - class Connection(TemplateBase): def __init__(self, key: str, metadata: IOHyperEdge, is_key_autogenerated: bool): diff --git a/mithril/framework/logical/base.py b/mithril/framework/logical/base.py index b6e475f0..74e30fe1 100644 --- a/mithril/framework/logical/base.py +++ b/mithril/framework/logical/base.py @@ -95,7 +95,7 @@ def __call__(self, **kwargs: ConnectionType) -> ExtendInfo: continue match con: case Connection(): - kwargs[key] = IOKey(value=val, connections=[con]) + kwargs[key] = IOKey(value=val, connections={con}) # TODO: Maybe we could check con's value if matches with val case item if isinstance(item, MainValueInstance) and con != val: raise ValueError( @@ -103,24 +103,20 @@ def __call__(self, **kwargs: ConnectionType) -> ExtendInfo: f"has already being set to {val}!" ) case str(): - kwargs[key] = IOKey(con, value=val, expose=False) + kwargs[key] = IOKey(name=con, value=val, expose=False) case IOKey(): - if con._value is not TBD and con._value != val: + if con.data.value is not TBD and con.data.value != val: raise ValueError( f"Given IOKey for local key: '{key}' is not valid!" ) else: - _conns: list[Connection | str] = [ - item.conn if isinstance(item, ConnectionData) else item - for item in con._connections - ] kwargs[key] = IOKey( - name=con._name, + name=con.name, + expose=con.expose, + connections=con.connections, + type=con.data.type, + shape=con.data.shape, value=val, - shape=con._shape, - type=con._type, - expose=con._expose, - connections=_conns, ) case ExtendTemplate(): raise ValueError( diff --git a/mithril/framework/logical/essential_primitives.py b/mithril/framework/logical/essential_primitives.py index 4235719a..051d9c94 100644 --- a/mithril/framework/logical/essential_primitives.py +++ b/mithril/framework/logical/essential_primitives.py @@ -20,10 +20,10 @@ from ..common import ( NOT_GIVEN, TBD, + BaseKey, Connection, ConnectionType, GenericTensorType, - IOKey, MyTensor, ShapeTemplateType, TensorValueType, @@ -123,8 +123,8 @@ def __init__( super().__init__( formula_key="buffer", name=name, - output=IOKey(shape=[("Var", ...)], type=GenericTensorType), - input=IOKey(shape=[("Var", ...)], type=GenericTensorType), + output=BaseKey(shape=[("Var", ...)], type=GenericTensorType), + input=BaseKey(shape=[("Var", ...)], type=GenericTensorType), ) self.factory_inputs = {"input": input} @@ -149,10 +149,10 @@ def __init__( ) -> None: self.factory_args = {"n": n} key_definitions = { - "output": IOKey(type=tuple[int | float | bool | list | tuple, ...]) + "output": BaseKey(type=tuple[int | float | bool | list | tuple, ...]) } key_definitions |= { - f"input{idx+1}": IOKey(type=int | float | bool | list | tuple) + f"input{idx+1}": BaseKey(type=int | float | bool | list | tuple) for idx in range(n) } self.factory_inputs = kwargs # type: ignore @@ -173,9 +173,9 @@ def __init__(self, formula_key: str, name: str | None = None) -> None: super().__init__( formula_key=formula_key, name=name, - output=IOKey(shape=[("Var_out", ...)], type=GenericTensorType), - left=IOKey(shape=[("Var_1", ...)], type=GenericTensorType), - right=IOKey(shape=[("Var_2", ...)], type=GenericTensorType), + output=BaseKey(shape=[("Var_out", ...)], type=GenericTensorType), + left=BaseKey(shape=[("Var_1", ...)], type=GenericTensorType), + right=BaseKey(shape=[("Var_2", ...)], type=GenericTensorType), ) self._set_constraint( @@ -216,18 +216,18 @@ def __init__( super().__init__( formula_key="robust_power", name=name, - output=IOKey(shape=[("Var_out", ...)], type=GenericTensorType), - base=IOKey(shape=[("Var_1", ...)], type=GenericTensorType), - exponent=IOKey(shape=[("Var_2", ...)], type=GenericTensorType), - threshold=IOKey(shape=[], type=GenericTensorType), + output=BaseKey(shape=[("Var_out", ...)], type=GenericTensorType), + base=BaseKey(shape=[("Var_1", ...)], type=GenericTensorType), + exponent=BaseKey(shape=[("Var_2", ...)], type=GenericTensorType), + threshold=BaseKey(shape=[], type=GenericTensorType), ) self.threshold.set_differentiable(False) # type: ignore else: super().__init__( formula_key="power", - output=IOKey(shape=[("Var_out", ...)], type=GenericTensorType), - base=IOKey(shape=[("Var_1", ...)], type=GenericTensorType), - exponent=IOKey(shape=[("Var_2", ...)], type=GenericTensorType), + output=BaseKey(shape=[("Var_out", ...)], type=GenericTensorType), + base=BaseKey(shape=[("Var_1", ...)], type=GenericTensorType), + exponent=BaseKey(shape=[("Var_2", ...)], type=GenericTensorType), ) self._set_constraint( @@ -303,9 +303,9 @@ def __init__( super().__init__( formula_key="divide", name=name, - output=IOKey(shape=[("Var_out", ...)], type=MyTensor[float]), - numerator=IOKey(shape=[("Var_1", ...)], type=GenericTensorType), - denominator=IOKey(shape=[("Var_2", ...)], type=GenericTensorType), + output=BaseKey(shape=[("Var_out", ...)], type=MyTensor[float]), + numerator=BaseKey(shape=[("Var_1", ...)], type=GenericTensorType), + denominator=BaseKey(shape=[("Var_2", ...)], type=GenericTensorType), ) self.factory_inputs = {"numerator": numerator, "denominator": denominator} self._set_constraint( @@ -339,9 +339,9 @@ def __init__( super().__init__( formula_key="floor_divide", name=name, - output=IOKey(shape=[("Var_out", ...)], type=GenericTensorType), - numerator=IOKey(shape=[("Var_1", ...)], type=GenericTensorType), - denominator=IOKey(shape=[("Var_2", ...)], type=GenericTensorType), + output=BaseKey(shape=[("Var_out", ...)], type=GenericTensorType), + numerator=BaseKey(shape=[("Var_1", ...)], type=GenericTensorType), + denominator=BaseKey(shape=[("Var_2", ...)], type=GenericTensorType), ) self.factory_inputs = {"numerator": numerator, "denominator": denominator} @@ -378,9 +378,9 @@ def __init__( super().__init__( formula_key="matrix_multiplication", name=name, - output=IOKey(shape=[("Var3", ...), "x", "z"], type=GenericTensorType), - left=IOKey(shape=[("Var1", ...), "x", "y"], type=GenericTensorType), - right=IOKey(shape=[("Var2", ...), "y", "z"], type=GenericTensorType), + output=BaseKey(shape=[("Var3", ...), "x", "z"], type=GenericTensorType), + left=BaseKey(shape=[("Var1", ...), "x", "y"], type=GenericTensorType), + right=BaseKey(shape=[("Var2", ...), "y", "z"], type=GenericTensorType), ) self.factory_inputs = {"left": left, "right": right} self._set_constraint( @@ -410,8 +410,8 @@ def __init__( super().__init__( formula_key="shape", name=name, - output=IOKey(shape=[], type=tuple[int, ...]), - input=IOKey(shape=[("input", ...)], type=GenericTensorType), + output=BaseKey(shape=[], type=tuple[int, ...]), + input=BaseKey(shape=[("input", ...)], type=GenericTensorType), ) self.factory_inputs = {"input": input} self._set_constraint(fn=shape_constraints, keys=["output", "input"]) @@ -442,9 +442,9 @@ def __init__( super().__init__( formula_key="reshape", name=name, - output=IOKey(shape=output_shape_map, type=GenericTensorType), - input=IOKey(shape=[("input", ...)], type=GenericTensorType), - shape=IOKey(type=tuple[int | None, ...] | list[int | None], value=shape), + output=BaseKey(shape=output_shape_map, type=GenericTensorType), + input=BaseKey(shape=[("input", ...)], type=GenericTensorType), + shape=BaseKey(type=tuple[int | None, ...] | list[int | None], value=shape), ) self.factory_inputs = {"input": input, "shape": shape} self._set_constraint(fn=reshape_constraints, keys=["output", "input", "shape"]) @@ -468,8 +468,8 @@ def __init__( super().__init__( formula_key="length", name=name, - output=IOKey(type=int), - input=IOKey(shape=[("Var", ...)], type=GenericTensorType), + output=BaseKey(type=int), + input=BaseKey(shape=[("Var", ...)], type=GenericTensorType), ) self.factory_inputs = {"input": input} @@ -490,9 +490,9 @@ def __init__( super().__init__( formula_key="astype", name=name, - output=IOKey(shape=[("Var", ...)], type=GenericTensorType), - input=IOKey(shape=[("Var", ...)], type=GenericTensorType), - dtype=IOKey(type=Dtype), + output=BaseKey(shape=[("Var", ...)], type=GenericTensorType), + input=BaseKey(shape=[("Var", ...)], type=GenericTensorType), + dtype=BaseKey(type=Dtype), ) self.factory_inputs = {"dtype": dtype} @@ -515,8 +515,8 @@ def __init__( super().__init__( formula_key="dtype", name=name, - output=IOKey(type=core.Dtype), - input=IOKey(shape=[("Var", ...)], type=GenericTensorType), + output=BaseKey(type=core.Dtype), + input=BaseKey(shape=[("Var", ...)], type=GenericTensorType), ) self.factory_inputs = {"input": input} @@ -541,9 +541,9 @@ def __init__( super().__init__( formula_key="size", name=name, - output=IOKey(type=int | tuple[int, ...]), - input=IOKey(shape=[("Var", ...)], type=GenericTensorType), - dim=IOKey(type=int | tuple[int, ...] | None, value=dim), + output=BaseKey(type=int | tuple[int, ...]), + input=BaseKey(shape=[("Var", ...)], type=GenericTensorType), + dim=BaseKey(type=int | tuple[int, ...] | None, value=dim), ) self.factory_inputs = {"input": input} self._set_constraint(fn=size_constraints, keys=["output", "input", "dim"]) @@ -576,13 +576,15 @@ def __init__( super().__init__( formula_key="sequence_slice", name=name, - output=IOKey( + output=BaseKey( type=tuple[int | float | bool, ...] | list[int | float | bool] ), - input=IOKey(type=tuple[int | float | bool, ...] | list[int | float | bool]), - start=IOKey(type=int | None, value=start), - stop=IOKey(type=int | None, value=stop), - step=IOKey(type=int | None, value=step), + input=BaseKey( + type=tuple[int | float | bool, ...] | list[int | float | bool] + ), + start=BaseKey(type=int | None, value=start), + stop=BaseKey(type=int | None, value=stop), + step=BaseKey(type=int | None, value=step), ) self.factory_inputs = {"input": input} @@ -623,11 +625,11 @@ def __init__( super().__init__( formula_key="tensor_slice", name=name, - output=IOKey(shape=["a", ("Var1", ...)], type=GenericTensorType), - input=IOKey(shape=["b", ("Var1", ...)], type=GenericTensorType), - start=IOKey(type=int | None, value=start), - stop=IOKey(type=int | None, value=stop), - step=IOKey(type=int | None, value=step), + output=BaseKey(shape=["a", ("Var1", ...)], type=GenericTensorType), + input=BaseKey(shape=["b", ("Var1", ...)], type=GenericTensorType), + start=BaseKey(type=int | None, value=start), + stop=BaseKey(type=int | None, value=stop), + step=BaseKey(type=int | None, value=step), ) self.factory_inputs = {"input": input} @@ -662,8 +664,8 @@ def __init__( super().__init__( formula_key="item", name=name, - output=IOKey(type=int | float), - input=IOKey(shape=[("Var", ...)], type=GenericTensorType), + output=BaseKey(type=int | float), + input=BaseKey(shape=[("Var", ...)], type=GenericTensorType), ) self.factory_inputs = {"input": input} self._set_constraint( @@ -692,9 +694,9 @@ def __init__( super().__init__( formula_key="scalar_item", name=name, - output=IOKey(type=int | float | list | tuple), - input=IOKey(type=list | tuple), - index=IOKey(type=int, value=index), + output=BaseKey(type=int | float | list | tuple), + input=BaseKey(type=list | tuple), + index=BaseKey(type=int, value=index), ) self.factory_inputs = {"input": input, "index": index} @@ -730,9 +732,9 @@ def __init__( super().__init__( formula_key="tensor_item", name=name, - output=IOKey(shape=[("Var2", ...)], type=GenericTensorType), - input=IOKey(shape=[("Var1", ...)], type=GenericTensorType), - index=IOKey( + output=BaseKey(shape=[("Var2", ...)], type=GenericTensorType), + input=BaseKey(shape=[("Var1", ...)], type=GenericTensorType), + index=BaseKey( type=int | slice | EllipsisType @@ -770,8 +772,8 @@ def __init__( super().__init__( formula_key="to_tensor", name=name, - output=IOKey(shape=[("Var", ...)], type=GenericTensorType), - input=IOKey(type=int | float | list | tuple), + output=BaseKey(shape=[("Var", ...)], type=GenericTensorType), + input=BaseKey(type=int | float | list | tuple), ) self._set_constraint( @@ -791,9 +793,11 @@ class ToList(PrimitiveModel): def __init__(self, n: int, name: str | None = None, **kwargs) -> None: self.factory_args = {"n": n} key_definitions = {} - key_definitions["output"] = IOKey(type=list[int | float | bool | list | tuple]) + key_definitions["output"] = BaseKey( + type=list[int | float | bool | list | tuple] + ) key_definitions |= { - f"input{idx+1}": IOKey(type=int | float | bool | list | tuple) + f"input{idx+1}": BaseKey(type=int | float | bool | list | tuple) for idx in range(n) } self.factory_inputs = kwargs @@ -816,8 +820,8 @@ def __init__( super().__init__( formula_key="tensor_to_list", name=name, - output=IOKey(type=NestedListType(int | float | bool)), - input=IOKey(shape=[("Var", ...)], type=GenericTensorType), + output=BaseKey(type=NestedListType(int | float | bool)), + input=BaseKey(shape=[("Var", ...)], type=GenericTensorType), ) self.factory_inputs = {"input": input} self._set_constraint( @@ -847,7 +851,7 @@ def __init__( name: str | None = None, axis: int | tuple[int, ...] | None | ToBeDetermined = None, keepdim: bool | ToBeDetermined = False, - **kwargs: IOKey, + **kwargs: BaseKey, ) -> None: # TODO: Handle axis type for conditional cases below. self.factory_args = {"axis": axis, "keepdim": keepdim} @@ -862,11 +866,11 @@ def __init__( else: raise ValueError("Requires valid axis type!") - init_kwargs: dict[str, IOKey] = { - "output": IOKey(shape=[("Var_out", ...)], type=GenericTensorType), - "input": IOKey(shape=[("Var_in", ...)], type=GenericTensorType), - "axis": IOKey(type=axis_type, value=axis), - "keepdim": IOKey(type=bool, value=keepdim), + init_kwargs: dict[str, BaseKey] = { + "output": BaseKey(shape=[("Var_out", ...)], type=GenericTensorType), + "input": BaseKey(shape=[("Var_in", ...)], type=GenericTensorType), + "axis": BaseKey(type=axis_type, value=axis), + "keepdim": BaseKey(type=bool, value=keepdim), } super().__init__(formula_key=formula_key, name=name, **(init_kwargs | kwargs)) @@ -899,7 +903,7 @@ def __init__( name=name, axis=axis, keepdim=keepdim, - output=IOKey(shape=[("Var_out", ...)], type=MyTensor[float]), + output=BaseKey(shape=[("Var_out", ...)], type=MyTensor[float]), ) self.factory_inputs = {"input": input, "axis": axis, "keepdim": keepdim} # self.factory_inputs = {"input": input} @@ -954,7 +958,7 @@ def __init__( axis=axis, keepdim=keepdim, # axis = Scalar(axis_type, axis), # TODO: Change axis type to int - output=IOKey(shape=[("Var_out", ...)], type=MyTensor[int]), + output=BaseKey(shape=[("Var_out", ...)], type=MyTensor[int]), ) self.factory_inputs = {"input": input, "axis": axis, "keepdim": keepdim} @@ -991,7 +995,7 @@ def __init__( axis=axis, keepdim=keepdim, # axis = Scalar(axis_type, axis), # TODO: Change axis type to int - output=IOKey(shape=[("Var_out", ...)], type=MyTensor[int]), + output=BaseKey(shape=[("Var_out", ...)], type=MyTensor[int]), ) self.factory_inputs = {"input": input, "axis": axis, "keepdim": keepdim} @@ -1029,8 +1033,8 @@ def __init__( name=name, axis=axis, keepdim=keepdim, - correction=IOKey(type=float | int | None, value=correction), - output=IOKey(shape=[("Var_out", ...)], type=MyTensor[float]), + correction=BaseKey(type=float | int | None, value=correction), + output=BaseKey(shape=[("Var_out", ...)], type=MyTensor[float]), ) self.factory_args = {"axis": axis, "correction": correction, "keepdim": keepdim} # TODO: Should we remove axis, correction and keepdim from factory_args? @@ -1067,14 +1071,14 @@ def __init__( formula_key: str, polymorphic_constraint: bool = True, name: str | None = None, - **kwargs: IOKey, + **kwargs: BaseKey, ) -> None: default_kwargs = dict( - output=IOKey(shape=[("Var", ...)], type=GenericTensorType), - input=IOKey(shape=[("Var", ...)], type=GenericTensorType), + output=BaseKey(shape=[("Var", ...)], type=GenericTensorType), + input=BaseKey(shape=[("Var", ...)], type=GenericTensorType), ) # Finalize kwargs. - new_kwargs: Mapping[str, IOKey] = default_kwargs | kwargs + new_kwargs: Mapping[str, BaseKey] = default_kwargs | kwargs super().__init__(formula_key, name=name, **new_kwargs) if polymorphic_constraint: @@ -1113,7 +1117,7 @@ def __init__( formula_key="exp", name=name, polymorphic_constraint=False, - output=IOKey(shape=[("Var", ...)], type=MyTensor[float]), + output=BaseKey(shape=[("Var", ...)], type=MyTensor[float]), ) self.factory_inputs = {"input": input} @@ -1137,16 +1141,16 @@ def __init__( super().__init__( formula_key="robust_sqrt", name=name, - output=IOKey(shape=[("Var", ...)], type=MyTensor[float]), - input=IOKey(shape=[("Var", ...)], type=GenericTensorType), - cutoff=IOKey(shape=[], type=GenericTensorType), + output=BaseKey(shape=[("Var", ...)], type=MyTensor[float]), + input=BaseKey(shape=[("Var", ...)], type=GenericTensorType), + cutoff=BaseKey(shape=[], type=GenericTensorType), ) self.factory_inputs = {"input": input, "cutoff": cutoff} else: super().__init__( formula_key="sqrt", - output=IOKey(shape=[("Var", ...)], type=GenericTensorType), - input=IOKey(shape=[("Var", ...)], type=GenericTensorType), + output=BaseKey(shape=[("Var", ...)], type=GenericTensorType), + input=BaseKey(shape=[("Var", ...)], type=GenericTensorType), ) self.factory_inputs = {"input": input} @@ -1177,9 +1181,9 @@ def __init__(self, formula_key: str, name: str | None = None) -> None: super().__init__( formula_key=formula_key, name=name, - output=IOKey(shape=[("Var1", ...)], type=MyTensor[bool]), - left=IOKey(shape=[("Var2", ...)], type=GenericTensorType), - right=IOKey(shape=[("Var3", ...)], type=GenericTensorType), + output=BaseKey(shape=[("Var1", ...)], type=MyTensor[bool]), + left=BaseKey(shape=[("Var2", ...)], type=GenericTensorType), + right=BaseKey(shape=[("Var3", ...)], type=GenericTensorType), ) self._set_constraint(bcast, ["output", "left", "right"]) @@ -1269,8 +1273,8 @@ def __init__( super().__init__( formula_key="logical_not", name=name, - output=IOKey(shape=[("Var", ...)], type=MyTensor[bool]), - input=IOKey(shape=[("Var", ...)], type=MyTensor[bool]), + output=BaseKey(shape=[("Var", ...)], type=MyTensor[bool]), + input=BaseKey(shape=[("Var", ...)], type=MyTensor[bool]), ) self.factory_inputs = {"input": input} @@ -1295,9 +1299,9 @@ def __init__( super().__init__( formula_key=formula_key, name=name, - output=IOKey(shape=[("Var1", ...)], type=MyTensor[bool]), - left=IOKey(shape=[("Var2", ...)], type=MyTensor[bool]), - right=IOKey(shape=[("Var3", ...)], type=MyTensor[bool]), + output=BaseKey(shape=[("Var1", ...)], type=MyTensor[bool]), + left=BaseKey(shape=[("Var2", ...)], type=MyTensor[bool]), + right=BaseKey(shape=[("Var3", ...)], type=MyTensor[bool]), ) self.factory_inputs = {"left": left, "right": right} self._set_constraint(bcast, ["output", "left", "right"]) @@ -1359,9 +1363,9 @@ def __init__( super().__init__( formula_key="shift_left", name=name, - output=IOKey(shape=[("Var3", ...)], type=MyTensor[int]), - input=IOKey(shape=[("Var1", ...)], type=MyTensor[int]), - shift=IOKey(shape=[("Var2", ...)], type=MyTensor[int]), + output=BaseKey(shape=[("Var3", ...)], type=MyTensor[int]), + input=BaseKey(shape=[("Var1", ...)], type=MyTensor[int]), + shift=BaseKey(shape=[("Var2", ...)], type=MyTensor[int]), ) self.factory_inputs = {"input": input, "shift": shift} @@ -1390,9 +1394,9 @@ def __init__( super().__init__( formula_key="shift_right", name=name, - output=IOKey(shape=[("Var3", ...)], type=GenericTensorType), - input=IOKey(shape=[("Var1", ...)], type=GenericTensorType), - shift=IOKey(shape=[("Var2", ...)], type=GenericTensorType), + output=BaseKey(shape=[("Var3", ...)], type=GenericTensorType), + input=BaseKey(shape=[("Var1", ...)], type=GenericTensorType), + shift=BaseKey(shape=[("Var2", ...)], type=GenericTensorType), ) self.factory_inputs = {"input": input, "shift": shift} @@ -1426,9 +1430,9 @@ def __init__( super().__init__( formula_key="transpose", name=name, - output=IOKey(shape=[("Var_out", ...)], type=GenericTensorType), - input=IOKey(shape=[("Var_in", ...)], type=GenericTensorType), - axes=IOKey(type=NoneType, value=axes), + output=BaseKey(shape=[("Var_out", ...)], type=GenericTensorType), + input=BaseKey(shape=[("Var_in", ...)], type=GenericTensorType), + axes=BaseKey(type=NoneType, value=axes), ) self.factory_inputs = {"input": input, "axes": axes} self._set_constraint( @@ -1442,18 +1446,18 @@ def __init__( super().__init__( formula_key="transpose", name=name, - output=IOKey(shape=output_shapes, type=GenericTensorType), - input=IOKey(shape=input_shapes, type=GenericTensorType), - axes=IOKey(type=int | tuple[int, ...], value=axes), + output=BaseKey(shape=output_shapes, type=GenericTensorType), + input=BaseKey(shape=input_shapes, type=GenericTensorType), + axes=BaseKey(type=int | tuple[int, ...], value=axes), ) elif axes is TBD: super().__init__( formula_key="transpose", name=name, - output=IOKey(shape=[("Var_out", ...)], type=GenericTensorType), - input=IOKey(shape=[("Var_in", ...)], type=GenericTensorType), - axes=IOKey(type=int | tuple[int, ...] | None, value=axes), + output=BaseKey(shape=[("Var_out", ...)], type=GenericTensorType), + input=BaseKey(shape=[("Var_in", ...)], type=GenericTensorType), + axes=BaseKey(type=int | tuple[int, ...] | None, value=axes), ) self._set_constraint( fn=reverse_constraints, keys=["output", "input", "axes"] @@ -1488,10 +1492,10 @@ def __init__( super().__init__( formula_key="split", name=name, - output=IOKey(shape=[("Var2", ...)], type=GenericTensorType), - input=IOKey(shape=[("Var1", ...)], type=GenericTensorType), - split_size=IOKey(type=int, value=split_size), - axis=IOKey(type=int, value=axis), + output=BaseKey(shape=[("Var2", ...)], type=GenericTensorType), + input=BaseKey(shape=[("Var1", ...)], type=GenericTensorType), + split_size=BaseKey(type=int, value=split_size), + axis=BaseKey(type=int, value=axis), ) self.factory_inputs = {"input": input, "split_size": split_size, "axis": axis} @@ -1527,10 +1531,10 @@ def __init__( super().__init__( formula_key="primitive_slice", name=name, - output=IOKey(type=slice), - start=IOKey(type=int | None, value=start), - stop=IOKey(type=int | None, value=stop), - step=IOKey(type=int | None, value=step), + output=BaseKey(type=slice), + start=BaseKey(type=int | None, value=start), + stop=BaseKey(type=int | None, value=stop), + step=BaseKey(type=int | None, value=step), ) self.factory_inputs = {"start": start, "stop": stop, "step": step} diff --git a/mithril/framework/logical/model.py b/mithril/framework/logical/model.py index 87213225..e9d976ea 100644 --- a/mithril/framework/logical/model.py +++ b/mithril/framework/logical/model.py @@ -311,7 +311,7 @@ def _convert_to_iokey( case str(): connection = IOKey(name=connection) case Connection(): - connection = IOKey(connections=[connection]) + connection = IOKey(connections={connection}) case ExtendTemplate(): # Unroll ExtendTemplate template_conn = model.conns.get_connection(key) @@ -319,7 +319,7 @@ def _convert_to_iokey( con_data = self._unroll_template( connection, type(template_conn.metadata.data) ) - connection = IOKey(connections=[con_data.conn], expose=False) + connection = IOKey(connections={con_data.conn}, expose=False) case _ if isinstance(connection, MainValueInstance): # find_dominant_type returns the dominant type in a container. # If a container has a value of type Connection or ExtendTemplate @@ -338,34 +338,29 @@ def _convert_to_iokey( result = conv_model.conns.get_connection("output") assert result is not None - connection = IOKey(connections=[result.conn], expose=None) + connection = IOKey(connections={result.conn}, expose=None) else: assert isinstance(connection, MainValueInstance) connection = IOKey(value=connection) case IOKey(): - expose = connection._expose - name = connection._name - # TODO: This check should be removed: conn._connections==OrderedSet([]) + expose = connection.expose + name = connection.name + # TODO: This check should be removed: conn.connections==set() # We should not operate different if _connections is given. Fix this and # also fix corresponding tests and dict conversions with "connect". if ( expose is None and (name is None or self.conns.get_connection(name) is None) - and connection._connections == OrderedSet([]) + and connection.connections == set() ): expose = True - _conns: list[Connection | str] = [ - item.conn if isinstance(item, ConnectionData) else item - for item in connection._connections - ] - # TODO: Add replicate method to IOKey (update def __call__ in BaseModel) connection = IOKey( name=name, - value=connection._value, - shape=connection._shape, - type=connection._type, expose=expose, - connections=_conns, + connections=connection.connections, + type=connection.data.type, + shape=connection.data.shape, + value=connection.data.value, ) case NotAvailable(): raise ValueError( @@ -402,14 +397,14 @@ def _add_connection( is_not_valued = local_connection.metadata.data.value is TBD d_map = self.dependency_map._local_output_dependency_map - expose = given_connection._expose - outer_key = given_connection._name + expose = given_connection.expose + outer_key = given_connection.name con_obj = None set_value: ToBeDetermined | str | MainValueType | NullConnection = NOT_GIVEN - if given_connection._value is not TBD: - set_value = given_connection._value + if given_connection.data.value is not TBD: + set_value = given_connection.data.value - if given_connection._connections == OrderedSet([]): + if given_connection.connections == set(): if outer_key is not None: con_obj = self.conns.get_connection(outer_key) if outer_key is None or con_obj is None: @@ -429,18 +424,19 @@ def _add_connection( ) else: initial_conn: ConnectionData - for idx, conn in enumerate(given_connection._connections): + for idx, conn in enumerate(given_connection.connections): if isinstance(conn, str): _conn = self.conns.get_connection(conn) else: - _conn = self.conns.get_con_by_metadata(conn.metadata) + _conn = self.conns.get_con_by_metadata(conn.data.metadata) + if conn.data in model.conns.all.values(): + raise ValueError( + f"Given connection '{conn.data.key}' should not " + "belong to the extending model!" + ) + if not isinstance(_conn, ConnectionData): raise KeyError("Requires accessible connection to be processed!") - elif conn in model.conns.all.values(): - raise ValueError( - f"Given connection '{conn.key}' should not " # type: ignore - "belong to the extending model!" - ) if idx == 0: initial_conn = _conn if outer_key is not None: @@ -716,11 +712,11 @@ def extend( } for local_key, value in io_keys.items(): - if value._shape is not None: - shape_info |= {local_key: value._shape} + if value.data.shape is not None: + shape_info |= {local_key: value.data.shape} - if value._type is not None: - type_info[local_key] = value._type + if value.data.type is not None: + type_info[local_key] = value.data.type con_obj, _updates = self._add_connection(model, local_key, value) updates |= _updates @@ -813,7 +809,7 @@ def _extend(self, info: ExtendInfo | PrimitiveModel | Model) -> Self: kwargs[model._canonical_input.key] = self.canonical_output for key, value in kwargs.items(): - _value = value._name if isinstance(value, IOKey) else value + _value = value.name if isinstance(value, IOKey) else value if isinstance(_value, str) and _value == "": if key in model._input_keys: @@ -824,7 +820,7 @@ def _extend(self, info: ExtendInfo | PrimitiveModel | Model) -> Self: ) if isinstance(value, IOKey): - value._name = None + value.name = None else: kwargs[key] = _value diff --git a/mithril/framework/logical/primitive.py b/mithril/framework/logical/primitive.py index 9d12df5e..f80b1371 100644 --- a/mithril/framework/logical/primitive.py +++ b/mithril/framework/logical/primitive.py @@ -20,9 +20,9 @@ from ..common import ( NOT_AVAILABLE, TBD, + BaseKey, Connection, IOHyperEdge, - IOKey, KeyType, NotAvailable, Scalar, @@ -54,7 +54,7 @@ def __init__( self, formula_key: str, name: str | None = None, - **kwargs: IOKey | Tensor | Scalar, + **kwargs: BaseKey | Tensor | Scalar, ) -> None: self._formula_key = formula_key self.grad_formula = formula_key + "_grad" @@ -62,9 +62,9 @@ def __init__( super().__init__(name=name) # Get shape_templates of TensorTypes and create corresponding shapes. shape_templates = { - key: value._shape + key: value.shape for key, value in kwargs.items() - if isinstance(value, IOKey) and value._shape is not None + if isinstance(value, BaseKey) and value.shape is not None } shapes = create_shape_map(shape_templates, self.constraint_solver) data_set: set[Tensor] = set() @@ -73,8 +73,9 @@ def __init__( for key, value in kwargs.items(): # TODO: The first if block is temporary. All if else blocks will be # removed after the implementation of the new type system. - if get_origin(value._type) is Union: - args = get_args(value._type) + value_type = value.type if isinstance(value, BaseKey) else value._type + if get_origin(value_type) is Union: + args = get_args(value_type) types = [] for _type in args: # TODO: assertion will be removed, @@ -83,30 +84,30 @@ def __init__( types.append(get_mytensor_subtype(_type)) possible_types = reduce(lambda x, y: x | y, types) # type: ignore - assert isinstance(value, IOKey) + assert isinstance(value, BaseKey) _value: Tensor | Scalar = Tensor( shape=shapes[key].node, possible_types=possible_types, - value=value._value, # type: ignore - interval=value._interval, + value=value.value, # type: ignore + interval=value.interval, ) assert isinstance(_value, Tensor) data_set.add(_value) - elif is_mytensor_type(value._type): - assert isinstance(value, IOKey) + elif is_mytensor_type(value_type): + assert isinstance(value, BaseKey) _value = Tensor( shape=shapes[key].node, - possible_types=get_mytensor_subtype(value._type), # type: ignore - value=value._value, # type: ignore - interval=value._interval, + possible_types=get_mytensor_subtype(value_type), # type: ignore + value=value.value, # type: ignore + interval=value.interval, ) data_set.add(_value) elif isinstance(value, Tensor | Scalar): _value = value else: _value = Scalar( - possible_types=value._type, # type: ignore - value=value._value, # type: ignore + possible_types=value_type, # type: ignore + value=value.value, # type: ignore ) conn_data = self.create_connection(IOHyperEdge(_value), key) diff --git a/mithril/models/models.py b/mithril/models/models.py index 26917614..47ad5b2a 100644 --- a/mithril/models/models.py +++ b/mithril/models/models.py @@ -776,7 +776,7 @@ def __init__( # Assumed input shape is [N, C, H, W] input_key = IOKey(name="input") - input_shape = input_key.shape() + input_shape = input_key.shape B = input_shape[0] input_key = input_key.reshape((B, num_groups, -1)) @@ -2899,7 +2899,7 @@ def __init__( )("pred", "label", "metric_out", "pred_formatted", "label_formatted") true_predictions = self.metric_out == 0 - n_prediction = self.label_formatted.shape()[0] + n_prediction = self.label_formatted.shape[0] self += Sum()(input=true_predictions, output="n_true_predictions") self += Divide()( @@ -3009,13 +3009,13 @@ def __init__( self += Divide()( numerator=sum_precision, - denominator=self.n_classes.shape()[0].tensor(), + denominator=self.n_classes.shape[0].tensor(), output=IOKey(name="output"), ) elif average == "weighted": precision = None - n_element = self.label_formatted.shape()[0] + n_element = self.label_formatted.shape[0] assert ( n_classes is not None ), "n_classes must be provided if average is or 'weighted'" @@ -3153,7 +3153,7 @@ def __init__( self += Divide()( numerator=sum_recall, - denominator=self.n_classes.shape()[0].tensor(), + denominator=self.n_classes.shape[0].tensor(), output=IOKey(name="output"), ) @@ -3162,7 +3162,7 @@ def __init__( assert ( n_classes is not None ), "n_classes must be provided if average is or 'weighted'" - n_element = self.label_formatted.shape()[0] + n_element = self.label_formatted.shape[0] for idx in range(n_classes): class_idxs = self.label_formatted == idx true_positive = (self.metric_out == 0) & class_idxs @@ -3299,7 +3299,7 @@ def __init__( self += Unique()(input=self.label_formatted, output="n_classes") self += Divide()( numerator=sum_precision, - denominator=self.n_classes.shape()[0].tensor(), + denominator=self.n_classes.shape[0].tensor(), output=IOKey(name="output"), ) @@ -3308,7 +3308,7 @@ def __init__( assert ( n_classes is not None ), "n_classes must be provided if average is or 'weighted'" - n_element = self.label_formatted.shape()[0].tensor() + n_element = self.label_formatted.shape[0].tensor() for idx in range(n_classes): class_idxs = self.label_formatted == idx true_positive = (self.metric_out == 0) & class_idxs diff --git a/mithril/models/primitives.py b/mithril/models/primitives.py index 9aaf7277..a3beff13 100644 --- a/mithril/models/primitives.py +++ b/mithril/models/primitives.py @@ -20,10 +20,10 @@ from ..framework.common import ( NOT_GIVEN, TBD, + BaseKey, Connection, ConnectionType, GenericTensorType, - IOKey, MyTensor, TensorValueType, ToBeDetermined, @@ -129,7 +129,7 @@ class CustomPrimitiveModel(PrimitiveModel): def __init__( - self, formula_key: str, name: str | None = None, **kwargs: IOKey + self, formula_key: str, name: str | None = None, **kwargs: BaseKey ) -> None: self.factory_args = {"formula_key": formula_key} | kwargs super().__init__(formula_key=formula_key, name=name, **kwargs) @@ -155,12 +155,12 @@ def __init__( formula_key: str, polymorphic_constraint: bool = True, name: str | None = None, - **kwargs: IOKey, + **kwargs: BaseKey, ) -> None: - default_kwargs: dict[str, IOKey] = { - "output": IOKey(shape=[("Var_1", ...)], type=GenericTensorType), - "input": IOKey(shape=[("Var_2", ...)], type=GenericTensorType), - "target": IOKey(shape=[("Var_3", ...)], type=GenericTensorType), + default_kwargs: dict[str, BaseKey] = { + "output": BaseKey(shape=[("Var_1", ...)], type=GenericTensorType), + "input": BaseKey(shape=[("Var_2", ...)], type=GenericTensorType), + "target": BaseKey(shape=[("Var_3", ...)], type=GenericTensorType), } # Finalize kwargs. kwargs = default_kwargs | kwargs @@ -224,7 +224,7 @@ def __init__( polymorphic_constraint=False, formula_key="hinge_loss", name=name, - output=IOKey(shape=["N", ("Var", ...)], type=MyTensor[float]), + output=BaseKey(shape=["N", ("Var", ...)], type=MyTensor[float]), ) self.factory_inputs = {"input": input, "target": target} @@ -240,7 +240,7 @@ def __init__( polymorphic_constraint=False, formula_key="quad_hinge_loss", name=name, - output=IOKey(shape=["N", ("Var", ...)], type=MyTensor[float]), + output=BaseKey(shape=["N", ("Var", ...)], type=MyTensor[float]), ) self.factory_inputs = {"input": input, "target": target} @@ -265,10 +265,10 @@ def __init__( super().__init__( formula_key="quantile_loss", name=name, - output=IOKey(shape=[("Var_1", ...)], type=GenericTensorType), - input=IOKey(shape=[("Var_2", ...)], type=GenericTensorType), - target=IOKey(shape=[("Var_3", ...)], type=GenericTensorType), - quantile=IOKey(shape=[], type=MyTensor[int] | MyTensor[float]), + output=BaseKey(shape=[("Var_1", ...)], type=GenericTensorType), + input=BaseKey(shape=[("Var_2", ...)], type=GenericTensorType), + target=BaseKey(shape=[("Var_3", ...)], type=GenericTensorType), + quantile=BaseKey(shape=[], type=MyTensor[int] | MyTensor[float]), ) self.factory_inputs = {"input": input, "target": target, "quantile": quantile} @@ -337,14 +337,14 @@ def __init__( else: final_weights = weights - kwargs: dict[str, IOKey] = { - "output": IOKey(shape=["N", ("Var", ...)], type=MyTensor[float]), - "input": IOKey(shape=["N", "C", ("Var", ...)], type=GenericTensorType), - "target": IOKey(shape=["N", ("VarTarget", ...)], type=GenericTensorType), - "weights": IOKey(type=weights_type, value=final_weights), - "categorical": IOKey(type=bool), - "cutoff": IOKey(shape=[], type=GenericTensorType), - "robust": IOKey(type=bool), + kwargs: dict[str, BaseKey] = { + "output": BaseKey(shape=["N", ("Var", ...)], type=MyTensor[float]), + "input": BaseKey(shape=["N", "C", ("Var", ...)], type=GenericTensorType), + "target": BaseKey(shape=["N", ("VarTarget", ...)], type=GenericTensorType), + "weights": BaseKey(type=weights_type, value=final_weights), + "categorical": BaseKey(type=bool), + "cutoff": BaseKey(shape=[], type=GenericTensorType), + "robust": BaseKey(type=bool), } if input_type == "logits": @@ -427,10 +427,10 @@ def __init__( super().__init__( formula_key="kl_divergence", name=name, - output=IOKey(shape=[("Var_1", ...)], type=MyTensor[float]), - input=IOKey(shape=[("Var_2", ...)], type=GenericTensorType), - target=IOKey(shape=[("Var_3", ...)], type=GenericTensorType), - cutoff=IOKey(shape=[], type=GenericTensorType), + output=BaseKey(shape=[("Var_1", ...)], type=MyTensor[float]), + input=BaseKey(shape=[("Var_2", ...)], type=GenericTensorType), + target=BaseKey(shape=[("Var_3", ...)], type=GenericTensorType), + cutoff=BaseKey(shape=[], type=GenericTensorType), ) self.factory_inputs = {"input": input, "target": target, "cutoff": cutoff} @@ -490,15 +490,15 @@ def __init__( pos_weight_type = ( float | bool if pos_weight in (..., None) else type(pos_weight) ) - kwargs: dict[str, IOKey] = { - "output": IOKey(shape=[("Var_out", ...)], type=MyTensor[float]), - "input": IOKey(shape=[("Var_out", ...)], type=GenericTensorType), - "target": IOKey( + kwargs: dict[str, BaseKey] = { + "output": BaseKey(shape=[("Var_out", ...)], type=MyTensor[float]), + "input": BaseKey(shape=[("Var_out", ...)], type=GenericTensorType), + "target": BaseKey( shape=[("Var_out", ...)], type=MyTensor[int] | MyTensor[float] ), - "pos_weight": IOKey(type=pos_weight_type, value=pos_weight), - "cutoff": IOKey(shape=[], type=GenericTensorType), - "robust": IOKey(type=bool), + "pos_weight": BaseKey(type=pos_weight_type, value=pos_weight), + "cutoff": BaseKey(shape=[], type=GenericTensorType), + "robust": BaseKey(type=bool), } if input_type == "logits": @@ -560,17 +560,17 @@ def __init__( super().__init__( formula_key="robust_log", name=name, - output=IOKey(shape=[("Var", ...)], type=MyTensor[float]), - input=IOKey(shape=[("Var", ...)], type=GenericTensorType), - cutoff=IOKey(shape=[], type=GenericTensorType), + output=BaseKey(shape=[("Var", ...)], type=MyTensor[float]), + input=BaseKey(shape=[("Var", ...)], type=GenericTensorType), + cutoff=BaseKey(shape=[], type=GenericTensorType), ) self.factory_inputs = {"input": input, "cutoff": cutoff} else: super().__init__( formula_key="log", name=name, - output=IOKey(shape=[("Var", ...)], type=MyTensor[float]), - input=IOKey(shape=[("Var", ...)], type=GenericTensorType), + output=BaseKey(shape=[("Var", ...)], type=MyTensor[float]), + input=BaseKey(shape=[("Var", ...)], type=GenericTensorType), ) self.factory_inputs = {"input": input} @@ -606,9 +606,9 @@ def __init__( super().__init__( formula_key="stable_reciprocal", name=name, - output=IOKey(shape=[("Var", ...)], type=MyTensor[float]), - input=IOKey(shape=[("Var", ...)], type=GenericTensorType), - cutoff=IOKey(shape=[], type=GenericTensorType), + output=BaseKey(shape=[("Var", ...)], type=MyTensor[float]), + input=BaseKey(shape=[("Var", ...)], type=GenericTensorType), + cutoff=BaseKey(shape=[], type=GenericTensorType), ) self.factory_inputs = {"input": input, "cutoff": cutoff} @@ -629,7 +629,7 @@ def __init__( formula_key="sin", name=name, polymorphic_constraint=False, - output=IOKey(shape=[("Var", ...)], type=MyTensor[float]), + output=BaseKey(shape=[("Var", ...)], type=MyTensor[float]), ) self.factory_inputs = {"input": input} @@ -642,7 +642,7 @@ def __init__( formula_key="cos", name=name, polymorphic_constraint=False, - output=IOKey(shape=[("Var", ...)], type=MyTensor[float]), + output=BaseKey(shape=[("Var", ...)], type=MyTensor[float]), ) self.factory_inputs = {"input": input} @@ -655,7 +655,7 @@ def __init__( formula_key="sign", name=name, polymorphic_constraint=False, - output=IOKey(shape=[("Var", ...)], type=MyTensor[int]), + output=BaseKey(shape=[("Var", ...)], type=MyTensor[int]), ) self.factory_inputs = {"input": input} @@ -678,15 +678,15 @@ def __init__( formula_key: str, polymorphic_constraint: bool = False, name: str | None = None, - **kwargs: IOKey, + **kwargs: BaseKey, ) -> None: # NOTE: Torch and JAX behave different for some activation functions. # For example JAX handles int type inputs for GELU or LeakyRelu while # Torch assumes only float inputs for these activations. Since JAX handles # more general case, default types are written taking this into account. - default_kwargs: dict[str, IOKey] = dict( - input=IOKey(shape=[("Var", ...)], type=GenericTensorType), - output=IOKey(shape=[("Var", ...)], type=MyTensor[float]), + default_kwargs: dict[str, BaseKey] = dict( + input=BaseKey(shape=[("Var", ...)], type=GenericTensorType), + output=BaseKey(shape=[("Var", ...)], type=MyTensor[float]), ) # Finalize kwargs. kwargs = default_kwargs | kwargs @@ -712,8 +712,8 @@ def __init__( formula_key="relu", name=name, polymorphic_constraint=True, - output=IOKey(shape=[("Var", ...)], type=GenericTensorType), - input=IOKey(shape=[("Var", ...)], type=GenericTensorType), + output=BaseKey(shape=[("Var", ...)], type=GenericTensorType), + input=BaseKey(shape=[("Var", ...)], type=GenericTensorType), ) self.factory_inputs = {"input": input} @@ -727,7 +727,7 @@ def __init__( ) -> None: super().__init__( formula_key="gelu", - approximate=IOKey(value=approximate, type=(bool)), + approximate=BaseKey(value=approximate, type=(bool)), name=name, ) self.factory_inputs = {"input": input} @@ -758,7 +758,9 @@ def __init__( input: TensorValueType | ToBeDetermined = TBD, axis: int | None | ToBeDetermined = TBD, ) -> None: - super().__init__(formula_key="softmax", name=name, axis=IOKey(type=int | None)) + super().__init__( + formula_key="softmax", name=name, axis=BaseKey(type=int | None) + ) self.factory_inputs = {"input": input, "axis": axis} def __call__( # type: ignore[override] @@ -800,7 +802,7 @@ def __init__( super().__init__( formula_key="leaky_relu", name=name, - slope=IOKey(shape=[], type=MyTensor[float]), + slope=BaseKey(shape=[], type=MyTensor[float]), ) self.factory_inputs = {"input": input, "slope": slope} @@ -823,8 +825,8 @@ def __init__( super().__init__( formula_key="stop_gradient", name=name, - output=IOKey(shape=[("Var", ...)], type=GenericTensorType), - input=IOKey(shape=[("Var", ...)], type=GenericTensorType), + output=BaseKey(shape=[("Var", ...)], type=GenericTensorType), + input=BaseKey(shape=[("Var", ...)], type=GenericTensorType), ) self.factory_inputs = {"input": input} @@ -848,9 +850,9 @@ def __init__( super().__init__( formula_key="cartesian_diff", name=name, - output=IOKey(shape=["N", "M", "dim"], type=GenericTensorType), - left=IOKey(shape=["N", "dim"], type=GenericTensorType), - right=IOKey(shape=["M", "dim"], type=GenericTensorType), + output=BaseKey(shape=["N", "M", "dim"], type=GenericTensorType), + left=BaseKey(shape=["N", "dim"], type=GenericTensorType), + right=BaseKey(shape=["M", "dim"], type=GenericTensorType), ) self.factory_inputs = {"left": left, "right": right} self._set_constraint( @@ -880,17 +882,17 @@ def __init__( ) -> None: self.factory_args = {"n": n, "axis": axis} - key_definitions: dict[str, IOKey] = {} - key_definitions["output"] = IOKey( + key_definitions: dict[str, BaseKey] = {} + key_definitions["output"] = BaseKey( shape=[("Var_out", ...)], type=GenericTensorType ) key_definitions |= { - f"input{idx+1}": IOKey( + f"input{idx+1}": BaseKey( shape=[(f"Var_{idx + 1}", ...)], type=GenericTensorType ) for idx in range(n) } - key_definitions["axis"] = IOKey(type=int | None, value=axis) + key_definitions["axis"] = BaseKey(type=int | None, value=axis) super().__init__(formula_key="concat", name=name, **key_definitions) # self.factory_inputs = {key: value for key, value in kwargs.items()} @@ -917,14 +919,14 @@ def __init__( ) -> None: self.factory_args = {"n": n} input_definitions = { - f"input{idx + 1}": IOKey(type=int | float | tuple[int | float, ...]) + f"input{idx + 1}": BaseKey(type=int | float | tuple[int | float, ...]) for idx in range(n) } super().__init__( formula_key="union", name=name, - output=IOKey(type=tuple[int | float, ...]), + output=BaseKey(type=tuple[int | float, ...]), **input_definitions, ) self.factory_inputs = kwargs # type: ignore @@ -944,9 +946,9 @@ def __init__( super().__init__( formula_key="permute_tensor", name=name, - output=IOKey(shape=["N", ("Var", ...)], type=GenericTensorType), - input=IOKey(shape=["N", ("Var", ...)], type=GenericTensorType), - indices=IOKey(shape=["N"], type=GenericTensorType), + output=BaseKey(shape=["N", ("Var", ...)], type=GenericTensorType), + input=BaseKey(shape=["N", ("Var", ...)], type=GenericTensorType), + indices=BaseKey(shape=["N"], type=GenericTensorType), ) self.factory_inputs = {"input": input, "indices": indices} @@ -987,18 +989,18 @@ def __init__( ) -> None: self.factory_args = {"use_bias": use_bias} formula_key = "conv1d_bias" - kwargs: dict[str, IOKey] = { - "output": IOKey( + kwargs: dict[str, BaseKey] = { + "output": BaseKey( shape=["N", "out_channels", "d_out"], type=GenericTensorType ), - "input": IOKey(shape=["N", "C_in", "d_in"], type=GenericTensorType), - "weight": IOKey( + "input": BaseKey(shape=["N", "C_in", "d_in"], type=GenericTensorType), + "weight": BaseKey( shape=["out_channels", "C_in", "kernel_size"], type=GenericTensorType ), - "bias": IOKey(shape=[1, "out_channels", 1], type=GenericTensorType), - "stride": IOKey(type=int), - "padding": IOKey(type=int | tuple[int, int]), - "dilation": IOKey(type=int), + "bias": BaseKey(shape=[1, "out_channels", 1], type=GenericTensorType), + "stride": BaseKey(type=int), + "padding": BaseKey(type=int | tuple[int, int]), + "dilation": BaseKey(type=int), } self.factory_inputs = { "input": input, @@ -1084,21 +1086,21 @@ def __init__( ) -> None: self.factory_args = {"use_bias": use_bias} formula_key = "conv2d_bias" - kwargs: dict[str, IOKey] = { - "output": IOKey( + kwargs: dict[str, BaseKey] = { + "output": BaseKey( shape=["N", "out_channels", "H_out", "W_out"], type=GenericTensorType ), - "input": IOKey(shape=["N", "C_in", "H", "W"], type=GenericTensorType), - "weight": IOKey( + "input": BaseKey(shape=["N", "C_in", "H", "W"], type=GenericTensorType), + "weight": BaseKey( shape=["out_channels", "C_in", "kernel_size_0", "kernel_size_1"], type=GenericTensorType, ), - "bias": IOKey(shape=[1, "out_channels", 1, 1], type=GenericTensorType), - "stride": IOKey(type=int | tuple[int, int]), - "padding": IOKey( + "bias": BaseKey(shape=[1, "out_channels", 1, 1], type=GenericTensorType), + "stride": BaseKey(type=int | tuple[int, int]), + "padding": BaseKey( type=int | tuple[int, int] | tuple[tuple[int, int], tuple[int, int]] ), - "dilation": IOKey(type=int | tuple[int, int]), + "dilation": BaseKey(type=int | tuple[int, int]), } if not use_bias: @@ -1172,11 +1174,11 @@ def __init__( ) -> None: self.factory_args = {"start_dim": start_dim, "end_dim": end_dim} - key_definitions: dict[str, IOKey] = { - "output": IOKey(shape=[("C_out", ...)], type=GenericTensorType), - "input": IOKey(shape=[("C_in", ...)], type=GenericTensorType), - "start_dim": IOKey(type=int, value=start_dim), - "end_dim": IOKey(type=int, value=end_dim), + key_definitions: dict[str, BaseKey] = { + "output": BaseKey(shape=[("C_out", ...)], type=GenericTensorType), + "input": BaseKey(shape=[("C_in", ...)], type=GenericTensorType), + "start_dim": BaseKey(type=int, value=start_dim), + "end_dim": BaseKey(type=int, value=end_dim), } super().__init__(formula_key="flatten", name=name, **key_definitions) # self.factory_inputs = {"input": input} @@ -1226,12 +1228,12 @@ def __init__( super().__init__( formula_key="max_pool1d", name=name, - output=IOKey(shape=["N", ("C_in", ...), "W_out"], type=GenericTensorType), - input=IOKey(shape=["N", ("C_in", ...), "W"], type=GenericTensorType), - kernel_size=IOKey(type=int), - stride=IOKey(type=int), - padding=IOKey(type=tuple[int, int]), - dilation=IOKey(type=int), + output=BaseKey(shape=["N", ("C_in", ...), "W_out"], type=GenericTensorType), + input=BaseKey(shape=["N", ("C_in", ...), "W"], type=GenericTensorType), + kernel_size=BaseKey(type=int), + stride=BaseKey(type=int), + padding=BaseKey(type=tuple[int, int]), + dilation=BaseKey(type=int), ) self.factory_inputs = { "input": input, @@ -1282,9 +1284,9 @@ def __init__( super().__init__( formula_key="padding_converter_1d", name=name, - output=IOKey(type=tuple[int, int]), - input=IOKey(type=int | PaddingType | tuple[int, int]), - kernel_size=IOKey(type=int), + output=BaseKey(type=tuple[int, int]), + input=BaseKey(type=int | PaddingType | tuple[int, int]), + kernel_size=BaseKey(type=int), ) self.factory_inputs = {"input": input, "kernel_size": kernel_size} @@ -1320,16 +1322,16 @@ def __init__( super().__init__( formula_key="padding_converter_2d", name=name, - output=IOKey( + output=BaseKey( type=tuple[int, int] | tuple[tuple[int, int], tuple[int, int]] ), - input=IOKey( + input=BaseKey( type=int | PaddingType | tuple[int, int] | tuple[tuple[int, int], tuple[int, int]] ), - kernel_size=IOKey(type=tuple[int, int]), + kernel_size=BaseKey(type=tuple[int, int]), ) self.factory_inputs = {"input": input, "kernel_size": kernel_size} @@ -1361,9 +1363,9 @@ def __init__( super().__init__( formula_key="stride_converter", name=name, - output=IOKey(type=int | tuple[int, int]), - input=IOKey(type=int | PaddingType | tuple[int, int] | None), - kernel_size=IOKey(type=int | tuple[int, int]), + output=BaseKey(type=int | tuple[int, int]), + input=BaseKey(type=int | PaddingType | tuple[int, int] | None), + kernel_size=BaseKey(type=int | tuple[int, int]), ) self.factory_inputs = {"input": input, "kernel_size": kernel_size} self._set_constraint( @@ -1396,10 +1398,10 @@ def __init__( super().__init__( formula_key="tuple_converter", name=name, - output=IOKey( + output=BaseKey( type=tuple[int, int] | tuple[tuple[int, int], tuple[int, int]] ), - input=IOKey( + input=BaseKey( type=int | PaddingType | tuple[int, int] @@ -1440,16 +1442,16 @@ def __init__( super().__init__( formula_key="max_pool2d", name=name, - output=IOKey( + output=BaseKey( shape=["N", ("C_in", ...), "H_out", "W_out"], type=GenericTensorType ), - input=IOKey(shape=["N", ("C_in", ...), "H", "W"], type=GenericTensorType), - kernel_size=IOKey(type=tuple[int, int]), - stride=IOKey(type=tuple[int, int]), - padding=IOKey( + input=BaseKey(shape=["N", ("C_in", ...), "H", "W"], type=GenericTensorType), + kernel_size=BaseKey(type=tuple[int, int]), + stride=BaseKey(type=tuple[int, int]), + padding=BaseKey( type=tuple[int, int] | tuple[tuple[int, int], tuple[int, int]] ), - dilation=IOKey(type=tuple[int, int]), + dilation=BaseKey(type=tuple[int, int]), ) self.factory_inputs = { "input": input, @@ -1514,8 +1516,8 @@ def __init__( super().__init__( formula_key="norm_modifier", name=name, - output=IOKey(shape=[], type=GenericTensorType), - input=IOKey(shape=[], type=GenericTensorType), + output=BaseKey(shape=[], type=GenericTensorType), + input=BaseKey(shape=[], type=GenericTensorType), ) self.factory_inputs = {"input": input} @@ -1545,10 +1547,10 @@ def __init__( super().__init__( formula_key="distance_matrix", name=name, - output=IOKey(shape=["N", "M"], type=GenericTensorType), - left=IOKey(shape=["N", "d"], type=GenericTensorType), - right=IOKey(shape=["M", "d"], type=GenericTensorType), - norm=IOKey(shape=[], type=GenericTensorType), + output=BaseKey(shape=["N", "M"], type=GenericTensorType), + left=BaseKey(shape=["N", "d"], type=GenericTensorType), + right=BaseKey(shape=["M", "d"], type=GenericTensorType), + norm=BaseKey(shape=[], type=GenericTensorType), ) self.factory_inputs = {"left": left, "right": right} @@ -1581,9 +1583,9 @@ def __init__( super().__init__( formula_key="polynomial_features", name=name, - output=IOKey(shape=["N", "d_out"], type=GenericTensorType), - input=IOKey(shape=["N", "d_in"], type=GenericTensorType), - degree=IOKey(type=int, value=degree), + output=BaseKey(shape=["N", "d_out"], type=GenericTensorType), + input=BaseKey(shape=["N", "d_in"], type=GenericTensorType), + degree=BaseKey(type=int, value=degree), ) self.factory_inputs = {"input": input, "degree": degree} @@ -1621,10 +1623,10 @@ def __init__( super().__init__( formula_key="tsne_p_joint", name=name, - output=IOKey(shape=["N", "M"], type=MyTensor[float]), - squared_distances=IOKey(shape=["N", "M"], type=GenericTensorType), - target_perplexity=IOKey(shape=[], type=MyTensor[float]), - threshold=IOKey(shape=[], type=GenericTensorType), + output=BaseKey(shape=["N", "M"], type=MyTensor[float]), + squared_distances=BaseKey(shape=["N", "M"], type=GenericTensorType), + target_perplexity=BaseKey(shape=[], type=MyTensor[float]), + threshold=BaseKey(shape=[], type=GenericTensorType), ) self.factory_inputs = { "squared_distances": squared_distances, @@ -1661,9 +1663,9 @@ def __init__( super().__init__( formula_key="ones_with_zero_diag", name=name, - output=IOKey(shape=["N", "M"], type=MyTensor[float]), - N=IOKey(type=int, value=N), - M=IOKey(type=int | None, value=M), + output=BaseKey(shape=["N", "M"], type=MyTensor[float]), + N=BaseKey(type=int, value=N), + M=BaseKey(type=int | None, value=M), ) self.factory_inputs = {"N": N, "M": M} self._set_constraint(fn=eye_constraints, keys=["output", "N", "M"]) @@ -1691,9 +1693,9 @@ def __init__( super().__init__( formula_key="eye", name=name, - output=IOKey(shape=["N", "M"], type=MyTensor[float]), - N=IOKey(type=int, value=N), - M=IOKey(type=int | None, value=M), + output=BaseKey(shape=["N", "M"], type=MyTensor[float]), + N=BaseKey(type=int, value=N), + M=BaseKey(type=int | None, value=M), ) self.factory_inputs = {"N": N, "M": M} @@ -1718,8 +1720,8 @@ def __init__( super().__init__( formula_key="cholesky", name=name, - output=IOKey(shape=["N", "N"], type=MyTensor[float]), - input=IOKey(shape=["N", "N"], type=GenericTensorType), + output=BaseKey(shape=["N", "N"], type=MyTensor[float]), + input=BaseKey(shape=["N", "N"], type=GenericTensorType), ) self.factory_inputs = {"input": input} @@ -1745,10 +1747,10 @@ def __init__( super().__init__( formula_key="gpr_alpha", name=name, - output=IOKey(shape=["N", 1], type=MyTensor[float]), - label_mu_diff=IOKey(shape=["N", 1], type=GenericTensorType), - L=IOKey(shape=["N", "N"], type=GenericTensorType), - K_term=IOKey(shape=["N", "N"], type=GenericTensorType), + output=BaseKey(shape=["N", 1], type=MyTensor[float]), + label_mu_diff=BaseKey(shape=["N", 1], type=GenericTensorType), + L=BaseKey(shape=["N", "N"], type=GenericTensorType), + K_term=BaseKey(shape=["N", "N"], type=GenericTensorType), ) self.factory_inputs = {"label_mu_diff": label_mu_diff, "L": L, "K_term": K_term} @@ -1780,10 +1782,10 @@ def __init__( super().__init__( formula_key="gpr_v_outer", name=name, - output=IOKey(shape=["N", "N"], type=MyTensor[float]), - K=IOKey(shape=["N", "N"], type=GenericTensorType), - K_term=IOKey(shape=["N", "N"], type=GenericTensorType), - L=IOKey(shape=["N", "N"], type=GenericTensorType), + output=BaseKey(shape=["N", "N"], type=MyTensor[float]), + K=BaseKey(shape=["N", "N"], type=GenericTensorType), + K_term=BaseKey(shape=["N", "N"], type=GenericTensorType), + L=BaseKey(shape=["N", "N"], type=GenericTensorType), ) self.factory_inputs = {"K": K, "K_term": K_term, "L": L} @@ -1807,8 +1809,8 @@ def __init__( super().__init__( formula_key="transposed_diag", name=name, - output=IOKey(shape=["N", 1], type=GenericTensorType), - input=IOKey(shape=["N", "N"], type=GenericTensorType), + output=BaseKey(shape=["N", 1], type=GenericTensorType), + input=BaseKey(shape=["N", "N"], type=GenericTensorType), ) self.factory_inputs = {"input": input} @@ -1857,10 +1859,10 @@ def __init__( super().__init__( formula_key="arange", name=name, - output=IOKey(shape=output_shp, type=GenericTensorType), - start=IOKey(type=int | float, value=start), - stop=IOKey(type=int | float, value=stop), - step=IOKey(type=int | float, value=step), + output=BaseKey(shape=output_shp, type=GenericTensorType), + start=BaseKey(type=int | float, value=start), + stop=BaseKey(type=int | float, value=stop), + step=BaseKey(type=int | float, value=step), ) self.set_canonical_input("stop") self.factory_inputs = {"start": start, "stop": stop, "step": step} @@ -1894,8 +1896,8 @@ def __init__( super().__init__( formula_key="randn", name=name, - output=IOKey(shape=[("output", ...)], type=GenericTensorType), - shape=IOKey(type=tuple[int, ...], value=shape), + output=BaseKey(shape=[("output", ...)], type=GenericTensorType), + shape=BaseKey(type=tuple[int, ...], value=shape), ) self.set_constraint(randn_constraints, keys=["output", "shape"]) @@ -1922,9 +1924,9 @@ def __init__( super().__init__( formula_key="broadcast_to", name=name, - output=IOKey(shape=[("output", ...)], type=GenericTensorType), - input=IOKey(shape=[("input", ...)], type=GenericTensorType), - shape=IOKey(type=tuple[int, ...], value=shape), + output=BaseKey(shape=[("output", ...)], type=GenericTensorType), + input=BaseKey(shape=[("input", ...)], type=GenericTensorType), + shape=BaseKey(type=tuple[int, ...], value=shape), ) self.factory_inputs = {"input": input, "shape": shape} @@ -1960,10 +1962,10 @@ def __init__( super().__init__( formula_key="eigvalsh", name=name, - output=IOKey(shape=["N", 1], type=MyTensor[float]), - K_term=IOKey(shape=["N", "N"], type=GenericTensorType), - L=IOKey(shape=["N", "N"], type=GenericTensorType), - threshold=IOKey(shape=[], type=GenericTensorType), + output=BaseKey(shape=["N", 1], type=MyTensor[float]), + K_term=BaseKey(shape=["N", "N"], type=GenericTensorType), + L=BaseKey(shape=["N", "N"], type=GenericTensorType), + threshold=BaseKey(shape=[], type=GenericTensorType), ) self.factory_inputs = {"K_term": K_term, "L": L, "threshold": threshold} @@ -1987,8 +1989,8 @@ def __init__( super().__init__( formula_key="squeeze", name=name, - output=IOKey(shape=[("Var_out", ...)], type=GenericTensorType), - input=IOKey(shape=[("Var", ...)], type=GenericTensorType), + output=BaseKey(shape=[("Var_out", ...)], type=GenericTensorType), + input=BaseKey(shape=[("Var", ...)], type=GenericTensorType), ) self.factory_inputs = {"input": input} @@ -2017,9 +2019,9 @@ def __init__( super().__init__( formula_key="auc_core", name=name, - output=IOKey(shape=[2, "M"], type=MyTensor[float]), - input=IOKey(shape=["N"], type=GenericTensorType), - label=IOKey(shape=["N"], type=GenericTensorType), + output=BaseKey(shape=[2, "M"], type=MyTensor[float]), + input=BaseKey(shape=["N"], type=GenericTensorType), + label=BaseKey(shape=["N"], type=GenericTensorType), ) self.factory_inputs = {"input": input, "label": label} @@ -2050,9 +2052,9 @@ def __init__( super().__init__( formula_key="primitive_embedding", name=name, - output=IOKey(shape=[("N1", ...), "d1", out_dim], type=GenericTensorType), - input=IOKey(shape=[("N1", ...), "d1"], type=MyTensor[int]), - weight=IOKey(shape=[num_embeddings, out_dim], type=GenericTensorType), + output=BaseKey(shape=[("N1", ...), "d1", out_dim], type=GenericTensorType), + input=BaseKey(shape=[("N1", ...), "d1"], type=MyTensor[int]), + weight=BaseKey(shape=[num_embeddings, out_dim], type=GenericTensorType), ) self.factory_inputs = {"input": input, "weight": weight} @@ -2100,19 +2102,19 @@ def __init__( self.use_attn_mask = use_attn_mask formula_key = "scaled_dot_product_attention" - kwargs: dict[str, IOKey] = { - "output": IOKey(shape=[("Var", ...), "L", "O"], type=MyTensor[float]), - "query": IOKey(shape=[("Var", ...), "L", "E"], type=GenericTensorType), - "key": IOKey(shape=[("Var", ...), "S", "E"], type=GenericTensorType), - "value": IOKey(shape=[("Var", ...), "S", "O"], type=GenericTensorType), - "dropout_p": IOKey(type=float, value=dropout_p), - "attn_mask": IOKey(type=NoneType, value=None), - "is_causal": IOKey(type=bool, value=is_causal), - "scale": IOKey(type=NoneType | int | float, value=scale), + kwargs: dict[str, BaseKey] = { + "output": BaseKey(shape=[("Var", ...), "L", "O"], type=MyTensor[float]), + "query": BaseKey(shape=[("Var", ...), "L", "E"], type=GenericTensorType), + "key": BaseKey(shape=[("Var", ...), "S", "E"], type=GenericTensorType), + "value": BaseKey(shape=[("Var", ...), "S", "O"], type=GenericTensorType), + "dropout_p": BaseKey(type=float, value=dropout_p), + "attn_mask": BaseKey(type=NoneType, value=None), + "is_causal": BaseKey(type=bool, value=is_causal), + "scale": BaseKey(type=NoneType | int | float, value=scale), } if use_attn_mask: - kwargs["attn_mask"] = IOKey( + kwargs["attn_mask"] = BaseKey( shape=["L", "S"], type=GenericTensorType, value=TBD ) @@ -2143,8 +2145,8 @@ def __call__( # type: ignore[override] not self.use_attn_mask and attn_mask is not NOT_GIVEN and not isinstance(attn_mask, str) - and isinstance(attn_mask, IOKey) - and attn_mask._value is not None # TODO: Here will be updated! + and isinstance(attn_mask, BaseKey) + and attn_mask.value is not None # TODO: Here will be updated! ): raise KeyError( "Model does not have 'attn_mask' input. Got attn_mask argument!" @@ -2181,10 +2183,10 @@ def __init__( super().__init__( formula_key="positional_encoding", name=name, - output=IOKey(shape=[("N1", ...)], type=GenericTensorType), - input=IOKey(shape=[("N1", ...)], type=GenericTensorType), - hidden_dim=IOKey(type=int, value=hidden_dim), - max_len=IOKey(type=int, value=max_len), + output=BaseKey(shape=[("N1", ...)], type=GenericTensorType), + input=BaseKey(shape=[("N1", ...)], type=GenericTensorType), + hidden_dim=BaseKey(type=int, value=hidden_dim), + max_len=BaseKey(type=int, value=max_len), ) self.factory_inputs = { "input": input, @@ -2222,10 +2224,10 @@ def __init__( super().__init__( formula_key="swapaxes", name=name, - output=IOKey(shape=[("Var_out", ...)], type=GenericTensorType), - input=IOKey(shape=[("Var_in", ...)], type=GenericTensorType), - axis1=IOKey(type=int, value=axis1), - axis2=IOKey(type=int, value=axis2), + output=BaseKey(shape=[("Var_out", ...)], type=GenericTensorType), + input=BaseKey(shape=[("Var_in", ...)], type=GenericTensorType), + axis1=BaseKey(type=int, value=axis1), + axis2=BaseKey(type=int, value=axis2), ) self.factory_inputs = {"input": input, "axis1": axis1, "axis2": axis2} @@ -2262,10 +2264,10 @@ def __init__( super().__init__( formula_key="where", name=name, - output=IOKey(shape=[("Var_out", ...)], type=GenericTensorType), - cond=IOKey(shape=[("Var3", ...)], type=MyTensor[bool], value=TBD), - input1=IOKey(shape=[("Var1", ...)], type=GenericTensorType), - input2=IOKey(shape=[("Var2", ...)], type=GenericTensorType), + output=BaseKey(shape=[("Var_out", ...)], type=GenericTensorType), + cond=BaseKey(shape=[("Var3", ...)], type=MyTensor[bool], value=TBD), + input1=BaseKey(shape=[("Var1", ...)], type=GenericTensorType), + input2=BaseKey(shape=[("Var2", ...)], type=GenericTensorType), ) self.factory_inputs = {"cond": cond, "input1": input1, "input2": input2} @@ -2297,8 +2299,8 @@ def __init__( super().__init__( formula_key="isnan", name=name, - output=IOKey(shape=[("Var", ...)], type=MyTensor[bool]), - input=IOKey(shape=[("Var", ...)], type=GenericTensorType), + output=BaseKey(shape=[("Var", ...)], type=MyTensor[bool]), + input=BaseKey(shape=[("Var", ...)], type=GenericTensorType), ) self.factory_inputs = {"input": input} @@ -2318,8 +2320,8 @@ def __init__( super().__init__( formula_key="unique", name=name, - input=IOKey(shape=[("Var1", ...)], type=GenericTensorType), - output=IOKey(shape=[("Var2", ...)], type=GenericTensorType), + input=BaseKey(shape=[("Var1", ...)], type=GenericTensorType), + output=BaseKey(shape=[("Var2", ...)], type=GenericTensorType), ) self.factory_inputs = {"input": input} @@ -2343,9 +2345,9 @@ def __init__( super().__init__( formula_key="trapezoid", name=name, - output=IOKey(shape=[], type=GenericTensorType), - y=IOKey(shape=[("Var", ...)], type=GenericTensorType), - x=IOKey(shape=[("Var", ...)], type=GenericTensorType), + output=BaseKey(shape=[], type=GenericTensorType), + y=BaseKey(shape=[("Var", ...)], type=GenericTensorType), + x=BaseKey(shape=[("Var", ...)], type=GenericTensorType), ) self.factory_inputs = {"y": y, "x": x} @@ -2376,11 +2378,11 @@ def __init__( super().__init__( formula_key="nan_to_num", name=name, - output=IOKey(shape=[("Var", ...)], type=GenericTensorType), - input=IOKey(shape=[("Var", ...)], type=GenericTensorType), - nan=IOKey(type=float, value=nan), - posinf=IOKey(type=float | None, value=posinf), - neginf=IOKey(type=float | None, value=neginf), + output=BaseKey(shape=[("Var", ...)], type=GenericTensorType), + input=BaseKey(shape=[("Var", ...)], type=GenericTensorType), + nan=BaseKey(type=float, value=nan), + posinf=BaseKey(type=float | None, value=posinf), + neginf=BaseKey(type=float | None, value=neginf), ) self.factory_inputs = { "input": input, @@ -2417,9 +2419,9 @@ def __init__( super().__init__( formula_key="pad", name=name, - output=IOKey(shape=[("Var2", ...)], type=GenericTensorType), - input=IOKey(shape=[("Var1", ...)], type=GenericTensorType), - pad_width=IOKey(type=tuple[tuple[int, int], ...], value=pad_width), + output=BaseKey(shape=[("Var2", ...)], type=GenericTensorType), + input=BaseKey(shape=[("Var1", ...)], type=GenericTensorType), + pad_width=BaseKey(type=tuple[tuple[int, int], ...], value=pad_width), ) self.factory_inputs = {"input": input, "pad_width": pad_width} @@ -2446,8 +2448,8 @@ def __init__( super().__init__( formula_key="zeros_like", name=name, - output=IOKey(shape=[("Var", ...)], type=GenericTensorType), - input=IOKey(shape=[("Var", ...)], type=GenericTensorType), + output=BaseKey(shape=[("Var", ...)], type=GenericTensorType), + input=BaseKey(shape=[("Var", ...)], type=GenericTensorType), ) self.factory_inputs = {"input": input} diff --git a/mithril/utils/dict_conversions.py b/mithril/utils/dict_conversions.py index 8cb17f93..a77406ca 100644 --- a/mithril/utils/dict_conversions.py +++ b/mithril/utils/dict_conversions.py @@ -153,12 +153,12 @@ def dict_to_model(modelparams: dict[str, Any]) -> BaseModel: key = IOKey(**key_kwargs) mappings[k] = IOKey( **key_kwargs, - connections=[ + connections={ getattr(submodels_dict[value[0]], value[1]) if isinstance(value, Sequence) else value for value in conn["connect"] - ], + }, ) elif "name" in conn: key_kwargs = create_iokey_kwargs(conn) @@ -538,21 +538,21 @@ def item_to_json(item: IOKey): # TODO: Currently type is not supported for Tensors. # Handle This whit conversion test updates. result: dict[str, Any] = {} - if not isinstance(item._value, ToBeDetermined): - result["value"] = item._value - if item._shape is not None: + if not isinstance(item.data.value, ToBeDetermined): + result["value"] = item.data.value + if item.shape is not None: shape_template = [] - for symbol in item._shape: + for symbol in item.shape: if isinstance(symbol, tuple): # variadic shape_template.append(f"{symbol[0]},...") else: shape_template.append(str(symbol)) result["shape_template"] = shape_template - elif isinstance(item._type, UnionType): - result["type"] = [type_to_str(item) for item in item._type.__args__] + elif isinstance(item.data.type, UnionType): + result["type"] = [type_to_str(item) for item in item.data.type.__args__] else: result["type"] = [ - type_to_str(item._type), + type_to_str(item.data.type), ] return result diff --git a/tests/scripts/test_constant_inputs.py b/tests/scripts/test_constant_inputs.py index 3aee00d6..5944e96b 100644 --- a/tests/scripts/test_constant_inputs.py +++ b/tests/scripts/test_constant_inputs.py @@ -969,7 +969,7 @@ def test_nontensor_extend_from_input_multiple_connection(): model += mean1 model += mean2 model += mean3 - model += mean4(axis=IOKey(connections=[mean1.axis, mean2.axis, mean3.axis])) + model += mean4(axis=IOKey(connections={mean1.axis, mean2.axis, mean3.axis})) assert ( mean1.axis.data.metadata == mean2.axis.data.metadata diff --git a/tests/scripts/test_constr_counter.py b/tests/scripts/test_constr_counter.py index 43421af3..fdf38fa6 100644 --- a/tests/scripts/test_constr_counter.py +++ b/tests/scripts/test_constr_counter.py @@ -18,9 +18,9 @@ from mithril.framework import Scalar, Tensor from mithril.framework.common import ( NOT_GIVEN, + BaseKey, ConnectionType, GenericTensorType, - IOKey, MyTensor, ShapeRepr, Uniadic, @@ -98,8 +98,8 @@ class Model1(PrimitiveModel): def __init__(self) -> None: super().__init__( formula_key="buffer", - input=IOKey(shape=[("Var1", ...)], type=GenericTensorType), - output=IOKey(shape=[("Var2", ...)], type=GenericTensorType), + input=BaseKey(shape=[("Var1", ...)], type=GenericTensorType), + output=BaseKey(shape=[("Var2", ...)], type=GenericTensorType), ) self._set_constraint(fn=dummy_constraint, keys=["output", "input"]) @@ -111,8 +111,8 @@ class Model2(PrimitiveModel): def __init__(self) -> None: super().__init__( formula_key="buffer", - input=IOKey(shape=[("Var1", ...)], type=GenericTensorType), - output=IOKey(shape=[("Var2", ...)], type=GenericTensorType), + input=BaseKey(shape=[("Var1", ...)], type=GenericTensorType), + output=BaseKey(shape=[("Var2", ...)], type=GenericTensorType), ) self._set_constraint(fn=dummy_constraint, keys=["output", "input"]) self._set_constraint( @@ -127,8 +127,8 @@ class Model3(PrimitiveModel): def __init__(self) -> None: super().__init__( formula_key="buffer", - input=IOKey(shape=[("Var1", ...)], type=MyTensor[int] | MyTensor[bool]), - output=IOKey(shape=[("Var2", ...)], type=MyTensor[int] | MyTensor[bool]), + input=BaseKey(shape=[("Var1", ...)], type=MyTensor[int] | MyTensor[bool]), + output=BaseKey(shape=[("Var2", ...)], type=MyTensor[int] | MyTensor[bool]), ) self._set_constraint(fn=dummy_constraint, keys=["output", "input"]) @@ -141,9 +141,9 @@ class MyAdd2(PrimitiveModel): def __init__(self, left, right, output) -> None: super().__init__( formula_key="add", - output=IOKey(shape=output, type=GenericTensorType), - left=IOKey(shape=left, type=GenericTensorType), - right=IOKey(shape=right, type=GenericTensorType), + output=BaseKey(shape=output, type=GenericTensorType), + left=BaseKey(shape=left, type=GenericTensorType), + right=BaseKey(shape=right, type=GenericTensorType), ) self._set_constraint( fn=bcast, keys=[PrimitiveModel.output_key, "left", "right"] diff --git a/tests/scripts/test_extend_template.py b/tests/scripts/test_extend_template.py index 01b0979a..a192cab1 100644 --- a/tests/scripts/test_extend_template.py +++ b/tests/scripts/test_extend_template.py @@ -150,7 +150,7 @@ def test_shape_reshape(): # Create with shortcut. model_1 = Model() model_1 += (lin_1 := Linear(dimension=1))(input="input_1", weight="w_1", bias="b_1") - shp = lin_1.input.shape() + shp = lin_1.input.shape model_1 += (lin_2 := Linear(dimension=2))(input="input_2", weight="w_2", bias="b_2") reshaped = lin_2.output.reshape(shp) model_1 += Add()(left=lin_1.output, right=reshaped, output=IOKey(name="output")) @@ -210,7 +210,7 @@ def test_slice_item(): model_1 += (lin_1 := Linear(dimension=1))( input="input", weight="weight", bias="bias" ) - shp = lin_1.input.shape() + shp = lin_1.input.shape item = shp[1].tensor() slc = shp[:].tensor() model_1 += Add()(left=item, right=slc, output=IOKey(name="output")) @@ -1391,68 +1391,6 @@ def test_invalid_input(): "asd" + model.input # type: ignore -# def test_coercion_models_1(): -# backend = JaxBackend() - -# data = {"left": backend.randn(3, 4, 5), "right": backend.randn(3, 4, 5)} - -# model1 = Model() -# model1 += (add_model := Add())(left="left", right="right") -# out = add_model.output -# scalar_item_output = out.shape()[1] -# tensor_item_output = out[1] -# model1 += Buffer()( -# input=scalar_item_output + tensor_item_output, output=IOKey(name="output") -# ) - -# model2 = Model() -# model2 += (add_model := Add())(left="left", right="right") -# model2 += (shp_model := Shape())(input=add_model.output) -# model2 += (to_tensor_model := ToTensor())(input=shp_model.output) -# model2 += (tensor_item_model1 := TensorItem())( -# input=to_tensor_model.output, index=1 -# ) -# model2 += (tensor_item_model2 := TensorItem())(input=add_model.output, index=1) -# model2 += (add_model_2 := Add())( -# left=tensor_item_model1.output, right=tensor_item_model2.output -# ) -# model2 += Buffer()(input=add_model_2.output, output=IOKey(name="output")) - -# compare_models(model1, model2, backend, data, check_internals=False) - - -# def test_coercion_models_2(): -# backend = JaxBackend() - -# data = {"left": backend.randn(5, 6, 2), "right": backend.randn(5, 6, 2)} - -# model1 = Model() -# model1 += (add_model := Add())(left="left", right="right") -# out = add_model.output -# scalar_item_output = out.shape()[1:3] -# tensor_item_output = out[1:3] -# model1 += Buffer()( -# input=scalar_item_output + tensor_item_output, output=IOKey(name="output") -# ) - -# model2 = Model() -# model2 += (add_model := Add())(left="left", right="right") -# model2 += (shp_model := Shape())(input=add_model.output) -# model2 += (to_tensor_model := ToTensor())(input=shp_model.output) -# model2 += (tensor_item_model1 := TensorSlice(start=TBD, stop=TBD, step=TBD))( -# input=to_tensor_model.output, start=1, stop=3, step=None -# ) -# model2 += (tensor_item_model2 := TensorSlice(start=TBD, stop=TBD, step=TBD))( -# input=add_model.output, start=1, stop=3, step=None -# ) -# model2 += (add_model_2 := Add())( -# left=tensor_item_model1.output, right=tensor_item_model2.output -# ) -# model2 += Buffer()(input=add_model_2.output, output=IOKey(name="output")) - -# compare_models(model1, model2, backend, data, check_internals=False) - - def test_tensoritem_multiple_slice_1(): model1 = Model() diff --git a/tests/scripts/test_functions.py b/tests/scripts/test_functions.py index 55a31abb..bbe5ecf6 100644 --- a/tests/scripts/test_functions.py +++ b/tests/scripts/test_functions.py @@ -20,7 +20,7 @@ from mithril import CBackend, JaxBackend, NumpyBackend, TorchBackend from mithril.backends.with_manualgrad.numpy_backend.ops_grad import add_grad from mithril.framework import NOT_GIVEN, ConnectionType, ExtendInfo -from mithril.framework.common import GenericTensorType, IOKey +from mithril.framework.common import BaseKey, GenericTensorType, IOKey from mithril.framework.constraints import bcast from mithril.models import ( Absolute, @@ -413,9 +413,9 @@ class MyAdder(PrimitiveModel): def __init__(self) -> None: super().__init__( formula_key="my_adder", - output=IOKey(shape=[("Var_out", ...)], type=GenericTensorType), - input=IOKey(shape=[("Var_1", ...)], type=GenericTensorType), - rhs=IOKey(shape=[("Var_2", ...)], type=GenericTensorType), + output=BaseKey(shape=[("Var_out", ...)], type=GenericTensorType), + input=BaseKey(shape=[("Var_1", ...)], type=GenericTensorType), + rhs=BaseKey(shape=[("Var_2", ...)], type=GenericTensorType), ) self.set_constraint( fn=bcast, keys=[PrimitiveModel.output_key, "input", "rhs"] @@ -522,9 +522,9 @@ class MyAdder(PrimitiveModel): def __init__(self) -> None: super().__init__( formula_key="my_adder", - output=IOKey(shape=[("Var_out", ...)], type=GenericTensorType), - input=IOKey(shape=[("Var_1", ...)], type=GenericTensorType), - rhs=IOKey(shape=[("Var_2", ...)], type=GenericTensorType), + output=BaseKey(shape=[("Var_out", ...)], type=GenericTensorType), + input=BaseKey(shape=[("Var_1", ...)], type=GenericTensorType), + rhs=BaseKey(shape=[("Var_2", ...)], type=GenericTensorType), ) self.set_constraint( fn=bcast, keys=[PrimitiveModel.output_key, "input", "rhs"] diff --git a/tests/scripts/test_io_key.py b/tests/scripts/test_io_key.py index a3425b78..ce717da9 100644 --- a/tests/scripts/test_io_key.py +++ b/tests/scripts/test_io_key.py @@ -272,7 +272,7 @@ def test_7(): model += (relu1 := Relu())(input="in1", output="relu1_output") model += (relu2 := Relu())(input="in2", output="relu2_output") model += (relu3 := Relu())( - input="", output=IOKey(name="my_input", connections=[relu1.input, relu2.input]) + input="", output=IOKey(name="my_input", connections={relu1.input, relu2.input}) ) assert ( model.dag[relu1]["input"].metadata @@ -448,7 +448,7 @@ def test_iokey_shapes_3(): input3=IOKey(name="input3", shape=[3, "a"]), ) - conns = [main_model.input1, main_model.input2, main_model.input3] # type: ignore + conns = {main_model.input1, main_model.input2, main_model.input3} # type: ignore key = IOKey(name="input", connections=conns) main_model += Buffer()(input=key, output="output1") @@ -1150,7 +1150,7 @@ def test_compare_models_5(): sigmoid = Sigmoid() add = Add() model2 += add(output=IOKey(name="output")) - conn = IOKey(connections=[add.left, add.right]) + conn = IOKey(connections={add.left, add.right}) model2 += sigmoid(input="input", output=conn) model2.set_shapes({"input": [2, 2]}) @@ -1245,7 +1245,7 @@ def test_iokey_template_4(): model = Model() left = IOKey("left") - res = left.shape()[0] + res = left.shape[0] model += Buffer()(res.tensor(), IOKey("output")) diff --git a/tests/scripts/test_jittable.py b/tests/scripts/test_jittable.py index ae3ec005..2ae0831c 100644 --- a/tests/scripts/test_jittable.py +++ b/tests/scripts/test_jittable.py @@ -29,7 +29,7 @@ to_tensor, ) from mithril.framework import NOT_GIVEN, ConnectionType, ExtendInfo -from mithril.framework.common import GenericTensorType +from mithril.framework.common import BaseKey, GenericTensorType from mithril.framework.constraints import bcast from mithril.models import ( TBD, @@ -221,9 +221,9 @@ class Adder(CustomPrimitiveModel): def __init__(self) -> None: super().__init__( formula_key="adder", - output=IOKey(shape=[("Var_out", ...)], type=GenericTensorType), - left=IOKey(shape=[("Var_1", ...)], type=GenericTensorType), - right=IOKey(shape=[("Var_2", ...)], type=GenericTensorType), + output=BaseKey(shape=[("Var_out", ...)], type=GenericTensorType), + left=BaseKey(shape=[("Var_1", ...)], type=GenericTensorType), + right=BaseKey(shape=[("Var_2", ...)], type=GenericTensorType), ) self.set_constraint(fn=bcast, keys=["output", "left", "right"]) @@ -257,7 +257,7 @@ def test_logical_model_jittable_1(): model += (add1 := Add())(left="l1", right="l2", output=IOKey(name="out1")) model += (add2 := Add())(left="l3", right="l4") with pytest.raises(Exception) as error_info: - model += Item()(input=IOKey(name="input", connections=[add1.left, add2.left])) + model += Item()(input=IOKey(name="input", connections={add1.left, add2.left})) modified_msg = re.sub("\\s*", "", str(error_info.value)) expected_msg = ( "Model with enforced Jit can not be extended by a non-jittable model! \ @@ -274,7 +274,7 @@ def test_logical_model_jittable_2(): model += (add1 := Add())(left="l1", right="l2", output=IOKey(name="out1")) model += (add2 := Add())(left="l3", right="l4") model.enforce_jit = False - input = IOKey(name="input", connections=[add1.left, add2.left], expose=True) + input = IOKey(name="input", connections={add1.left, add2.left}, expose=True) model += Item()(input=input) assert not model.enforce_jit @@ -287,7 +287,7 @@ def test_logical_model_jittable_3(): model += (add1 := Add())(left="l1", right="l2", output=IOKey(name="out1")) model += (add2 := Add())(left="l3", right="l4") model.enforce_jit = False - input = IOKey(name="input", connections=[add1.left, add2.left], expose=True) + input = IOKey(name="input", connections={add1.left, add2.left}, expose=True) model += Item()(input=input) assert not model.enforce_jit @@ -300,7 +300,7 @@ def test_physical_model_jit_1(): model += (add1 := Add())(left="l1", right="l2", output=IOKey(name="out1")) model += (add2 := Add())(left="l3", right="l4") model.enforce_jit = False - input = IOKey(name="input", connections=[add1.left, add2.left], expose=True) + input = IOKey(name="input", connections={add1.left, add2.left}, expose=True) model += Item()(input=input) backend = JaxBackend() @@ -320,7 +320,7 @@ def test_physical_model_jit_2(): model += (add1 := Add())(left="l1", right="l2", output=IOKey(name="out1")) model += (add2 := Add())(left="l3", right="l4") model.enforce_jit = False - input = IOKey(name="input", connections=[add1.left, add2.left], expose=True) + input = IOKey(name="input", connections={add1.left, add2.left}, expose=True) model += Item()(input=input) backend = JaxBackend() @@ -345,9 +345,9 @@ class Adder(CustomPrimitiveModel): def __init__(self) -> None: super().__init__( formula_key="adder", - output=IOKey(shape=[("Var_out", ...)], type=GenericTensorType), - left=IOKey(shape=[("Var_1", ...)], type=GenericTensorType), - right=IOKey(shape=[("Var_2", ...)], type=GenericTensorType), + output=BaseKey(shape=[("Var_out", ...)], type=GenericTensorType), + left=BaseKey(shape=[("Var_1", ...)], type=GenericTensorType), + right=BaseKey(shape=[("Var_2", ...)], type=GenericTensorType), ) self.set_constraint(fn=bcast, keys=["output", "left", "right"]) @@ -367,7 +367,7 @@ def test_jit_2(): model = Model(enforce_jit=False) model += (add_model := Add())(left="left", right="right") in1 = add_model.output - out1 = in1.shape() + out1 = in1.shape out2 = out1.tensor().sum() mean_model = Mean(axis=TBD) model += (to_list := Item())(input=out2) diff --git a/tests/scripts/test_key_values_in_init.py b/tests/scripts/test_key_values_in_init.py index bf4b8ef6..a0763227 100644 --- a/tests/scripts/test_key_values_in_init.py +++ b/tests/scripts/test_key_values_in_init.py @@ -16,7 +16,6 @@ import mithril as ml from mithril.models import Add, Model -from mithril.utils.utils import OrderedSet def test_directed_call_connection(): @@ -28,9 +27,9 @@ def test_directed_call_connection(): left_info = info._connections["left"] assert isinstance(left_info, ml.IOKey) - assert left_info._connections == OrderedSet([connection.data]) - assert left_info._name is None - assert left_info._value == 1 + assert left_info.connections == {connection} + assert left_info.name is None + assert left_info.data.value == 1 def test_directed_call_int(): @@ -59,8 +58,8 @@ def test_directed_call_str(): left_info = info._connections["left"] assert isinstance(left_info, ml.IOKey) - assert left_info._name == "in1" - assert left_info._value == 1 + assert left_info.name == "in1" + assert left_info.data.value == 1 def test_directed_call_iokey_value_equal(): @@ -71,8 +70,8 @@ def test_directed_call_iokey_value_equal(): left_info = info._connections["left"] assert isinstance(left_info, ml.IOKey) - assert left_info._name == "in1" - assert left_info._value == 1 + assert left_info.name == "in1" + assert left_info.data.value == 1 def test_directed_call_iokey_value_not_equal(): @@ -92,13 +91,13 @@ def test_directed_call_iokey_value_tbd(): left_info = info._connections["left"] assert isinstance(left_info, ml.IOKey) - assert left_info._name == "in1" - assert left_info._value == 1 # value is set to val from factory_inputs + assert left_info.name == "in1" + assert left_info.data.value == 1 # value is set to val from factory_inputs def test_directed_call_connect_key_value_not_equal(): add1 = Add(left=1) - iokey = ml.IOKey("in1", value=2, connections=[Add().left]) + iokey = ml.IOKey("in1", value=2, connections={Add().left}) with pytest.raises(ValueError) as err_info: add1(left=iokey) @@ -108,40 +107,40 @@ def test_directed_call_connect_key_value_not_equal(): def test_directed_call_connect_key_none(): add1 = Add(left=1) connection = Add().left - con = ml.IOKey(connections=[connection]) + con = ml.IOKey(connections={connection}) info = add1(left=con, right="right") left_info = info._connections["left"] assert isinstance(left_info, ml.IOKey) - assert left_info._connections == OrderedSet([connection.data]) - assert left_info._value == 1 # key is set to IOKey with val from factory_inputs + assert left_info.connections == {connection} + assert left_info.data.value == 1 # key is set to IOKey with val from factory_inputs def test_directed_call_connect_key_value_tbd(): add1 = Add(left=1) connection = Add().left - con = ml.IOKey(name="in1", connections=[connection]) + con = ml.IOKey(name="in1", connections={connection}) info = add1(left=con, right="right") left_info = info._connections["left"] assert isinstance(left_info, ml.IOKey) - assert left_info._connections == OrderedSet([connection.data]) + assert left_info.connections == {connection} assert isinstance(left_info, ml.IOKey) - assert left_info._value == 1 # value is set to val from factory_inputs + assert left_info.data.value == 1 # value is set to val from factory_inputs def test_directed_call_connect_key_value_equal(): add1 = Add(left=1) connection = Add().left - con = ml.IOKey("in1", value=1, connections=[connection]) + con = ml.IOKey("in1", value=1, connections={connection}) info = add1(left=con, right="right") left_info = info._connections["left"] assert isinstance(left_info, ml.IOKey) - assert left_info._connections == OrderedSet([connection.data]) - assert left_info._value == 1 # value is set to val from factory_inputs + assert left_info.connections == {connection} + assert left_info.data.value == 1 # value is set to val from factory_inputs def test_directed_call_extend_template(): @@ -256,7 +255,7 @@ def test_integration_call_arg_iokey_value_tbd(): def test_integration_call_arg_connect_key_value_not_equal(): add1 = Add(left=1) - connect = ml.IOKey("in1", value=2, connections=[Add().left]) + connect = ml.IOKey("in1", value=2, connections={Add().left}) model = Model() with pytest.raises(ValueError) as err_info: @@ -267,7 +266,7 @@ def test_integration_call_arg_connect_key_value_not_equal(): def test_integration_call_arg_connect_key_none(): add1 = Add(left=1) add2 = Add() - con = ml.IOKey(connections=[add2.left]) + con = ml.IOKey(connections={add2.left}) model = Model() model += add2(left="in1", right="in2") @@ -281,7 +280,7 @@ def test_integration_call_arg_connect_key_none(): def test_integration_call_arg_connect_key_value_tbd(): add1 = Add(left=1) add2 = Add() - con = ml.IOKey(name="in1", expose=True, connections=[add2.left]) + con = ml.IOKey(name="in1", expose=True, connections={add2.left}) model = Model() model += add2(right="in2") @@ -295,7 +294,7 @@ def test_integration_call_arg_connect_key_value_tbd(): def test_integration_call_arg_connect_key_value_equal(): add1 = Add(left=1) add2 = Add() - con = ml.IOKey(connections=[add2.left], value=1) + con = ml.IOKey(connections={add2.left}, value=1) model = Model() model += add2(right="in2") diff --git a/tests/scripts/test_model_to_dict_rtt.py b/tests/scripts/test_model_to_dict_rtt.py index 460bf2cd..c6efe302 100644 --- a/tests/scripts/test_model_to_dict_rtt.py +++ b/tests/scripts/test_model_to_dict_rtt.py @@ -16,7 +16,7 @@ import mithril from mithril import JaxBackend, TorchBackend -from mithril.framework.common import TBD, GenericTensorType, IOKey +from mithril.framework.common import TBD, BaseKey, GenericTensorType, IOKey from mithril.framework.constraints import squeeze_constraints from mithril.models import ( L2, @@ -489,7 +489,7 @@ def test_composite_9(): input="", weight="weight1", output=IOKey(name="output2") ) model += Linear(dimension=71)( - input="input", weight="weight2", output=IOKey(connections=[l1.input, l2.input]) + input="input", weight="weight2", output=IOKey(connections={l1.input, l2.input}) ) model_dict_created = dict_conversions.model_to_dict(model) @@ -516,7 +516,7 @@ def test_composite_10(): model += Linear(dimension=71)( input="input", weight="weight2", - output=IOKey(name="my_input", connections=["input1", "input2"]), + output=IOKey(name="my_input", connections={"input1", "input2"}), ) model_dict_created = dict_conversions.model_to_dict(model) @@ -543,7 +543,7 @@ def test_composite_10_expose_false(): model += Linear(dimension=71)( input="input", weight="weight2", - output=IOKey(name="my_input", connections=["input1", "input2"], expose=False), + output=IOKey(name="my_input", connections={"input1", "input2"}, expose=False), ) model_dict_created = dict_conversions.model_to_dict(model) @@ -588,7 +588,7 @@ def test_composite_12(): Linear(dimension=71), input="input", weight="weight2", - output=IOKey(name="my_input", connections=["input1", "input2"]), + output=IOKey(name="my_input", connections={"input1", "input2"}), ) model_dict_created = dict_conversions.model_to_dict(model) @@ -622,7 +622,7 @@ def test_composite_13(): Linear(dimension=71), input="input", weight="weight2", - output=IOKey(name="my_input", connections=["input1", "input2"]), + output=IOKey(name="my_input", connections={"input1", "input2"}), ) model_dict_created = dict_conversions.model_to_dict(model) @@ -920,9 +920,9 @@ def __init__(self, threshold=3) -> None: threshold *= 2 super().__init__( formula_key="my_adder", - output=IOKey(shape=[("Var_out", ...)], type=GenericTensorType), - input=IOKey(shape=[("Var_1", ...)], type=GenericTensorType), - rhs=IOKey(type=int, value=threshold), + output=BaseKey(shape=[("Var_out", ...)], type=GenericTensorType), + input=BaseKey(shape=[("Var_1", ...)], type=GenericTensorType), + rhs=BaseKey(type=int, value=threshold), ) self.set_constraint( fn=squeeze_constraints, keys=[CustomPrimitiveModel.output_key, "input"] diff --git a/tests/scripts/test_parallel.py b/tests/scripts/test_parallel.py index 24472dee..ea8767a6 100644 --- a/tests/scripts/test_parallel.py +++ b/tests/scripts/test_parallel.py @@ -383,7 +383,7 @@ def test_torch_parallel_2(): # primitive eye. model = Model() model += (linear := Linear(256))(input="input", weight="w", bias="b") - model += (e := Eye(N=TBD))(N=linear.output.shape()[0]) + model += (e := Eye(N=TBD))(N=linear.output.shape[0]) model += Add()(left=linear.output, right=e.output, output="output") backend = create_parallel_backend(device_mesh=(4, 1)) backend.ones([256]) @@ -507,7 +507,7 @@ def test_torch_parallel_5(): # primitive eye. model = Model() model += (linear := Linear(256))(input="input", weight="w", bias="b") - model += (e := Eye(N=TBD))(N=linear.output.shape()[0]) + model += (e := Eye(N=TBD))(N=linear.output.shape[0]) model += Add()(left=linear.output, right=e.output, output="output") backend = mithril.TorchBackend() @@ -957,7 +957,7 @@ def test_jax_parallel_2(): if "cuda" in mithril.JaxBackend.get_available_devices(): model = Model() model += (linear := Linear(256))(input="input", weight="w", bias="b") - model += (e := Eye(N=TBD))(N=linear.output.shape()[0]) + model += (e := Eye(N=TBD))(N=linear.output.shape[0]) model += Add()(left=linear.output, right=e.output, output="output") backend = mithril.JaxBackend(device="cuda", device_mesh=(4, 1)) backend.ones([256]) @@ -1090,7 +1090,7 @@ def test_jax_parallel_5(): if "cuda" in mithril.JaxBackend.get_available_devices(): model = Model() model += (linear := Linear(256))(input="input", weight="w", bias="b") - model += (e := Eye(N=TBD))(N=linear.output.shape()[0]) + model += (e := Eye(N=TBD))(N=linear.output.shape[0]) model += Add()(left=linear.output, right=e.output, output="output") backend = mithril.JaxBackend(device="cuda") diff --git a/tests/scripts/test_ref_counts.py b/tests/scripts/test_ref_counts.py index c7001ac2..4e86d3a0 100644 --- a/tests/scripts/test_ref_counts.py +++ b/tests/scripts/test_ref_counts.py @@ -18,6 +18,7 @@ from mithril.framework.common import ( NOT_GIVEN, + BaseKey, Connection, ConnectionType, GenericTensorType, @@ -81,8 +82,8 @@ class TestModel(PrimitiveModel): def __init__(self) -> None: super().__init__( formula_key="buffer", - input=IOKey(shape=["a", "b", ("Var1", ...)], type=GenericTensorType), - output=IOKey(shape=["c", "d", ("Var2", ...)], type=GenericTensorType), + input=BaseKey(shape=["a", "b", ("Var1", ...)], type=GenericTensorType), + output=BaseKey(shape=["c", "d", ("Var2", ...)], type=GenericTensorType), ) model = Model() @@ -114,8 +115,8 @@ class MyModel(PrimitiveModel): def __init__(self) -> None: super().__init__( formula_key="buffer", - input=IOKey(shape=[("Var1", ...)], type=GenericTensorType), - output=IOKey(shape=[("Var1", ...)], type=GenericTensorType), + input=BaseKey(shape=[("Var1", ...)], type=GenericTensorType), + output=BaseKey(shape=[("Var1", ...)], type=GenericTensorType), ) buff_model1 = MyModel() @@ -238,14 +239,14 @@ def test_deleted_variadic_ref_count_7(): model += add_5(left="") conn = IOKey( - connections=[ + connections={ add_1.left, add_1.right, add_2.left, add_2.right, add_3.left, add_3.right, - ] + } ) model += add_6(left=conn, right="right", output="output") @@ -345,8 +346,8 @@ class MyModel(PrimitiveModel): def __init__(self) -> None: super().__init__( formula_key="buffer", - input=IOKey(shape=["a", "b"], type=GenericTensorType), - output=IOKey(shape=["c", "d"], type=GenericTensorType), + input=BaseKey(shape=["a", "b"], type=GenericTensorType), + output=BaseKey(shape=["c", "d"], type=GenericTensorType), ) all_uniadics = set() @@ -384,8 +385,8 @@ class MyModel(PrimitiveModel): def __init__(self) -> None: super().__init__( formula_key="buffer", - input=IOKey(shape=["a", "b", "c"], type=GenericTensorType), - output=IOKey(shape=["c", "d", "e"], type=GenericTensorType), + input=BaseKey(shape=["a", "b", "c"], type=GenericTensorType), + output=BaseKey(shape=["c", "d", "e"], type=GenericTensorType), ) all_uniadics = set() @@ -409,8 +410,8 @@ class MyModel(PrimitiveModel): def __init__(self) -> None: super().__init__( formula_key="buffer", - input=IOKey(shape=["a", "b", "c"], type=GenericTensorType), - output=IOKey(shape=["d", "e", "f"], type=GenericTensorType), + input=BaseKey(shape=["a", "b", "c"], type=GenericTensorType), + output=BaseKey(shape=["d", "e", "f"], type=GenericTensorType), ) def __call__( # type: ignore[override] @@ -440,8 +441,8 @@ class MyModel(PrimitiveModel): def __init__(self) -> None: super().__init__( formula_key="buffer", - input=IOKey(shape=[1, 1, 1], type=GenericTensorType), - output=IOKey(shape=[1, 1, 1], type=GenericTensorType), + input=BaseKey(shape=[1, 1, 1], type=GenericTensorType), + output=BaseKey(shape=[1, 1, 1], type=GenericTensorType), ) def __call__( # type: ignore[override] @@ -564,8 +565,8 @@ class MyModel(PrimitiveModel): def __init__(self) -> None: super().__init__( formula_key="buffer", - input=IOKey(shape=[("Var1", ...), "a"], type=GenericTensorType), - output=IOKey(shape=["a", ("Var1", ...)], type=GenericTensorType), + input=BaseKey(shape=[("Var1", ...), "a"], type=GenericTensorType), + output=BaseKey(shape=["a", ("Var1", ...)], type=GenericTensorType), ) def __call__( # type: ignore[override] @@ -618,8 +619,8 @@ class MyModel(PrimitiveModel): def __init__(self) -> None: super().__init__( formula_key="buffer", - input=IOKey(shape=[("Var1", ...), "a"], type=GenericTensorType), - output=IOKey(shape=["a", ("Var1", ...)], type=GenericTensorType), + input=BaseKey(shape=[("Var1", ...), "a"], type=GenericTensorType), + output=BaseKey(shape=["a", ("Var1", ...)], type=GenericTensorType), ) def __call__( # type: ignore[override] @@ -671,8 +672,8 @@ class MyModel(PrimitiveModel): def __init__(self) -> None: super().__init__( formula_key="buffer", - input=IOKey(shape=[("Var1", ...), "a"], type=GenericTensorType), - output=IOKey(shape=["a", ("Var1", ...)], type=GenericTensorType), + input=BaseKey(shape=[("Var1", ...), "a"], type=GenericTensorType), + output=BaseKey(shape=["a", ("Var1", ...)], type=GenericTensorType), ) def __call__( # type: ignore[override] @@ -707,8 +708,8 @@ class MyModel(PrimitiveModel): def __init__(self) -> None: super().__init__( formula_key="buffer", - input=IOKey(shape=["a", "b"], type=GenericTensorType), - output=IOKey(shape=["b", "a"], type=GenericTensorType), + input=BaseKey(shape=["a", "b"], type=GenericTensorType), + output=BaseKey(shape=["b", "a"], type=GenericTensorType), ) def __call__( # type: ignore[override] @@ -911,7 +912,7 @@ def test_deleted_tensors_ref_count_3(): model += buffer4(input="input4", output=IOKey(name="output4")) model += buffer5(input="input5", output=IOKey(name="output5")) model += buffer6(input="input6", output=IOKey(name="output6")) - connections = [buffer1.input, buffer2.input, buffer3.input, model.output4] # type: ignore + connections = {buffer1.input, buffer2.input, buffer3.input, model.output4} # type: ignore conn = IOKey(connections=connections) model += buffer7(input=conn, output=IOKey(name="output")) @@ -1137,7 +1138,7 @@ def test_deleted_edge_ref_count_6(): output2=IOKey(name="output2"), output3=IOKey(name="output3"), ) - connections = [main_model.output1, main_model.input2] # type: ignore + connections = {main_model.output1, main_model.input2} # type: ignore conn = IOKey(name="abcd", expose=True, connections=connections) main_model += sigmoid4(input=conn, output=IOKey(name="output5")) @@ -1268,8 +1269,8 @@ class MyModel(PrimitiveModel): def __init__(self) -> None: super().__init__( formula_key="buffer", - input=IOKey(shape=["a1"], type=GenericTensorType), - output=IOKey(shape=["a2"], type=GenericTensorType), + input=BaseKey(shape=["a1"], type=GenericTensorType), + output=BaseKey(shape=["a2"], type=GenericTensorType), ) buff_model1 = MyModel() @@ -1308,8 +1309,8 @@ class MyModel(PrimitiveModel): def __init__(self) -> None: super().__init__( formula_key="buffer", - input=IOKey(shape=["a", "b"], type=GenericTensorType), - output=IOKey(shape=["c", "d"], type=GenericTensorType), + input=BaseKey(shape=["a", "b"], type=GenericTensorType), + output=BaseKey(shape=["c", "d"], type=GenericTensorType), ) model = Model() @@ -1340,8 +1341,8 @@ class MyModel(PrimitiveModel): def __init__(self) -> None: super().__init__( formula_key="buffer", - input=IOKey(shape=["a", "b"], type=GenericTensorType), - output=IOKey(shape=["c", "d"], type=GenericTensorType), + input=BaseKey(shape=["a", "b"], type=GenericTensorType), + output=BaseKey(shape=["c", "d"], type=GenericTensorType), ) model = Model() @@ -1372,8 +1373,8 @@ class MyModel(PrimitiveModel): def __init__(self) -> None: super().__init__( formula_key="buffer", - input=IOKey(shape=["a", "b"], type=GenericTensorType), - output=IOKey(shape=["c", "d"], type=GenericTensorType), + input=BaseKey(shape=["a", "b"], type=GenericTensorType), + output=BaseKey(shape=["c", "d"], type=GenericTensorType), ) model = Model() @@ -1402,8 +1403,8 @@ class MyModel(PrimitiveModel): def __init__(self) -> None: super().__init__( formula_key="buffer", - input=IOKey(shape=["a", "b"], type=GenericTensorType), - output=IOKey(shape=["c", "d"], type=GenericTensorType), + input=BaseKey(shape=["a", "b"], type=GenericTensorType), + output=BaseKey(shape=["c", "d"], type=GenericTensorType), ) model = Model() @@ -1430,8 +1431,8 @@ class MyModel(PrimitiveModel): def __init__(self) -> None: super().__init__( formula_key="buffer", - input=IOKey(shape=["a", "b"], type=GenericTensorType), - output=IOKey(shape=["c", "d"], type=GenericTensorType), + input=BaseKey(shape=["a", "b"], type=GenericTensorType), + output=BaseKey(shape=["c", "d"], type=GenericTensorType), ) model = Model() diff --git a/tests/scripts/test_scripts.py b/tests/scripts/test_scripts.py index 235020f4..17c4b287 100644 --- a/tests/scripts/test_scripts.py +++ b/tests/scripts/test_scripts.py @@ -34,12 +34,12 @@ NOT_AVAILABLE, NOT_GIVEN, TBD, + BaseKey, ConnectionData, ConnectionType, GenericTensorType, IOKey, NotAvailable, - OrderedSet, ToBeDetermined, UniadicRecord, Variadic, @@ -109,6 +109,7 @@ Where, ) from mithril.utils.type_utils import is_list_int +from mithril.utils.utils import OrderedSet from .helper import assert_models_equal from .test_shapes import check_shapes_semantically @@ -261,7 +262,7 @@ def test_cyclic_extension_5(): left="input5", right="input6", output=IOKey( - name="my_input", expose=False, connections=[sum1.left, sum2.right] + name="my_input", expose=False, connections={sum1.left, sum2.right} ), ) @@ -1226,9 +1227,9 @@ class MyAdder(PrimitiveModel): def __init__(self) -> None: super().__init__( formula_key="my_adder", - output=IOKey(shape=[("Var_out", ...)], type=GenericTensorType), - left=IOKey(shape=[("Var_1", ...)], type=GenericTensorType), - right=IOKey(shape=[("Var_2", ...)], type=GenericTensorType), + output=BaseKey(shape=[("Var_out", ...)], type=GenericTensorType), + left=BaseKey(shape=[("Var_1", ...)], type=GenericTensorType), + right=BaseKey(shape=[("Var_2", ...)], type=GenericTensorType), ) self.set_constraint( fn=bcast, keys=[PrimitiveModel.output_key, "left", "right"] @@ -2007,7 +2008,7 @@ def test_multiple_output_connections(): with pytest.raises(Exception) as err_info: model += add_1( - left="left", right="right", output=IOKey(connections=[add_2.left, "out2"]) + left="left", right="right", output=IOKey(connections={add_2.left, "out2"}) ) assert ( @@ -2024,7 +2025,7 @@ def test_multiple_output_connections_2(): model += add_1( left="left", right="right", - output=IOKey(name="my_internal_key", connections=[add_2.left, "in3"]), + output=IOKey(name="my_internal_key", connections={add_2.left, "in3"}), ) assert ( @@ -3944,7 +3945,7 @@ def test_connect_1(): relu3 = Relu() model += relu1(output="relu_output_1") model += relu2(input="", output="relu_output_2") - model += relu3(input="", output=IOKey(connections=[relu1.input, relu2.input])) + model += relu3(input="", output=IOKey(connections={relu1.input, relu2.input})) assert ( model.dag[relu1]["input"].metadata @@ -3961,7 +3962,7 @@ def test_connect_2(): model += relu1(input="in1", output="relu_output_1") model += relu2(input="in2", output="relu_output_2") model += relu3( - input="", output=IOKey(name="my_input", connections=[relu1.input, relu2.input]) + input="", output=IOKey(name="my_input", connections={relu1.input, relu2.input}) ) assert ( @@ -3978,7 +3979,7 @@ def test_connect_3(): relu3 = Relu() model += relu1(output="relu_output_1") model += relu2(input="", output="relu_output_2") - model += relu3(input=IOKey(connections=[relu1.input, relu2.input])) + model += relu3(input=IOKey(connections={relu1.input, relu2.input})) assert ( model.dag[relu1]["input"].metadata @@ -3994,7 +3995,7 @@ def test_connect_4(): relu3 = Relu() model += relu1(input="in1", output="relu_output_1") model += relu2(input="in2", output="relu_output_2") - model += relu3(input=IOKey(name="my_input", connections=[relu1.input, relu2.input])) + model += relu3(input=IOKey(name="my_input", connections={relu1.input, relu2.input})) assert ( model.dag[relu1]["input"].metadata @@ -4011,7 +4012,7 @@ def test_connect_5(): relu3 = Relu() model += relu1(input="in1", output="relu_output_1") model += relu2(input="", output="relu_output_2") - model += relu3(input=IOKey(connections=[relu1.input, relu2.input])) + model += relu3(input=IOKey(connections={relu1.input, relu2.input})) assert ( model.dag[relu1]["input"].key @@ -4034,7 +4035,7 @@ def test_connect_6(): model += relu2(input="in2", output="relu_output_2") with pytest.raises(KeyError) as error_info: - model += Relu()(input=IOKey(connections=[relu1.input, relu2.input])) + model += Relu()(input=IOKey(connections={relu1.input, relu2.input})) assert str(error_info.value) == ( "'Requires a connection to have only one unique key name but " @@ -4118,10 +4119,10 @@ def test_connect_composite_2_extend_from_inputs(): m2 = deepcopy(submodel) subcopy = deepcopy(submodel) model += m1(left="left", right="right") - model += m2(left=IOKey(connections=[m1.output]), right="right") # type: ignore + model += m2(left=IOKey(connections={m1.output}), right="right") # type: ignore model += subcopy( - left=IOKey(connections=[m2.output]), # type: ignore - right=IOKey(connections=[m2.output]), # type: ignore + left=IOKey(connections={m2.output}), # type: ignore + right=IOKey(connections={m2.output}), # type: ignore output="output", ) @@ -4139,9 +4140,9 @@ def test_composite_6_extend_from_inputs_connect(): relu3 = Relu() relu4 = Relu() model += relu1(output="output") - model += relu2(input=IOKey(connections=[relu1.input])) - model += relu3(input="my_input", output=IOKey(connections=[relu2.input])) - model += relu4(input=IOKey(connections=[relu3.input])) + model += relu2(input=IOKey(connections={relu1.input})) + model += relu3(input="my_input", output=IOKey(connections={relu2.input})) + model += relu4(input=IOKey(connections={relu3.input})) assert ( relu2.input.data.metadata @@ -4163,8 +4164,8 @@ def test_composite_4_extend_from_inputs_connect(): relu3 = Relu() relu4 = Relu() model += relu1(input="my_input", output=IOKey(name="output")) - model += relu2(input=IOKey(connections=[relu1.input])) - model += relu3(input=IOKey(connections=[relu2.input])) + model += relu2(input=IOKey(connections={relu1.input})) + model += relu3(input=IOKey(connections={relu2.input})) model += relu4(input="input1", output="my_input") backend = TorchBackend() @@ -4184,7 +4185,7 @@ def test_integration_composite_1_extend_from_inputs_1_with_connect(): m1 = Layer(dimension=2, activation=Sigmoid()) model += m2(weight="w1", bias="b1", output="output") model += m1( - input="input", weight="w0", bias="b0", output=IOKey(connections=[m2.input]) + input="input", weight="w0", bias="b0", output=IOKey(connections={m2.input}) ) assert m1.output.data.metadata == m2.input.data.metadata @@ -4229,7 +4230,7 @@ def test_connect_8(): r2 = Relu() model += t(output="output1") model += r1(input="input2", output="output2") - model += r2(input="", output=IOKey(connections=[t.input, r1.input])) + model += r2(input="", output=IOKey(connections={t.input, r1.input})) assert r1.input.data.metadata == r2.output.data.metadata == t.input.data.metadata @@ -4241,7 +4242,7 @@ def test_connect_9(): r2 = Relu() model += t(input="input1", output="output1") model += r1(input="", output="output2") - model += r2(input="", output=IOKey(connections=["input1", r1.input])) + model += r2(input="", output=IOKey(connections={"input1", r1.input})) assert ( r1.input.data.metadata @@ -4260,7 +4261,7 @@ def test_connect_10(): model += r1(input="input2", output=IOKey(name="output2")) model += r2( input="", - output=IOKey(connections=["input1", "input2"], expose=True, name="internal"), + output=IOKey(connections={"input1", "input2"}, expose=True, name="internal"), ) assert ( @@ -4292,7 +4293,7 @@ def test_connect_12(): model += add2(left="l3", right="l4", output=IOKey(name="out2")) model += add3( - left=IOKey(name="left", connections=[add1.left, add2.left]), + left=IOKey(name="left", connections={add1.left, add2.left}), right="right", output=IOKey(name="out3"), ) @@ -4310,7 +4311,7 @@ def test_connect_13(): buf = Buffer() model += add1(left="l1", right="l2", output=IOKey(name="out1")) model += add2(left="l3", right="l4") - model += buf(input=IOKey(name="input", connections=[add1.left, add2.left])) + model += buf(input=IOKey(name="input", connections={add1.left, add2.left})) model += Add()(left=add2.output, right=buf.output, output=IOKey(name="out2")) assert model._input_keys == {"input", "l2", "l4"} @@ -4334,7 +4335,7 @@ def test_connect_error_1(): with pytest.raises(Exception) as error_info: model += Relu()( input="input", - output=IOKey(name="my_input", connections=["input1", "input2", "output3"]), + output=IOKey(name="my_input", connections={"input1", "input2", "output3"}), ) assert ( @@ -4353,7 +4354,7 @@ def test_connect_error_2(): with pytest.raises(KeyError) as error_info: model += Relu()( input=IOKey( - name="my_input", connections=["input1", "input2", "output3", "output4"] + name="my_input", connections={"input1", "input2", "output3", "output4"} ) ) @@ -4370,7 +4371,7 @@ def test_connect_error_5(): with pytest.raises(KeyError) as error_info: model_2 += Relu()( - output=IOKey(expose=True, connections=[tanh.input, relu.input]) + output=IOKey(expose=True, connections={tanh.input, relu.input}) ) assert ( @@ -4388,7 +4389,7 @@ def test_connect_error_6(): model += l2(input="input1", weight="w1", output=IOKey(name="output2")) model += l3(input="", output=IOKey(name="output3")) model += l4( - input=IOKey(name="my_output", connections=["input1", "input2", "output3"]) + input=IOKey(name="my_output", connections={"input1", "input2", "output3"}) ) assert ( @@ -4436,9 +4437,9 @@ class MyAdder(PrimitiveModel): def __init__(self) -> None: super().__init__( formula_key="my_adder", - output=IOKey(shape=[("Var_out", ...)], type=GenericTensorType), - left=IOKey(shape=[("Var_1", ...)], type=GenericTensorType), - right=IOKey(shape=[("Var_2", ...)], type=GenericTensorType), + output=BaseKey(shape=[("Var_out", ...)], type=GenericTensorType), + left=BaseKey(shape=[("Var_1", ...)], type=GenericTensorType), + right=BaseKey(shape=[("Var_2", ...)], type=GenericTensorType), ) self.set_constraint( fn=bcast, keys=[PrimitiveModel.output_key, "left", "right"] @@ -5121,7 +5122,7 @@ def test_dependency_map_latent_to_input(): # Add third model which changes name of a latent input and # makes it a real input of the model. - conn = IOKey(name="mean_axis", connections=[mean.axis], expose=True) + conn = IOKey(name="mean_axis", connections={mean.axis}, expose=True) model += (to_tensor := ToTensor())(conn, output="output") # Assert dependency map and connection keys status in model. output: ConnectionData = model.output.data # type: ignore @@ -6676,7 +6677,7 @@ def test_multi_write_7(): model += add1(left="left1", right="right1", output="output1") model += add2(left="left2", right="right2", output="output2") - out = IOKey(connections=[model.output1, model.output2]) # type: ignore + out = IOKey(connections={model.output1, model.output2}) # type: ignore with pytest.raises(KeyError) as err_info: model += Buffer()(input=out, output="output3") @@ -6917,11 +6918,11 @@ def test_extend_with_wrong_values(): def test_cyclic_extend(): - with pytest.raises(KeyError) as error_info1: + with pytest.raises(Exception) as error_info1: model = Model() model += Relu()(input="input1", output="input1") - with pytest.raises(KeyError) as error_info2: + with pytest.raises(Exception) as error_info2: model = Model() model += LogisticRegression()(input="input1", probs_output="input1") @@ -7355,17 +7356,17 @@ def __init__( all_output_shapes = list(output) # Create IOKey shape = and Scalar Input, type = GenericTensorTypes # Note that equation is string - tensor_input = IOKey(shape=all_input_shapes, type=GenericTensorType) - tensor_output = IOKey(shape=all_output_shapes, type=GenericTensorType) - scalar_equation = IOKey(type=str, value=equation) + tensor_input = BaseKey(shape=all_input_shapes, type=GenericTensorType) + tensor_output = BaseKey(shape=all_output_shapes, type=GenericTensorType) + scalar_equation = BaseKey(type=str, value=equation) else: # case where equation is TBD - tensor_input = IOKey(shape=[("Var1", ...)], type=GenericTensorType) - tensor_output = IOKey(shape=[("Var2", ...)], type=GenericTensorType) - scalar_equation = IOKey(type=str) + tensor_input = BaseKey(shape=[("Var1", ...)], type=GenericTensorType) + tensor_output = BaseKey(shape=[("Var2", ...)], type=GenericTensorType) + scalar_equation = BaseKey(type=str) - kwargs: dict[str, IOKey] = { + kwargs: dict[str, BaseKey] = { "output": tensor_output, "input": tensor_input, "equation": scalar_equation, @@ -7439,17 +7440,17 @@ def __init__( all_output_shapes = list(output) # Create TensorType and Scalar Inputs # Note that equation is string - tensor_input = IOKey(shape=all_input_shapes, type=GenericTensorType) - tensor_output = IOKey(shape=all_output_shapes, type=GenericTensorType) - scalar_equation = IOKey(type=str, value=equation) + tensor_input = BaseKey(shape=all_input_shapes, type=GenericTensorType) + tensor_output = BaseKey(shape=all_output_shapes, type=GenericTensorType) + scalar_equation = BaseKey(type=str, value=equation) else: # case where equation is TBD - tensor_input = IOKey(shape=[("Var1", ...)], type=GenericTensorType) - tensor_output = IOKey(shape=[("Var2", ...)], type=GenericTensorType) - scalar_equation = IOKey(type=str) + tensor_input = BaseKey(shape=[("Var1", ...)], type=GenericTensorType) + tensor_output = BaseKey(shape=[("Var2", ...)], type=GenericTensorType) + scalar_equation = BaseKey(type=str) - kwargs: dict[str, IOKey] = { + kwargs: dict[str, BaseKey] = { "output": tensor_output, "input": tensor_input, "equation": scalar_equation, diff --git a/tests/scripts/test_set_types.py b/tests/scripts/test_set_types.py index 9fdcb46f..12397c79 100644 --- a/tests/scripts/test_set_types.py +++ b/tests/scripts/test_set_types.py @@ -200,7 +200,7 @@ def test_types_iokey_3(): output=IOKey(name="output2", type=float | int), ) - conn = IOKey("sub", connections=[buffer_model1.input, buffer_model2.input]) + conn = IOKey("sub", connections={buffer_model1.input, buffer_model2.input}) buffer_model3 = Buffer() diff --git a/tests/scripts/test_shapes.py b/tests/scripts/test_shapes.py index b18dc97e..1763be6d 100644 --- a/tests/scripts/test_shapes.py +++ b/tests/scripts/test_shapes.py @@ -27,6 +27,7 @@ AND, DNF, NOT_GIVEN, + BaseKey, Connection, ConnectionType, Equivalences, @@ -559,7 +560,7 @@ def test_shapes_4(): input="", weight="weight1", output=IOKey(name="output2") ) model += Linear(dimension=71)( - input="input", weight="weight2", output=IOKey(connections=[l1.input, l2.input]) + input="input", weight="weight2", output=IOKey(connections={l1.input, l2.input}) ) shapes = {"input": [4, 256]} logical_ref: Mapping[str, list | None] = { @@ -2257,7 +2258,7 @@ def test_composite_3_extend_shapes_1(): m3 = Model() m3 += m2(right=IOKey(name="right")) m3 += Add()( - left=IOKey(name="left", connections=[m1.left], expose=True), # type: ignore + left=IOKey(name="left", connections={m1.left}, expose=True), # type: ignore right=m2.output, # type: ignore output=IOKey(name="output"), ) # type: ignore @@ -2931,8 +2932,8 @@ class Model1(PrimitiveModel): def __init__(self) -> None: super().__init__( formula_key="relu", - input=IOKey(shape=[("Var1", ...)], type=GenericTensorType), - output=IOKey(shape=[("Var1", ...), "u1", "u2"], type=GenericTensorType), + input=BaseKey(shape=[("Var1", ...)], type=GenericTensorType), + output=BaseKey(shape=[("Var1", ...), "u1", "u2"], type=GenericTensorType), ) def __call__( # type: ignore[override] @@ -2948,8 +2949,8 @@ class Model2(PrimitiveModel): def __init__(self) -> None: super().__init__( formula_key="relu", - input=IOKey(shape=[("Var1", ...), "u1"], type=GenericTensorType), - output=IOKey(shape=["u1", ("Var1", ...)], type=GenericTensorType), + input=BaseKey(shape=[("Var1", ...), "u1"], type=GenericTensorType), + output=BaseKey(shape=["u1", ("Var1", ...)], type=GenericTensorType), ) def __call__( # type: ignore[override] @@ -2967,10 +2968,10 @@ class Model3(PrimitiveModel): def __init__(self) -> None: super().__init__( formula_key="concat", - input1=IOKey(shape=["u1", "u2", "u3"], type=GenericTensorType), - input2=IOKey(shape=["u3", "u2", "u1"], type=GenericTensorType), - output=IOKey(shape=["u1", ("Var1", ...), "u3"], type=GenericTensorType), - axis=IOKey(type=int), + input1=BaseKey(shape=["u1", "u2", "u3"], type=GenericTensorType), + input2=BaseKey(shape=["u3", "u2", "u1"], type=GenericTensorType), + output=BaseKey(shape=["u1", ("Var1", ...), "u3"], type=GenericTensorType), + axis=BaseKey(type=int), ) def __call__( # type: ignore[override] @@ -2989,8 +2990,8 @@ class Model4(PrimitiveModel): def __init__(self) -> None: super().__init__( formula_key="relu", - input=IOKey(shape=[("Var1", ...)], type=GenericTensorType), - output=IOKey(shape=[("Var1", ...), 1], type=GenericTensorType), + input=BaseKey(shape=[("Var1", ...)], type=GenericTensorType), + output=BaseKey(shape=[("Var1", ...), 1], type=GenericTensorType), ) def __call__( # type: ignore[override] @@ -3007,9 +3008,9 @@ class Model5(PrimitiveModel): def __init__(self, axis=None) -> None: super().__init__( formula_key="relu", - input=IOKey(shape=[("Var1", ...)], type=GenericTensorType), - output=IOKey(shape=[("Var2", ...)], type=GenericTensorType), - axis=IOKey(type=NoneType | list[int], value=axis), + input=BaseKey(shape=[("Var1", ...)], type=GenericTensorType), + output=BaseKey(shape=[("Var2", ...)], type=GenericTensorType), + axis=BaseKey(type=NoneType | list[int], value=axis), ) def __call__( # type: ignore[override] @@ -3028,8 +3029,8 @@ class Model6(PrimitiveModel): def __init__(self) -> None: super().__init__( formula_key="relu", - input=IOKey(shape=["u1", ("Var1", ...)], type=GenericTensorType), - output=IOKey(shape=[("Var1", ...), "u1"], type=GenericTensorType), + input=BaseKey(shape=["u1", ("Var1", ...)], type=GenericTensorType), + output=BaseKey(shape=[("Var1", ...), "u1"], type=GenericTensorType), ) def __call__( # type: ignore[override] @@ -3045,8 +3046,8 @@ class Model7(PrimitiveModel): def __init__(self) -> None: super().__init__( formula_key="relu", - input=IOKey(shape=["u1", ("Var1", ...)], type=GenericTensorType), - output=IOKey(shape=["u1", ("Var1", ...)], type=GenericTensorType), + input=BaseKey(shape=["u1", ("Var1", ...)], type=GenericTensorType), + output=BaseKey(shape=["u1", ("Var1", ...)], type=GenericTensorType), ) def __call__( # type: ignore[override] @@ -3062,8 +3063,8 @@ class Model8(PrimitiveModel): def __init__(self) -> None: super().__init__( formula_key="relu", - input=IOKey(shape=[("Var1", ...), "u1"], type=GenericTensorType), - output=IOKey(shape=[("Var1", ...), "u1"], type=GenericTensorType), + input=BaseKey(shape=[("Var1", ...), "u1"], type=GenericTensorType), + output=BaseKey(shape=[("Var1", ...), "u1"], type=GenericTensorType), ) def __call__( # type: ignore[override] @@ -3079,8 +3080,8 @@ class Model9(PrimitiveModel): def __init__(self) -> None: super().__init__( formula_key="buffer", - input=IOKey(shape=["u1", ("Var1", ...)], type=GenericTensorType), - output=IOKey(shape=["u2", "u1", ("Var1", ...)], type=GenericTensorType), + input=BaseKey(shape=["u1", ("Var1", ...)], type=GenericTensorType), + output=BaseKey(shape=["u2", "u1", ("Var1", ...)], type=GenericTensorType), ) def __call__( # type: ignore[override] @@ -3968,8 +3969,8 @@ class MyVariadic1(PrimitiveModel): def __init__(self) -> None: super().__init__( formula_key="buffer", - input=IOKey(shape=[("Var1", ...), "a"], type=GenericTensorType), - output=IOKey(shape=[("Var1", ...), "a"], type=GenericTensorType), + input=BaseKey(shape=[("Var1", ...), "a"], type=GenericTensorType), + output=BaseKey(shape=[("Var1", ...), "a"], type=GenericTensorType), ) def __call__( # type: ignore[override] @@ -3985,8 +3986,8 @@ class MyVariadic2(PrimitiveModel): def __init__(self) -> None: super().__init__( formula_key="buffer", - input=IOKey(shape=[("Var1", ...), "a", "b"], type=GenericTensorType), - output=IOKey(shape=[("Var1", ...), "a", "b"], type=GenericTensorType), + input=BaseKey(shape=[("Var1", ...), "a", "b"], type=GenericTensorType), + output=BaseKey(shape=[("Var1", ...), "a", "b"], type=GenericTensorType), ) def __call__( # type: ignore[override] @@ -4002,8 +4003,8 @@ class MyVariadic3(PrimitiveModel): def __init__(self) -> None: super().__init__( formula_key="buffer", - input=IOKey(shape=["a", "b", ("Var1", ...)], type=GenericTensorType), - output=IOKey(shape=["a", "b", ("Var1", ...)], type=GenericTensorType), + input=BaseKey(shape=["a", "b", ("Var1", ...)], type=GenericTensorType), + output=BaseKey(shape=["a", "b", ("Var1", ...)], type=GenericTensorType), ) def __call__( # type: ignore[override] @@ -4019,8 +4020,8 @@ class MyVariadic4(PrimitiveModel): def __init__(self) -> None: super().__init__( formula_key="buffer", - input=IOKey(shape=["a", ("Var1", ...)], type=GenericTensorType), - output=IOKey(shape=["a", ("Var1", ...)], type=GenericTensorType), + input=BaseKey(shape=["a", ("Var1", ...)], type=GenericTensorType), + output=BaseKey(shape=["a", ("Var1", ...)], type=GenericTensorType), ) def __call__( # type: ignore[override] @@ -4036,8 +4037,8 @@ class MyVariadic5(PrimitiveModel): def __init__(self) -> None: super().__init__( formula_key="buffer", - input=IOKey(shape=[("Var1", ...), "a", "b"], type=GenericTensorType), - output=IOKey(shape=[("Var1", ...), "a"], type=GenericTensorType), + input=BaseKey(shape=[("Var1", ...), "a", "b"], type=GenericTensorType), + output=BaseKey(shape=[("Var1", ...), "a"], type=GenericTensorType), ) def __call__( # type: ignore[override] @@ -4053,8 +4054,8 @@ class MyVariadic6(PrimitiveModel): def __init__(self) -> None: super().__init__( formula_key="buffer", - input=IOKey(shape=[("Var1", ...), "a"], type=GenericTensorType), - output=IOKey(shape=["a", "a"], type=GenericTensorType), + input=BaseKey(shape=[("Var1", ...), "a"], type=GenericTensorType), + output=BaseKey(shape=["a", "a"], type=GenericTensorType), ) def __call__( # type: ignore[override] @@ -4070,8 +4071,8 @@ class MyVariadic7(PrimitiveModel): def __init__(self) -> None: super().__init__( formula_key="buffer", - input=IOKey(shape=[("Var1", ...), "u1", "u2"], type=GenericTensorType), - output=IOKey(shape=["u3", ("Var2", ...), "u4"], type=GenericTensorType), + input=BaseKey(shape=[("Var1", ...), "u1", "u2"], type=GenericTensorType), + output=BaseKey(shape=["u3", ("Var2", ...), "u4"], type=GenericTensorType), ) def __call__( # type: ignore[override] @@ -4092,27 +4093,27 @@ class MyVariadic8(PrimitiveModel): def __init__(self) -> None: super().__init__( formula_key="buffer", - input1=IOKey( + input1=BaseKey( shape=["u1", "u2", "u3", ("Var1", ...)], type=GenericTensorType ), - input2=IOKey( + input2=BaseKey( shape=["u4", "u5", ("Var2", ...), "u6"], type=GenericTensorType ), - input3=IOKey( + input3=BaseKey( shape=["u7", ("Var3", ...), "u8", "u9"], type=GenericTensorType ), - input4=IOKey( + input4=BaseKey( shape=[("Var4", ...), "u10", "u11", "u12"], type=GenericTensorType ), - input5=IOKey( + input5=BaseKey( shape=[("Var5", ...), "u13", "u14", "u15", "u16"], type=GenericTensorType, ), - input6=IOKey( + input6=BaseKey( shape=["u17", "u18", ("Var6", ...), "u19", "u20"], type=GenericTensorType, ), - output=IOKey( + output=BaseKey( shape=["u13", ("Var1", ...), "u14", "u15", "u16"], type=GenericTensorType, ), @@ -4151,10 +4152,10 @@ class MyVariadic9(PrimitiveModel): def __init__(self) -> None: super().__init__( formula_key="buffer", - input1=IOKey(shape=["u1", ("Var1", ...)], type=GenericTensorType), - input2=IOKey(shape=[("Var2", ...), "u2"], type=GenericTensorType), - input3=IOKey(shape=["u3", ("Var3", ...), "u4"], type=GenericTensorType), - output=IOKey(shape=["u5", "u5"], type=GenericTensorType), + input1=BaseKey(shape=["u1", ("Var1", ...)], type=GenericTensorType), + input2=BaseKey(shape=[("Var2", ...), "u2"], type=GenericTensorType), + input3=BaseKey(shape=["u3", ("Var3", ...), "u4"], type=GenericTensorType), + output=BaseKey(shape=["u5", "u5"], type=GenericTensorType), ) def __call__( # type: ignore[override] @@ -4181,16 +4182,16 @@ class MyVariadic10(PrimitiveModel): def __init__(self) -> None: super().__init__( formula_key="buffer", - input1=IOKey(shape=["u1", "u2", ("Var1", ...)], type=GenericTensorType), - input2=IOKey(shape=["u3", ("Var2", ...), "u4"], type=GenericTensorType), - input3=IOKey(shape=[("Var3", ...), "u5", "u6"], type=GenericTensorType), - input4=IOKey( + input1=BaseKey(shape=["u1", "u2", ("Var1", ...)], type=GenericTensorType), + input2=BaseKey(shape=["u3", ("Var2", ...), "u4"], type=GenericTensorType), + input3=BaseKey(shape=[("Var3", ...), "u5", "u6"], type=GenericTensorType), + input4=BaseKey( shape=["u7", "u8", ("Var4", ...), "u9", "u10"], type=GenericTensorType ), - input5=IOKey( + input5=BaseKey( shape=["u11", ("Var4", ...), "u12", "u13"], type=GenericTensorType ), - output=IOKey(shape=["u5", "u5"], type=GenericTensorType), + output=BaseKey(shape=["u5", "u5"], type=GenericTensorType), ) def __call__( # type: ignore[override] @@ -4222,8 +4223,8 @@ class MyVariadic11(PrimitiveModel): def __init__(self) -> None: super().__init__( formula_key="buffer", - input=IOKey(shape=["a", ("Var1", ...)], type=GenericTensorType), - output=IOKey(shape=["a", ("Var1", ...), "b"], type=GenericTensorType), + input=BaseKey(shape=["a", ("Var1", ...)], type=GenericTensorType), + output=BaseKey(shape=["a", ("Var1", ...), "b"], type=GenericTensorType), ) def __call__( # type: ignore[override] @@ -4239,8 +4240,10 @@ class MyVariadic12(PrimitiveModel): def __init__(self) -> None: super().__init__( formula_key="buffer", - input=IOKey(shape=["a", "b", ("Var1", ...)], type=GenericTensorType), - output=IOKey(shape=["a", "b", "c", ("Var1", ...)], type=GenericTensorType), + input=BaseKey(shape=["a", "b", ("Var1", ...)], type=GenericTensorType), + output=BaseKey( + shape=["a", "b", "c", ("Var1", ...)], type=GenericTensorType + ), ) def __call__( # type: ignore[override] @@ -5131,9 +5134,9 @@ class MyModel(PrimitiveModel): def __init__(self) -> None: super().__init__( formula_key="buffer", - input1=IOKey(shape=["a", ("Var1", ...), "b"], type=GenericTensorType), - input2=IOKey(shape=[("Var1", ...)], type=GenericTensorType), - output=IOKey(shape=["a", ("Var1", ...), "c"], type=GenericTensorType), + input1=BaseKey(shape=["a", ("Var1", ...), "b"], type=GenericTensorType), + input2=BaseKey(shape=[("Var1", ...)], type=GenericTensorType), + output=BaseKey(shape=["a", ("Var1", ...), "c"], type=GenericTensorType), ) model = MyModel() @@ -5163,9 +5166,9 @@ class MyModel(PrimitiveModel): def __init__(self) -> None: super().__init__( formula_key="buffer", - input1=IOKey(shape=["a", ("Var1", ...), "b"], type=GenericTensorType), - input2=IOKey(shape=["b", "c"], type=GenericTensorType), - output=IOKey(shape=["a", ("Var1", ...), "c"], type=GenericTensorType), + input1=BaseKey(shape=["a", ("Var1", ...), "b"], type=GenericTensorType), + input2=BaseKey(shape=["b", "c"], type=GenericTensorType), + output=BaseKey(shape=["a", ("Var1", ...), "c"], type=GenericTensorType), ) model = MyModel() @@ -5191,8 +5194,8 @@ class MyModel(PrimitiveModel): def __init__(self) -> None: super().__init__( formula_key="buffer", - input=IOKey(shape=["a", ("Var1", ...), "b"], type=GenericTensorType), - output=IOKey(shape=["b", ("Var1", ...), "a"], type=GenericTensorType), + input=BaseKey(shape=["a", ("Var1", ...), "b"], type=GenericTensorType), + output=BaseKey(shape=["b", ("Var1", ...), "a"], type=GenericTensorType), ) def __call__( # type: ignore[override] @@ -5232,8 +5235,8 @@ class MyModel(PrimitiveModel): def __init__(self) -> None: super().__init__( formula_key="buffer", - input=IOKey(shape=["a", ("Var1", ...), "b"], type=GenericTensorType), - output=IOKey(shape=["b", ("Var1", ...), "a"], type=GenericTensorType), + input=BaseKey(shape=["a", ("Var1", ...), "b"], type=GenericTensorType), + output=BaseKey(shape=["b", ("Var1", ...), "a"], type=GenericTensorType), ) def __call__( # type: ignore[override] @@ -5549,10 +5552,10 @@ class MyModel(PrimitiveModel): def __init__(self) -> None: super().__init__( formula_key="buffer", - input=IOKey( + input=BaseKey( shape=[("Var1", ...), "a", "b", "c"], type=GenericTensorType ), - output=IOKey( + output=BaseKey( shape=["c", ("Var1", ...), "a", "b"], type=GenericTensorType ), ) @@ -5596,10 +5599,12 @@ class MyModel(PrimitiveModel): def __init__(self) -> None: super().__init__( formula_key="buffer", - input1=IOKey(shape=["a", ("Var1", ...), "b"], type=GenericTensorType), - input2=IOKey(shape=["_a", ("Var1", ...), "_b"], type=GenericTensorType), - input3=IOKey(shape=["b", "c"], type=GenericTensorType), - output=IOKey(shape=["a", ("Var1", ...), "c"], type=GenericTensorType), + input1=BaseKey(shape=["a", ("Var1", ...), "b"], type=GenericTensorType), + input2=BaseKey( + shape=["_a", ("Var1", ...), "_b"], type=GenericTensorType + ), + input3=BaseKey(shape=["b", "c"], type=GenericTensorType), + output=BaseKey(shape=["a", ("Var1", ...), "c"], type=GenericTensorType), ) def __call__( # type: ignore[override] @@ -5646,8 +5651,8 @@ class MyModel(PrimitiveModel): def __init__(self) -> None: super().__init__( formula_key="buffer", - input=IOKey(shape=[("Var1", ...), "u1"], type=GenericTensorType), - output=IOKey(shape=[("Var1", ...)], type=GenericTensorType), + input=BaseKey(shape=[("Var1", ...), "u1"], type=GenericTensorType), + output=BaseKey(shape=[("Var1", ...)], type=GenericTensorType), ) def __call__( # type: ignore[override] @@ -5689,10 +5694,10 @@ class MyModel(PrimitiveModel): def __init__(self) -> None: super().__init__( formula_key="buffer", - input=IOKey( + input=BaseKey( shape=[("Var1", ...), "u1", "u2", "u3"], type=GenericTensorType ), - output=IOKey(shape=[("Var1", ...), "u4"], type=GenericTensorType), + output=BaseKey(shape=[("Var1", ...), "u4"], type=GenericTensorType), ) def __call__( # type: ignore[override] @@ -5730,10 +5735,10 @@ class MyModel(PrimitiveModel): def __init__(self) -> None: super().__init__( formula_key="buffer", - input=IOKey( + input=BaseKey( shape=[("Var1", ...), "u1", "u2", "u3"], type=GenericTensorType ), - output=IOKey(shape=[("Var1", ...), "u4"], type=GenericTensorType), + output=BaseKey(shape=[("Var1", ...), "u4"], type=GenericTensorType), ) def __call__( # type: ignore[override] @@ -6902,8 +6907,8 @@ class MyModel(PrimitiveModel): def __init__(self) -> None: super().__init__( formula_key="buffer", - input=IOKey(shape=["a", "b"], type=GenericTensorType), - output=IOKey(shape=["b", "c", "d"], type=GenericTensorType), + input=BaseKey(shape=["a", "b"], type=GenericTensorType), + output=BaseKey(shape=["b", "c", "d"], type=GenericTensorType), ) model = MyModel() @@ -6921,8 +6926,8 @@ class MyModel(PrimitiveModel): def __init__(self) -> None: super().__init__( formula_key="buffer", - input=IOKey(shape=["a", ("Var1", ...)], type=GenericTensorType), - output=IOKey( + input=BaseKey(shape=["a", ("Var1", ...)], type=GenericTensorType), + output=BaseKey( shape=[("Var1", ...), "c", "d", "e"], type=GenericTensorType ), ) @@ -6946,10 +6951,10 @@ class MyModel(PrimitiveModel): def __init__(self) -> None: super().__init__( formula_key="buffer", - input=IOKey( + input=BaseKey( shape=[("Var1", ...), "c", "d", "e"], type=GenericTensorType ), - output=IOKey(shape=["a", "b", ("Var1", ...)], type=GenericTensorType), + output=BaseKey(shape=["a", "b", ("Var1", ...)], type=GenericTensorType), ) model = Model() @@ -7406,8 +7411,8 @@ class MyModel(PrimitiveModel): def __init__(self) -> None: super().__init__( formula_key="buffer", - input=IOKey(shape=["a", ("Var1", ...)], type=GenericTensorType), - output=IOKey(shape=[("Var1", ...), "b"], type=GenericTensorType), + input=BaseKey(shape=["a", ("Var1", ...)], type=GenericTensorType), + output=BaseKey(shape=[("Var1", ...), "b"], type=GenericTensorType), ) def __call__( # type: ignore[override] @@ -7439,9 +7444,9 @@ class MyModel(PrimitiveModel): def __init__(self) -> None: super().__init__( formula_key="buffer", - input1=IOKey(shape=["a", ("Var1", ...)], type=GenericTensorType), - input2=IOKey(shape=[("Var1", ...), "b"], type=GenericTensorType), - output=IOKey(shape=["a", ("Var1", ...)], type=GenericTensorType), + input1=BaseKey(shape=["a", ("Var1", ...)], type=GenericTensorType), + input2=BaseKey(shape=[("Var1", ...), "b"], type=GenericTensorType), + output=BaseKey(shape=["a", ("Var1", ...)], type=GenericTensorType), ) def __call__( # type: ignore[override] @@ -7461,7 +7466,7 @@ def __call__( # type: ignore[override] shapes: dict[str, list] = {"input": ["a", ("Var1", ...)]} buff_model.set_shapes(shapes) model += test_model - con = IOKey(connections=[test_model.input2, buff_model.input]) # type: ignore + con = IOKey(connections={test_model.input2, buff_model.input}) # type: ignore model += Buffer()(input=con, output=IOKey(name="output")) all_nodes = get_all_nodes(model) @@ -7659,8 +7664,8 @@ class MyModel(PrimitiveModel): def __init__(self) -> None: super().__init__( formula_key="buffer", - input=IOKey(shape=[5, 5], type=GenericTensorType), - output=IOKey(shape=[5, 5], type=GenericTensorType), + input=BaseKey(shape=[5, 5], type=GenericTensorType), + output=BaseKey(shape=[5, 5], type=GenericTensorType), ) def __call__( # type: ignore[override] @@ -7689,8 +7694,8 @@ class MyModel(PrimitiveModel): def __init__(self) -> None: super().__init__( formula_key="buffer", - input=IOKey(shape=["a", "b"], type=GenericTensorType), - output=IOKey(shape=["a", "b"], type=GenericTensorType), + input=BaseKey(shape=["a", "b"], type=GenericTensorType), + output=BaseKey(shape=["a", "b"], type=GenericTensorType), ) def __call__( # type: ignore[override] @@ -7718,8 +7723,10 @@ class MyModel(PrimitiveModel): def __init__(self) -> None: super().__init__( formula_key="buffer", - input=IOKey(shape=["a", ("V1", ...), "b", "c"], type=GenericTensorType), - output=IOKey( + input=BaseKey( + shape=["a", ("V1", ...), "b", "c"], type=GenericTensorType + ), + output=BaseKey( shape=["c", ("V1", ...), "a", "b"], type=GenericTensorType ), ) @@ -7756,8 +7763,8 @@ class MyModel(PrimitiveModel): def __init__(self) -> None: super().__init__( formula_key="buffer", - input=IOKey(shape=["a", "b", "c"], type=GenericTensorType), - output=IOKey(shape=["b", "c", "a"], type=GenericTensorType), + input=BaseKey(shape=["a", "b", "c"], type=GenericTensorType), + output=BaseKey(shape=["b", "c", "a"], type=GenericTensorType), ) def __call__( # type: ignore[override] @@ -7792,8 +7799,8 @@ class MyModel(PrimitiveModel): def __init__(self) -> None: super().__init__( formula_key="buffer", - input=IOKey(shape=[], type=GenericTensorType), - output=IOKey(shape=[], type=GenericTensorType), + input=BaseKey(shape=[], type=GenericTensorType), + output=BaseKey(shape=[], type=GenericTensorType), ) def __call__( # type: ignore[override] @@ -9655,7 +9662,7 @@ def test_connect_shapes(): model = Model() model += relu1(input="") model += relu2(input="") - model += relu3(input="input", output=IOKey(connections=[relu1.input, relu2.input])) + model += relu3(input="input", output=IOKey(connections={relu1.input, relu2.input})) assert model.shapes["input"] == [5, 7] diff --git a/tests/scripts/test_summary.py b/tests/scripts/test_summary.py index fcee738d..9e49c16f 100644 --- a/tests/scripts/test_summary.py +++ b/tests/scripts/test_summary.py @@ -186,7 +186,7 @@ def test_extract_logical_connections_4(): output3=IOKey(name="out_3"), ) model += model_2( - output1=IOKey(connections=[model_1.input1, model_1.input2]), # type: ignore + output1=IOKey(connections={model_1.input1, model_1.input2}), # type: ignore output2=IOKey(name="out_4"), output3=IOKey(name="out_5"), input1="in1", @@ -1487,7 +1487,7 @@ def test_logical_model_summary_9(): model = Model() add_1, add_2 = Add(), Add() model += add_1(left="left") - model += add_2(output=IOKey(connections=[add_1.left, add_1.right]), left="left_1") + model += add_2(output=IOKey(connections={add_1.left, add_1.right}), left="left_1") with redirect_stdout(StringIO()) as summary: model.summary(shapes=True, symbolic=True) @@ -1535,13 +1535,13 @@ def test_logical_model_summary_11(): ) model_n += model_2( input1="", - output1=IOKey(connections=[model_3.input1, model_3.input2, model_3.input3]), # type: ignore + output1=IOKey(connections={model_3.input1, model_3.input2, model_3.input3}), # type: ignore output2=IOKey(name="output4"), output3=IOKey(name="output5"), ) model_n += model_1( input1="", - output1=IOKey(connections=[model_2.input1, model_2.input2, model_2.input3]), # type: ignore + output1=IOKey(connections={model_2.input1, model_2.input2, model_2.input3}), # type: ignore ) with redirect_stdout(StringIO()) as summary: @@ -1581,7 +1581,7 @@ def test_logical_model_summary_12(): input1=model_1.output1, # type: ignore input2=model_1.output2, # type: ignore input3=model_1.output3, # type: ignore - output1=IOKey(connections=[model_3.input1, model_3.input2, model_3.input3]), # type: ignore + output1=IOKey(connections={model_3.input1, model_3.input2, model_3.input3}), # type: ignore output2=IOKey(name="output4"), output3=IOKey(name="output5"), ) diff --git a/tests/scripts/test_tuple_list_args_in_extend.py b/tests/scripts/test_tuple_list_args_in_extend.py index 48846f75..87de7266 100644 --- a/tests/scripts/test_tuple_list_args_in_extend.py +++ b/tests/scripts/test_tuple_list_args_in_extend.py @@ -80,7 +80,7 @@ def test_tuple_argument_3(): add_model_2 = Add() model += add_model(left="left", right="right") model += add_model_2( - left=(add_model.left.shape(), add_model.right.shape()), + left=(add_model.left.shape, add_model.right.shape), right=add_model.left + add_model.right, output="output", ) @@ -110,7 +110,7 @@ def test_tuple_argument_4(): add_model_2 = Add() model += add_model(left="left", right="right") model += add_model_2( - left=(add_model.left.shape() * 2, add_model.right.shape() * 2), + left=(add_model.left.shape * 2, add_model.right.shape * 2), right=add_model.left + add_model.right, output="output", ) @@ -138,8 +138,8 @@ def test_tuple_argument_5(): model += add_model(left="left", right="right") model += add_model_2( left=( - (add_model.left.shape()[0], add_model.left.shape()[0]), - (add_model.left.shape()[0], add_model.left.shape()[0]), + (add_model.left.shape[0], add_model.left.shape[0]), + (add_model.left.shape[0], add_model.left.shape[0]), ), right=add_model.left + add_model.right, output="output", @@ -168,8 +168,8 @@ def test_list_tuple_mixed_argument_1(): model += add_model(left="left", right="right") model += add_model_2( left=( - [add_model.left.shape()[0], add_model.left.shape()[0]], - [add_model.left.shape()[0], add_model.left.shape()[0]], + [add_model.left.shape[0], add_model.left.shape[0]], + [add_model.left.shape[0], add_model.left.shape[0]], ), right=add_model.left + add_model.right, output="output", @@ -197,8 +197,8 @@ def test_list_tuple_mixed_argument_2(): model += add_model(left="left", right="right") - left_first_shape = add_model.left.shape()[0] - right_first_shape = add_model.right.shape()[0] + left_first_shape = add_model.left.shape[0] + right_first_shape = add_model.right.shape[0] matmul_left = ([left_first_shape, 0], [2, right_first_shape]) @@ -296,7 +296,7 @@ def test_list_argument_3(): model += add_model(left="left", right="right") model += add_model_2( - left=[add_model.left.shape(), add_model.right.shape()], + left=[add_model.left.shape, add_model.right.shape], right=add_model.left + add_model.right, output="output", ) @@ -327,7 +327,7 @@ def test_list_argument_4(): model += add_model(left="left", right="right") model += add_model_2( - left=[add_model.left.shape() * 2, add_model.right.shape() * 2], + left=[add_model.left.shape * 2, add_model.right.shape * 2], right=add_model.left + add_model.right, output="output", ) @@ -356,8 +356,8 @@ def test_list_argument_5(): model += add_model(left="left", right="right") model += add_model_2( left=[ - [add_model.left.shape()[0], add_model.left.shape()[0]], - [add_model.left.shape()[0], add_model.left.shape()[0]], + [add_model.left.shape[0], add_model.left.shape[0]], + [add_model.left.shape[0], add_model.left.shape[0]], ], right=add_model.left + add_model.right, output="output", diff --git a/tests/scripts/test_type_coercion.py b/tests/scripts/test_type_coercion.py index 72fe06f4..31a725ce 100644 --- a/tests/scripts/test_type_coercion.py +++ b/tests/scripts/test_type_coercion.py @@ -22,6 +22,7 @@ from mithril import JaxBackend, NumpyBackend, TorchBackend, compile from mithril.framework.common import ( NOT_GIVEN, + BaseKey, Connection, ConnectionType, GenericTensorType, @@ -115,7 +116,7 @@ def test_scalar_to_tensor_2(): lin_2 = Linear(dimension=2) model += lin_1(input="input_1", weight="w_1", bias="b_1") model += lin_2(input="input_2", weight="w_2", bias="b_2") - shp_1 = lin_1.input.shape() + shp_1 = lin_1.input.shape reshaped_1 = lin_2.output.reshape(shp_1) to_tensor = ToTensor() model += to_tensor(input=shp_1) @@ -128,7 +129,7 @@ def test_scalar_to_tensor_2(): lin_4 = Linear(dimension=2) model += lin_3(input="input_1", weight="w_1", bias="b_1") model += lin_4(input="input_2", weight="w_2", bias="b_2") - shp_2 = lin_3.input.shape() + shp_2 = lin_3.input.shape reshaped_2 = lin_4.output.reshape(shp_2) model += Add()(left=shp_2.tensor(), right=reshaped_2, output="output") model_2 = model @@ -184,7 +185,7 @@ def test_scalar_to_tensor_3(): def test_tensor_to_scalar_1(): - """Model enforces Jit so we reshape with to_tensor_1_output.shape(). + """Model enforces Jit so we reshape with to_tensor_1_output.shape. We can not directly reshape with to_tensor_1_output which is valued as [2, 1] in tensor domain since it requires TensorToList conversion before being argument to reshape method. @@ -198,7 +199,7 @@ def test_tensor_to_scalar_1(): model += to_tensor_1(input=[2, 1]) model += to_tensor_2(input=[[1, 1]]) model += add_1(left=to_tensor_1.output, right=to_tensor_2.output) - reshaped_1 = add_1.output.reshape(to_tensor_1.output.shape()) + reshaped_1 = add_1.output.reshape(to_tensor_1.output.shape) model += Buffer()(input=reshaped_1, output="output") model_1 = model @@ -208,7 +209,7 @@ def test_tensor_to_scalar_1(): left = IOKey(value=[2, 1]).tensor() right = IOKey(value=[1, 1]).tensor() model += add_2(left=left, right=right) - reshaped_2 = add_2.output.reshape(add_2.left.shape()) + reshaped_2 = add_2.output.reshape(add_2.left.shape) model += Buffer()(input=reshaped_2, output="output") model_2 = model @@ -281,7 +282,7 @@ def test_slice_item_conversions(): model = Model() lin_2 = Linear(dimension=1) model += lin_2(input="input", weight="w", bias="b") - shp2 = lin_2.input.shape() + shp2 = lin_2.input.shape shp2_1 = shp2[1] assert shp2_1 is not None shp_item = shp2_1.tensor() @@ -306,7 +307,7 @@ def test_tuple_conversion_1(): model = Model() lin_1 = Linear(dimension=2) model += lin_1(input="input", weight="w", bias="b") - shp1 = lin_1.output.shape() + shp1 = lin_1.output.shape model += ToTensor()(input=(shp1[0], shp1[1]), output="output") model_1 = model @@ -315,7 +316,7 @@ def test_tuple_conversion_1(): lin_2 = Linear(dimension=2) tupl = ToTuple(n=2) model += lin_2(input="input", weight="w", bias="b") - shp2 = lin_2.output.shape() + shp2 = lin_2.output.shape model += tupl(input1=shp2[0], input2=shp2[1]) model += ToTensor()(input=tupl.output, output="output") # type: ignore model_2 = model @@ -336,7 +337,7 @@ def test_tuple_conversion_2(): lin_1 = Linear(dimension=2) tt1 = ToTensor() model += lin_1(input=[[1], [2.0]], weight="w", bias="b") - shp1 = lin_1.input.shape() + shp1 = lin_1.input.shape model += tt1(input=(shp1[0], shp1[1])) model += Add()(left=lin_1.output, right=tt1.output, output="output") model_1 = model @@ -382,7 +383,7 @@ def test_tuple_conversion_3(): lin_1 = Linear(dimension=3) tt1 = ToTensor() model += lin_1(input=[[1], [2.0]], weight="w", bias="b") - shp1 = lin_1.input.shape() + shp1 = lin_1.input.shape model += tt1(input=(shp1[0], shp1[1], 3)) model += Add()(left=lin_1.output, right=tt1.output, output="output") model_1 = model @@ -428,7 +429,7 @@ def test_list_conversion_1(): lin_1 = Linear(dimension=3) tt1 = ToTensor() model += lin_1(input=[[1], [2.0]], weight="w", bias="b") - shp1 = lin_1.input.shape() + shp1 = lin_1.input.shape model += tt1(input=[shp1[0], shp1[1], 3.0]) model += Add()(left=lin_1.output, right=tt1.output, output="output") model_1 = model @@ -473,7 +474,7 @@ def test_nested_list_conversion_1(): lin_1 = Linear(dimension=3) tt1 = ToTensor() model += lin_1(input=[[1], [2.0]], weight="w", bias="b") - shp1 = lin_1.input.shape() + shp1 = lin_1.input.shape model += tt1(input=[[shp1[0], shp1[1], 3.0]]) model += Add()(left=lin_1.output, right=tt1.output, output="output") model_1 = model @@ -518,7 +519,7 @@ def test_nested_list_conversion_2(): lin_1 = Linear(dimension=3) tt1 = ToTensor() model += lin_1(input="input", weight="w", bias="b") - shp1 = lin_1.input.shape() + shp1 = lin_1.input.shape model += tt1(input=[[shp1[0], shp1[1], 3.0]]) model += Add()(left=lin_1.output, right=tt1.output, output="output") model_1 = model @@ -661,8 +662,8 @@ class ArtificialPrimitive(PrimitiveModel): def __init__(self, type) -> None: super().__init__( formula_key="tensor_to_list", - output=IOKey(shape=[("Var1", ...)], type=GenericTensorType), - input=IOKey(shape=[("Var2", ...)], type=type), + output=BaseKey(shape=[("Var1", ...)], type=GenericTensorType), + input=BaseKey(shape=[("Var2", ...)], type=type), ) self._set_constraint( fn=self.artificial_constraint, keys=[PrimitiveModel.output_key, "input"] @@ -815,8 +816,8 @@ class Model1(PrimitiveModel): def __init__(self) -> None: super().__init__( formula_key="buffer", - input=IOKey(shape=[("Var1", ...)], type=GenericTensorType), - output=IOKey(shape=[("Var1", ...)], type=GenericTensorType), + input=BaseKey(shape=[("Var1", ...)], type=GenericTensorType), + output=BaseKey(shape=[("Var1", ...)], type=GenericTensorType), ) def __call__( # type: ignore[override] @@ -837,7 +838,7 @@ def test_connect_type_conv_handling_1(): model.extend((a1 := Buffer()), input="input1") model.extend((a2 := Buffer()), input="input2") con_object = IOKey( - name="abcd", connections=[a1.input, a2.input], value=[[2.0]], expose=True + name="abcd", connections={a1.input, a2.input}, value=[[2.0]], expose=True ) model.extend( mat_mul := MatrixMultiply(), left=con_object, output=IOKey(name="output") @@ -850,7 +851,7 @@ def test_connect_type_conv_handling_1(): model.extend((a1 := Buffer()), input="input1") model.extend((a2 := Buffer()), input="input2") con_object = IOKey( - connections=["input1", "input2"], value=[[2.0]], name="abcd", expose=True + connections={"input1", "input2"}, value=[[2.0]], name="abcd", expose=True ) model.extend( (mat_mul := MatrixMultiply()), left=con_object, output=IOKey(name="output") @@ -863,7 +864,7 @@ def test_connect_type_conv_handling_1(): model.extend((a1 := Buffer()), input="input1") model.extend((a2 := Buffer()), input="input2") con_object = IOKey( - connections=["input1", a2.input], value=[[2.0]], name="abcd", expose=True + connections={"input1", a2.input}, value=[[2.0]], name="abcd", expose=True ) model.extend( (mat_mul := MatrixMultiply()), left=con_object, output=IOKey(name="output") @@ -894,8 +895,8 @@ def test_connect_1(): model += concat_model( input1="input1", input2="input2", input3="input3", output=IOKey(name="output") ) - conn_list = [concat_model.input1, concat_model.input2, concat_model.input3] # type: ignore - conn = IOKey(connections=conn_list, name="abcd", expose=True) + conns = {concat_model.input1, concat_model.input2, concat_model.input3} # type: ignore + conn = IOKey(connections=conns, name="abcd", expose=True) model += Sigmoid()(input=conn, output=IOKey(name="output1")) assert ( @@ -916,8 +917,8 @@ def test_connect_2(): model += concat_model( input1="input1", input2="input2", input3="input3", output=IOKey(name="output") ) - conn_list = [concat_model.input1, concat_model.input2, concat_model.input3] # type: ignore - conn = IOKey(connections=conn_list, name="abcd", expose=True) + conns = {concat_model.input1, concat_model.input2, concat_model.input3} # type: ignore + conn = IOKey(connections=conns, name="abcd", expose=True) model += ToTensor()(conn) @@ -938,8 +939,8 @@ def test_connect_3(): model += concat_model( input1="input1", input2="input2", input3="input3", output=IOKey(name="output") ) - conn_list = [concat_model.input1, concat_model.input2, concat_model.input3] # type: ignore - conn = IOKey(connections=conn_list, name="abcd", expose=True, value=3.0) + conns = {concat_model.input1, concat_model.input2, concat_model.input3} # type: ignore + conn = IOKey(connections=conns, name="abcd", expose=True, value=3.0) model += (to_tensor := ToTensor())(conn) @@ -971,13 +972,13 @@ def test_connect_4(): input1="input1", input2="input2", input3="input3", output=IOKey(name="output") ) model += union_model(input1="") - conn_list = [ + conns = { concat_model.input1, # type: ignore concat_model.input2, # type: ignore concat_model.input3, # type: ignore union_model.input1.tensor(), # type: ignore - ] - conn = IOKey(connections=conn_list, name="abcd", expose=True, value=(3, 2)) + } + conn = IOKey(connections=conns, name="abcd", expose=True, value=(3, 2)) model += Buffer()(input=conn, output=IOKey(name="output1")) pm = compile(model=model, backend=backend, jit=False, inference=True) @@ -1001,8 +1002,8 @@ def test_connect_6(): model = Model() concat_model = Concat(n=3) model += concat_model(input1=[[3.0]], output=IOKey(name="output")) - conn_list = [concat_model.input1, concat_model.input2, concat_model.input3] # type: ignore - conn = IOKey(connections=conn_list, name="abcd", expose=True) + conns = {concat_model.input1, concat_model.input2, concat_model.input3} # type: ignore + conn = IOKey(connections=conns, name="abcd", expose=True) model += Buffer()(input=conn, output=IOKey(name="output1")) @@ -1032,7 +1033,7 @@ def test_connect_7(): model += add_model_2(left="left1", right="right1") conn = IOKey( - connections=[add_model_2.output, model.right], # type: ignore + connections={add_model_2.output, model.right}, # type: ignore name="abcd", expose=False, ) @@ -1082,8 +1083,8 @@ def test_connect_7_expose_output(): add_model_2 = Add() model += add_model_1(left="left", right="right", output=IOKey(name="output2")) model += add_model_2(left="left1", right="right1") - conn_list = [add_model_2.output, model.right] # type: ignore - conn = IOKey(name="abcd", expose=True, connections=conn_list) # type: ignore + conns = {add_model_2.output, model.right} # type: ignore + conn = IOKey(name="abcd", expose=True, connections=conns) # type: ignore model += (buf := Buffer())(input=conn, output=IOKey(name="output")) assert ( @@ -1139,7 +1140,7 @@ def test_connect_8(): left=add_model_1.output, right="right1", output=IOKey(name="output1") ) conn = IOKey( - connections=[add_model_1.output, model.right1], # type: ignore + connections={add_model_1.output, model.right1}, # type: ignore name="abcd", expose=False, ) @@ -1176,8 +1177,8 @@ def test_connect_9(): model = Model() concat_model = Concat(n=3) model += concat_model(input1=[[3.0]], input2=[[2.0]], input3="input3") - con_list = [concat_model.input1, concat_model.input2, concat_model.input3] # type: ignore - conn = IOKey(connections=con_list) + conns = {concat_model.input1, concat_model.input2, concat_model.input3} # type: ignore + conn = IOKey(connections=conns) with pytest.raises(ValueError) as err_info: model += Buffer()(input=conn, output=IOKey(name="output")) @@ -1193,8 +1194,8 @@ def test_connect_10(): model = Model() concat_model = Concat(n=3) model += concat_model(input1=[[3.0]], input3="input3") - conn_list = [concat_model.input1, concat_model.input2, concat_model.input3] # type: ignore - conn = IOKey(connections=conn_list, value=2.0, expose=True) + conns = {concat_model.input1, concat_model.input2, concat_model.input3} # type: ignore + conn = IOKey(connections=conns, value=2.0, expose=True) with pytest.raises(ValueError) as err_info: model += Buffer()(input=conn, output=IOKey(name="output")) @@ -1219,13 +1220,13 @@ def test_connect_11(): union_model = PrimitiveUnion(n=2) model += concat_model(input1="", output=IOKey(name="output1")) model += union_model(input1="", output=IOKey(name="output2")) - conn_list = [ + conns = { concat_model.input1, # type: ignore concat_model.input2, # type: ignore union_model.input1, # type: ignore union_model.input2, # type: ignore - ] - conn = IOKey(connections=conn_list, value=(2.0,), expose=True) + } + conn = IOKey(connections=conns, value=(2.0,), expose=True) model += Buffer()(input=conn, output=IOKey(name="output3")) pm = compile(model=model, backend=backend, jit=False) @@ -1256,12 +1257,12 @@ def test_connect_12(): model += concat_model(input1="", output=IOKey(name="output1")) model += union_model(input1="", output=IOKey(name="output2")) conn = IOKey( - connections=[ + connections={ concat_model.input1, # type: ignore concat_model.input2, # type: ignore union_model.input1, # type: ignore union_model.input2, # type: ignore - ], + }, value=(2.0,), ) model += Buffer()(input=conn, output=IOKey(name="output3")) @@ -1314,7 +1315,7 @@ def test_tensor_to_scalar_connect_1(): axis2 = mean_model_2.axis axis3 = mean_model_3.axis - con = IOKey(connections=[axis1, axis2, axis3], name="axis4", value=(2, 3)) + con = IOKey(connections={axis1, axis2, axis3}, name="axis4", value=(2, 3)) model += Mean(axis=TBD)(axis=con) assert axis1.data.metadata == axis2.data.metadata == axis3.data.metadata @@ -1335,7 +1336,7 @@ def test_tensor_to_scalar_connect_3_error_existing_key(): model += mean_model_2(axis="axis2") model += mean_model_3(axis="axis3") - con = IOKey(connections=[axis1, axis2, axis3], name="axis2", value=(2, 3)) + con = IOKey(connections={axis1, axis2, axis3}, name="axis2", value=(2, 3)) model += Mean(axis=TBD)(axis=con) @@ -1510,7 +1511,7 @@ def test_tensor_to_scalar_template_1(): model += buff_model_1(input="input1") in1 = buff_model_1.output - out1 = in1.shape().tensor() ** 2 + out1 = in1.shape.tensor() ** 2 model += Buffer()(input=out1, output="output") model.set_shapes({"input1": [3, 4, 5, 6]}) @@ -1533,7 +1534,7 @@ def test_tensor_to_scalar_template_2(): in1 = buff_model_1.output in2 = buff_model_2.output in3 = buff_model_3.output - out1 = (in1.shape().tensor() ** 2 * in2) @ in3 / 2 + out1 = (in1.shape.tensor() ** 2 * in2) @ in3 / 2 model += Buffer()(input=out1, output="output") pm = compile(model=model, backend=backend) diff --git a/tests/scripts/test_type_consistencies.py b/tests/scripts/test_type_consistencies.py index 5a6eb966..c415b71d 100644 --- a/tests/scripts/test_type_consistencies.py +++ b/tests/scripts/test_type_consistencies.py @@ -28,6 +28,7 @@ ) from mithril.models import ( TBD, + BaseKey, Convolution2D, ExtendInfo, IOKey, @@ -49,9 +50,9 @@ class Model1(PrimitiveModel): def __init__(self) -> None: super().__init__( formula_key="None", - input1=IOKey(type=tuple[int, ...]), - input2=IOKey(type=list[float]), - output=IOKey(type=tuple[tuple[int, ...]]), + input1=BaseKey(type=tuple[int, ...]), + input2=BaseKey(type=list[float]), + output=BaseKey(type=tuple[tuple[int, ...]]), ) def __call__( # type: ignore[override] @@ -68,10 +69,10 @@ class Model2(PrimitiveModel): def __init__(self) -> None: super().__init__( formula_key="None", - input1=IOKey(type=int | float), - input2=IOKey(type=int | str), - input3=IOKey(type=str | float), - output=IOKey(type=tuple[int | float, int | float, int | float]), + input1=BaseKey(type=int | float), + input2=BaseKey(type=int | str), + input3=BaseKey(type=str | float), + output=BaseKey(type=tuple[int | float, int | float, int | float]), ) def __call__( # type: ignore[override] @@ -94,19 +95,21 @@ class Model3(PrimitiveModel): def __init__(self) -> None: super().__init__( formula_key="None", - input1=IOKey( + input1=BaseKey( type=tuple[tuple[int | float, ...], ...] | list[int | float] | tuple[int, int, int, int] ), - input2=IOKey(type=list[int] | tuple[int, ...] | tuple[tuple[int | float]]), - input3=IOKey( + input2=BaseKey( + type=list[int] | tuple[int, ...] | tuple[tuple[int | float]] + ), + input3=BaseKey( type=list[tuple[int | tuple[float | int]]] | int | float | tuple[int | float, ...] ), - output=IOKey(type=int | float | str | tuple[int, int]), + output=BaseKey(type=int | float | str | tuple[int, int]), ) def __call__( # type: ignore[override] @@ -387,7 +390,7 @@ def test_type_16(): with pytest.raises(TypeError) as err_info: model += sig_model_2( - input=IOKey(connections=[sig_model_1.input], value=[False, True]), + input=IOKey(connections={sig_model_1.input}, value=[False, True]), output=IOKey(name="output2"), ) assert str(err_info.value) == ( @@ -406,7 +409,7 @@ def test_type_17(): model.extend( sig_model_2, input=IOKey( - connections=[sig_model_1.input], + connections={sig_model_1.input}, value=[False, True], name="a", expose=True,