From d825c37eefe72b626c77b11dfcbc5c66159fb3b6 Mon Sep 17 00:00:00 2001 From: berat Tuna KARLI Date: Fri, 20 Dec 2024 16:08:37 +0300 Subject: [PATCH 1/9] refactor extend, remove redundant methods, update corresponding tests --- mithril/framework/logical/base.py | 4 +- mithril/framework/logical/model.py | 357 +++++++++-------------- mithril/models/models.py | 112 ++++--- tests/scripts/test_data_store.py | 8 +- tests/scripts/test_io_key.py | 3 +- tests/scripts/test_key_values_in_init.py | 14 +- tests/scripts/test_scripts.py | 33 +-- 7 files changed, 242 insertions(+), 289 deletions(-) diff --git a/mithril/framework/logical/base.py b/mithril/framework/logical/base.py index 44a8e44e..7a6210e6 100644 --- a/mithril/framework/logical/base.py +++ b/mithril/framework/logical/base.py @@ -110,7 +110,7 @@ def __call__(self, **kwargs: ConnectionType) -> ExtendInfo: f"Given IOKey for local key: '{key}' is not valid!" ) else: - conns = [ + _conns: list[Connection | str] = [ item.conn if isinstance(item, ConnectionData) else item for item in con._connections ] @@ -120,7 +120,7 @@ def __call__(self, **kwargs: ConnectionType) -> ExtendInfo: shape=con._shape, type=con._type, expose=con._expose, - connections=conns, + connections=_conns, ) case ExtendTemplate(): raise ValueError( diff --git a/mithril/framework/logical/model.py b/mithril/framework/logical/model.py index 8028ffcb..3bdfaa71 100644 --- a/mithril/framework/logical/model.py +++ b/mithril/framework/logical/model.py @@ -15,7 +15,7 @@ from __future__ import annotations from collections.abc import Mapping -from types import EllipsisType, NoneType, UnionType +from types import UnionType from typing import Any, Self from ...utils.utils import OrderedSet, find_dominant_type @@ -34,6 +34,7 @@ MainValueInstance, MainValueType, NestedListType, + NotAvailable, NullConnection, Scalar, ShapeTemplateType, @@ -295,119 +296,160 @@ def _check_multi_write( "Multi-write error!" ) + def _convert_to_iokey( + self, model: BaseModel, key: str, connection: ConnectionType + ) -> IOKey: + is_input = key in model._input_keys + local_connection = model.conns.get_connection(key) + assert local_connection is not None, "Connection is not found!" + not_valued = local_connection.metadata.data.value is TBD + match connection: + case NullConnection(): + connection = IOKey(expose=is_input and not_valued) + case str(): + expose = None + if self.conns.get_connection(connection) is None: + expose = is_input and not_valued + connection = IOKey(name=connection, expose=expose) + case Connection(): + connection = IOKey(connections=[connection], expose=None) + case ExtendTemplate(): + # Unroll ExtendTemplate + template_conn = model.conns.get_connection(key) + assert template_conn is not None, "Connection type is not found!" + connection = self._unroll_template( + connection, type(template_conn.metadata.data) + ) + connection = IOKey(connections=[connection.conn], expose=False) + case _ if isinstance(connection, MainValueInstance): + # find_dominant_type returns the dominant type in a container. + # If a container has a value of type Connection or ExtendTemplate + # we add necessary models. 
+ if isinstance(connection, tuple | list) and find_dominant_type( + connection, raise_error=False + ) in [ConnectionData, ExtendTemplate, Connection, IOKey]: + kwargs = { + f"input{idx + 1}": item for idx, item in enumerate(connection) + } + connection_model = ( + ToTuple if isinstance(connection, tuple) else ToList + ) + conv_model = connection_model(n=len(connection)) + self.extend(conv_model, **kwargs) + + result = conv_model.conns.get_connection("output") + assert result is not None + connection = IOKey(connections=[result.conn], expose=None) + else: + assert isinstance(connection, MainValueInstance) + expose = None + if not not_valued: + expose = False + connection = IOKey(value=connection, expose=expose) + case IOKey(): + expose = connection._expose + name = connection._name + # TODO: This check should be removed: conn._connections==OrderedSet([]) + # We should not operate different if _connections is given. Fix this and + # also fix corresponding tests and dict conversions with "connect". + if ( + expose is None + and (name is None or self.conns.get_connection(name) is None) + and connection._connections == OrderedSet([]) + ): + expose = True + _conns: list[Connection | str] = [ + item.conn if isinstance(item, ConnectionData) else item + for item in connection._connections + ] + # TODO: Add replicate method to IOKey (update def __call__ in BaseModel) + connection = IOKey( + name=name, + value=connection._value, + shape=connection._shape, + type=connection._type, + expose=expose, + connections=_conns, + ) + case NotAvailable(): + raise ValueError( + f"Given value for key: '{key}' is not available. " + "Probably Canonical input/output connections are used, " + "but the model canonical connections is not determined. Please " + "provide connection/key explicitly, or set canonical connections." + ) + case _: + raise TypeError("Requires valid connection type!") + + return connection + + def update_key_name(self, connection: ConnectionData, key: str) -> None: + for key_type in KeyType: + key_dict = self.conns._connection_dict[key_type] + if key_dict.get(connection.key) is not None: + key_dict[key] = key_dict.pop(connection.conn.key) + # Update connection key + connection.key = key + connection.is_key_autogenerated = False + connection.metadata.key_origin = key + setattr(self, key, connection.conn) + break + def _add_connection( self, model: BaseModel, local_key: str, - given_connection: ConnectionType, - expose: bool | None = None, + given_connection: IOKey, ) -> tuple[ConnectionData, Updates]: updates = Updates() - outer_key, con_obj = None, None is_input = local_key in model._input_keys local_connection = model.conns.get_connection(local_key) assert local_connection is not None, "Connection is not found!" - # Flags for use in required operations. - create_connection = None - set_value: ToBeDetermined | str | MainValueType | NullConnection = NOT_GIVEN - match_connection = None - d_map = self.dependency_map._local_output_dependency_map - if isinstance( - given_connection, MainValueInstance | NullConnection - ): # or given_connection == NOT_GIVEN: - # Immediate values can be provided only for inputs. - # if given_connection != NOT_GIVEN: - if isinstance(given_connection, MainValueInstance): - set_value = given_connection - - if expose is None: - expose = is_input - create_connection = True - - elif isinstance(given_connection, str): - # Connection is given as str. 
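#
# A rough summary, not lines from the diff: the per-type branches being
# deleted in this hunk used to classify the raw `given_connection` inside
# _add_connection itself. After this refactor the new _convert_to_iokey
# (added above) normalizes every accepted spec to an IOKey up front,
# roughly:
#   NOT_GIVEN            -> IOKey(expose=<is this an unvalued input?>)
#   "some_name"          -> IOKey(name="some_name"); expose defaulted only
#                           when the name is new to the outer model
#   a Connection         -> IOKey(connections=[conn])
#   an ExtendTemplate    -> unrolled into submodels, then
#                           IOKey(connections=[result], expose=False)
#   a plain value (3.0)  -> IOKey(value=3.0); tuples/lists that contain
#                           connections are first wrapped via ToTuple/ToList
#   an existing IOKey    -> rebuilt with its expose default resolved
# so _add_connection only ever has to handle IOKey from here on.
#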
- outer_key = given_connection - if (con_obj := self.conns.get_connection(given_connection)) is None: - expose = is_input - create_connection = True # Create new connection. - else: - # Connection match is required. - match_connection = True - - elif isinstance(given_connection, ConnectionData): - # Connection is given as a Connection object. - # TODO: maybe use directly connections - if ( - con_obj := self.conns.get_con_by_metadata(given_connection.metadata) - ) is None: - raise KeyError("Requires accessible connection to be processed!") - - if given_connection in model.conns.all.values(): - raise ValueError( - f"Given connection '{given_connection.key}' should not belong " - "to the extending model!" - ) + d_map = self.dependency_map._local_output_dependency_map + expose = given_connection._expose + outer_key = given_connection._name + con_obj = None + set_value: ToBeDetermined | str | MainValueType | NullConnection = NOT_GIVEN + if given_connection._value is not TBD: + set_value = given_connection._value - outer_key = con_obj.key - expose = outer_key in self.conns.output_keys and not is_input - match_connection = True - elif isinstance( - given_connection, IOKey - ) and given_connection._connections == OrderedSet([]): - outer_key = given_connection._name - if (expose := given_connection._expose) is None: - expose = True - if outer_key is None or self.conns.get_connection(outer_key) is None: - create_connection = True # Create new connection. - else: + if given_connection._connections == OrderedSet([]): + if outer_key is not None: con_obj = self.conns.get_connection(outer_key) - # Connection match is required. - match_connection = True - if given_connection._value is not TBD: - set_value = given_connection._value + if expose is None and con_obj is None and is_input: + expose = True + if outer_key is None or con_obj is None: + con_obj = self.create_connection(local_connection.metadata, outer_key) if ( - not expose + expose is False and is_input and set_value is NOT_GIVEN + and local_connection.metadata.data.value is TBD and (con_obj is None or con_obj not in d_map) ): raise ValueError( "Expose flag cannot be false when " "no value is provided for input keys!" ) - elif isinstance( - given_connection, IOKey - ) and given_connection._connections != OrderedSet([]): - match_connection = True - expose = given_connection._expose - if given_connection._name is not None: - outer_key = given_connection._name - if given_connection._value is not TBD: - set_value = given_connection._value + elif given_connection._connections != OrderedSet([]): initial_conn: ConnectionData for idx, conn in enumerate(given_connection._connections): if isinstance(conn, str): _conn = self.conns.get_connection(conn) else: _conn = self.conns.get_con_by_metadata(conn.metadata) - assert isinstance(_conn, ConnectionData) + if not isinstance(_conn, ConnectionData): + raise KeyError("Requires accessible connection to be processed!") + elif conn in model.conns.all.values(): + raise ValueError( + f"Given connection '{conn.key}' should not " # type: ignore + "belong to the extending model!" 
+ ) if idx == 0: initial_conn = _conn - # TODO: Convert this to a method of Connections class named set_name if outer_key is not None: - for key_type in KeyType: - key_dict = self.conns._connection_dict[key_type] - if key_dict.get(initial_conn.key) is not None: - key_dict[outer_key] = key_dict.pop( - initial_conn.conn.key - ) - # Update connection key - initial_conn.key = outer_key - initial_conn.is_key_autogenerated = False - initial_conn.metadata.key_origin = outer_key - setattr(self, outer_key, initial_conn.conn) - break + self.update_key_name(initial_conn, outer_key) else: if _conn in d_map: if initial_conn in d_map: @@ -428,15 +470,13 @@ def _add_connection( updates |= self.merge_connections(initial_conn, _conn) if outer_key is None and is_input and initial_conn not in d_map: expose = True - if not outer_key and initial_conn in d_map and expose: + if not outer_key and initial_conn in d_map and expose is True: raise KeyError("Connection without a name cannot be set as output") con_obj = initial_conn - else: - raise TypeError("Requires valid connection type!") - # Name "input" can only be used for input connections. - if not is_input and outer_key == "input": + is_key_name_input = con_obj is not None and con_obj.key == "input" + if not is_input and (outer_key == "input" or is_key_name_input): raise KeyError( "The key 'input' is a reserved key which could not be used for " "internal keys." @@ -448,23 +488,12 @@ def _add_connection( "not be set in extend." ) - # If connection is not created yet, create it. - if create_connection: - con_obj = self.create_connection(local_connection.metadata, outer_key) - # Inherit submodel connections dict self.conns.connections_dict.setdefault(local_connection.metadata, set()) self.conns.connections_dict[local_connection.metadata] |= ( model.conns.connections_dict.pop(local_connection.metadata, set()) ) - # If connection already has a value set expose to False. - if ( - local_connection.metadata.data.value is not TBD - and con_obj not in self.conns.input_connections - and not isinstance(given_connection, IOKey) - ): - expose = False # If any value provided, set. assert con_obj is not None if not isinstance(set_value, NullConnection): @@ -474,7 +503,7 @@ def _add_connection( self._check_multi_write(is_input, local_connection, con_obj) # If match required, perform. - if match_connection is not None: + if con_obj.metadata != local_connection.metadata: local_key_origin = local_connection.metadata.key_origin updates |= self._match_hyper_edges( con_obj.metadata, local_connection.metadata @@ -492,19 +521,11 @@ def _add_connection( # internal based on expose and is_input flag. if is_input: if outer_key not in self._input_keys: - if expose: - # if con_obj in self.conns.internal_connections: + if expose is True: if con_obj in d_map: self.conns.set_connection_type(con_obj, KeyType.OUTPUT) else: self.conns.set_connection_type(con_obj, KeyType.INPUT) - # TODO: We both set given IOKey in handle_auto_conversion and here. - # This causes duality and confusion in self.conns. We need to refactor - # extend to disentangle this problem. We should avoid using - # self._input_keys, self._output_keys, self._latent_input_keys or - # self.conns.all in _add_connection. 
- - # elif outer_key not in self.conns.all: elif con_obj not in d_map: self.conns.set_connection_type(con_obj, KeyType.LATENT_INPUT) else: @@ -666,46 +687,11 @@ def merge_connections( main_connection2.is_key_autogenerated = main_connection1.is_key_autogenerated return updates - def create_connection_model( - self, - connection: ConnectionType | tuple[ConnectionType, ...] | list[ConnectionType], - ) -> ( - ConnectionData - | ConnectionType - | tuple[ConnectionType, ...] - | list[ConnectionType] - ): - result = connection - - if isinstance(connection, tuple | list): - # find_dominant_type returns the dominant type in a container. - # If a container has a value of type Connection or ExtendTemplate - # we add necessary models. For any type other than Connection, - # ExtendTemplate, float, int or bool, raise TypeError. - dominant_type = find_dominant_type(connection, raise_error=False) - if dominant_type in [ConnectionData, ExtendTemplate, Connection, IOKey]: - kwargs = { - f"input{idx + 1}": item for idx, item in enumerate(connection) - } - connection_model = ToTuple if isinstance(connection, tuple) else ToList - self.extend(conv_model := connection_model(n=len(connection)), **kwargs) - - result = conv_model.conns.get_connection("output") - assert result is not None - elif dominant_type not in [float, int, bool, slice, NoneType, EllipsisType]: - raise TypeError( - f"{dominant_type} type is not supported for conversion in " - "a container!" - ) - - return result - def extend( self, model: Model | PrimitiveModel | BaseModel, **kwargs: ConnectionType, ): - # TODO: kwargs -> ConnectionType # Check possible errors before the extension. model.check_extendability() if self.parent is not None: @@ -729,73 +715,9 @@ def extend( updates = Updates() - input_values: set[str] = set() - output_values: set[str] = set() - shape_info: dict[str, ShapeTemplateType] = {} type_info: dict[str, type | UnionType | NestedListType] = {} - for key, value in kwargs.items(): - # Check if given keys are among model's keys. - if key not in model._input_keys | model.conns.output_keys: - raise KeyError( - f"Given '{key}' key is not an input or output of " - "the model '{model}'" - ) - - # Check proper naming of given keys. - if isinstance(value, str): - if key in model._input_keys and not value.isidentifier(): - raise KeyError( - f"Given key name {value} is not a proper identifier string!" - ) - elif key in model._input_keys: - input_values.add(value) - else: - output_values.add(value) - - elif isinstance(value, Connection): - # Get ConnectionData if value is Connection type. - kwargs[key] = value.data - - elif isinstance(value, ExtendTemplate): - # Unroll ExtendTemplate - template_conn = model.conns.get_connection(key) - assert template_conn is not None, "Connection type is not found!" - kwargs[key] = self._unroll_template( - value, type(template_conn.metadata.data) - ) - - elif isinstance(value, IOKey): - # Hold shape information for IOKey type values in order - # to set all in a bulk after all connections are added. - if value._shape is not None: - shape_info |= {key: value._shape} - - if value._type is not None: - type_info[key] = value._type - - elif isinstance(value, NullConnection): - continue - - elif value is NOT_AVAILABLE: - raise ValueError( - f"Given value for key: '{key}' is not available. " - "Probably Canonical input/output connections are used, " - "but the model canonical connections is not determined. Please " - "provide connection/key explicitly, or set canonical connections." 
- ) - - if (updated_conn := self.create_connection_model(kwargs[key])) is not None: - kwargs[key] = updated_conn - - # Check if any cycles occur with namings. - if common_keys := input_values.intersection(output_values): - raise KeyError( - f"Given connections: '{[key for key in common_keys]}' are used both " - "in input and output keys, which creates cycle!" - ) - submodel_dag: dict[str, ConnectionData] = {} updates = self.constraint_solver.match(model.constraint_solver) @@ -807,10 +729,19 @@ def extend( ): external_keys.append(model.canonical_output.key) - for local_key in external_keys: - con_obj, _updates = self._add_connection( - model, local_key, kwargs.get(local_key, NOT_GIVEN) - ) + io_keys: dict[str, IOKey] = { + key: self._convert_to_iokey(model, key, kwargs.get(key, NOT_GIVEN)) + for key in external_keys + } + + for local_key, value in io_keys.items(): + if value._shape is not None: + shape_info |= {local_key: value._shape} + + if value._type is not None: + type_info[local_key] = value._type + + con_obj, _updates = self._add_connection(model, local_key, value) updates |= _updates submodel_dag[local_key] = con_obj if isinstance(con_obj.metadata.data, Tensor): diff --git a/mithril/models/models.py b/mithril/models/models.py index 0a59b7ec..b675b994 100644 --- a/mithril/models/models.py +++ b/mithril/models/models.py @@ -200,7 +200,7 @@ def __init__( stride=stride_conv.output, padding=pad_conv.output, dilation=IOKey(name="dilation", value=dilation), - output=IOKey(name="output"), + output=IOKey(name="output", expose=True), ) self.input.set_differentiable(False) self._freeze() @@ -308,7 +308,7 @@ def __init__( stride=st_converter.output, padding=pt_converter.output, dilation=dt_converter.output, - output=IOKey(name="output"), + output=IOKey(name="output", expose=True), ) self.input.set_differentiable(False) self._freeze() @@ -401,7 +401,7 @@ def __init__( ) conv_connections: dict[str, ConnectionType] = { - "output": IOKey(name="output"), + "output": IOKey(name="output", expose=True), "input": "input", "weight": "weight", "stride": IOKey(name="stride", value=stride), @@ -502,7 +502,7 @@ def __init__( self += dt_converter(input=IOKey(name="dilation", value=dilation)) conv_connections: dict[str, ConnectionType] = { - "output": IOKey(name="output"), + "output": IOKey(name="output", expose=True), "input": "input", "weight": "weight", "stride": st_converter.output, @@ -562,7 +562,7 @@ def __init__( mult = MatrixMultiply() - output = IOKey(name="output") + output = IOKey(name="output", expose=True) weight_key = IOKey(name="weight", value=weight).transpose() if use_bias: @@ -616,7 +616,9 @@ def __init__( self += mult_model(left="input", right="weight") self += sum_model( - left=mult_model.output, right="bias", output=IOKey(name="output") + left=mult_model.output, + right="bias", + output=IOKey(name="output", expose=True), ) self.input.set_differentiable(False) self._freeze() @@ -656,7 +658,9 @@ def __init__( self.factory_args = {"activation": activation, "dimension": dimension} linear_model = Linear(dimension=dimension) self += linear_model(input="input", weight="weight", bias="bias") - self += activation(input=linear_model.output, output=IOKey(name="output")) + self += activation( + input=linear_model.output, output=IOKey(name="output", expose=True) + ) self._freeze() def __call__( # type: ignore[override] @@ -726,7 +730,9 @@ def __init__( self += add(left=self.canonical_output, right="bias") add._set_shapes(shapes) - self += Buffer()(input=self.canonical_output, 
output=IOKey(name="output")) + self += Buffer()( + input=self.canonical_output, output=IOKey(name="output", expose=True) + ) self._freeze() @@ -803,7 +809,9 @@ def __init__( self += add(left=self.canonical_output, right="bias") add._set_shapes(shapes) - self += Buffer()(input=self.canonical_output, output=IOKey(name="output")) + self += Buffer()( + input=self.canonical_output, output=IOKey(name="output", expose=True) + ) def __call__( # type: ignore[override] self, @@ -841,7 +849,7 @@ def __init__( abs_model = Absolute() self += abs_model(input="input") - self += Sum()(input=abs_model.output, output=IOKey(name="output")) + self += Sum()(input=abs_model.output, output=IOKey(name="output", expose=True)) self._freeze() @@ -869,7 +877,9 @@ def __init__( self += square(input="input") self += sum(input=square.output) - self += Multiply()(left=sum.output, right=0.5, output=IOKey(name="output")) + self += Multiply()( + left=sum.output, right=0.5, output=IOKey(name="output", expose=True) + ) self._freeze() @@ -904,7 +914,7 @@ def __init__( self += dot_model1(left=transpose_model.input, right="kernel") self += dot_model2(left=dot_model1.output, right=transpose_model.output) self += Multiply()( - left=dot_model2.output, right=0.5, output=IOKey(name="output") + left=dot_model2.output, right=0.5, output=IOKey(name="output", expose=True) ) shapes: dict[str, ShapeTemplateType] = {"input": [1, "N"], "kernel": ["N", "N"]} self._set_shapes(shapes) @@ -967,7 +977,9 @@ def __init__( self += exp_model(input=div_model.output) self += l_square(left="l_scale", right="l_scale") self += mult_model2( - left=l_square.output, right=exp_model.output, output=IOKey(name="output") + left=l_square.output, + right=exp_model.output, + output=IOKey(name="output", expose=True), ) self.set_canonical_input("input1") @@ -1032,7 +1044,9 @@ def __init__( self += mult_model(left="input1", right=transpose_model.output) self += sum_model(left=mult_model.output, right="poly_coef") self += power_model( - base=sum_model.output, exponent="degree", output=IOKey(name="output") + base=sum_model.output, + exponent="degree", + output=IOKey(name="output", expose=True), ) self._set_shapes( @@ -1107,7 +1121,7 @@ def __init__( input=kernel.canonical_output, weight="weight", bias="bias", - output=IOKey(name="output"), + output=IOKey(name="output", expose=True), ) shapes: dict[str, ShapeTemplateType] = { @@ -1159,7 +1173,10 @@ def __init__( decision_model = Sign() self += linear_model( - input="input", weight="weight", bias="bias", output=IOKey(name="output") + input="input", + weight="weight", + bias="bias", + output=IOKey(name="output", expose=True), ) self += decision_model( input=linear_model.output, output=IOKey(name="decision_output") @@ -1207,7 +1224,10 @@ def __init__( sigmoid_model = Sigmoid() self += linear_model( - input="input", weight="weight", bias="bias", output=IOKey(name="output") + input="input", + weight="weight", + bias="bias", + output=IOKey(name="output", expose=True), ) self += sigmoid_model( input=linear_model.output, output=IOKey(name="probs_output") @@ -1270,7 +1290,7 @@ def __init__( "bias": bias + "0", } if len(activations) == 1: - extend_kwargs["output"] = IOKey(name="output") + extend_kwargs["output"] = IOKey(name="output", expose=True) self += prev_layer(**extend_kwargs) # Add layers sequentially starting from second elements. @@ -1290,7 +1310,7 @@ def __init__( if idx == ( len(activations) - 2 ): # Loop starts to iterate from second elemets, so it is -2. 
- kwargs |= {"output": IOKey(name="output")} + kwargs |= {"output": IOKey(name="output", expose=True)} # Add current layer to the model. self += current_layer(**kwargs) @@ -1391,7 +1411,9 @@ def __init__( self += Tanh()(input=sum_model_2.output, output=IOKey(name="hidden")) self += mult_model_3(input="hidden", weight="w_ho") self += Add()( - left=mult_model_3.output, right="bias_o", output=IOKey(name="output") + left=mult_model_3.output, + right="bias_o", + output=IOKey(name="output", expose=True), ) shapes: dict[str, ShapeTemplateType] = { "input": ["N", 1, "d_in"], @@ -1542,7 +1564,7 @@ def __init__( input="hidden", weight="w_out", bias="bias_out", - output=IOKey(name="output"), + output=IOKey(name="output", expose=True), ) shapes: dict[str, ShapeTemplateType] = { "input": ["N", 1, "d_in"], @@ -1696,7 +1718,7 @@ def __init__( self += Concat(n=2, axis=0)( input1=sum_model_4.output, input2=mult_model_3.output, - output=IOKey(name="output"), + output=IOKey(name="output", expose=True), ) shapes: dict[str, ShapeTemplateType] = { "input": ["N", 1, "d_in"], @@ -2099,7 +2121,7 @@ def __init__( self += power_model( base=dist_model.output, exponent=reciprocal_model.output, - output=IOKey(name="output"), + output=IOKey(name="output", expose=True), ) else: @@ -2108,7 +2130,7 @@ def __init__( left="input1", right="input2", norm=modifier_model.output, - output=IOKey(name="output"), + output=IOKey(name="output", expose=True), ) self._freeze() @@ -2150,7 +2172,7 @@ def __init__( input=feature_model.output, weight="weight", bias="bias", - output=IOKey(name="output"), + output=IOKey(name="output", expose=True), ) self.input.set_differentiable(False) self._freeze() @@ -2221,7 +2243,7 @@ def __init__( self += power_model_3( base=mult_model.output, exponent=reciprocal_model_1.output, - output=IOKey(name="output"), + output=IOKey(name="output", expose=True), ) else: @@ -2243,7 +2265,7 @@ def __init__( self += power_model_3( base=mult_model.output, exponent=reciprocal_model_1.output, - output=IOKey(name="output"), + output=IOKey(name="output", expose=True), ) self.distances.set_differentiable(False) @@ -2329,7 +2351,8 @@ def __init__( target=p_joint_model.output if calculate_p_joint else "p_joint", ) self += sum_model_3( - input=kl_divergence_model.output, output=IOKey(name="output") + input=kl_divergence_model.output, + output=IOKey(name="output", expose=True), ) else: @@ -2353,7 +2376,8 @@ def __init__( target=p_joint_model.output if calculate_p_joint else "p_joint", ) self += sum_model_3( - input=kl_divergence_model.output, output=IOKey(name="output") + input=kl_divergence_model.output, + output=IOKey(name="output", expose=True), ) self.distances.set_differentiable(False) @@ -2426,7 +2450,7 @@ def __init__( base_kwargs: dict[str, ConnectionType] = { "distances": input_distance_matrix.output, "pred_distances": coords_distance_matrix.output, - "output": IOKey(name="output"), + "output": IOKey(name="output", expose=True), } # Create inputs taking "requires_norm" attribute of base model class. 
if base_model.requires_norm: @@ -2452,7 +2476,7 @@ def __init__( base_kwargs = { "distances": "input", "pred_distances": coords_distance_matrix.output, - "output": IOKey(name="output"), + "output": IOKey(name="output", expose=True), } if base_model.requires_norm: base_kwargs["norm"] = "norm" @@ -2779,7 +2803,7 @@ def __init__( self += Add()( left=sum_model_1.output, right=mult_model_2.output, - output=IOKey(name="output"), + output=IOKey(name="output", expose=True), ) shapes: dict[str, ShapeTemplateType] = { @@ -2920,7 +2944,7 @@ def __init__( self += Divide()( numerator="n_true_predictions", denominator=n_prediction.tensor(), - output=IOKey(name="output"), + output=IOKey(name="output", expose=True), ) self.set_canonical_input(self.pred) @@ -2989,7 +3013,7 @@ def __init__( self += Buffer()( input=self.n_true_positive / (self.n_true_positive + self.n_false_positive), - output=IOKey(name="output"), + output=IOKey(name="output", expose=True), ) if average == "macro": @@ -3025,7 +3049,7 @@ def __init__( self += Divide()( numerator=sum_precision, denominator=self.n_classes.shape()[0].tensor(), - output=IOKey(name="output"), + output=IOKey(name="output", expose=True), ) elif average == "weighted": @@ -3063,7 +3087,7 @@ def __init__( else: precision += getattr(self, f"weighted_precision_{idx}") - self += Buffer()(input=precision, output=IOKey(name="output")) + self += Buffer()(input=precision, output=IOKey(name="output", expose=True)) self.label.set_differentiable(False) self.set_canonical_input(self.pred) @@ -3134,7 +3158,7 @@ def __init__( self += Buffer()( input=self.n_true_positive / (self.n_true_positive + self.n_false_negative), - output=IOKey(name="output"), + output=IOKey(name="output", expose=True), ) if average == "macro": @@ -3169,7 +3193,7 @@ def __init__( self += Divide()( numerator=sum_recall, denominator=self.n_classes.shape()[0].tensor(), - output=IOKey(name="output"), + output=IOKey(name="output", expose=True), ) elif average == "weighted": @@ -3207,7 +3231,7 @@ def __init__( else: recall += getattr(self, f"weighted_recall_{idx}") - self += Buffer()(input=recall, output=IOKey(name="output")) + self += Buffer()(input=recall, output=IOKey(name="output", expose=True)) self.label.set_differentiable(False) self.set_canonical_input(self.pred) @@ -3278,7 +3302,7 @@ def __init__( self += Buffer()( input=self.n_true_positive / (self.n_true_positive + self.n_false_positive), - output=IOKey(name="output"), + output=IOKey(name="output", expose=True), ) if average == "macro": @@ -3315,7 +3339,7 @@ def __init__( self += Divide()( numerator=sum_precision, denominator=self.n_classes.shape()[0].tensor(), - output=IOKey(name="output"), + output=IOKey(name="output", expose=True), ) elif average == "weighted": @@ -3356,7 +3380,7 @@ def __init__( else: precision += getattr(self, f"weighted_precision_{idx}") - self += Buffer()(input=precision, output=IOKey(name="output")) + self += Buffer()(input=precision, output=IOKey(name="output", expose=True)) self.label.set_differentiable(False) self.set_canonical_input(self.pred) @@ -3451,7 +3475,9 @@ def __init__( self += Exponential()(input="minus", output="exp") self += Add()(left=1, right="exp", output="add") self += Divide()( - numerator="input", denominator="add", output=IOKey(name="output") + numerator="input", + denominator="add", + output=IOKey(name="output", expose=True), ) self._set_shapes({"input": [("Var", ...)], "output": [("Var", ...)]}) diff --git a/tests/scripts/test_data_store.py b/tests/scripts/test_data_store.py index c0bec714..95c7f855 
100644 --- a/tests/scripts/test_data_store.py +++ b/tests/scripts/test_data_store.py @@ -246,8 +246,8 @@ def test_data_store_10(): """Infer static keys from pruned buffer 2""" backend = TorchBackend(precision=32) model = Model() - model += Buffer()(input="input", output=IOKey(name="output1")) - model += Sigmoid()(input="input", output=IOKey(name="output2")) + model += Buffer()(input="input", output=IOKey(name="output1", expose=True)) + model += Sigmoid()(input="input", output=IOKey(name="output2", expose=True)) value = backend.array([[1.0, 2, 3]]) pm = mithril.compile(model, backend=backend, constant_keys={"input": value}) @@ -262,8 +262,8 @@ def test_data_store_10(): def test_data_store_11(): backend = TorchBackend(precision=32) model = Model() - model += Sigmoid()(input="input", output=IOKey(name="output1")) - model += Sigmoid()(input="input", output=IOKey(name="output2")) + model += Sigmoid()(input="input", output=IOKey(name="output1", expose=True)) + model += Sigmoid()(input="input", output=IOKey(name="output2", expose=True)) model += Add()(left="output2", right=2, output=IOKey(name="output3", expose=True)) value = backend.array([[1.0, 2, 3]]) pm = mithril.compile(model, backend=backend, constant_keys={"input": value}) diff --git a/tests/scripts/test_io_key.py b/tests/scripts/test_io_key.py index a3425b78..f0bf29e3 100644 --- a/tests/scripts/test_io_key.py +++ b/tests/scripts/test_io_key.py @@ -689,8 +689,9 @@ def test_iokey_tensor_input_all_args(): Exception: _description_ Exception: _description_ """ + from mithril import JaxBackend - backend = TorchBackend() + backend = JaxBackend() # collect all possible values possible_names = ["left", None] possible_values = [[[2.0]], TBD] diff --git a/tests/scripts/test_key_values_in_init.py b/tests/scripts/test_key_values_in_init.py index a502fe1a..1771569e 100644 --- a/tests/scripts/test_key_values_in_init.py +++ b/tests/scripts/test_key_values_in_init.py @@ -180,11 +180,23 @@ def test_integration_call_arg_connection(): model = Model() model += add2(left="in1", right="in2", output="out1") model += add1(left=add2.left, right=add2.output, output="output") - model.summary() backend = ml.TorchBackend() pm = ml.compile(model, backend, data_keys=["in2"], jit=False) assert pm.evaluate(data={"in2": 2})["output"] == backend.array(4.0) + # model = Model() + # model += (add := Add())("mahmut", "mahmout") + # model += Add()(input = IOKey("mahmut", expose=False), output = IOKey()) + # model += Add()(input = IOKey(), output = IOKey()) + # model += Add() + + # model = Model() + # model += (add := Add())("mahmut", "mahmout1") + # model += (add := Add())("mahmut", "mahmout2") + + # model = Model() + # model += Add() + # model += Add() def test_integration_call_arg_str(): diff --git a/tests/scripts/test_scripts.py b/tests/scripts/test_scripts.py index acaa092f..90973a81 100644 --- a/tests/scripts/test_scripts.py +++ b/tests/scripts/test_scripts.py @@ -239,8 +239,9 @@ def test_model_with_misconnection_error(): model += Add() final_model = Model() final_model += model - with pytest.raises(KeyError): + with pytest.raises(KeyError) as error_info: final_model += Add()(left=add.output) + assert str(error_info.value) == "'Requires accessible connection to be processed!'" def test_cyclic_extension_5(): @@ -4485,22 +4486,6 @@ def test_cycle_extend(): ) -def test_cycle_extend_2(): - model = Model() - - model_2 = Model() - model_2 += Tanh()(input="input1", output=IOKey(name="output1")) - model_2 += Sine()(input="input2", output=IOKey(name="output2")) - - with 
pytest.raises(KeyError) as err: - model += model_2(input2="input", output1="input", output2="output") - - assert str(err.value) == ( - "\"Given connections: '['input']' are used both in input and output keys, " - 'which creates cycle!"' - ) - - def test_cycle_handling_1(): backend = TorchBackend(precision=64) model = Model() @@ -6901,21 +6886,19 @@ def test_extend_with_wrong_values(): def test_cyclic_extend(): - with pytest.raises(KeyError) as error_info1: + with pytest.raises(Exception) as error_info1: model = Model() - model += Relu()(input="input", output="input") + model += Relu()(input="input1", output="input1") - with pytest.raises(KeyError) as error_info2: + with pytest.raises(Exception) as error_info2: model = Model() - model += LogisticRegression()(input="input", probs_output="input") + model += LogisticRegression()(input="input1", probs_output="input1") assert str(error_info1.value) == ( - "\"Given connections: '['input']' are used both in input and output keys, " - 'which creates cycle!"' + "There exists a cyclic subgraph between input1 key and ['input1'] key(s)!" ) assert str(error_info2.value) == ( - "\"Given connections: '['input']' are used both in input and output keys, " - 'which creates cycle!"' + "There exists a cyclic subgraph between input1 key and ['$3', 'input1'] key(s)!" ) From 66d3dbdc3b73d755ebfe953316fa7b47f3eaa362 Mon Sep 17 00:00:00 2001 From: berat Tuna KARLI Date: Fri, 20 Dec 2024 16:17:09 +0300 Subject: [PATCH 2/9] minor updates made --- mithril/models/models.py | 94 +++++++++++------------- tests/scripts/test_io_key.py | 3 +- tests/scripts/test_key_values_in_init.py | 13 ---- 3 files changed, 44 insertions(+), 66 deletions(-) diff --git a/mithril/models/models.py b/mithril/models/models.py index b675b994..65680c7f 100644 --- a/mithril/models/models.py +++ b/mithril/models/models.py @@ -200,7 +200,7 @@ def __init__( stride=stride_conv.output, padding=pad_conv.output, dilation=IOKey(name="dilation", value=dilation), - output=IOKey(name="output", expose=True), + output=IOKey(name="output"), ) self.input.set_differentiable(False) self._freeze() @@ -308,7 +308,7 @@ def __init__( stride=st_converter.output, padding=pt_converter.output, dilation=dt_converter.output, - output=IOKey(name="output", expose=True), + output=IOKey(name="output"), ) self.input.set_differentiable(False) self._freeze() @@ -401,7 +401,7 @@ def __init__( ) conv_connections: dict[str, ConnectionType] = { - "output": IOKey(name="output", expose=True), + "output": IOKey(name="output"), "input": "input", "weight": "weight", "stride": IOKey(name="stride", value=stride), @@ -502,7 +502,7 @@ def __init__( self += dt_converter(input=IOKey(name="dilation", value=dilation)) conv_connections: dict[str, ConnectionType] = { - "output": IOKey(name="output", expose=True), + "output": IOKey(name="output"), "input": "input", "weight": "weight", "stride": st_converter.output, @@ -562,7 +562,7 @@ def __init__( mult = MatrixMultiply() - output = IOKey(name="output", expose=True) + output = IOKey(name="output") weight_key = IOKey(name="weight", value=weight).transpose() if use_bias: @@ -618,7 +618,7 @@ def __init__( self += sum_model( left=mult_model.output, right="bias", - output=IOKey(name="output", expose=True), + output=IOKey(name="output"), ) self.input.set_differentiable(False) self._freeze() @@ -658,9 +658,7 @@ def __init__( self.factory_args = {"activation": activation, "dimension": dimension} linear_model = Linear(dimension=dimension) self += linear_model(input="input", weight="weight", bias="bias") - 
self += activation( - input=linear_model.output, output=IOKey(name="output", expose=True) - ) + self += activation(input=linear_model.output, output=IOKey(name="output")) self._freeze() def __call__( # type: ignore[override] @@ -730,9 +728,7 @@ def __init__( self += add(left=self.canonical_output, right="bias") add._set_shapes(shapes) - self += Buffer()( - input=self.canonical_output, output=IOKey(name="output", expose=True) - ) + self += Buffer()(input=self.canonical_output, output=IOKey(name="output")) self._freeze() @@ -809,9 +805,7 @@ def __init__( self += add(left=self.canonical_output, right="bias") add._set_shapes(shapes) - self += Buffer()( - input=self.canonical_output, output=IOKey(name="output", expose=True) - ) + self += Buffer()(input=self.canonical_output, output=IOKey(name="output")) def __call__( # type: ignore[override] self, @@ -849,7 +843,7 @@ def __init__( abs_model = Absolute() self += abs_model(input="input") - self += Sum()(input=abs_model.output, output=IOKey(name="output", expose=True)) + self += Sum()(input=abs_model.output, output=IOKey(name="output")) self._freeze() @@ -877,9 +871,7 @@ def __init__( self += square(input="input") self += sum(input=square.output) - self += Multiply()( - left=sum.output, right=0.5, output=IOKey(name="output", expose=True) - ) + self += Multiply()(left=sum.output, right=0.5, output=IOKey(name="output")) self._freeze() @@ -914,7 +906,7 @@ def __init__( self += dot_model1(left=transpose_model.input, right="kernel") self += dot_model2(left=dot_model1.output, right=transpose_model.output) self += Multiply()( - left=dot_model2.output, right=0.5, output=IOKey(name="output", expose=True) + left=dot_model2.output, right=0.5, output=IOKey(name="output") ) shapes: dict[str, ShapeTemplateType] = {"input": [1, "N"], "kernel": ["N", "N"]} self._set_shapes(shapes) @@ -979,7 +971,7 @@ def __init__( self += mult_model2( left=l_square.output, right=exp_model.output, - output=IOKey(name="output", expose=True), + output=IOKey(name="output"), ) self.set_canonical_input("input1") @@ -1046,7 +1038,7 @@ def __init__( self += power_model( base=sum_model.output, exponent="degree", - output=IOKey(name="output", expose=True), + output=IOKey(name="output"), ) self._set_shapes( @@ -1121,7 +1113,7 @@ def __init__( input=kernel.canonical_output, weight="weight", bias="bias", - output=IOKey(name="output", expose=True), + output=IOKey(name="output"), ) shapes: dict[str, ShapeTemplateType] = { @@ -1176,7 +1168,7 @@ def __init__( input="input", weight="weight", bias="bias", - output=IOKey(name="output", expose=True), + output=IOKey(name="output"), ) self += decision_model( input=linear_model.output, output=IOKey(name="decision_output") @@ -1227,7 +1219,7 @@ def __init__( input="input", weight="weight", bias="bias", - output=IOKey(name="output", expose=True), + output=IOKey(name="output"), ) self += sigmoid_model( input=linear_model.output, output=IOKey(name="probs_output") @@ -1290,7 +1282,7 @@ def __init__( "bias": bias + "0", } if len(activations) == 1: - extend_kwargs["output"] = IOKey(name="output", expose=True) + extend_kwargs["output"] = IOKey(name="output") self += prev_layer(**extend_kwargs) # Add layers sequentially starting from second elements. @@ -1310,7 +1302,7 @@ def __init__( if idx == ( len(activations) - 2 ): # Loop starts to iterate from second elemets, so it is -2. - kwargs |= {"output": IOKey(name="output", expose=True)} + kwargs |= {"output": IOKey(name="output")} # Add current layer to the model. 
self += current_layer(**kwargs) @@ -1413,7 +1405,7 @@ def __init__( self += Add()( left=mult_model_3.output, right="bias_o", - output=IOKey(name="output", expose=True), + output=IOKey(name="output"), ) shapes: dict[str, ShapeTemplateType] = { "input": ["N", 1, "d_in"], @@ -1564,7 +1556,7 @@ def __init__( input="hidden", weight="w_out", bias="bias_out", - output=IOKey(name="output", expose=True), + output=IOKey(name="output"), ) shapes: dict[str, ShapeTemplateType] = { "input": ["N", 1, "d_in"], @@ -1718,7 +1710,7 @@ def __init__( self += Concat(n=2, axis=0)( input1=sum_model_4.output, input2=mult_model_3.output, - output=IOKey(name="output", expose=True), + output=IOKey(name="output"), ) shapes: dict[str, ShapeTemplateType] = { "input": ["N", 1, "d_in"], @@ -2121,7 +2113,7 @@ def __init__( self += power_model( base=dist_model.output, exponent=reciprocal_model.output, - output=IOKey(name="output", expose=True), + output=IOKey(name="output"), ) else: @@ -2130,7 +2122,7 @@ def __init__( left="input1", right="input2", norm=modifier_model.output, - output=IOKey(name="output", expose=True), + output=IOKey(name="output"), ) self._freeze() @@ -2172,7 +2164,7 @@ def __init__( input=feature_model.output, weight="weight", bias="bias", - output=IOKey(name="output", expose=True), + output=IOKey(name="output"), ) self.input.set_differentiable(False) self._freeze() @@ -2243,7 +2235,7 @@ def __init__( self += power_model_3( base=mult_model.output, exponent=reciprocal_model_1.output, - output=IOKey(name="output", expose=True), + output=IOKey(name="output"), ) else: @@ -2265,7 +2257,7 @@ def __init__( self += power_model_3( base=mult_model.output, exponent=reciprocal_model_1.output, - output=IOKey(name="output", expose=True), + output=IOKey(name="output"), ) self.distances.set_differentiable(False) @@ -2352,7 +2344,7 @@ def __init__( ) self += sum_model_3( input=kl_divergence_model.output, - output=IOKey(name="output", expose=True), + output=IOKey(name="output"), ) else: @@ -2377,7 +2369,7 @@ def __init__( ) self += sum_model_3( input=kl_divergence_model.output, - output=IOKey(name="output", expose=True), + output=IOKey(name="output"), ) self.distances.set_differentiable(False) @@ -2450,7 +2442,7 @@ def __init__( base_kwargs: dict[str, ConnectionType] = { "distances": input_distance_matrix.output, "pred_distances": coords_distance_matrix.output, - "output": IOKey(name="output", expose=True), + "output": IOKey(name="output"), } # Create inputs taking "requires_norm" attribute of base model class. 
if base_model.requires_norm: @@ -2476,7 +2468,7 @@ def __init__( base_kwargs = { "distances": "input", "pred_distances": coords_distance_matrix.output, - "output": IOKey(name="output", expose=True), + "output": IOKey(name="output"), } if base_model.requires_norm: base_kwargs["norm"] = "norm" @@ -2803,7 +2795,7 @@ def __init__( self += Add()( left=sum_model_1.output, right=mult_model_2.output, - output=IOKey(name="output", expose=True), + output=IOKey(name="output"), ) shapes: dict[str, ShapeTemplateType] = { @@ -2944,7 +2936,7 @@ def __init__( self += Divide()( numerator="n_true_predictions", denominator=n_prediction.tensor(), - output=IOKey(name="output", expose=True), + output=IOKey(name="output"), ) self.set_canonical_input(self.pred) @@ -3013,7 +3005,7 @@ def __init__( self += Buffer()( input=self.n_true_positive / (self.n_true_positive + self.n_false_positive), - output=IOKey(name="output", expose=True), + output=IOKey(name="output"), ) if average == "macro": @@ -3049,7 +3041,7 @@ def __init__( self += Divide()( numerator=sum_precision, denominator=self.n_classes.shape()[0].tensor(), - output=IOKey(name="output", expose=True), + output=IOKey(name="output"), ) elif average == "weighted": @@ -3087,7 +3079,7 @@ def __init__( else: precision += getattr(self, f"weighted_precision_{idx}") - self += Buffer()(input=precision, output=IOKey(name="output", expose=True)) + self += Buffer()(input=precision, output=IOKey(name="output")) self.label.set_differentiable(False) self.set_canonical_input(self.pred) @@ -3158,7 +3150,7 @@ def __init__( self += Buffer()( input=self.n_true_positive / (self.n_true_positive + self.n_false_negative), - output=IOKey(name="output", expose=True), + output=IOKey(name="output"), ) if average == "macro": @@ -3193,7 +3185,7 @@ def __init__( self += Divide()( numerator=sum_recall, denominator=self.n_classes.shape()[0].tensor(), - output=IOKey(name="output", expose=True), + output=IOKey(name="output"), ) elif average == "weighted": @@ -3231,7 +3223,7 @@ def __init__( else: recall += getattr(self, f"weighted_recall_{idx}") - self += Buffer()(input=recall, output=IOKey(name="output", expose=True)) + self += Buffer()(input=recall, output=IOKey(name="output")) self.label.set_differentiable(False) self.set_canonical_input(self.pred) @@ -3302,7 +3294,7 @@ def __init__( self += Buffer()( input=self.n_true_positive / (self.n_true_positive + self.n_false_positive), - output=IOKey(name="output", expose=True), + output=IOKey(name="output"), ) if average == "macro": @@ -3339,7 +3331,7 @@ def __init__( self += Divide()( numerator=sum_precision, denominator=self.n_classes.shape()[0].tensor(), - output=IOKey(name="output", expose=True), + output=IOKey(name="output"), ) elif average == "weighted": @@ -3380,7 +3372,7 @@ def __init__( else: precision += getattr(self, f"weighted_precision_{idx}") - self += Buffer()(input=precision, output=IOKey(name="output", expose=True)) + self += Buffer()(input=precision, output=IOKey(name="output")) self.label.set_differentiable(False) self.set_canonical_input(self.pred) @@ -3477,7 +3469,7 @@ def __init__( self += Divide()( numerator="input", denominator="add", - output=IOKey(name="output", expose=True), + output=IOKey(name="output"), ) self._set_shapes({"input": [("Var", ...)], "output": [("Var", ...)]}) diff --git a/tests/scripts/test_io_key.py b/tests/scripts/test_io_key.py index f0bf29e3..a3425b78 100644 --- a/tests/scripts/test_io_key.py +++ b/tests/scripts/test_io_key.py @@ -689,9 +689,8 @@ def test_iokey_tensor_input_all_args(): Exception: 
_description_ Exception: _description_ """ - from mithril import JaxBackend - backend = JaxBackend() + backend = TorchBackend() # collect all possible values possible_names = ["left", None] possible_values = [[[2.0]], TBD] diff --git a/tests/scripts/test_key_values_in_init.py b/tests/scripts/test_key_values_in_init.py index 1771569e..bf4b8ef6 100644 --- a/tests/scripts/test_key_values_in_init.py +++ b/tests/scripts/test_key_values_in_init.py @@ -184,19 +184,6 @@ def test_integration_call_arg_connection(): backend = ml.TorchBackend() pm = ml.compile(model, backend, data_keys=["in2"], jit=False) assert pm.evaluate(data={"in2": 2})["output"] == backend.array(4.0) - # model = Model() - # model += (add := Add())("mahmut", "mahmout") - # model += Add()(input = IOKey("mahmut", expose=False), output = IOKey()) - # model += Add()(input = IOKey(), output = IOKey()) - # model += Add() - - # model = Model() - # model += (add := Add())("mahmut", "mahmout1") - # model += (add := Add())("mahmut", "mahmout2") - - # model = Model() - # model += Add() - # model += Add() def test_integration_call_arg_str(): From 2e462710ffd0af3f63da87cfc4d0903bba88e260 Mon Sep 17 00:00:00 2001 From: berat Tuna KARLI Date: Mon, 23 Dec 2024 16:54:14 +0300 Subject: [PATCH 3/9] minor updates on add_connection and convert_to_iokey methods --- mithril/framework/common.py | 4 +- mithril/framework/logical/model.py | 60 +++++++++++++----------------- 2 files changed, 27 insertions(+), 37 deletions(-) diff --git a/mithril/framework/common.py b/mithril/framework/common.py index 6b828d9c..59ac2ea3 100644 --- a/mithril/framework/common.py +++ b/mithril/framework/common.py @@ -150,8 +150,8 @@ class UpdateType(Enum): class KeyType(Enum): INPUT = 1 OUTPUT = 2 - INTERNAL = 3 - LATENT_INPUT = 4 + LATENT_INPUT = 3 + INTERNAL = 4 type FixedValueType = ( diff --git a/mithril/framework/logical/model.py b/mithril/framework/logical/model.py index a271d2b3..56855fce 100644 --- a/mithril/framework/logical/model.py +++ b/mithril/framework/logical/model.py @@ -299,20 +299,15 @@ def _check_multi_write( def _convert_to_iokey( self, model: BaseModel, key: str, connection: ConnectionType ) -> IOKey: - is_input = key in model._input_keys local_connection = model.conns.get_connection(key) assert local_connection is not None, "Connection is not found!" - not_valued = local_connection.metadata.data.value is TBD match connection: case NullConnection(): - connection = IOKey(expose=is_input and not_valued) + connection = IOKey() case str(): - expose = None - if self.conns.get_connection(connection) is None: - expose = is_input and not_valued - connection = IOKey(name=connection, expose=expose) + connection = IOKey(name=connection) case Connection(): - connection = IOKey(connections=[connection], expose=None) + connection = IOKey(connections=[connection]) case ExtendTemplate(): # Unroll ExtendTemplate template_conn = model.conns.get_connection(key) @@ -342,10 +337,7 @@ def _convert_to_iokey( connection = IOKey(connections=[result.conn], expose=None) else: assert isinstance(connection, MainValueInstance) - expose = None - if not not_valued: - expose = False - connection = IOKey(value=connection, expose=expose) + connection = IOKey(value=connection) case IOKey(): expose = connection._expose name = connection._name @@ -405,6 +397,7 @@ def _add_connection( is_input = local_key in model._input_keys local_connection = model.conns.get_connection(local_key) assert local_connection is not None, "Connection is not found!" 
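#
# Sketch of the revised expose defaulting in the hunk below (illustrative,
# not lines from the diff): the `expose = True` fallback now fires only
# when a brand-new connection is created for an input key whose value is
# still TBD:
#   expose passed explicitly        -> kept as given
#   existing outer connection found -> left as None, resolved later
#   new connection, valued input    -> left as None
#   new connection, unvalued input  -> defaults to True
#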
+ is_not_valued = local_connection.metadata.data.value is TBD d_map = self.dependency_map._local_output_dependency_map expose = given_connection._expose @@ -417,9 +410,9 @@ def _add_connection( if given_connection._connections == OrderedSet([]): if outer_key is not None: con_obj = self.conns.get_connection(outer_key) - if expose is None and con_obj is None and is_input: - expose = True if outer_key is None or con_obj is None: + if expose is None and is_input and is_not_valued: + expose = True con_obj = self.create_connection(local_connection.metadata, outer_key) if ( expose is False @@ -432,7 +425,7 @@ def _add_connection( "Expose flag cannot be false when " "no value is provided for input keys!" ) - elif given_connection._connections != OrderedSet([]): + else: initial_conn: ConnectionData for idx, conn in enumerate(given_connection._connections): if isinstance(conn, str): @@ -468,8 +461,6 @@ def _add_connection( "name but encountered more!" ) updates |= self.merge_connections(initial_conn, _conn) - if outer_key is None and is_input and initial_conn not in d_map: - expose = True if not outer_key and initial_conn in d_map and expose is True: raise KeyError("Connection without a name cannot be set as output") con_obj = initial_conn @@ -497,7 +488,7 @@ def _add_connection( # If any value provided, set. assert con_obj is not None if not isinstance(set_value, NullConnection): - updates |= con_obj.metadata.data.set_value(set_value) # type: ignore + updates |= con_obj.metadata.data.set_value(set_value) # Check multi-write error for con_obj. self._check_multi_write(is_input, local_connection, con_obj) @@ -517,22 +508,21 @@ def _add_connection( ): con_obj.metadata.key_origin = local_key_origin - # Set connection as input, output, latent input or - # internal based on expose and is_input flag. 
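#
# The if/else classification removed below is replaced by a two-bit index
# into KeyType, which this patch reordered above (INPUT = 1, OUTPUT = 2,
# LATENT_INPUT = 3, INTERNAL = 4). Worked out per combination as an
# illustration, not lines from the diff:
#   unexpose=0, is_output=0 -> (0 << 1 | 0) + 1 = 1 -> KeyType.INPUT
#   unexpose=0, is_output=1 -> (0 << 1 | 1) + 1 = 2 -> KeyType.OUTPUT
#   unexpose=1, is_output=0 -> (1 << 1 | 0) + 1 = 3 -> KeyType.LATENT_INPUT
#   unexpose=1, is_output=1 -> (1 << 1 | 1) + 1 = 4 -> KeyType.INTERNAL
#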
- if is_input: - if outer_key not in self._input_keys: - if expose is True: - if con_obj in d_map: - self.conns.set_connection_type(con_obj, KeyType.OUTPUT) - else: - self.conns.set_connection_type(con_obj, KeyType.INPUT) - elif con_obj not in d_map: - self.conns.set_connection_type(con_obj, KeyType.LATENT_INPUT) - else: - if expose and outer_key not in self.conns.output_keys: - self.conns.set_connection_type(con_obj, KeyType.OUTPUT) - elif not expose and outer_key not in self.conns.internal_keys: - self.conns.set_connection_type(con_obj, KeyType.INTERNAL) + unexpose = int(not (expose or (is_input and con_obj.key in self.conns.io_keys))) + is_output = int(not (is_input and con_obj not in d_map)) + bitwise_key_type = ( + unexpose << 1 | is_output + ) + 1 # bits: (unexpose, is_output) + self.conns.set_connection_type(con_obj, KeyType(bitwise_key_type)) + + # if expose is None: + # expose = is_input and con_obj.key in self.conns.io_keys + # is_output = not (is_input and con_obj not in d_map) + # if expose: + # key_type = (KeyType.INPUT, KeyType.OUTPUT)[is_output] + # else: + # key_type = (KeyType.LATENT_INPUT, KeyType.INTERNAL)[is_output] + # self.conns.set_connection_type(con_obj, key_type) return con_obj, updates @@ -721,7 +711,7 @@ def extend( submodel_dag: dict[str, ConnectionData] = {} updates = self.constraint_solver.match(model.constraint_solver) - # Add canonical output if it is not in externel_keys + # Add canonical output if it is not in external_keys external_keys = list(model.external_keys) if ( model.canonical_output is not NOT_AVAILABLE From 41579a42b842d69f1b6d38b5ecaaeb4a5833442c Mon Sep 17 00:00:00 2001 From: berat Tuna KARLI Date: Mon, 23 Dec 2024 16:55:22 +0300 Subject: [PATCH 4/9] minor comment deletion --- mithril/framework/logical/model.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/mithril/framework/logical/model.py b/mithril/framework/logical/model.py index 56855fce..2e440a7f 100644 --- a/mithril/framework/logical/model.py +++ b/mithril/framework/logical/model.py @@ -515,15 +515,6 @@ def _add_connection( ) + 1 # bits: (unexpose, is_output) self.conns.set_connection_type(con_obj, KeyType(bitwise_key_type)) - # if expose is None: - # expose = is_input and con_obj.key in self.conns.io_keys - # is_output = not (is_input and con_obj not in d_map) - # if expose: - # key_type = (KeyType.INPUT, KeyType.OUTPUT)[is_output] - # else: - # key_type = (KeyType.LATENT_INPUT, KeyType.INTERNAL)[is_output] - # self.conns.set_connection_type(con_obj, key_type) - return con_obj, updates def _unroll_template( From ec3e458f07a7e47e55b9bd0a4e3f4ba44d46106e Mon Sep 17 00:00:00 2001 From: berat Tuna KARLI Date: Tue, 24 Dec 2024 16:49:02 +0300 Subject: [PATCH 5/9] minor pre-commit updates --- mithril/framework/common.py | 3 -- mithril/framework/logical/model.py | 8 ++-- mithril/models/models.py | 74 ++++++++---------------------- 3 files changed, 23 insertions(+), 62 deletions(-) diff --git a/mithril/framework/common.py b/mithril/framework/common.py index 59ac2ea3..3560bd4b 100644 --- a/mithril/framework/common.py +++ b/mithril/framework/common.py @@ -1135,9 +1135,6 @@ def __init__( conn = item.data if isinstance(item, Connection) else item self._connections.add(conn) - def __hash__(self) -> int: - return hash(id(self)) - class Connection(TemplateBase): def __init__(self, key: str, metadata: IOHyperEdge, is_key_autogenerated: bool): diff --git a/mithril/framework/logical/model.py b/mithril/framework/logical/model.py index 2e440a7f..e95f35e6 100644 --- 
a/mithril/framework/logical/model.py +++ b/mithril/framework/logical/model.py @@ -466,7 +466,7 @@ def _add_connection( con_obj = initial_conn # Name "input" can only be used for input connections. - is_key_name_input = con_obj is not None and con_obj.key == "input" + is_key_name_input = con_obj is not None and (con_key := con_obj.key) == "input" if not is_input and (outer_key == "input" or is_key_name_input): raise KeyError( "The key 'input' is a reserved key which could not be used for " @@ -508,11 +508,9 @@ def _add_connection( ): con_obj.metadata.key_origin = local_key_origin - unexpose = int(not (expose or (is_input and con_obj.key in self.conns.io_keys))) + unexposed = int(not (expose or (is_input and con_key in self.conns.io_keys))) is_output = int(not (is_input and con_obj not in d_map)) - bitwise_key_type = ( - unexpose << 1 | is_output - ) + 1 # bits: (unexpose, is_output) + bitwise_key_type = (unexposed << 1 | is_output) + 1 # (unexpose, is_output) self.conns.set_connection_type(con_obj, KeyType(bitwise_key_type)) return con_obj, updates diff --git a/mithril/models/models.py b/mithril/models/models.py index 65680c7f..717ba3f3 100644 --- a/mithril/models/models.py +++ b/mithril/models/models.py @@ -1036,11 +1036,8 @@ def __init__( self += mult_model(left="input1", right=transpose_model.output) self += sum_model(left=mult_model.output, right="poly_coef") self += power_model( - base=sum_model.output, - exponent="degree", - output=IOKey(name="output"), + base=sum_model.output, exponent="degree", output=IOKey(name="output") ) - self._set_shapes( { "input1": ["N", "d"], @@ -1165,10 +1162,7 @@ def __init__( decision_model = Sign() self += linear_model( - input="input", - weight="weight", - bias="bias", - output=IOKey(name="output"), + input="input", weight="weight", bias="bias", output=IOKey(name="output") ) self += decision_model( input=linear_model.output, output=IOKey(name="decision_output") @@ -1403,9 +1397,7 @@ def __init__( self += Tanh()(input=sum_model_2.output, output=IOKey(name="hidden")) self += mult_model_3(input="hidden", weight="w_ho") self += Add()( - left=mult_model_3.output, - right="bias_o", - output=IOKey(name="output"), + left=mult_model_3.output, right="bias_o", output=IOKey(name="output") ) shapes: dict[str, ShapeTemplateType] = { "input": ["N", 1, "d_in"], @@ -2327,50 +2319,27 @@ def __init__( self += p_joint_model( squared_distances=square_model.output, target_perplexity=perplexity ) - self += sum_model_1(left=1.0, right="pred_distances") - self += divide_model_1(numerator=1.0, denominator=sum_model_1.output) - self += size_model(input=getattr(self, "distances", "distances")) - self += zero_diagonal_model(N=size_model.output) - self += mult_model( - left=divide_model_1.output, right=zero_diagonal_model.output - ) - self += sum_model_2(input=mult_model.output) - self += divide_model_2( - numerator=mult_model.output, denominator=sum_model_2.output - ) - self += kl_divergence_model( - input=divide_model_2.output, - target=p_joint_model.output if calculate_p_joint else "p_joint", - ) - self += sum_model_3( - input=kl_divergence_model.output, - output=IOKey(name="output"), - ) - else: if calculate_p_joint: self += p_joint_model( squared_distances="distances", target_perplexity=perplexity ) - self += sum_model_1(left=1.0, right="pred_distances") - self += divide_model_1(numerator=1.0, denominator=sum_model_1.output) - self += size_model(input=getattr(self, "distances", "distances")) - self += zero_diagonal_model(N=size_model.output) - self += mult_model( - 
left=divide_model_1.output, right=zero_diagonal_model.output - ) - self += sum_model_2(input=mult_model.output) - self += divide_model_2( - numerator=mult_model.output, denominator=sum_model_2.output - ) - self += kl_divergence_model( - input=divide_model_2.output, - target=p_joint_model.output if calculate_p_joint else "p_joint", - ) - self += sum_model_3( - input=kl_divergence_model.output, - output=IOKey(name="output"), - ) + self += sum_model_1(left=1.0, right="pred_distances") + self += divide_model_1(numerator=1.0, denominator=sum_model_1.output) + self += size_model(input=getattr(self, "distances", "distances")) + self += zero_diagonal_model(N=size_model.output) + self += mult_model(left=divide_model_1.output, right=zero_diagonal_model.output) + self += sum_model_2(input=mult_model.output) + self += divide_model_2( + numerator=mult_model.output, denominator=sum_model_2.output + ) + self += kl_divergence_model( + input=divide_model_2.output, + target=p_joint_model.output if calculate_p_joint else "p_joint", + ) + self += sum_model_3( + input=kl_divergence_model.output, output=IOKey(name="output") + ) self.distances.set_differentiable(False) self._set_shapes({"distances": ["N", "N"], "pred_distances": ["N", "N"]}) @@ -3467,11 +3436,8 @@ def __init__( self += Exponential()(input="minus", output="exp") self += Add()(left=1, right="exp", output="add") self += Divide()( - numerator="input", - denominator="add", - output=IOKey(name="output"), + numerator="input", denominator="add", output=IOKey(name="output") ) - self._set_shapes({"input": [("Var", ...)], "output": [("Var", ...)]}) self.input.set_differentiable(False) From b21b9a972fb9f18443a17687a8ea9020c13c8f06 Mon Sep 17 00:00:00 2001 From: berat Tuna KARLI Date: Wed, 25 Dec 2024 10:01:05 +0300 Subject: [PATCH 6/9] Add BaseKey and use BaseKey for primitive models, update tests --- mithril/framework/common.py | 15 + .../framework/logical/essential_primitives.py | 234 +++++----- mithril/framework/logical/primitive.py | 33 +- mithril/models/primitives.py | 406 +++++++++--------- tests/scripts/test_constr_counter.py | 20 +- tests/scripts/test_functions.py | 14 +- tests/scripts/test_jittable.py | 14 +- tests/scripts/test_model_to_dict_rtt.py | 8 +- tests/scripts/test_ref_counts.py | 65 +-- tests/scripts/test_scripts.py | 41 +- tests/scripts/test_shapes.py | 205 ++++----- tests/scripts/test_type_coercion.py | 9 +- tests/scripts/test_type_consistencies.py | 25 +- 13 files changed, 562 insertions(+), 527 deletions(-) diff --git a/mithril/framework/common.py b/mithril/framework/common.py index bb495bfd..8fa15dae 100644 --- a/mithril/framework/common.py +++ b/mithril/framework/common.py @@ -1100,6 +1100,15 @@ def __init__( self.output_connection = None +@dataclass +class BaseKey: + value: TensorValueType | MainValueType | ToBeDetermined | str = TBD + shape: ShapeTemplateType | None = None + type: NestedListType | UnionType | type | None = None + interval: list[float | int] | None = None + + +@dataclass class IOKey(TemplateBase): def __init__( self, @@ -1141,6 +1150,12 @@ def __init__( conn = item.data if isinstance(item, Connection) else item self._connections.add(conn) + def __eq__(self, other: object): + if isinstance(other, int | float | bool | list | Connection | IOKey | tuple): + return ExtendTemplate(connections=[self, other], model="eq") + else: + raise ValueError("Unsupported type for equality operation.") + class Connection(TemplateBase): def __init__(self, key: str, metadata: IOHyperEdge, is_key_autogenerated: bool): diff --git 
a/mithril/framework/logical/essential_primitives.py b/mithril/framework/logical/essential_primitives.py index 4235719a..051d9c94 100644 --- a/mithril/framework/logical/essential_primitives.py +++ b/mithril/framework/logical/essential_primitives.py @@ -20,10 +20,10 @@ from ..common import ( NOT_GIVEN, TBD, + BaseKey, Connection, ConnectionType, GenericTensorType, - IOKey, MyTensor, ShapeTemplateType, TensorValueType, @@ -123,8 +123,8 @@ def __init__( super().__init__( formula_key="buffer", name=name, - output=IOKey(shape=[("Var", ...)], type=GenericTensorType), - input=IOKey(shape=[("Var", ...)], type=GenericTensorType), + output=BaseKey(shape=[("Var", ...)], type=GenericTensorType), + input=BaseKey(shape=[("Var", ...)], type=GenericTensorType), ) self.factory_inputs = {"input": input} @@ -149,10 +149,10 @@ def __init__( ) -> None: self.factory_args = {"n": n} key_definitions = { - "output": IOKey(type=tuple[int | float | bool | list | tuple, ...]) + "output": BaseKey(type=tuple[int | float | bool | list | tuple, ...]) } key_definitions |= { - f"input{idx+1}": IOKey(type=int | float | bool | list | tuple) + f"input{idx+1}": BaseKey(type=int | float | bool | list | tuple) for idx in range(n) } self.factory_inputs = kwargs # type: ignore @@ -173,9 +173,9 @@ def __init__(self, formula_key: str, name: str | None = None) -> None: super().__init__( formula_key=formula_key, name=name, - output=IOKey(shape=[("Var_out", ...)], type=GenericTensorType), - left=IOKey(shape=[("Var_1", ...)], type=GenericTensorType), - right=IOKey(shape=[("Var_2", ...)], type=GenericTensorType), + output=BaseKey(shape=[("Var_out", ...)], type=GenericTensorType), + left=BaseKey(shape=[("Var_1", ...)], type=GenericTensorType), + right=BaseKey(shape=[("Var_2", ...)], type=GenericTensorType), ) self._set_constraint( @@ -216,18 +216,18 @@ def __init__( super().__init__( formula_key="robust_power", name=name, - output=IOKey(shape=[("Var_out", ...)], type=GenericTensorType), - base=IOKey(shape=[("Var_1", ...)], type=GenericTensorType), - exponent=IOKey(shape=[("Var_2", ...)], type=GenericTensorType), - threshold=IOKey(shape=[], type=GenericTensorType), + output=BaseKey(shape=[("Var_out", ...)], type=GenericTensorType), + base=BaseKey(shape=[("Var_1", ...)], type=GenericTensorType), + exponent=BaseKey(shape=[("Var_2", ...)], type=GenericTensorType), + threshold=BaseKey(shape=[], type=GenericTensorType), ) self.threshold.set_differentiable(False) # type: ignore else: super().__init__( formula_key="power", - output=IOKey(shape=[("Var_out", ...)], type=GenericTensorType), - base=IOKey(shape=[("Var_1", ...)], type=GenericTensorType), - exponent=IOKey(shape=[("Var_2", ...)], type=GenericTensorType), + output=BaseKey(shape=[("Var_out", ...)], type=GenericTensorType), + base=BaseKey(shape=[("Var_1", ...)], type=GenericTensorType), + exponent=BaseKey(shape=[("Var_2", ...)], type=GenericTensorType), ) self._set_constraint( @@ -303,9 +303,9 @@ def __init__( super().__init__( formula_key="divide", name=name, - output=IOKey(shape=[("Var_out", ...)], type=MyTensor[float]), - numerator=IOKey(shape=[("Var_1", ...)], type=GenericTensorType), - denominator=IOKey(shape=[("Var_2", ...)], type=GenericTensorType), + output=BaseKey(shape=[("Var_out", ...)], type=MyTensor[float]), + numerator=BaseKey(shape=[("Var_1", ...)], type=GenericTensorType), + denominator=BaseKey(shape=[("Var_2", ...)], type=GenericTensorType), ) self.factory_inputs = {"numerator": numerator, "denominator": denominator} self._set_constraint( @@ -339,9 +339,9 @@ def __init__( 
super().__init__( formula_key="floor_divide", name=name, - output=IOKey(shape=[("Var_out", ...)], type=GenericTensorType), - numerator=IOKey(shape=[("Var_1", ...)], type=GenericTensorType), - denominator=IOKey(shape=[("Var_2", ...)], type=GenericTensorType), + output=BaseKey(shape=[("Var_out", ...)], type=GenericTensorType), + numerator=BaseKey(shape=[("Var_1", ...)], type=GenericTensorType), + denominator=BaseKey(shape=[("Var_2", ...)], type=GenericTensorType), ) self.factory_inputs = {"numerator": numerator, "denominator": denominator} @@ -378,9 +378,9 @@ def __init__( super().__init__( formula_key="matrix_multiplication", name=name, - output=IOKey(shape=[("Var3", ...), "x", "z"], type=GenericTensorType), - left=IOKey(shape=[("Var1", ...), "x", "y"], type=GenericTensorType), - right=IOKey(shape=[("Var2", ...), "y", "z"], type=GenericTensorType), + output=BaseKey(shape=[("Var3", ...), "x", "z"], type=GenericTensorType), + left=BaseKey(shape=[("Var1", ...), "x", "y"], type=GenericTensorType), + right=BaseKey(shape=[("Var2", ...), "y", "z"], type=GenericTensorType), ) self.factory_inputs = {"left": left, "right": right} self._set_constraint( @@ -410,8 +410,8 @@ def __init__( super().__init__( formula_key="shape", name=name, - output=IOKey(shape=[], type=tuple[int, ...]), - input=IOKey(shape=[("input", ...)], type=GenericTensorType), + output=BaseKey(shape=[], type=tuple[int, ...]), + input=BaseKey(shape=[("input", ...)], type=GenericTensorType), ) self.factory_inputs = {"input": input} self._set_constraint(fn=shape_constraints, keys=["output", "input"]) @@ -442,9 +442,9 @@ def __init__( super().__init__( formula_key="reshape", name=name, - output=IOKey(shape=output_shape_map, type=GenericTensorType), - input=IOKey(shape=[("input", ...)], type=GenericTensorType), - shape=IOKey(type=tuple[int | None, ...] | list[int | None], value=shape), + output=BaseKey(shape=output_shape_map, type=GenericTensorType), + input=BaseKey(shape=[("input", ...)], type=GenericTensorType), + shape=BaseKey(type=tuple[int | None, ...] | list[int | None], value=shape), ) self.factory_inputs = {"input": input, "shape": shape} self._set_constraint(fn=reshape_constraints, keys=["output", "input", "shape"]) @@ -468,8 +468,8 @@ def __init__( super().__init__( formula_key="length", name=name, - output=IOKey(type=int), - input=IOKey(shape=[("Var", ...)], type=GenericTensorType), + output=BaseKey(type=int), + input=BaseKey(shape=[("Var", ...)], type=GenericTensorType), ) self.factory_inputs = {"input": input} @@ -490,9 +490,9 @@ def __init__( super().__init__( formula_key="astype", name=name, - output=IOKey(shape=[("Var", ...)], type=GenericTensorType), - input=IOKey(shape=[("Var", ...)], type=GenericTensorType), - dtype=IOKey(type=Dtype), + output=BaseKey(shape=[("Var", ...)], type=GenericTensorType), + input=BaseKey(shape=[("Var", ...)], type=GenericTensorType), + dtype=BaseKey(type=Dtype), ) self.factory_inputs = {"dtype": dtype} @@ -515,8 +515,8 @@ def __init__( super().__init__( formula_key="dtype", name=name, - output=IOKey(type=core.Dtype), - input=IOKey(shape=[("Var", ...)], type=GenericTensorType), + output=BaseKey(type=core.Dtype), + input=BaseKey(shape=[("Var", ...)], type=GenericTensorType), ) self.factory_inputs = {"input": input} @@ -541,9 +541,9 @@ def __init__( super().__init__( formula_key="size", name=name, - output=IOKey(type=int | tuple[int, ...]), - input=IOKey(shape=[("Var", ...)], type=GenericTensorType), - dim=IOKey(type=int | tuple[int, ...] 
| None, value=dim), + output=BaseKey(type=int | tuple[int, ...]), + input=BaseKey(shape=[("Var", ...)], type=GenericTensorType), + dim=BaseKey(type=int | tuple[int, ...] | None, value=dim), ) self.factory_inputs = {"input": input} self._set_constraint(fn=size_constraints, keys=["output", "input", "dim"]) @@ -576,13 +576,15 @@ def __init__( super().__init__( formula_key="sequence_slice", name=name, - output=IOKey( + output=BaseKey( type=tuple[int | float | bool, ...] | list[int | float | bool] ), - input=IOKey(type=tuple[int | float | bool, ...] | list[int | float | bool]), - start=IOKey(type=int | None, value=start), - stop=IOKey(type=int | None, value=stop), - step=IOKey(type=int | None, value=step), + input=BaseKey( + type=tuple[int | float | bool, ...] | list[int | float | bool] + ), + start=BaseKey(type=int | None, value=start), + stop=BaseKey(type=int | None, value=stop), + step=BaseKey(type=int | None, value=step), ) self.factory_inputs = {"input": input} @@ -623,11 +625,11 @@ def __init__( super().__init__( formula_key="tensor_slice", name=name, - output=IOKey(shape=["a", ("Var1", ...)], type=GenericTensorType), - input=IOKey(shape=["b", ("Var1", ...)], type=GenericTensorType), - start=IOKey(type=int | None, value=start), - stop=IOKey(type=int | None, value=stop), - step=IOKey(type=int | None, value=step), + output=BaseKey(shape=["a", ("Var1", ...)], type=GenericTensorType), + input=BaseKey(shape=["b", ("Var1", ...)], type=GenericTensorType), + start=BaseKey(type=int | None, value=start), + stop=BaseKey(type=int | None, value=stop), + step=BaseKey(type=int | None, value=step), ) self.factory_inputs = {"input": input} @@ -662,8 +664,8 @@ def __init__( super().__init__( formula_key="item", name=name, - output=IOKey(type=int | float), - input=IOKey(shape=[("Var", ...)], type=GenericTensorType), + output=BaseKey(type=int | float), + input=BaseKey(shape=[("Var", ...)], type=GenericTensorType), ) self.factory_inputs = {"input": input} self._set_constraint( @@ -692,9 +694,9 @@ def __init__( super().__init__( formula_key="scalar_item", name=name, - output=IOKey(type=int | float | list | tuple), - input=IOKey(type=list | tuple), - index=IOKey(type=int, value=index), + output=BaseKey(type=int | float | list | tuple), + input=BaseKey(type=list | tuple), + index=BaseKey(type=int, value=index), ) self.factory_inputs = {"input": input, "index": index} @@ -730,9 +732,9 @@ def __init__( super().__init__( formula_key="tensor_item", name=name, - output=IOKey(shape=[("Var2", ...)], type=GenericTensorType), - input=IOKey(shape=[("Var1", ...)], type=GenericTensorType), - index=IOKey( + output=BaseKey(shape=[("Var2", ...)], type=GenericTensorType), + input=BaseKey(shape=[("Var1", ...)], type=GenericTensorType), + index=BaseKey( type=int | slice | EllipsisType @@ -770,8 +772,8 @@ def __init__( super().__init__( formula_key="to_tensor", name=name, - output=IOKey(shape=[("Var", ...)], type=GenericTensorType), - input=IOKey(type=int | float | list | tuple), + output=BaseKey(shape=[("Var", ...)], type=GenericTensorType), + input=BaseKey(type=int | float | list | tuple), ) self._set_constraint( @@ -791,9 +793,11 @@ class ToList(PrimitiveModel): def __init__(self, n: int, name: str | None = None, **kwargs) -> None: self.factory_args = {"n": n} key_definitions = {} - key_definitions["output"] = IOKey(type=list[int | float | bool | list | tuple]) + key_definitions["output"] = BaseKey( + type=list[int | float | bool | list | tuple] + ) key_definitions |= { - f"input{idx+1}": IOKey(type=int | float | bool | list | 
tuple) + f"input{idx+1}": BaseKey(type=int | float | bool | list | tuple) for idx in range(n) } self.factory_inputs = kwargs @@ -816,8 +820,8 @@ def __init__( super().__init__( formula_key="tensor_to_list", name=name, - output=IOKey(type=NestedListType(int | float | bool)), - input=IOKey(shape=[("Var", ...)], type=GenericTensorType), + output=BaseKey(type=NestedListType(int | float | bool)), + input=BaseKey(shape=[("Var", ...)], type=GenericTensorType), ) self.factory_inputs = {"input": input} self._set_constraint( @@ -847,7 +851,7 @@ def __init__( name: str | None = None, axis: int | tuple[int, ...] | None | ToBeDetermined = None, keepdim: bool | ToBeDetermined = False, - **kwargs: IOKey, + **kwargs: BaseKey, ) -> None: # TODO: Handle axis type for conditional cases below. self.factory_args = {"axis": axis, "keepdim": keepdim} @@ -862,11 +866,11 @@ def __init__( else: raise ValueError("Requires valid axis type!") - init_kwargs: dict[str, IOKey] = { - "output": IOKey(shape=[("Var_out", ...)], type=GenericTensorType), - "input": IOKey(shape=[("Var_in", ...)], type=GenericTensorType), - "axis": IOKey(type=axis_type, value=axis), - "keepdim": IOKey(type=bool, value=keepdim), + init_kwargs: dict[str, BaseKey] = { + "output": BaseKey(shape=[("Var_out", ...)], type=GenericTensorType), + "input": BaseKey(shape=[("Var_in", ...)], type=GenericTensorType), + "axis": BaseKey(type=axis_type, value=axis), + "keepdim": BaseKey(type=bool, value=keepdim), } super().__init__(formula_key=formula_key, name=name, **(init_kwargs | kwargs)) @@ -899,7 +903,7 @@ def __init__( name=name, axis=axis, keepdim=keepdim, - output=IOKey(shape=[("Var_out", ...)], type=MyTensor[float]), + output=BaseKey(shape=[("Var_out", ...)], type=MyTensor[float]), ) self.factory_inputs = {"input": input, "axis": axis, "keepdim": keepdim} # self.factory_inputs = {"input": input} @@ -954,7 +958,7 @@ def __init__( axis=axis, keepdim=keepdim, # axis = Scalar(axis_type, axis), # TODO: Change axis type to int - output=IOKey(shape=[("Var_out", ...)], type=MyTensor[int]), + output=BaseKey(shape=[("Var_out", ...)], type=MyTensor[int]), ) self.factory_inputs = {"input": input, "axis": axis, "keepdim": keepdim} @@ -991,7 +995,7 @@ def __init__( axis=axis, keepdim=keepdim, # axis = Scalar(axis_type, axis), # TODO: Change axis type to int - output=IOKey(shape=[("Var_out", ...)], type=MyTensor[int]), + output=BaseKey(shape=[("Var_out", ...)], type=MyTensor[int]), ) self.factory_inputs = {"input": input, "axis": axis, "keepdim": keepdim} @@ -1029,8 +1033,8 @@ def __init__( name=name, axis=axis, keepdim=keepdim, - correction=IOKey(type=float | int | None, value=correction), - output=IOKey(shape=[("Var_out", ...)], type=MyTensor[float]), + correction=BaseKey(type=float | int | None, value=correction), + output=BaseKey(shape=[("Var_out", ...)], type=MyTensor[float]), ) self.factory_args = {"axis": axis, "correction": correction, "keepdim": keepdim} # TODO: Should we remove axis, correction and keepdim from factory_args? @@ -1067,14 +1071,14 @@ def __init__( formula_key: str, polymorphic_constraint: bool = True, name: str | None = None, - **kwargs: IOKey, + **kwargs: BaseKey, ) -> None: default_kwargs = dict( - output=IOKey(shape=[("Var", ...)], type=GenericTensorType), - input=IOKey(shape=[("Var", ...)], type=GenericTensorType), + output=BaseKey(shape=[("Var", ...)], type=GenericTensorType), + input=BaseKey(shape=[("Var", ...)], type=GenericTensorType), ) # Finalize kwargs. 
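# Illustrative sketch, not part of the diff: the `default_kwargs | kwargs`
# union on the next line is how single-input primitives get specialized.
# Dict union keeps the right-hand operand's entries on key collisions, so a
# subclass only spells out the keys it overrides (e.g. Exponential only
# replaces "output"). Plain strings stand in for the BaseKey specs:
defaults = {"output": "Tensor[Var]", "input": "Tensor[Var]"}
overrides = {"output": "Tensor[float]"}  # e.g. what Exponential passes
assert (defaults | overrides) == {
    "output": "Tensor[float]",
    "input": "Tensor[Var]",
}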
- new_kwargs: Mapping[str, IOKey] = default_kwargs | kwargs + new_kwargs: Mapping[str, BaseKey] = default_kwargs | kwargs super().__init__(formula_key, name=name, **new_kwargs) if polymorphic_constraint: @@ -1113,7 +1117,7 @@ def __init__( formula_key="exp", name=name, polymorphic_constraint=False, - output=IOKey(shape=[("Var", ...)], type=MyTensor[float]), + output=BaseKey(shape=[("Var", ...)], type=MyTensor[float]), ) self.factory_inputs = {"input": input} @@ -1137,16 +1141,16 @@ def __init__( super().__init__( formula_key="robust_sqrt", name=name, - output=IOKey(shape=[("Var", ...)], type=MyTensor[float]), - input=IOKey(shape=[("Var", ...)], type=GenericTensorType), - cutoff=IOKey(shape=[], type=GenericTensorType), + output=BaseKey(shape=[("Var", ...)], type=MyTensor[float]), + input=BaseKey(shape=[("Var", ...)], type=GenericTensorType), + cutoff=BaseKey(shape=[], type=GenericTensorType), ) self.factory_inputs = {"input": input, "cutoff": cutoff} else: super().__init__( formula_key="sqrt", - output=IOKey(shape=[("Var", ...)], type=GenericTensorType), - input=IOKey(shape=[("Var", ...)], type=GenericTensorType), + output=BaseKey(shape=[("Var", ...)], type=GenericTensorType), + input=BaseKey(shape=[("Var", ...)], type=GenericTensorType), ) self.factory_inputs = {"input": input} @@ -1177,9 +1181,9 @@ def __init__(self, formula_key: str, name: str | None = None) -> None: super().__init__( formula_key=formula_key, name=name, - output=IOKey(shape=[("Var1", ...)], type=MyTensor[bool]), - left=IOKey(shape=[("Var2", ...)], type=GenericTensorType), - right=IOKey(shape=[("Var3", ...)], type=GenericTensorType), + output=BaseKey(shape=[("Var1", ...)], type=MyTensor[bool]), + left=BaseKey(shape=[("Var2", ...)], type=GenericTensorType), + right=BaseKey(shape=[("Var3", ...)], type=GenericTensorType), ) self._set_constraint(bcast, ["output", "left", "right"]) @@ -1269,8 +1273,8 @@ def __init__( super().__init__( formula_key="logical_not", name=name, - output=IOKey(shape=[("Var", ...)], type=MyTensor[bool]), - input=IOKey(shape=[("Var", ...)], type=MyTensor[bool]), + output=BaseKey(shape=[("Var", ...)], type=MyTensor[bool]), + input=BaseKey(shape=[("Var", ...)], type=MyTensor[bool]), ) self.factory_inputs = {"input": input} @@ -1295,9 +1299,9 @@ def __init__( super().__init__( formula_key=formula_key, name=name, - output=IOKey(shape=[("Var1", ...)], type=MyTensor[bool]), - left=IOKey(shape=[("Var2", ...)], type=MyTensor[bool]), - right=IOKey(shape=[("Var3", ...)], type=MyTensor[bool]), + output=BaseKey(shape=[("Var1", ...)], type=MyTensor[bool]), + left=BaseKey(shape=[("Var2", ...)], type=MyTensor[bool]), + right=BaseKey(shape=[("Var3", ...)], type=MyTensor[bool]), ) self.factory_inputs = {"left": left, "right": right} self._set_constraint(bcast, ["output", "left", "right"]) @@ -1359,9 +1363,9 @@ def __init__( super().__init__( formula_key="shift_left", name=name, - output=IOKey(shape=[("Var3", ...)], type=MyTensor[int]), - input=IOKey(shape=[("Var1", ...)], type=MyTensor[int]), - shift=IOKey(shape=[("Var2", ...)], type=MyTensor[int]), + output=BaseKey(shape=[("Var3", ...)], type=MyTensor[int]), + input=BaseKey(shape=[("Var1", ...)], type=MyTensor[int]), + shift=BaseKey(shape=[("Var2", ...)], type=MyTensor[int]), ) self.factory_inputs = {"input": input, "shift": shift} @@ -1390,9 +1394,9 @@ def __init__( super().__init__( formula_key="shift_right", name=name, - output=IOKey(shape=[("Var3", ...)], type=GenericTensorType), - input=IOKey(shape=[("Var1", ...)], type=GenericTensorType), - 
shift=IOKey(shape=[("Var2", ...)], type=GenericTensorType), + output=BaseKey(shape=[("Var3", ...)], type=GenericTensorType), + input=BaseKey(shape=[("Var1", ...)], type=GenericTensorType), + shift=BaseKey(shape=[("Var2", ...)], type=GenericTensorType), ) self.factory_inputs = {"input": input, "shift": shift} @@ -1426,9 +1430,9 @@ def __init__( super().__init__( formula_key="transpose", name=name, - output=IOKey(shape=[("Var_out", ...)], type=GenericTensorType), - input=IOKey(shape=[("Var_in", ...)], type=GenericTensorType), - axes=IOKey(type=NoneType, value=axes), + output=BaseKey(shape=[("Var_out", ...)], type=GenericTensorType), + input=BaseKey(shape=[("Var_in", ...)], type=GenericTensorType), + axes=BaseKey(type=NoneType, value=axes), ) self.factory_inputs = {"input": input, "axes": axes} self._set_constraint( @@ -1442,18 +1446,18 @@ def __init__( super().__init__( formula_key="transpose", name=name, - output=IOKey(shape=output_shapes, type=GenericTensorType), - input=IOKey(shape=input_shapes, type=GenericTensorType), - axes=IOKey(type=int | tuple[int, ...], value=axes), + output=BaseKey(shape=output_shapes, type=GenericTensorType), + input=BaseKey(shape=input_shapes, type=GenericTensorType), + axes=BaseKey(type=int | tuple[int, ...], value=axes), ) elif axes is TBD: super().__init__( formula_key="transpose", name=name, - output=IOKey(shape=[("Var_out", ...)], type=GenericTensorType), - input=IOKey(shape=[("Var_in", ...)], type=GenericTensorType), - axes=IOKey(type=int | tuple[int, ...] | None, value=axes), + output=BaseKey(shape=[("Var_out", ...)], type=GenericTensorType), + input=BaseKey(shape=[("Var_in", ...)], type=GenericTensorType), + axes=BaseKey(type=int | tuple[int, ...] | None, value=axes), ) self._set_constraint( fn=reverse_constraints, keys=["output", "input", "axes"] @@ -1488,10 +1492,10 @@ def __init__( super().__init__( formula_key="split", name=name, - output=IOKey(shape=[("Var2", ...)], type=GenericTensorType), - input=IOKey(shape=[("Var1", ...)], type=GenericTensorType), - split_size=IOKey(type=int, value=split_size), - axis=IOKey(type=int, value=axis), + output=BaseKey(shape=[("Var2", ...)], type=GenericTensorType), + input=BaseKey(shape=[("Var1", ...)], type=GenericTensorType), + split_size=BaseKey(type=int, value=split_size), + axis=BaseKey(type=int, value=axis), ) self.factory_inputs = {"input": input, "split_size": split_size, "axis": axis} @@ -1527,10 +1531,10 @@ def __init__( super().__init__( formula_key="primitive_slice", name=name, - output=IOKey(type=slice), - start=IOKey(type=int | None, value=start), - stop=IOKey(type=int | None, value=stop), - step=IOKey(type=int | None, value=step), + output=BaseKey(type=slice), + start=BaseKey(type=int | None, value=start), + stop=BaseKey(type=int | None, value=stop), + step=BaseKey(type=int | None, value=step), ) self.factory_inputs = {"start": start, "stop": stop, "step": step} diff --git a/mithril/framework/logical/primitive.py b/mithril/framework/logical/primitive.py index 9d12df5e..f80b1371 100644 --- a/mithril/framework/logical/primitive.py +++ b/mithril/framework/logical/primitive.py @@ -20,9 +20,9 @@ from ..common import ( NOT_AVAILABLE, TBD, + BaseKey, Connection, IOHyperEdge, - IOKey, KeyType, NotAvailable, Scalar, @@ -54,7 +54,7 @@ def __init__( self, formula_key: str, name: str | None = None, - **kwargs: IOKey | Tensor | Scalar, + **kwargs: BaseKey | Tensor | Scalar, ) -> None: self._formula_key = formula_key self.grad_formula = formula_key + "_grad" @@ -62,9 +62,9 @@ def __init__( 
super().__init__(name=name) # Get shape_templates of TensorTypes and create corresponding shapes. shape_templates = { - key: value._shape + key: value.shape for key, value in kwargs.items() - if isinstance(value, IOKey) and value._shape is not None + if isinstance(value, BaseKey) and value.shape is not None } shapes = create_shape_map(shape_templates, self.constraint_solver) data_set: set[Tensor] = set() @@ -73,8 +73,9 @@ def __init__( for key, value in kwargs.items(): # TODO: The first if block is temporary. All if else blocks will be # removed after the implementation of the new type system. - if get_origin(value._type) is Union: - args = get_args(value._type) + value_type = value.type if isinstance(value, BaseKey) else value._type + if get_origin(value_type) is Union: + args = get_args(value_type) types = [] for _type in args: # TODO: assertion will be removed, @@ -83,30 +84,30 @@ def __init__( types.append(get_mytensor_subtype(_type)) possible_types = reduce(lambda x, y: x | y, types) # type: ignore - assert isinstance(value, IOKey) + assert isinstance(value, BaseKey) _value: Tensor | Scalar = Tensor( shape=shapes[key].node, possible_types=possible_types, - value=value._value, # type: ignore - interval=value._interval, + value=value.value, # type: ignore + interval=value.interval, ) assert isinstance(_value, Tensor) data_set.add(_value) - elif is_mytensor_type(value._type): - assert isinstance(value, IOKey) + elif is_mytensor_type(value_type): + assert isinstance(value, BaseKey) _value = Tensor( shape=shapes[key].node, - possible_types=get_mytensor_subtype(value._type), # type: ignore - value=value._value, # type: ignore - interval=value._interval, + possible_types=get_mytensor_subtype(value_type), # type: ignore + value=value.value, # type: ignore + interval=value.interval, ) data_set.add(_value) elif isinstance(value, Tensor | Scalar): _value = value else: _value = Scalar( - possible_types=value._type, # type: ignore - value=value._value, # type: ignore + possible_types=value_type, # type: ignore + value=value.value, # type: ignore ) conn_data = self.create_connection(IOHyperEdge(_value), key) diff --git a/mithril/models/primitives.py b/mithril/models/primitives.py index 9aaf7277..983f8e96 100644 --- a/mithril/models/primitives.py +++ b/mithril/models/primitives.py @@ -20,10 +20,10 @@ from ..framework.common import ( NOT_GIVEN, TBD, + BaseKey, Connection, ConnectionType, GenericTensorType, - IOKey, MyTensor, TensorValueType, ToBeDetermined, @@ -129,7 +129,7 @@ class CustomPrimitiveModel(PrimitiveModel): def __init__( - self, formula_key: str, name: str | None = None, **kwargs: IOKey + self, formula_key: str, name: str | None = None, **kwargs: BaseKey ) -> None: self.factory_args = {"formula_key": formula_key} | kwargs super().__init__(formula_key=formula_key, name=name, **kwargs) @@ -155,12 +155,12 @@ def __init__( formula_key: str, polymorphic_constraint: bool = True, name: str | None = None, - **kwargs: IOKey, + **kwargs: BaseKey, ) -> None: - default_kwargs: dict[str, IOKey] = { - "output": IOKey(shape=[("Var_1", ...)], type=GenericTensorType), - "input": IOKey(shape=[("Var_2", ...)], type=GenericTensorType), - "target": IOKey(shape=[("Var_3", ...)], type=GenericTensorType), + default_kwargs: dict[str, BaseKey] = { + "output": BaseKey(shape=[("Var_1", ...)], type=GenericTensorType), + "input": BaseKey(shape=[("Var_2", ...)], type=GenericTensorType), + "target": BaseKey(shape=[("Var_3", ...)], type=GenericTensorType), } # Finalize kwargs. 
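# Illustrative sketch, not part of the diff: BaseKey (added to
# mithril/framework/common.py earlier in this commit) is a plain dataclass
# carrying the per-key spec — value, shape, type, interval — that primitive
# models previously passed around as IOKey, leaving IOKey as the
# user-facing connection object. A self-contained stand-in mirroring those
# four fields:
from dataclasses import dataclass
from typing import Any

TBD = object()  # stand-in for mithril's to-be-determined sentinel

@dataclass
class BaseKeySketch:
    value: Any = TBD
    shape: list | None = None
    type: Any = None
    interval: list | None = None

# Mirrors the "target" entry in the default_kwargs above.
target_spec = BaseKeySketch(shape=[("Var_3", ...)], type="GenericTensorType")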
kwargs = default_kwargs | kwargs @@ -224,7 +224,7 @@ def __init__( polymorphic_constraint=False, formula_key="hinge_loss", name=name, - output=IOKey(shape=["N", ("Var", ...)], type=MyTensor[float]), + output=BaseKey(shape=["N", ("Var", ...)], type=MyTensor[float]), ) self.factory_inputs = {"input": input, "target": target} @@ -240,7 +240,7 @@ def __init__( polymorphic_constraint=False, formula_key="quad_hinge_loss", name=name, - output=IOKey(shape=["N", ("Var", ...)], type=MyTensor[float]), + output=BaseKey(shape=["N", ("Var", ...)], type=MyTensor[float]), ) self.factory_inputs = {"input": input, "target": target} @@ -265,10 +265,10 @@ def __init__( super().__init__( formula_key="quantile_loss", name=name, - output=IOKey(shape=[("Var_1", ...)], type=GenericTensorType), - input=IOKey(shape=[("Var_2", ...)], type=GenericTensorType), - target=IOKey(shape=[("Var_3", ...)], type=GenericTensorType), - quantile=IOKey(shape=[], type=MyTensor[int] | MyTensor[float]), + output=BaseKey(shape=[("Var_1", ...)], type=GenericTensorType), + input=BaseKey(shape=[("Var_2", ...)], type=GenericTensorType), + target=BaseKey(shape=[("Var_3", ...)], type=GenericTensorType), + quantile=BaseKey(shape=[], type=MyTensor[int] | MyTensor[float]), ) self.factory_inputs = {"input": input, "target": target, "quantile": quantile} @@ -337,14 +337,14 @@ def __init__( else: final_weights = weights - kwargs: dict[str, IOKey] = { - "output": IOKey(shape=["N", ("Var", ...)], type=MyTensor[float]), - "input": IOKey(shape=["N", "C", ("Var", ...)], type=GenericTensorType), - "target": IOKey(shape=["N", ("VarTarget", ...)], type=GenericTensorType), - "weights": IOKey(type=weights_type, value=final_weights), - "categorical": IOKey(type=bool), - "cutoff": IOKey(shape=[], type=GenericTensorType), - "robust": IOKey(type=bool), + kwargs: dict[str, BaseKey] = { + "output": BaseKey(shape=["N", ("Var", ...)], type=MyTensor[float]), + "input": BaseKey(shape=["N", "C", ("Var", ...)], type=GenericTensorType), + "target": BaseKey(shape=["N", ("VarTarget", ...)], type=GenericTensorType), + "weights": BaseKey(type=weights_type, value=final_weights), + "categorical": BaseKey(type=bool), + "cutoff": BaseKey(shape=[], type=GenericTensorType), + "robust": BaseKey(type=bool), } if input_type == "logits": @@ -427,10 +427,10 @@ def __init__( super().__init__( formula_key="kl_divergence", name=name, - output=IOKey(shape=[("Var_1", ...)], type=MyTensor[float]), - input=IOKey(shape=[("Var_2", ...)], type=GenericTensorType), - target=IOKey(shape=[("Var_3", ...)], type=GenericTensorType), - cutoff=IOKey(shape=[], type=GenericTensorType), + output=BaseKey(shape=[("Var_1", ...)], type=MyTensor[float]), + input=BaseKey(shape=[("Var_2", ...)], type=GenericTensorType), + target=BaseKey(shape=[("Var_3", ...)], type=GenericTensorType), + cutoff=BaseKey(shape=[], type=GenericTensorType), ) self.factory_inputs = {"input": input, "target": target, "cutoff": cutoff} @@ -490,15 +490,15 @@ def __init__( pos_weight_type = ( float | bool if pos_weight in (..., None) else type(pos_weight) ) - kwargs: dict[str, IOKey] = { - "output": IOKey(shape=[("Var_out", ...)], type=MyTensor[float]), - "input": IOKey(shape=[("Var_out", ...)], type=GenericTensorType), - "target": IOKey( + kwargs: dict[str, BaseKey] = { + "output": BaseKey(shape=[("Var_out", ...)], type=MyTensor[float]), + "input": BaseKey(shape=[("Var_out", ...)], type=GenericTensorType), + "target": BaseKey( shape=[("Var_out", ...)], type=MyTensor[int] | MyTensor[float] ), - "pos_weight": IOKey(type=pos_weight_type, 
value=pos_weight), - "cutoff": IOKey(shape=[], type=GenericTensorType), - "robust": IOKey(type=bool), + "pos_weight": BaseKey(type=pos_weight_type, value=pos_weight), + "cutoff": BaseKey(shape=[], type=GenericTensorType), + "robust": BaseKey(type=bool), } if input_type == "logits": @@ -560,17 +560,17 @@ def __init__( super().__init__( formula_key="robust_log", name=name, - output=IOKey(shape=[("Var", ...)], type=MyTensor[float]), - input=IOKey(shape=[("Var", ...)], type=GenericTensorType), - cutoff=IOKey(shape=[], type=GenericTensorType), + output=BaseKey(shape=[("Var", ...)], type=MyTensor[float]), + input=BaseKey(shape=[("Var", ...)], type=GenericTensorType), + cutoff=BaseKey(shape=[], type=GenericTensorType), ) self.factory_inputs = {"input": input, "cutoff": cutoff} else: super().__init__( formula_key="log", name=name, - output=IOKey(shape=[("Var", ...)], type=MyTensor[float]), - input=IOKey(shape=[("Var", ...)], type=GenericTensorType), + output=BaseKey(shape=[("Var", ...)], type=MyTensor[float]), + input=BaseKey(shape=[("Var", ...)], type=GenericTensorType), ) self.factory_inputs = {"input": input} @@ -606,9 +606,9 @@ def __init__( super().__init__( formula_key="stable_reciprocal", name=name, - output=IOKey(shape=[("Var", ...)], type=MyTensor[float]), - input=IOKey(shape=[("Var", ...)], type=GenericTensorType), - cutoff=IOKey(shape=[], type=GenericTensorType), + output=BaseKey(shape=[("Var", ...)], type=MyTensor[float]), + input=BaseKey(shape=[("Var", ...)], type=GenericTensorType), + cutoff=BaseKey(shape=[], type=GenericTensorType), ) self.factory_inputs = {"input": input, "cutoff": cutoff} @@ -629,7 +629,7 @@ def __init__( formula_key="sin", name=name, polymorphic_constraint=False, - output=IOKey(shape=[("Var", ...)], type=MyTensor[float]), + output=BaseKey(shape=[("Var", ...)], type=MyTensor[float]), ) self.factory_inputs = {"input": input} @@ -642,7 +642,7 @@ def __init__( formula_key="cos", name=name, polymorphic_constraint=False, - output=IOKey(shape=[("Var", ...)], type=MyTensor[float]), + output=BaseKey(shape=[("Var", ...)], type=MyTensor[float]), ) self.factory_inputs = {"input": input} @@ -655,7 +655,7 @@ def __init__( formula_key="sign", name=name, polymorphic_constraint=False, - output=IOKey(shape=[("Var", ...)], type=MyTensor[int]), + output=BaseKey(shape=[("Var", ...)], type=MyTensor[int]), ) self.factory_inputs = {"input": input} @@ -678,15 +678,15 @@ def __init__( formula_key: str, polymorphic_constraint: bool = False, name: str | None = None, - **kwargs: IOKey, + **kwargs: BaseKey, ) -> None: # NOTE: Torch and JAX behave different for some activation functions. # For example JAX handles int type inputs for GELU or LeakyRelu while # Torch assumes only float inputs for these activations. Since JAX handles # more general case, default types are written taking this into account. - default_kwargs: dict[str, IOKey] = dict( - input=IOKey(shape=[("Var", ...)], type=GenericTensorType), - output=IOKey(shape=[("Var", ...)], type=MyTensor[float]), + default_kwargs: dict[str, BaseKey] = dict( + input=BaseKey(shape=[("Var", ...)], type=GenericTensorType), + output=BaseKey(shape=[("Var", ...)], type=MyTensor[float]), ) # Finalize kwargs. 
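# Illustrative sketch, not part of the diff: activations default to a
# float-typed output here, and Relu (just below) overrides both keys back
# to GenericTensorType with the polymorphic constraint enabled — relu
# preserves its input's dtype, whereas e.g. Gelu or Sigmoid always emit
# floats. The same idea with plain Python scalars:
def relu(x: int | float) -> int | float:
    # Output type follows input type, unlike the float-only activations.
    return x if x > 0 else type(x)(0)

assert isinstance(relu(3), int)       # int in, int out
assert isinstance(relu(-2.5), float)  # float in, float out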
kwargs = default_kwargs | kwargs @@ -712,8 +712,8 @@ def __init__( formula_key="relu", name=name, polymorphic_constraint=True, - output=IOKey(shape=[("Var", ...)], type=GenericTensorType), - input=IOKey(shape=[("Var", ...)], type=GenericTensorType), + output=BaseKey(shape=[("Var", ...)], type=GenericTensorType), + input=BaseKey(shape=[("Var", ...)], type=GenericTensorType), ) self.factory_inputs = {"input": input} @@ -727,7 +727,7 @@ def __init__( ) -> None: super().__init__( formula_key="gelu", - approximate=IOKey(value=approximate, type=(bool)), + approximate=BaseKey(value=approximate, type=(bool)), name=name, ) self.factory_inputs = {"input": input} @@ -758,7 +758,9 @@ def __init__( input: TensorValueType | ToBeDetermined = TBD, axis: int | None | ToBeDetermined = TBD, ) -> None: - super().__init__(formula_key="softmax", name=name, axis=IOKey(type=int | None)) + super().__init__( + formula_key="softmax", name=name, axis=BaseKey(type=int | None) + ) self.factory_inputs = {"input": input, "axis": axis} def __call__( # type: ignore[override] @@ -800,7 +802,7 @@ def __init__( super().__init__( formula_key="leaky_relu", name=name, - slope=IOKey(shape=[], type=MyTensor[float]), + slope=BaseKey(shape=[], type=MyTensor[float]), ) self.factory_inputs = {"input": input, "slope": slope} @@ -823,8 +825,8 @@ def __init__( super().__init__( formula_key="stop_gradient", name=name, - output=IOKey(shape=[("Var", ...)], type=GenericTensorType), - input=IOKey(shape=[("Var", ...)], type=GenericTensorType), + output=BaseKey(shape=[("Var", ...)], type=GenericTensorType), + input=BaseKey(shape=[("Var", ...)], type=GenericTensorType), ) self.factory_inputs = {"input": input} @@ -848,9 +850,9 @@ def __init__( super().__init__( formula_key="cartesian_diff", name=name, - output=IOKey(shape=["N", "M", "dim"], type=GenericTensorType), - left=IOKey(shape=["N", "dim"], type=GenericTensorType), - right=IOKey(shape=["M", "dim"], type=GenericTensorType), + output=BaseKey(shape=["N", "M", "dim"], type=GenericTensorType), + left=BaseKey(shape=["N", "dim"], type=GenericTensorType), + right=BaseKey(shape=["M", "dim"], type=GenericTensorType), ) self.factory_inputs = {"left": left, "right": right} self._set_constraint( @@ -880,17 +882,17 @@ def __init__( ) -> None: self.factory_args = {"n": n, "axis": axis} - key_definitions: dict[str, IOKey] = {} - key_definitions["output"] = IOKey( + key_definitions: dict[str, BaseKey] = {} + key_definitions["output"] = BaseKey( shape=[("Var_out", ...)], type=GenericTensorType ) key_definitions |= { - f"input{idx+1}": IOKey( + f"input{idx+1}": BaseKey( shape=[(f"Var_{idx + 1}", ...)], type=GenericTensorType ) for idx in range(n) } - key_definitions["axis"] = IOKey(type=int | None, value=axis) + key_definitions["axis"] = BaseKey(type=int | None, value=axis) super().__init__(formula_key="concat", name=name, **key_definitions) # self.factory_inputs = {key: value for key, value in kwargs.items()} @@ -917,14 +919,14 @@ def __init__( ) -> None: self.factory_args = {"n": n} input_definitions = { - f"input{idx + 1}": IOKey(type=int | float | tuple[int | float, ...]) + f"input{idx + 1}": BaseKey(type=int | float | tuple[int | float, ...]) for idx in range(n) } super().__init__( formula_key="union", name=name, - output=IOKey(type=tuple[int | float, ...]), + output=BaseKey(type=tuple[int | float, ...]), **input_definitions, ) self.factory_inputs = kwargs # type: ignore @@ -944,9 +946,9 @@ def __init__( super().__init__( formula_key="permute_tensor", name=name, - output=IOKey(shape=["N", ("Var", 
...)], type=GenericTensorType), - input=IOKey(shape=["N", ("Var", ...)], type=GenericTensorType), - indices=IOKey(shape=["N"], type=GenericTensorType), + output=BaseKey(shape=["N", ("Var", ...)], type=GenericTensorType), + input=BaseKey(shape=["N", ("Var", ...)], type=GenericTensorType), + indices=BaseKey(shape=["N"], type=GenericTensorType), ) self.factory_inputs = {"input": input, "indices": indices} @@ -987,18 +989,18 @@ def __init__( ) -> None: self.factory_args = {"use_bias": use_bias} formula_key = "conv1d_bias" - kwargs: dict[str, IOKey] = { - "output": IOKey( + kwargs: dict[str, BaseKey] = { + "output": BaseKey( shape=["N", "out_channels", "d_out"], type=GenericTensorType ), - "input": IOKey(shape=["N", "C_in", "d_in"], type=GenericTensorType), - "weight": IOKey( + "input": BaseKey(shape=["N", "C_in", "d_in"], type=GenericTensorType), + "weight": BaseKey( shape=["out_channels", "C_in", "kernel_size"], type=GenericTensorType ), - "bias": IOKey(shape=[1, "out_channels", 1], type=GenericTensorType), - "stride": IOKey(type=int), - "padding": IOKey(type=int | tuple[int, int]), - "dilation": IOKey(type=int), + "bias": BaseKey(shape=[1, "out_channels", 1], type=GenericTensorType), + "stride": BaseKey(type=int), + "padding": BaseKey(type=int | tuple[int, int]), + "dilation": BaseKey(type=int), } self.factory_inputs = { "input": input, @@ -1084,21 +1086,21 @@ def __init__( ) -> None: self.factory_args = {"use_bias": use_bias} formula_key = "conv2d_bias" - kwargs: dict[str, IOKey] = { - "output": IOKey( + kwargs: dict[str, BaseKey] = { + "output": BaseKey( shape=["N", "out_channels", "H_out", "W_out"], type=GenericTensorType ), - "input": IOKey(shape=["N", "C_in", "H", "W"], type=GenericTensorType), - "weight": IOKey( + "input": BaseKey(shape=["N", "C_in", "H", "W"], type=GenericTensorType), + "weight": BaseKey( shape=["out_channels", "C_in", "kernel_size_0", "kernel_size_1"], type=GenericTensorType, ), - "bias": IOKey(shape=[1, "out_channels", 1, 1], type=GenericTensorType), - "stride": IOKey(type=int | tuple[int, int]), - "padding": IOKey( + "bias": BaseKey(shape=[1, "out_channels", 1, 1], type=GenericTensorType), + "stride": BaseKey(type=int | tuple[int, int]), + "padding": BaseKey( type=int | tuple[int, int] | tuple[tuple[int, int], tuple[int, int]] ), - "dilation": IOKey(type=int | tuple[int, int]), + "dilation": BaseKey(type=int | tuple[int, int]), } if not use_bias: @@ -1172,11 +1174,11 @@ def __init__( ) -> None: self.factory_args = {"start_dim": start_dim, "end_dim": end_dim} - key_definitions: dict[str, IOKey] = { - "output": IOKey(shape=[("C_out", ...)], type=GenericTensorType), - "input": IOKey(shape=[("C_in", ...)], type=GenericTensorType), - "start_dim": IOKey(type=int, value=start_dim), - "end_dim": IOKey(type=int, value=end_dim), + key_definitions: dict[str, BaseKey] = { + "output": BaseKey(shape=[("C_out", ...)], type=GenericTensorType), + "input": BaseKey(shape=[("C_in", ...)], type=GenericTensorType), + "start_dim": BaseKey(type=int, value=start_dim), + "end_dim": BaseKey(type=int, value=end_dim), } super().__init__(formula_key="flatten", name=name, **key_definitions) # self.factory_inputs = {"input": input} @@ -1226,12 +1228,12 @@ def __init__( super().__init__( formula_key="max_pool1d", name=name, - output=IOKey(shape=["N", ("C_in", ...), "W_out"], type=GenericTensorType), - input=IOKey(shape=["N", ("C_in", ...), "W"], type=GenericTensorType), - kernel_size=IOKey(type=int), - stride=IOKey(type=int), - padding=IOKey(type=tuple[int, int]), - dilation=IOKey(type=int), + 
output=BaseKey(shape=["N", ("C_in", ...), "W_out"], type=GenericTensorType), + input=BaseKey(shape=["N", ("C_in", ...), "W"], type=GenericTensorType), + kernel_size=BaseKey(type=int), + stride=BaseKey(type=int), + padding=BaseKey(type=tuple[int, int]), + dilation=BaseKey(type=int), ) self.factory_inputs = { "input": input, @@ -1282,9 +1284,9 @@ def __init__( super().__init__( formula_key="padding_converter_1d", name=name, - output=IOKey(type=tuple[int, int]), - input=IOKey(type=int | PaddingType | tuple[int, int]), - kernel_size=IOKey(type=int), + output=BaseKey(type=tuple[int, int]), + input=BaseKey(type=int | PaddingType | tuple[int, int]), + kernel_size=BaseKey(type=int), ) self.factory_inputs = {"input": input, "kernel_size": kernel_size} @@ -1320,16 +1322,16 @@ def __init__( super().__init__( formula_key="padding_converter_2d", name=name, - output=IOKey( + output=BaseKey( type=tuple[int, int] | tuple[tuple[int, int], tuple[int, int]] ), - input=IOKey( + input=BaseKey( type=int | PaddingType | tuple[int, int] | tuple[tuple[int, int], tuple[int, int]] ), - kernel_size=IOKey(type=tuple[int, int]), + kernel_size=BaseKey(type=tuple[int, int]), ) self.factory_inputs = {"input": input, "kernel_size": kernel_size} @@ -1361,9 +1363,9 @@ def __init__( super().__init__( formula_key="stride_converter", name=name, - output=IOKey(type=int | tuple[int, int]), - input=IOKey(type=int | PaddingType | tuple[int, int] | None), - kernel_size=IOKey(type=int | tuple[int, int]), + output=BaseKey(type=int | tuple[int, int]), + input=BaseKey(type=int | PaddingType | tuple[int, int] | None), + kernel_size=BaseKey(type=int | tuple[int, int]), ) self.factory_inputs = {"input": input, "kernel_size": kernel_size} self._set_constraint( @@ -1396,10 +1398,10 @@ def __init__( super().__init__( formula_key="tuple_converter", name=name, - output=IOKey( + output=BaseKey( type=tuple[int, int] | tuple[tuple[int, int], tuple[int, int]] ), - input=IOKey( + input=BaseKey( type=int | PaddingType | tuple[int, int] @@ -1440,16 +1442,16 @@ def __init__( super().__init__( formula_key="max_pool2d", name=name, - output=IOKey( + output=BaseKey( shape=["N", ("C_in", ...), "H_out", "W_out"], type=GenericTensorType ), - input=IOKey(shape=["N", ("C_in", ...), "H", "W"], type=GenericTensorType), - kernel_size=IOKey(type=tuple[int, int]), - stride=IOKey(type=tuple[int, int]), - padding=IOKey( + input=BaseKey(shape=["N", ("C_in", ...), "H", "W"], type=GenericTensorType), + kernel_size=BaseKey(type=tuple[int, int]), + stride=BaseKey(type=tuple[int, int]), + padding=BaseKey( type=tuple[int, int] | tuple[tuple[int, int], tuple[int, int]] ), - dilation=IOKey(type=tuple[int, int]), + dilation=BaseKey(type=tuple[int, int]), ) self.factory_inputs = { "input": input, @@ -1514,8 +1516,8 @@ def __init__( super().__init__( formula_key="norm_modifier", name=name, - output=IOKey(shape=[], type=GenericTensorType), - input=IOKey(shape=[], type=GenericTensorType), + output=BaseKey(shape=[], type=GenericTensorType), + input=BaseKey(shape=[], type=GenericTensorType), ) self.factory_inputs = {"input": input} @@ -1545,10 +1547,10 @@ def __init__( super().__init__( formula_key="distance_matrix", name=name, - output=IOKey(shape=["N", "M"], type=GenericTensorType), - left=IOKey(shape=["N", "d"], type=GenericTensorType), - right=IOKey(shape=["M", "d"], type=GenericTensorType), - norm=IOKey(shape=[], type=GenericTensorType), + output=BaseKey(shape=["N", "M"], type=GenericTensorType), + left=BaseKey(shape=["N", "d"], type=GenericTensorType), + 
right=BaseKey(shape=["M", "d"], type=GenericTensorType), + norm=BaseKey(shape=[], type=GenericTensorType), ) self.factory_inputs = {"left": left, "right": right} @@ -1581,9 +1583,9 @@ def __init__( super().__init__( formula_key="polynomial_features", name=name, - output=IOKey(shape=["N", "d_out"], type=GenericTensorType), - input=IOKey(shape=["N", "d_in"], type=GenericTensorType), - degree=IOKey(type=int, value=degree), + output=BaseKey(shape=["N", "d_out"], type=GenericTensorType), + input=BaseKey(shape=["N", "d_in"], type=GenericTensorType), + degree=BaseKey(type=int, value=degree), ) self.factory_inputs = {"input": input, "degree": degree} @@ -1621,10 +1623,10 @@ def __init__( super().__init__( formula_key="tsne_p_joint", name=name, - output=IOKey(shape=["N", "M"], type=MyTensor[float]), - squared_distances=IOKey(shape=["N", "M"], type=GenericTensorType), - target_perplexity=IOKey(shape=[], type=MyTensor[float]), - threshold=IOKey(shape=[], type=GenericTensorType), + output=BaseKey(shape=["N", "M"], type=MyTensor[float]), + squared_distances=BaseKey(shape=["N", "M"], type=GenericTensorType), + target_perplexity=BaseKey(shape=[], type=MyTensor[float]), + threshold=BaseKey(shape=[], type=GenericTensorType), ) self.factory_inputs = { "squared_distances": squared_distances, @@ -1661,9 +1663,9 @@ def __init__( super().__init__( formula_key="ones_with_zero_diag", name=name, - output=IOKey(shape=["N", "M"], type=MyTensor[float]), - N=IOKey(type=int, value=N), - M=IOKey(type=int | None, value=M), + output=BaseKey(shape=["N", "M"], type=MyTensor[float]), + N=BaseKey(type=int, value=N), + M=BaseKey(type=int | None, value=M), ) self.factory_inputs = {"N": N, "M": M} self._set_constraint(fn=eye_constraints, keys=["output", "N", "M"]) @@ -1691,9 +1693,9 @@ def __init__( super().__init__( formula_key="eye", name=name, - output=IOKey(shape=["N", "M"], type=MyTensor[float]), - N=IOKey(type=int, value=N), - M=IOKey(type=int | None, value=M), + output=BaseKey(shape=["N", "M"], type=MyTensor[float]), + N=BaseKey(type=int, value=N), + M=BaseKey(type=int | None, value=M), ) self.factory_inputs = {"N": N, "M": M} @@ -1718,8 +1720,8 @@ def __init__( super().__init__( formula_key="cholesky", name=name, - output=IOKey(shape=["N", "N"], type=MyTensor[float]), - input=IOKey(shape=["N", "N"], type=GenericTensorType), + output=BaseKey(shape=["N", "N"], type=MyTensor[float]), + input=BaseKey(shape=["N", "N"], type=GenericTensorType), ) self.factory_inputs = {"input": input} @@ -1745,10 +1747,10 @@ def __init__( super().__init__( formula_key="gpr_alpha", name=name, - output=IOKey(shape=["N", 1], type=MyTensor[float]), - label_mu_diff=IOKey(shape=["N", 1], type=GenericTensorType), - L=IOKey(shape=["N", "N"], type=GenericTensorType), - K_term=IOKey(shape=["N", "N"], type=GenericTensorType), + output=BaseKey(shape=["N", 1], type=MyTensor[float]), + label_mu_diff=BaseKey(shape=["N", 1], type=GenericTensorType), + L=BaseKey(shape=["N", "N"], type=GenericTensorType), + K_term=BaseKey(shape=["N", "N"], type=GenericTensorType), ) self.factory_inputs = {"label_mu_diff": label_mu_diff, "L": L, "K_term": K_term} @@ -1780,10 +1782,10 @@ def __init__( super().__init__( formula_key="gpr_v_outer", name=name, - output=IOKey(shape=["N", "N"], type=MyTensor[float]), - K=IOKey(shape=["N", "N"], type=GenericTensorType), - K_term=IOKey(shape=["N", "N"], type=GenericTensorType), - L=IOKey(shape=["N", "N"], type=GenericTensorType), + output=BaseKey(shape=["N", "N"], type=MyTensor[float]), + K=BaseKey(shape=["N", "N"], 
type=GenericTensorType), + K_term=BaseKey(shape=["N", "N"], type=GenericTensorType), + L=BaseKey(shape=["N", "N"], type=GenericTensorType), ) self.factory_inputs = {"K": K, "K_term": K_term, "L": L} @@ -1807,8 +1809,8 @@ def __init__( super().__init__( formula_key="transposed_diag", name=name, - output=IOKey(shape=["N", 1], type=GenericTensorType), - input=IOKey(shape=["N", "N"], type=GenericTensorType), + output=BaseKey(shape=["N", 1], type=GenericTensorType), + input=BaseKey(shape=["N", "N"], type=GenericTensorType), ) self.factory_inputs = {"input": input} @@ -1857,10 +1859,10 @@ def __init__( super().__init__( formula_key="arange", name=name, - output=IOKey(shape=output_shp, type=GenericTensorType), - start=IOKey(type=int | float, value=start), - stop=IOKey(type=int | float, value=stop), - step=IOKey(type=int | float, value=step), + output=BaseKey(shape=output_shp, type=GenericTensorType), + start=BaseKey(type=int | float, value=start), + stop=BaseKey(type=int | float, value=stop), + step=BaseKey(type=int | float, value=step), ) self.set_canonical_input("stop") self.factory_inputs = {"start": start, "stop": stop, "step": step} @@ -1894,8 +1896,8 @@ def __init__( super().__init__( formula_key="randn", name=name, - output=IOKey(shape=[("output", ...)], type=GenericTensorType), - shape=IOKey(type=tuple[int, ...], value=shape), + output=BaseKey(shape=[("output", ...)], type=GenericTensorType), + shape=BaseKey(type=tuple[int, ...], value=shape), ) self.set_constraint(randn_constraints, keys=["output", "shape"]) @@ -1922,9 +1924,9 @@ def __init__( super().__init__( formula_key="broadcast_to", name=name, - output=IOKey(shape=[("output", ...)], type=GenericTensorType), - input=IOKey(shape=[("input", ...)], type=GenericTensorType), - shape=IOKey(type=tuple[int, ...], value=shape), + output=BaseKey(shape=[("output", ...)], type=GenericTensorType), + input=BaseKey(shape=[("input", ...)], type=GenericTensorType), + shape=BaseKey(type=tuple[int, ...], value=shape), ) self.factory_inputs = {"input": input, "shape": shape} @@ -1960,10 +1962,10 @@ def __init__( super().__init__( formula_key="eigvalsh", name=name, - output=IOKey(shape=["N", 1], type=MyTensor[float]), - K_term=IOKey(shape=["N", "N"], type=GenericTensorType), - L=IOKey(shape=["N", "N"], type=GenericTensorType), - threshold=IOKey(shape=[], type=GenericTensorType), + output=BaseKey(shape=["N", 1], type=MyTensor[float]), + K_term=BaseKey(shape=["N", "N"], type=GenericTensorType), + L=BaseKey(shape=["N", "N"], type=GenericTensorType), + threshold=BaseKey(shape=[], type=GenericTensorType), ) self.factory_inputs = {"K_term": K_term, "L": L, "threshold": threshold} @@ -1987,8 +1989,8 @@ def __init__( super().__init__( formula_key="squeeze", name=name, - output=IOKey(shape=[("Var_out", ...)], type=GenericTensorType), - input=IOKey(shape=[("Var", ...)], type=GenericTensorType), + output=BaseKey(shape=[("Var_out", ...)], type=GenericTensorType), + input=BaseKey(shape=[("Var", ...)], type=GenericTensorType), ) self.factory_inputs = {"input": input} @@ -2017,9 +2019,9 @@ def __init__( super().__init__( formula_key="auc_core", name=name, - output=IOKey(shape=[2, "M"], type=MyTensor[float]), - input=IOKey(shape=["N"], type=GenericTensorType), - label=IOKey(shape=["N"], type=GenericTensorType), + output=BaseKey(shape=[2, "M"], type=MyTensor[float]), + input=BaseKey(shape=["N"], type=GenericTensorType), + label=BaseKey(shape=["N"], type=GenericTensorType), ) self.factory_inputs = {"input": input, "label": label} @@ -2050,9 +2052,9 @@ def __init__( 
super().__init__( formula_key="primitive_embedding", name=name, - output=IOKey(shape=[("N1", ...), "d1", out_dim], type=GenericTensorType), - input=IOKey(shape=[("N1", ...), "d1"], type=MyTensor[int]), - weight=IOKey(shape=[num_embeddings, out_dim], type=GenericTensorType), + output=BaseKey(shape=[("N1", ...), "d1", out_dim], type=GenericTensorType), + input=BaseKey(shape=[("N1", ...), "d1"], type=MyTensor[int]), + weight=BaseKey(shape=[num_embeddings, out_dim], type=GenericTensorType), ) self.factory_inputs = {"input": input, "weight": weight} @@ -2100,19 +2102,19 @@ def __init__( self.use_attn_mask = use_attn_mask formula_key = "scaled_dot_product_attention" - kwargs: dict[str, IOKey] = { - "output": IOKey(shape=[("Var", ...), "L", "O"], type=MyTensor[float]), - "query": IOKey(shape=[("Var", ...), "L", "E"], type=GenericTensorType), - "key": IOKey(shape=[("Var", ...), "S", "E"], type=GenericTensorType), - "value": IOKey(shape=[("Var", ...), "S", "O"], type=GenericTensorType), - "dropout_p": IOKey(type=float, value=dropout_p), - "attn_mask": IOKey(type=NoneType, value=None), - "is_causal": IOKey(type=bool, value=is_causal), - "scale": IOKey(type=NoneType | int | float, value=scale), + kwargs: dict[str, BaseKey] = { + "output": BaseKey(shape=[("Var", ...), "L", "O"], type=MyTensor[float]), + "query": BaseKey(shape=[("Var", ...), "L", "E"], type=GenericTensorType), + "key": BaseKey(shape=[("Var", ...), "S", "E"], type=GenericTensorType), + "value": BaseKey(shape=[("Var", ...), "S", "O"], type=GenericTensorType), + "dropout_p": BaseKey(type=float, value=dropout_p), + "attn_mask": BaseKey(type=NoneType, value=None), + "is_causal": BaseKey(type=bool, value=is_causal), + "scale": BaseKey(type=NoneType | int | float, value=scale), } if use_attn_mask: - kwargs["attn_mask"] = IOKey( + kwargs["attn_mask"] = BaseKey( shape=["L", "S"], type=GenericTensorType, value=TBD ) @@ -2143,7 +2145,7 @@ def __call__( # type: ignore[override] not self.use_attn_mask and attn_mask is not NOT_GIVEN and not isinstance(attn_mask, str) - and isinstance(attn_mask, IOKey) + and isinstance(attn_mask, BaseKey) and attn_mask._value is not None # TODO: Here will be updated! 
): raise KeyError( @@ -2181,10 +2183,10 @@ def __init__( super().__init__( formula_key="positional_encoding", name=name, - output=IOKey(shape=[("N1", ...)], type=GenericTensorType), - input=IOKey(shape=[("N1", ...)], type=GenericTensorType), - hidden_dim=IOKey(type=int, value=hidden_dim), - max_len=IOKey(type=int, value=max_len), + output=BaseKey(shape=[("N1", ...)], type=GenericTensorType), + input=BaseKey(shape=[("N1", ...)], type=GenericTensorType), + hidden_dim=BaseKey(type=int, value=hidden_dim), + max_len=BaseKey(type=int, value=max_len), ) self.factory_inputs = { "input": input, @@ -2222,10 +2224,10 @@ def __init__( super().__init__( formula_key="swapaxes", name=name, - output=IOKey(shape=[("Var_out", ...)], type=GenericTensorType), - input=IOKey(shape=[("Var_in", ...)], type=GenericTensorType), - axis1=IOKey(type=int, value=axis1), - axis2=IOKey(type=int, value=axis2), + output=BaseKey(shape=[("Var_out", ...)], type=GenericTensorType), + input=BaseKey(shape=[("Var_in", ...)], type=GenericTensorType), + axis1=BaseKey(type=int, value=axis1), + axis2=BaseKey(type=int, value=axis2), ) self.factory_inputs = {"input": input, "axis1": axis1, "axis2": axis2} @@ -2262,10 +2264,10 @@ def __init__( super().__init__( formula_key="where", name=name, - output=IOKey(shape=[("Var_out", ...)], type=GenericTensorType), - cond=IOKey(shape=[("Var3", ...)], type=MyTensor[bool], value=TBD), - input1=IOKey(shape=[("Var1", ...)], type=GenericTensorType), - input2=IOKey(shape=[("Var2", ...)], type=GenericTensorType), + output=BaseKey(shape=[("Var_out", ...)], type=GenericTensorType), + cond=BaseKey(shape=[("Var3", ...)], type=MyTensor[bool], value=TBD), + input1=BaseKey(shape=[("Var1", ...)], type=GenericTensorType), + input2=BaseKey(shape=[("Var2", ...)], type=GenericTensorType), ) self.factory_inputs = {"cond": cond, "input1": input1, "input2": input2} @@ -2297,8 +2299,8 @@ def __init__( super().__init__( formula_key="isnan", name=name, - output=IOKey(shape=[("Var", ...)], type=MyTensor[bool]), - input=IOKey(shape=[("Var", ...)], type=GenericTensorType), + output=BaseKey(shape=[("Var", ...)], type=MyTensor[bool]), + input=BaseKey(shape=[("Var", ...)], type=GenericTensorType), ) self.factory_inputs = {"input": input} @@ -2318,8 +2320,8 @@ def __init__( super().__init__( formula_key="unique", name=name, - input=IOKey(shape=[("Var1", ...)], type=GenericTensorType), - output=IOKey(shape=[("Var2", ...)], type=GenericTensorType), + input=BaseKey(shape=[("Var1", ...)], type=GenericTensorType), + output=BaseKey(shape=[("Var2", ...)], type=GenericTensorType), ) self.factory_inputs = {"input": input} @@ -2343,9 +2345,9 @@ def __init__( super().__init__( formula_key="trapezoid", name=name, - output=IOKey(shape=[], type=GenericTensorType), - y=IOKey(shape=[("Var", ...)], type=GenericTensorType), - x=IOKey(shape=[("Var", ...)], type=GenericTensorType), + output=BaseKey(shape=[], type=GenericTensorType), + y=BaseKey(shape=[("Var", ...)], type=GenericTensorType), + x=BaseKey(shape=[("Var", ...)], type=GenericTensorType), ) self.factory_inputs = {"y": y, "x": x} @@ -2376,11 +2378,11 @@ def __init__( super().__init__( formula_key="nan_to_num", name=name, - output=IOKey(shape=[("Var", ...)], type=GenericTensorType), - input=IOKey(shape=[("Var", ...)], type=GenericTensorType), - nan=IOKey(type=float, value=nan), - posinf=IOKey(type=float | None, value=posinf), - neginf=IOKey(type=float | None, value=neginf), + output=BaseKey(shape=[("Var", ...)], type=GenericTensorType), + input=BaseKey(shape=[("Var", ...)], 
type=GenericTensorType), + nan=BaseKey(type=float, value=nan), + posinf=BaseKey(type=float | None, value=posinf), + neginf=BaseKey(type=float | None, value=neginf), ) self.factory_inputs = { "input": input, @@ -2417,9 +2419,9 @@ def __init__( super().__init__( formula_key="pad", name=name, - output=IOKey(shape=[("Var2", ...)], type=GenericTensorType), - input=IOKey(shape=[("Var1", ...)], type=GenericTensorType), - pad_width=IOKey(type=tuple[tuple[int, int], ...], value=pad_width), + output=BaseKey(shape=[("Var2", ...)], type=GenericTensorType), + input=BaseKey(shape=[("Var1", ...)], type=GenericTensorType), + pad_width=BaseKey(type=tuple[tuple[int, int], ...], value=pad_width), ) self.factory_inputs = {"input": input, "pad_width": pad_width} @@ -2446,8 +2448,8 @@ def __init__( super().__init__( formula_key="zeros_like", name=name, - output=IOKey(shape=[("Var", ...)], type=GenericTensorType), - input=IOKey(shape=[("Var", ...)], type=GenericTensorType), + output=BaseKey(shape=[("Var", ...)], type=GenericTensorType), + input=BaseKey(shape=[("Var", ...)], type=GenericTensorType), ) self.factory_inputs = {"input": input} diff --git a/tests/scripts/test_constr_counter.py b/tests/scripts/test_constr_counter.py index 43421af3..fdf38fa6 100644 --- a/tests/scripts/test_constr_counter.py +++ b/tests/scripts/test_constr_counter.py @@ -18,9 +18,9 @@ from mithril.framework import Scalar, Tensor from mithril.framework.common import ( NOT_GIVEN, + BaseKey, ConnectionType, GenericTensorType, - IOKey, MyTensor, ShapeRepr, Uniadic, @@ -98,8 +98,8 @@ class Model1(PrimitiveModel): def __init__(self) -> None: super().__init__( formula_key="buffer", - input=IOKey(shape=[("Var1", ...)], type=GenericTensorType), - output=IOKey(shape=[("Var2", ...)], type=GenericTensorType), + input=BaseKey(shape=[("Var1", ...)], type=GenericTensorType), + output=BaseKey(shape=[("Var2", ...)], type=GenericTensorType), ) self._set_constraint(fn=dummy_constraint, keys=["output", "input"]) @@ -111,8 +111,8 @@ class Model2(PrimitiveModel): def __init__(self) -> None: super().__init__( formula_key="buffer", - input=IOKey(shape=[("Var1", ...)], type=GenericTensorType), - output=IOKey(shape=[("Var2", ...)], type=GenericTensorType), + input=BaseKey(shape=[("Var1", ...)], type=GenericTensorType), + output=BaseKey(shape=[("Var2", ...)], type=GenericTensorType), ) self._set_constraint(fn=dummy_constraint, keys=["output", "input"]) self._set_constraint( @@ -127,8 +127,8 @@ class Model3(PrimitiveModel): def __init__(self) -> None: super().__init__( formula_key="buffer", - input=IOKey(shape=[("Var1", ...)], type=MyTensor[int] | MyTensor[bool]), - output=IOKey(shape=[("Var2", ...)], type=MyTensor[int] | MyTensor[bool]), + input=BaseKey(shape=[("Var1", ...)], type=MyTensor[int] | MyTensor[bool]), + output=BaseKey(shape=[("Var2", ...)], type=MyTensor[int] | MyTensor[bool]), ) self._set_constraint(fn=dummy_constraint, keys=["output", "input"]) @@ -141,9 +141,9 @@ class MyAdd2(PrimitiveModel): def __init__(self, left, right, output) -> None: super().__init__( formula_key="add", - output=IOKey(shape=output, type=GenericTensorType), - left=IOKey(shape=left, type=GenericTensorType), - right=IOKey(shape=right, type=GenericTensorType), + output=BaseKey(shape=output, type=GenericTensorType), + left=BaseKey(shape=left, type=GenericTensorType), + right=BaseKey(shape=right, type=GenericTensorType), ) self._set_constraint( fn=bcast, keys=[PrimitiveModel.output_key, "left", "right"] diff --git a/tests/scripts/test_functions.py b/tests/scripts/test_functions.py 
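[Annotation: the hunks above and below repeat one mechanical rewrite: keys declared inside a PrimitiveModel subclass switch from IOKey to BaseKey, while IOKey stays the public call-site type. A minimal sketch of the resulting definition pattern, assuming only names already imported in these hunks; the fixture class itself is hypothetical and the PrimitiveModel import path is assumed:

    from mithril.framework.common import BaseKey, GenericTensorType
    from mithril.models import PrimitiveModel  # assumed import path

    class MyBuffer(PrimitiveModel):
        # Interface keys are declared with BaseKey, as in the updated tests.
        def __init__(self) -> None:
            super().__init__(
                formula_key="buffer",
                input=BaseKey(shape=[("Var1", ...)], type=GenericTensorType),
                output=BaseKey(shape=[("Var1", ...)], type=GenericTensorType),
            )
]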
index 55a31abb..bbe5ecf6 100644 --- a/tests/scripts/test_functions.py +++ b/tests/scripts/test_functions.py @@ -20,7 +20,7 @@ from mithril import CBackend, JaxBackend, NumpyBackend, TorchBackend from mithril.backends.with_manualgrad.numpy_backend.ops_grad import add_grad from mithril.framework import NOT_GIVEN, ConnectionType, ExtendInfo -from mithril.framework.common import GenericTensorType, IOKey +from mithril.framework.common import BaseKey, GenericTensorType, IOKey from mithril.framework.constraints import bcast from mithril.models import ( Absolute, @@ -413,9 +413,9 @@ class MyAdder(PrimitiveModel): def __init__(self) -> None: super().__init__( formula_key="my_adder", - output=IOKey(shape=[("Var_out", ...)], type=GenericTensorType), - input=IOKey(shape=[("Var_1", ...)], type=GenericTensorType), - rhs=IOKey(shape=[("Var_2", ...)], type=GenericTensorType), + output=BaseKey(shape=[("Var_out", ...)], type=GenericTensorType), + input=BaseKey(shape=[("Var_1", ...)], type=GenericTensorType), + rhs=BaseKey(shape=[("Var_2", ...)], type=GenericTensorType), ) self.set_constraint( fn=bcast, keys=[PrimitiveModel.output_key, "input", "rhs"] @@ -522,9 +522,9 @@ class MyAdder(PrimitiveModel): def __init__(self) -> None: super().__init__( formula_key="my_adder", - output=IOKey(shape=[("Var_out", ...)], type=GenericTensorType), - input=IOKey(shape=[("Var_1", ...)], type=GenericTensorType), - rhs=IOKey(shape=[("Var_2", ...)], type=GenericTensorType), + output=BaseKey(shape=[("Var_out", ...)], type=GenericTensorType), + input=BaseKey(shape=[("Var_1", ...)], type=GenericTensorType), + rhs=BaseKey(shape=[("Var_2", ...)], type=GenericTensorType), ) self.set_constraint( fn=bcast, keys=[PrimitiveModel.output_key, "input", "rhs"] diff --git a/tests/scripts/test_jittable.py b/tests/scripts/test_jittable.py index ae3ec005..9052c78f 100644 --- a/tests/scripts/test_jittable.py +++ b/tests/scripts/test_jittable.py @@ -29,7 +29,7 @@ to_tensor, ) from mithril.framework import NOT_GIVEN, ConnectionType, ExtendInfo -from mithril.framework.common import GenericTensorType +from mithril.framework.common import BaseKey, GenericTensorType from mithril.framework.constraints import bcast from mithril.models import ( TBD, @@ -221,9 +221,9 @@ class Adder(CustomPrimitiveModel): def __init__(self) -> None: super().__init__( formula_key="adder", - output=IOKey(shape=[("Var_out", ...)], type=GenericTensorType), - left=IOKey(shape=[("Var_1", ...)], type=GenericTensorType), - right=IOKey(shape=[("Var_2", ...)], type=GenericTensorType), + output=BaseKey(shape=[("Var_out", ...)], type=GenericTensorType), + left=BaseKey(shape=[("Var_1", ...)], type=GenericTensorType), + right=BaseKey(shape=[("Var_2", ...)], type=GenericTensorType), ) self.set_constraint(fn=bcast, keys=["output", "left", "right"]) @@ -345,9 +345,9 @@ class Adder(CustomPrimitiveModel): def __init__(self) -> None: super().__init__( formula_key="adder", - output=IOKey(shape=[("Var_out", ...)], type=GenericTensorType), - left=IOKey(shape=[("Var_1", ...)], type=GenericTensorType), - right=IOKey(shape=[("Var_2", ...)], type=GenericTensorType), + output=BaseKey(shape=[("Var_out", ...)], type=GenericTensorType), + left=BaseKey(shape=[("Var_1", ...)], type=GenericTensorType), + right=BaseKey(shape=[("Var_2", ...)], type=GenericTensorType), ) self.set_constraint(fn=bcast, keys=["output", "left", "right"]) diff --git a/tests/scripts/test_model_to_dict_rtt.py b/tests/scripts/test_model_to_dict_rtt.py index 460bf2cd..8876a549 100644 --- a/tests/scripts/test_model_to_dict_rtt.py +++ 
b/tests/scripts/test_model_to_dict_rtt.py @@ -16,7 +16,7 @@ import mithril from mithril import JaxBackend, TorchBackend -from mithril.framework.common import TBD, GenericTensorType, IOKey +from mithril.framework.common import TBD, BaseKey, GenericTensorType, IOKey from mithril.framework.constraints import squeeze_constraints from mithril.models import ( L2, @@ -920,9 +920,9 @@ def __init__(self, threshold=3) -> None: threshold *= 2 super().__init__( formula_key="my_adder", - output=IOKey(shape=[("Var_out", ...)], type=GenericTensorType), - input=IOKey(shape=[("Var_1", ...)], type=GenericTensorType), - rhs=IOKey(type=int, value=threshold), + output=BaseKey(shape=[("Var_out", ...)], type=GenericTensorType), + input=BaseKey(shape=[("Var_1", ...)], type=GenericTensorType), + rhs=BaseKey(type=int, value=threshold), ) self.set_constraint( fn=squeeze_constraints, keys=[CustomPrimitiveModel.output_key, "input"] diff --git a/tests/scripts/test_ref_counts.py b/tests/scripts/test_ref_counts.py index c7001ac2..62294f42 100644 --- a/tests/scripts/test_ref_counts.py +++ b/tests/scripts/test_ref_counts.py @@ -18,6 +18,7 @@ from mithril.framework.common import ( NOT_GIVEN, + BaseKey, Connection, ConnectionType, GenericTensorType, @@ -81,8 +82,8 @@ class TestModel(PrimitiveModel): def __init__(self) -> None: super().__init__( formula_key="buffer", - input=IOKey(shape=["a", "b", ("Var1", ...)], type=GenericTensorType), - output=IOKey(shape=["c", "d", ("Var2", ...)], type=GenericTensorType), + input=BaseKey(shape=["a", "b", ("Var1", ...)], type=GenericTensorType), + output=BaseKey(shape=["c", "d", ("Var2", ...)], type=GenericTensorType), ) model = Model() @@ -114,8 +115,8 @@ class MyModel(PrimitiveModel): def __init__(self) -> None: super().__init__( formula_key="buffer", - input=IOKey(shape=[("Var1", ...)], type=GenericTensorType), - output=IOKey(shape=[("Var1", ...)], type=GenericTensorType), + input=BaseKey(shape=[("Var1", ...)], type=GenericTensorType), + output=BaseKey(shape=[("Var1", ...)], type=GenericTensorType), ) buff_model1 = MyModel() @@ -345,8 +346,8 @@ class MyModel(PrimitiveModel): def __init__(self) -> None: super().__init__( formula_key="buffer", - input=IOKey(shape=["a", "b"], type=GenericTensorType), - output=IOKey(shape=["c", "d"], type=GenericTensorType), + input=BaseKey(shape=["a", "b"], type=GenericTensorType), + output=BaseKey(shape=["c", "d"], type=GenericTensorType), ) all_uniadics = set() @@ -384,8 +385,8 @@ class MyModel(PrimitiveModel): def __init__(self) -> None: super().__init__( formula_key="buffer", - input=IOKey(shape=["a", "b", "c"], type=GenericTensorType), - output=IOKey(shape=["c", "d", "e"], type=GenericTensorType), + input=BaseKey(shape=["a", "b", "c"], type=GenericTensorType), + output=BaseKey(shape=["c", "d", "e"], type=GenericTensorType), ) all_uniadics = set() @@ -409,8 +410,8 @@ class MyModel(PrimitiveModel): def __init__(self) -> None: super().__init__( formula_key="buffer", - input=IOKey(shape=["a", "b", "c"], type=GenericTensorType), - output=IOKey(shape=["d", "e", "f"], type=GenericTensorType), + input=BaseKey(shape=["a", "b", "c"], type=GenericTensorType), + output=BaseKey(shape=["d", "e", "f"], type=GenericTensorType), ) def __call__( # type: ignore[override] @@ -440,8 +441,8 @@ class MyModel(PrimitiveModel): def __init__(self) -> None: super().__init__( formula_key="buffer", - input=IOKey(shape=[1, 1, 1], type=GenericTensorType), - output=IOKey(shape=[1, 1, 1], type=GenericTensorType), + input=BaseKey(shape=[1, 1, 1], type=GenericTensorType), + 
output=BaseKey(shape=[1, 1, 1], type=GenericTensorType), ) def __call__( # type: ignore[override] @@ -564,8 +565,8 @@ class MyModel(PrimitiveModel): def __init__(self) -> None: super().__init__( formula_key="buffer", - input=IOKey(shape=[("Var1", ...), "a"], type=GenericTensorType), - output=IOKey(shape=["a", ("Var1", ...)], type=GenericTensorType), + input=BaseKey(shape=[("Var1", ...), "a"], type=GenericTensorType), + output=BaseKey(shape=["a", ("Var1", ...)], type=GenericTensorType), ) def __call__( # type: ignore[override] @@ -618,8 +619,8 @@ class MyModel(PrimitiveModel): def __init__(self) -> None: super().__init__( formula_key="buffer", - input=IOKey(shape=[("Var1", ...), "a"], type=GenericTensorType), - output=IOKey(shape=["a", ("Var1", ...)], type=GenericTensorType), + input=BaseKey(shape=[("Var1", ...), "a"], type=GenericTensorType), + output=BaseKey(shape=["a", ("Var1", ...)], type=GenericTensorType), ) def __call__( # type: ignore[override] @@ -671,8 +672,8 @@ class MyModel(PrimitiveModel): def __init__(self) -> None: super().__init__( formula_key="buffer", - input=IOKey(shape=[("Var1", ...), "a"], type=GenericTensorType), - output=IOKey(shape=["a", ("Var1", ...)], type=GenericTensorType), + input=BaseKey(shape=[("Var1", ...), "a"], type=GenericTensorType), + output=BaseKey(shape=["a", ("Var1", ...)], type=GenericTensorType), ) def __call__( # type: ignore[override] @@ -707,8 +708,8 @@ class MyModel(PrimitiveModel): def __init__(self) -> None: super().__init__( formula_key="buffer", - input=IOKey(shape=["a", "b"], type=GenericTensorType), - output=IOKey(shape=["b", "a"], type=GenericTensorType), + input=BaseKey(shape=["a", "b"], type=GenericTensorType), + output=BaseKey(shape=["b", "a"], type=GenericTensorType), ) def __call__( # type: ignore[override] @@ -1268,8 +1269,8 @@ class MyModel(PrimitiveModel): def __init__(self) -> None: super().__init__( formula_key="buffer", - input=IOKey(shape=["a1"], type=GenericTensorType), - output=IOKey(shape=["a2"], type=GenericTensorType), + input=BaseKey(shape=["a1"], type=GenericTensorType), + output=BaseKey(shape=["a2"], type=GenericTensorType), ) buff_model1 = MyModel() @@ -1308,8 +1309,8 @@ class MyModel(PrimitiveModel): def __init__(self) -> None: super().__init__( formula_key="buffer", - input=IOKey(shape=["a", "b"], type=GenericTensorType), - output=IOKey(shape=["c", "d"], type=GenericTensorType), + input=BaseKey(shape=["a", "b"], type=GenericTensorType), + output=BaseKey(shape=["c", "d"], type=GenericTensorType), ) model = Model() @@ -1340,8 +1341,8 @@ class MyModel(PrimitiveModel): def __init__(self) -> None: super().__init__( formula_key="buffer", - input=IOKey(shape=["a", "b"], type=GenericTensorType), - output=IOKey(shape=["c", "d"], type=GenericTensorType), + input=BaseKey(shape=["a", "b"], type=GenericTensorType), + output=BaseKey(shape=["c", "d"], type=GenericTensorType), ) model = Model() @@ -1372,8 +1373,8 @@ class MyModel(PrimitiveModel): def __init__(self) -> None: super().__init__( formula_key="buffer", - input=IOKey(shape=["a", "b"], type=GenericTensorType), - output=IOKey(shape=["c", "d"], type=GenericTensorType), + input=BaseKey(shape=["a", "b"], type=GenericTensorType), + output=BaseKey(shape=["c", "d"], type=GenericTensorType), ) model = Model() @@ -1402,8 +1403,8 @@ class MyModel(PrimitiveModel): def __init__(self) -> None: super().__init__( formula_key="buffer", - input=IOKey(shape=["a", "b"], type=GenericTensorType), - output=IOKey(shape=["c", "d"], type=GenericTensorType), + input=BaseKey(shape=["a", "b"], 
type=GenericTensorType), + output=BaseKey(shape=["c", "d"], type=GenericTensorType), ) model = Model() @@ -1430,8 +1431,8 @@ class MyModel(PrimitiveModel): def __init__(self) -> None: super().__init__( formula_key="buffer", - input=IOKey(shape=["a", "b"], type=GenericTensorType), - output=IOKey(shape=["c", "d"], type=GenericTensorType), + input=BaseKey(shape=["a", "b"], type=GenericTensorType), + output=BaseKey(shape=["c", "d"], type=GenericTensorType), ) model = Model() diff --git a/tests/scripts/test_scripts.py b/tests/scripts/test_scripts.py index 4cd38765..7a711fe5 100644 --- a/tests/scripts/test_scripts.py +++ b/tests/scripts/test_scripts.py @@ -34,6 +34,7 @@ NOT_AVAILABLE, NOT_GIVEN, TBD, + BaseKey, ConnectionData, ConnectionType, GenericTensorType, @@ -1226,9 +1227,9 @@ class MyAdder(PrimitiveModel): def __init__(self) -> None: super().__init__( formula_key="my_adder", - output=IOKey(shape=[("Var_out", ...)], type=GenericTensorType), - left=IOKey(shape=[("Var_1", ...)], type=GenericTensorType), - right=IOKey(shape=[("Var_2", ...)], type=GenericTensorType), + output=BaseKey(shape=[("Var_out", ...)], type=GenericTensorType), + left=BaseKey(shape=[("Var_1", ...)], type=GenericTensorType), + right=BaseKey(shape=[("Var_2", ...)], type=GenericTensorType), ) self.set_constraint( fn=bcast, keys=[PrimitiveModel.output_key, "left", "right"] @@ -4436,9 +4437,9 @@ class MyAdder(PrimitiveModel): def __init__(self) -> None: super().__init__( formula_key="my_adder", - output=IOKey(shape=[("Var_out", ...)], type=GenericTensorType), - left=IOKey(shape=[("Var_1", ...)], type=GenericTensorType), - right=IOKey(shape=[("Var_2", ...)], type=GenericTensorType), + output=BaseKey(shape=[("Var_out", ...)], type=GenericTensorType), + left=BaseKey(shape=[("Var_1", ...)], type=GenericTensorType), + right=BaseKey(shape=[("Var_2", ...)], type=GenericTensorType), ) self.set_constraint( fn=bcast, keys=[PrimitiveModel.output_key, "left", "right"] @@ -7357,17 +7358,17 @@ def __init__( all_output_shapes = list(output) # Create tensor IOKeys and a Scalar equation input, type = GenericTensorType # Note that equation is string - tensor_input = IOKey(shape=all_input_shapes, type=GenericTensorType) - tensor_output = IOKey(shape=all_output_shapes, type=GenericTensorType) - scalar_equation = IOKey(type=str, value=equation) + tensor_input = BaseKey(shape=all_input_shapes, type=GenericTensorType) + tensor_output = BaseKey(shape=all_output_shapes, type=GenericTensorType) + scalar_equation = BaseKey(type=str, value=equation) else: # case where equation is TBD - tensor_input = IOKey(shape=[("Var1", ...)], type=GenericTensorType) - tensor_output = IOKey(shape=[("Var2", ...)], type=GenericTensorType) - scalar_equation = IOKey(type=str) + tensor_input = BaseKey(shape=[("Var1", ...)], type=GenericTensorType) + tensor_output = BaseKey(shape=[("Var2", ...)], type=GenericTensorType) + scalar_equation = BaseKey(type=str) - kwargs: dict[str, IOKey] = { + kwargs: dict[str, BaseKey] = { "output": tensor_output, "input": tensor_input, "equation": scalar_equation, @@ -7441,17 +7442,17 @@ def __init__( all_output_shapes = list(output) # Create TensorType and Scalar Inputs # Note that equation is string - tensor_input = IOKey(shape=all_input_shapes, type=GenericTensorType) - tensor_output = IOKey(shape=all_output_shapes, type=GenericTensorType) - scalar_equation = IOKey(type=str, value=equation) + tensor_input = BaseKey(shape=all_input_shapes, type=GenericTensorType) + tensor_output = BaseKey(shape=all_output_shapes, type=GenericTensorType) +
scalar_equation = BaseKey(type=str, value=equation) else: # case where equation is TBD - tensor_input = IOKey(shape=[("Var1", ...)], type=GenericTensorType) - tensor_output = IOKey(shape=[("Var2", ...)], type=GenericTensorType) - scalar_equation = IOKey(type=str) + tensor_input = BaseKey(shape=[("Var1", ...)], type=GenericTensorType) + tensor_output = BaseKey(shape=[("Var2", ...)], type=GenericTensorType) + scalar_equation = BaseKey(type=str) - kwargs: dict[str, IOKey] = { + kwargs: dict[str, BaseKey] = { "output": tensor_output, "input": tensor_input, "equation": scalar_equation, diff --git a/tests/scripts/test_shapes.py b/tests/scripts/test_shapes.py index b18dc97e..dd49a95d 100644 --- a/tests/scripts/test_shapes.py +++ b/tests/scripts/test_shapes.py @@ -27,6 +27,7 @@ AND, DNF, NOT_GIVEN, + BaseKey, Connection, ConnectionType, Equivalences, @@ -2931,8 +2932,8 @@ class Model1(PrimitiveModel): def __init__(self) -> None: super().__init__( formula_key="relu", - input=IOKey(shape=[("Var1", ...)], type=GenericTensorType), - output=IOKey(shape=[("Var1", ...), "u1", "u2"], type=GenericTensorType), + input=BaseKey(shape=[("Var1", ...)], type=GenericTensorType), + output=BaseKey(shape=[("Var1", ...), "u1", "u2"], type=GenericTensorType), ) def __call__( # type: ignore[override] @@ -2948,8 +2949,8 @@ class Model2(PrimitiveModel): def __init__(self) -> None: super().__init__( formula_key="relu", - input=IOKey(shape=[("Var1", ...), "u1"], type=GenericTensorType), - output=IOKey(shape=["u1", ("Var1", ...)], type=GenericTensorType), + input=BaseKey(shape=[("Var1", ...), "u1"], type=GenericTensorType), + output=BaseKey(shape=["u1", ("Var1", ...)], type=GenericTensorType), ) def __call__( # type: ignore[override] @@ -2967,10 +2968,10 @@ class Model3(PrimitiveModel): def __init__(self) -> None: super().__init__( formula_key="concat", - input1=IOKey(shape=["u1", "u2", "u3"], type=GenericTensorType), - input2=IOKey(shape=["u3", "u2", "u1"], type=GenericTensorType), - output=IOKey(shape=["u1", ("Var1", ...), "u3"], type=GenericTensorType), - axis=IOKey(type=int), + input1=BaseKey(shape=["u1", "u2", "u3"], type=GenericTensorType), + input2=BaseKey(shape=["u3", "u2", "u1"], type=GenericTensorType), + output=BaseKey(shape=["u1", ("Var1", ...), "u3"], type=GenericTensorType), + axis=BaseKey(type=int), ) def __call__( # type: ignore[override] @@ -2989,8 +2990,8 @@ class Model4(PrimitiveModel): def __init__(self) -> None: super().__init__( formula_key="relu", - input=IOKey(shape=[("Var1", ...)], type=GenericTensorType), - output=IOKey(shape=[("Var1", ...), 1], type=GenericTensorType), + input=BaseKey(shape=[("Var1", ...)], type=GenericTensorType), + output=BaseKey(shape=[("Var1", ...), 1], type=GenericTensorType), ) def __call__( # type: ignore[override] @@ -3007,9 +3008,9 @@ class Model5(PrimitiveModel): def __init__(self, axis=None) -> None: super().__init__( formula_key="relu", - input=IOKey(shape=[("Var1", ...)], type=GenericTensorType), - output=IOKey(shape=[("Var2", ...)], type=GenericTensorType), - axis=IOKey(type=NoneType | list[int], value=axis), + input=BaseKey(shape=[("Var1", ...)], type=GenericTensorType), + output=BaseKey(shape=[("Var2", ...)], type=GenericTensorType), + axis=BaseKey(type=NoneType | list[int], value=axis), ) def __call__( # type: ignore[override] @@ -3028,8 +3029,8 @@ class Model6(PrimitiveModel): def __init__(self) -> None: super().__init__( formula_key="relu", - input=IOKey(shape=["u1", ("Var1", ...)], type=GenericTensorType), - output=IOKey(shape=[("Var1", ...), "u1"], 
type=GenericTensorType), + input=BaseKey(shape=["u1", ("Var1", ...)], type=GenericTensorType), + output=BaseKey(shape=[("Var1", ...), "u1"], type=GenericTensorType), ) def __call__( # type: ignore[override] @@ -3045,8 +3046,8 @@ class Model7(PrimitiveModel): def __init__(self) -> None: super().__init__( formula_key="relu", - input=IOKey(shape=["u1", ("Var1", ...)], type=GenericTensorType), - output=IOKey(shape=["u1", ("Var1", ...)], type=GenericTensorType), + input=BaseKey(shape=["u1", ("Var1", ...)], type=GenericTensorType), + output=BaseKey(shape=["u1", ("Var1", ...)], type=GenericTensorType), ) def __call__( # type: ignore[override] @@ -3062,8 +3063,8 @@ class Model8(PrimitiveModel): def __init__(self) -> None: super().__init__( formula_key="relu", - input=IOKey(shape=[("Var1", ...), "u1"], type=GenericTensorType), - output=IOKey(shape=[("Var1", ...), "u1"], type=GenericTensorType), + input=BaseKey(shape=[("Var1", ...), "u1"], type=GenericTensorType), + output=BaseKey(shape=[("Var1", ...), "u1"], type=GenericTensorType), ) def __call__( # type: ignore[override] @@ -3079,8 +3080,8 @@ class Model9(PrimitiveModel): def __init__(self) -> None: super().__init__( formula_key="buffer", - input=IOKey(shape=["u1", ("Var1", ...)], type=GenericTensorType), - output=IOKey(shape=["u2", "u1", ("Var1", ...)], type=GenericTensorType), + input=BaseKey(shape=["u1", ("Var1", ...)], type=GenericTensorType), + output=BaseKey(shape=["u2", "u1", ("Var1", ...)], type=GenericTensorType), ) def __call__( # type: ignore[override] @@ -3968,8 +3969,8 @@ class MyVariadic1(PrimitiveModel): def __init__(self) -> None: super().__init__( formula_key="buffer", - input=IOKey(shape=[("Var1", ...), "a"], type=GenericTensorType), - output=IOKey(shape=[("Var1", ...), "a"], type=GenericTensorType), + input=BaseKey(shape=[("Var1", ...), "a"], type=GenericTensorType), + output=BaseKey(shape=[("Var1", ...), "a"], type=GenericTensorType), ) def __call__( # type: ignore[override] @@ -3985,8 +3986,8 @@ class MyVariadic2(PrimitiveModel): def __init__(self) -> None: super().__init__( formula_key="buffer", - input=IOKey(shape=[("Var1", ...), "a", "b"], type=GenericTensorType), - output=IOKey(shape=[("Var1", ...), "a", "b"], type=GenericTensorType), + input=BaseKey(shape=[("Var1", ...), "a", "b"], type=GenericTensorType), + output=BaseKey(shape=[("Var1", ...), "a", "b"], type=GenericTensorType), ) def __call__( # type: ignore[override] @@ -4002,8 +4003,8 @@ class MyVariadic3(PrimitiveModel): def __init__(self) -> None: super().__init__( formula_key="buffer", - input=IOKey(shape=["a", "b", ("Var1", ...)], type=GenericTensorType), - output=IOKey(shape=["a", "b", ("Var1", ...)], type=GenericTensorType), + input=BaseKey(shape=["a", "b", ("Var1", ...)], type=GenericTensorType), + output=BaseKey(shape=["a", "b", ("Var1", ...)], type=GenericTensorType), ) def __call__( # type: ignore[override] @@ -4019,8 +4020,8 @@ class MyVariadic4(PrimitiveModel): def __init__(self) -> None: super().__init__( formula_key="buffer", - input=IOKey(shape=["a", ("Var1", ...)], type=GenericTensorType), - output=IOKey(shape=["a", ("Var1", ...)], type=GenericTensorType), + input=BaseKey(shape=["a", ("Var1", ...)], type=GenericTensorType), + output=BaseKey(shape=["a", ("Var1", ...)], type=GenericTensorType), ) def __call__( # type: ignore[override] @@ -4036,8 +4037,8 @@ class MyVariadic5(PrimitiveModel): def __init__(self) -> None: super().__init__( formula_key="buffer", - input=IOKey(shape=[("Var1", ...), "a", "b"], type=GenericTensorType), - 
output=IOKey(shape=[("Var1", ...), "a"], type=GenericTensorType), + input=BaseKey(shape=[("Var1", ...), "a", "b"], type=GenericTensorType), + output=BaseKey(shape=[("Var1", ...), "a"], type=GenericTensorType), ) def __call__( # type: ignore[override] @@ -4053,8 +4054,8 @@ class MyVariadic6(PrimitiveModel): def __init__(self) -> None: super().__init__( formula_key="buffer", - input=IOKey(shape=[("Var1", ...), "a"], type=GenericTensorType), - output=IOKey(shape=["a", "a"], type=GenericTensorType), + input=BaseKey(shape=[("Var1", ...), "a"], type=GenericTensorType), + output=BaseKey(shape=["a", "a"], type=GenericTensorType), ) def __call__( # type: ignore[override] @@ -4070,8 +4071,8 @@ class MyVariadic7(PrimitiveModel): def __init__(self) -> None: super().__init__( formula_key="buffer", - input=IOKey(shape=[("Var1", ...), "u1", "u2"], type=GenericTensorType), - output=IOKey(shape=["u3", ("Var2", ...), "u4"], type=GenericTensorType), + input=BaseKey(shape=[("Var1", ...), "u1", "u2"], type=GenericTensorType), + output=BaseKey(shape=["u3", ("Var2", ...), "u4"], type=GenericTensorType), ) def __call__( # type: ignore[override] @@ -4092,27 +4093,27 @@ class MyVariadic8(PrimitiveModel): def __init__(self) -> None: super().__init__( formula_key="buffer", - input1=IOKey( + input1=BaseKey( shape=["u1", "u2", "u3", ("Var1", ...)], type=GenericTensorType ), - input2=IOKey( + input2=BaseKey( shape=["u4", "u5", ("Var2", ...), "u6"], type=GenericTensorType ), - input3=IOKey( + input3=BaseKey( shape=["u7", ("Var3", ...), "u8", "u9"], type=GenericTensorType ), - input4=IOKey( + input4=BaseKey( shape=[("Var4", ...), "u10", "u11", "u12"], type=GenericTensorType ), - input5=IOKey( + input5=BaseKey( shape=[("Var5", ...), "u13", "u14", "u15", "u16"], type=GenericTensorType, ), - input6=IOKey( + input6=BaseKey( shape=["u17", "u18", ("Var6", ...), "u19", "u20"], type=GenericTensorType, ), - output=IOKey( + output=BaseKey( shape=["u13", ("Var1", ...), "u14", "u15", "u16"], type=GenericTensorType, ), @@ -4151,10 +4152,10 @@ class MyVariadic9(PrimitiveModel): def __init__(self) -> None: super().__init__( formula_key="buffer", - input1=IOKey(shape=["u1", ("Var1", ...)], type=GenericTensorType), - input2=IOKey(shape=[("Var2", ...), "u2"], type=GenericTensorType), - input3=IOKey(shape=["u3", ("Var3", ...), "u4"], type=GenericTensorType), - output=IOKey(shape=["u5", "u5"], type=GenericTensorType), + input1=BaseKey(shape=["u1", ("Var1", ...)], type=GenericTensorType), + input2=BaseKey(shape=[("Var2", ...), "u2"], type=GenericTensorType), + input3=BaseKey(shape=["u3", ("Var3", ...), "u4"], type=GenericTensorType), + output=BaseKey(shape=["u5", "u5"], type=GenericTensorType), ) def __call__( # type: ignore[override] @@ -4181,16 +4182,16 @@ class MyVariadic10(PrimitiveModel): def __init__(self) -> None: super().__init__( formula_key="buffer", - input1=IOKey(shape=["u1", "u2", ("Var1", ...)], type=GenericTensorType), - input2=IOKey(shape=["u3", ("Var2", ...), "u4"], type=GenericTensorType), - input3=IOKey(shape=[("Var3", ...), "u5", "u6"], type=GenericTensorType), - input4=IOKey( + input1=BaseKey(shape=["u1", "u2", ("Var1", ...)], type=GenericTensorType), + input2=BaseKey(shape=["u3", ("Var2", ...), "u4"], type=GenericTensorType), + input3=BaseKey(shape=[("Var3", ...), "u5", "u6"], type=GenericTensorType), + input4=BaseKey( shape=["u7", "u8", ("Var4", ...), "u9", "u10"], type=GenericTensorType ), - input5=IOKey( + input5=BaseKey( shape=["u11", ("Var4", ...), "u12", "u13"], type=GenericTensorType ), - output=IOKey(shape=["u5", 
"u5"], type=GenericTensorType), + output=BaseKey(shape=["u5", "u5"], type=GenericTensorType), ) def __call__( # type: ignore[override] @@ -4222,8 +4223,8 @@ class MyVariadic11(PrimitiveModel): def __init__(self) -> None: super().__init__( formula_key="buffer", - input=IOKey(shape=["a", ("Var1", ...)], type=GenericTensorType), - output=IOKey(shape=["a", ("Var1", ...), "b"], type=GenericTensorType), + input=BaseKey(shape=["a", ("Var1", ...)], type=GenericTensorType), + output=BaseKey(shape=["a", ("Var1", ...), "b"], type=GenericTensorType), ) def __call__( # type: ignore[override] @@ -4239,8 +4240,10 @@ class MyVariadic12(PrimitiveModel): def __init__(self) -> None: super().__init__( formula_key="buffer", - input=IOKey(shape=["a", "b", ("Var1", ...)], type=GenericTensorType), - output=IOKey(shape=["a", "b", "c", ("Var1", ...)], type=GenericTensorType), + input=BaseKey(shape=["a", "b", ("Var1", ...)], type=GenericTensorType), + output=BaseKey( + shape=["a", "b", "c", ("Var1", ...)], type=GenericTensorType + ), ) def __call__( # type: ignore[override] @@ -5131,9 +5134,9 @@ class MyModel(PrimitiveModel): def __init__(self) -> None: super().__init__( formula_key="buffer", - input1=IOKey(shape=["a", ("Var1", ...), "b"], type=GenericTensorType), - input2=IOKey(shape=[("Var1", ...)], type=GenericTensorType), - output=IOKey(shape=["a", ("Var1", ...), "c"], type=GenericTensorType), + input1=BaseKey(shape=["a", ("Var1", ...), "b"], type=GenericTensorType), + input2=BaseKey(shape=[("Var1", ...)], type=GenericTensorType), + output=BaseKey(shape=["a", ("Var1", ...), "c"], type=GenericTensorType), ) model = MyModel() @@ -5163,9 +5166,9 @@ class MyModel(PrimitiveModel): def __init__(self) -> None: super().__init__( formula_key="buffer", - input1=IOKey(shape=["a", ("Var1", ...), "b"], type=GenericTensorType), - input2=IOKey(shape=["b", "c"], type=GenericTensorType), - output=IOKey(shape=["a", ("Var1", ...), "c"], type=GenericTensorType), + input1=BaseKey(shape=["a", ("Var1", ...), "b"], type=GenericTensorType), + input2=BaseKey(shape=["b", "c"], type=GenericTensorType), + output=BaseKey(shape=["a", ("Var1", ...), "c"], type=GenericTensorType), ) model = MyModel() @@ -5191,8 +5194,8 @@ class MyModel(PrimitiveModel): def __init__(self) -> None: super().__init__( formula_key="buffer", - input=IOKey(shape=["a", ("Var1", ...), "b"], type=GenericTensorType), - output=IOKey(shape=["b", ("Var1", ...), "a"], type=GenericTensorType), + input=BaseKey(shape=["a", ("Var1", ...), "b"], type=GenericTensorType), + output=BaseKey(shape=["b", ("Var1", ...), "a"], type=GenericTensorType), ) def __call__( # type: ignore[override] @@ -5232,8 +5235,8 @@ class MyModel(PrimitiveModel): def __init__(self) -> None: super().__init__( formula_key="buffer", - input=IOKey(shape=["a", ("Var1", ...), "b"], type=GenericTensorType), - output=IOKey(shape=["b", ("Var1", ...), "a"], type=GenericTensorType), + input=BaseKey(shape=["a", ("Var1", ...), "b"], type=GenericTensorType), + output=BaseKey(shape=["b", ("Var1", ...), "a"], type=GenericTensorType), ) def __call__( # type: ignore[override] @@ -5549,10 +5552,10 @@ class MyModel(PrimitiveModel): def __init__(self) -> None: super().__init__( formula_key="buffer", - input=IOKey( + input=BaseKey( shape=[("Var1", ...), "a", "b", "c"], type=GenericTensorType ), - output=IOKey( + output=BaseKey( shape=["c", ("Var1", ...), "a", "b"], type=GenericTensorType ), ) @@ -5596,10 +5599,12 @@ class MyModel(PrimitiveModel): def __init__(self) -> None: super().__init__( formula_key="buffer", - 
input1=IOKey(shape=["a", ("Var1", ...), "b"], type=GenericTensorType), - input2=IOKey(shape=["_a", ("Var1", ...), "_b"], type=GenericTensorType), - input3=IOKey(shape=["b", "c"], type=GenericTensorType), - output=IOKey(shape=["a", ("Var1", ...), "c"], type=GenericTensorType), + input1=BaseKey(shape=["a", ("Var1", ...), "b"], type=GenericTensorType), + input2=BaseKey( + shape=["_a", ("Var1", ...), "_b"], type=GenericTensorType + ), + input3=BaseKey(shape=["b", "c"], type=GenericTensorType), + output=BaseKey(shape=["a", ("Var1", ...), "c"], type=GenericTensorType), ) def __call__( # type: ignore[override] @@ -5646,8 +5651,8 @@ class MyModel(PrimitiveModel): def __init__(self) -> None: super().__init__( formula_key="buffer", - input=IOKey(shape=[("Var1", ...), "u1"], type=GenericTensorType), - output=IOKey(shape=[("Var1", ...)], type=GenericTensorType), + input=BaseKey(shape=[("Var1", ...), "u1"], type=GenericTensorType), + output=BaseKey(shape=[("Var1", ...)], type=GenericTensorType), ) def __call__( # type: ignore[override] @@ -5689,10 +5694,10 @@ class MyModel(PrimitiveModel): def __init__(self) -> None: super().__init__( formula_key="buffer", - input=IOKey( + input=BaseKey( shape=[("Var1", ...), "u1", "u2", "u3"], type=GenericTensorType ), - output=IOKey(shape=[("Var1", ...), "u4"], type=GenericTensorType), + output=BaseKey(shape=[("Var1", ...), "u4"], type=GenericTensorType), ) def __call__( # type: ignore[override] @@ -5730,10 +5735,10 @@ class MyModel(PrimitiveModel): def __init__(self) -> None: super().__init__( formula_key="buffer", - input=IOKey( + input=BaseKey( shape=[("Var1", ...), "u1", "u2", "u3"], type=GenericTensorType ), - output=IOKey(shape=[("Var1", ...), "u4"], type=GenericTensorType), + output=BaseKey(shape=[("Var1", ...), "u4"], type=GenericTensorType), ) def __call__( # type: ignore[override] @@ -6902,8 +6907,8 @@ class MyModel(PrimitiveModel): def __init__(self) -> None: super().__init__( formula_key="buffer", - input=IOKey(shape=["a", "b"], type=GenericTensorType), - output=IOKey(shape=["b", "c", "d"], type=GenericTensorType), + input=BaseKey(shape=["a", "b"], type=GenericTensorType), + output=BaseKey(shape=["b", "c", "d"], type=GenericTensorType), ) model = MyModel() @@ -6921,8 +6926,8 @@ class MyModel(PrimitiveModel): def __init__(self) -> None: super().__init__( formula_key="buffer", - input=IOKey(shape=["a", ("Var1", ...)], type=GenericTensorType), - output=IOKey( + input=BaseKey(shape=["a", ("Var1", ...)], type=GenericTensorType), + output=BaseKey( shape=[("Var1", ...), "c", "d", "e"], type=GenericTensorType ), ) @@ -6946,10 +6951,10 @@ class MyModel(PrimitiveModel): def __init__(self) -> None: super().__init__( formula_key="buffer", - input=IOKey( + input=BaseKey( shape=[("Var1", ...), "c", "d", "e"], type=GenericTensorType ), - output=IOKey(shape=["a", "b", ("Var1", ...)], type=GenericTensorType), + output=BaseKey(shape=["a", "b", ("Var1", ...)], type=GenericTensorType), ) model = Model() @@ -7406,8 +7411,8 @@ class MyModel(PrimitiveModel): def __init__(self) -> None: super().__init__( formula_key="buffer", - input=IOKey(shape=["a", ("Var1", ...)], type=GenericTensorType), - output=IOKey(shape=[("Var1", ...), "b"], type=GenericTensorType), + input=BaseKey(shape=["a", ("Var1", ...)], type=GenericTensorType), + output=BaseKey(shape=[("Var1", ...), "b"], type=GenericTensorType), ) def __call__( # type: ignore[override] @@ -7439,9 +7444,9 @@ class MyModel(PrimitiveModel): def __init__(self) -> None: super().__init__( formula_key="buffer", - 
input1=IOKey(shape=["a", ("Var1", ...)], type=GenericTensorType), - input2=IOKey(shape=[("Var1", ...), "b"], type=GenericTensorType), - output=IOKey(shape=["a", ("Var1", ...)], type=GenericTensorType), + input1=BaseKey(shape=["a", ("Var1", ...)], type=GenericTensorType), + input2=BaseKey(shape=[("Var1", ...), "b"], type=GenericTensorType), + output=BaseKey(shape=["a", ("Var1", ...)], type=GenericTensorType), ) def __call__( # type: ignore[override] @@ -7659,8 +7664,8 @@ class MyModel(PrimitiveModel): def __init__(self) -> None: super().__init__( formula_key="buffer", - input=IOKey(shape=[5, 5], type=GenericTensorType), - output=IOKey(shape=[5, 5], type=GenericTensorType), + input=BaseKey(shape=[5, 5], type=GenericTensorType), + output=BaseKey(shape=[5, 5], type=GenericTensorType), ) def __call__( # type: ignore[override] @@ -7689,8 +7694,8 @@ class MyModel(PrimitiveModel): def __init__(self) -> None: super().__init__( formula_key="buffer", - input=IOKey(shape=["a", "b"], type=GenericTensorType), - output=IOKey(shape=["a", "b"], type=GenericTensorType), + input=BaseKey(shape=["a", "b"], type=GenericTensorType), + output=BaseKey(shape=["a", "b"], type=GenericTensorType), ) def __call__( # type: ignore[override] @@ -7718,8 +7723,10 @@ class MyModel(PrimitiveModel): def __init__(self) -> None: super().__init__( formula_key="buffer", - input=IOKey(shape=["a", ("V1", ...), "b", "c"], type=GenericTensorType), - output=IOKey( + input=BaseKey( + shape=["a", ("V1", ...), "b", "c"], type=GenericTensorType + ), + output=BaseKey( shape=["c", ("V1", ...), "a", "b"], type=GenericTensorType ), ) @@ -7756,8 +7763,8 @@ class MyModel(PrimitiveModel): def __init__(self) -> None: super().__init__( formula_key="buffer", - input=IOKey(shape=["a", "b", "c"], type=GenericTensorType), - output=IOKey(shape=["b", "c", "a"], type=GenericTensorType), + input=BaseKey(shape=["a", "b", "c"], type=GenericTensorType), + output=BaseKey(shape=["b", "c", "a"], type=GenericTensorType), ) def __call__( # type: ignore[override] @@ -7792,8 +7799,8 @@ class MyModel(PrimitiveModel): def __init__(self) -> None: super().__init__( formula_key="buffer", - input=IOKey(shape=[], type=GenericTensorType), - output=IOKey(shape=[], type=GenericTensorType), + input=BaseKey(shape=[], type=GenericTensorType), + output=BaseKey(shape=[], type=GenericTensorType), ) def __call__( # type: ignore[override] diff --git a/tests/scripts/test_type_coercion.py b/tests/scripts/test_type_coercion.py index 72fe06f4..0a04dd0e 100644 --- a/tests/scripts/test_type_coercion.py +++ b/tests/scripts/test_type_coercion.py @@ -22,6 +22,7 @@ from mithril import JaxBackend, NumpyBackend, TorchBackend, compile from mithril.framework.common import ( NOT_GIVEN, + BaseKey, Connection, ConnectionType, GenericTensorType, @@ -661,8 +662,8 @@ class ArtificialPrimitive(PrimitiveModel): def __init__(self, type) -> None: super().__init__( formula_key="tensor_to_list", - output=IOKey(shape=[("Var1", ...)], type=GenericTensorType), - input=IOKey(shape=[("Var2", ...)], type=type), + output=BaseKey(shape=[("Var1", ...)], type=GenericTensorType), + input=BaseKey(shape=[("Var2", ...)], type=type), ) self._set_constraint( fn=self.artificial_constraint, keys=[PrimitiveModel.output_key, "input"] @@ -815,8 +816,8 @@ class Model1(PrimitiveModel): def __init__(self) -> None: super().__init__( formula_key="buffer", - input=IOKey(shape=[("Var1", ...)], type=GenericTensorType), - output=IOKey(shape=[("Var1", ...)], type=GenericTensorType), + input=BaseKey(shape=[("Var1", ...)], 
type=GenericTensorType), + output=BaseKey(shape=[("Var1", ...)], type=GenericTensorType), ) def __call__( # type: ignore[override] diff --git a/tests/scripts/test_type_consistencies.py b/tests/scripts/test_type_consistencies.py index 5a6eb966..9f8103e1 100644 --- a/tests/scripts/test_type_consistencies.py +++ b/tests/scripts/test_type_consistencies.py @@ -28,6 +28,7 @@ ) from mithril.models import ( TBD, + BaseKey, Convolution2D, ExtendInfo, IOKey, @@ -49,9 +50,9 @@ class Model1(PrimitiveModel): def __init__(self) -> None: super().__init__( formula_key="None", - input1=IOKey(type=tuple[int, ...]), - input2=IOKey(type=list[float]), - output=IOKey(type=tuple[tuple[int, ...]]), + input1=BaseKey(type=tuple[int, ...]), + input2=BaseKey(type=list[float]), + output=BaseKey(type=tuple[tuple[int, ...]]), ) def __call__( # type: ignore[override] @@ -68,10 +69,10 @@ class Model2(PrimitiveModel): def __init__(self) -> None: super().__init__( formula_key="None", - input1=IOKey(type=int | float), - input2=IOKey(type=int | str), - input3=IOKey(type=str | float), - output=IOKey(type=tuple[int | float, int | float, int | float]), + input1=BaseKey(type=int | float), + input2=BaseKey(type=int | str), + input3=BaseKey(type=str | float), + output=BaseKey(type=tuple[int | float, int | float, int | float]), ) def __call__( # type: ignore[override] @@ -94,19 +95,21 @@ class Model3(PrimitiveModel): def __init__(self) -> None: super().__init__( formula_key="None", - input1=IOKey( + input1=BaseKey( type=tuple[tuple[int | float, ...], ...] | list[int | float] | tuple[int, int, int, int] ), - input2=IOKey(type=list[int] | tuple[int, ...] | tuple[tuple[int | float]]), - input3=IOKey( + input2=BaseKey( + type=list[int] | tuple[int, ...] | tuple[tuple[int | float]] + ), + input3=BaseKey( type=list[tuple[int | tuple[float | int]]] | int | float | tuple[int | float, ...] 
), - output=IOKey(type=int | float | str | tuple[int, int]), + output=BaseKey(type=int | float | str | tuple[int, int]), ) def __call__( # type: ignore[override] From 8572f6b1bab9da64018b6e6ee3a5aa6fff66bb2f Mon Sep 17 00:00:00 2001 From: berat Tuna KARLI Date: Wed, 25 Dec 2024 11:59:26 +0300 Subject: [PATCH 7/9] update shape method name in TemplateBase, convert IOKey _connections to set of Connection or str --- examples/flux/auto_encoder.py | 4 +- examples/flux/layers.py | 14 +-- examples/gpt/model.py | 2 +- mithril/framework/common.py | 14 +-- mithril/framework/logical/base.py | 8 +- mithril/framework/logical/model.py | 31 +++---- mithril/models/models.py | 16 ++-- mithril/utils/dict_conversions.py | 4 +- tests/scripts/test_constant_inputs.py | 2 +- tests/scripts/test_extend_template.py | 66 +------------- tests/scripts/test_io_key.py | 8 +- tests/scripts/test_jittable.py | 12 +-- tests/scripts/test_key_values_in_init.py | 25 +++--- tests/scripts/test_model_to_dict_rtt.py | 10 +-- tests/scripts/test_parallel.py | 8 +- tests/scripts/test_ref_counts.py | 8 +- tests/scripts/test_scripts.py | 60 ++++++------- tests/scripts/test_set_types.py | 2 +- tests/scripts/test_shapes.py | 8 +- tests/scripts/test_summary.py | 10 +-- .../scripts/test_tuple_list_args_in_extend.py | 24 ++--- tests/scripts/test_type_coercion.py | 88 +++++++++---------- tests/scripts/test_type_consistencies.py | 4 +- 23 files changed, 176 insertions(+), 252 deletions(-) diff --git a/examples/flux/auto_encoder.py b/examples/flux/auto_encoder.py index 1ed7fbed..6b8868da 100644 --- a/examples/flux/auto_encoder.py +++ b/examples/flux/auto_encoder.py @@ -66,7 +66,7 @@ def attn_block(n_channels: int, name: str | None = None): key = block.key # type: ignore[attr-defined] value = block.value # type: ignore[attr-defined] - shape = query.shape() + shape = query.get_shape() query = query.transpose((0, 2, 3, 1)).reshape((shape[0], 1, -1, shape[1])) key = key.transpose((0, 2, 3, 1)).reshape((shape[0], 1, -1, shape[1])) @@ -92,7 +92,7 @@ def downsample(n_channels: int): def upsample(n_channels: int, name: str | None = None): block = Model(enforce_jit=False, name=name) # TODO: Remove enforce_jit=False input = IOKey("input") - input_shape = input.shape() + input_shape = input.get_shape() B, C, H, W = input_shape[0], input_shape[1], input_shape[2], input_shape[3] input = input[:, :, :, None, :, None] diff --git a/examples/flux/layers.py b/examples/flux/layers.py index 33a87547..be327e85 100644 --- a/examples/flux/layers.py +++ b/examples/flux/layers.py @@ -65,8 +65,8 @@ def apply_rope() -> Model: xq = IOKey("xq") xk = IOKey("xk") freqs_cis = IOKey("freqs_cis") - xq_shape = xq.shape() - xk_shape = xk.shape() + xq_shape = xq.get_shape() + xk_shape = xk.get_shape() B, L, H = xq_shape[0], xq_shape[1], xq_shape[2] block += Reshape()(xq, shape=(B, L, H, -1, 1, 2), output="xq_") B, L, H = xk_shape[0], xk_shape[1], xk_shape[2] @@ -96,7 +96,7 @@ def attention() -> Model: ) # We can get named connection as model.'connection_name' - context_shape = block.context.shape() # type: ignore[attr-defined] + context_shape = block.context.get_shape() # type: ignore[attr-defined] block += Transpose(axes=(0, 2, 1, 3))(block.context) # type: ignore[attr-defined] # NOTE: Reshape input is automatically connected to Transpose output block += Reshape()( @@ -137,7 +137,7 @@ def modulation(dim: int, double: bool, name: str | None = None): def rearrange(num_heads: int): block = Model() input = IOKey("input") - input_shaepe = input.shape() - B, L = input_shaepe[0], input_shaepe[1] + input_shape = input.get_shape() + B, L =
input_shape[0], input_shape[1] block += Reshape()(shape=(B, L, 3, num_heads, -1)) block += Transpose(axes=(2, 0, 3, 1, 4))(output=IOKey("output")) @@ -209,7 +209,7 @@ def double_stream_block( block += Concat(axis=2, n=2)(input1=txt_v, input2=img_v, output="v_concat") block += attention()(q="q_concat", k="k_concat", v="v_concat", pe=pe, output="attn") - # TODO: use'[:, txt.shape()[1] :]' when fixed. + # TODO: use '[:, txt.get_shape()[1] :]' when fixed. img_attn = block.attn[:, 256:] # type: ignore[attr-defined] block += Linear(hidden_size, name="img_attn_proj")(img_attn, output="img_proj") @@ -234,7 +234,7 @@ def double_stream_block( ) img = img + block.img_mod_2[2] * block.img_mlp # type: ignore[attr-defined] - # TODO: Use txt.shape()[1]] + # TODO: Use txt.get_shape()[1] txt_attn = block.attn[:, :256] # type: ignore[attr-defined] block += Linear(hidden_size, name="txt_attn_proj")(txt_attn, output="txt_proj") @@ -355,7 +355,7 @@ def rope(dim: int, theta: int) -> Model: omega = 1.0 / (theta ** (block.arange / dim)) # type: ignore out = input[..., None] * omega - out_shape = out.shape() + out_shape = out.get_shape() B, N, D = out_shape[0], out_shape[1], out_shape[2] block += Cosine()(out, output="cos") diff --git a/examples/gpt/model.py b/examples/gpt/model.py index 4d503a28..3f7a539f 100644 --- a/examples/gpt/model.py +++ b/examples/gpt/model.py @@ -41,7 +41,7 @@ def causal_attention(input_dim, num_heads, bias=True): model += Linear(input_dim * 3, name="c_attn")("input", output="c_attn_out") t_axes = (0, 2, 1, 3) - shp_con = model.input.shape() # type: ignore + shp_con = model.input.get_shape() # type: ignore reshape_con = (shp_con[0], shp_con[1], num_heads, -1) model += Split(3, axis=-1)(model.c_attn_out, output="split_out") # type: ignore diff --git a/mithril/framework/common.py b/mithril/framework/common.py index 8fa15dae..1007ce8a 100644 --- a/mithril/framework/common.py +++ b/mithril/framework/common.py @@ -40,7 +40,7 @@ constant_type_table, epsilon_table, ) -from ..utils.utils import OrderedSet, PaddingType, find_dominant_type +from ..utils.utils import PaddingType, find_dominant_type from .utils import ( NestedListType, align_shapes, @@ -1001,7 +1001,7 @@ def abs(self): def len(self): return ExtendTemplate(connections=[self], model="len") - def shape(self): + def get_shape(self): return ExtendTemplate(connections=[self], model="shape") def reshape(self, shape: tuple[int, ...] | TemplateBase): @@ -1118,7 +1118,7 @@ def __init__( type: NestedListType | UnionType | type | None = None, expose: bool | None = None, interval: list[float | int] | None = None, - connections: list[Connection | str] | None = None, + connections: set[Connection | str] | None = None, ) -> None: super().__init__() self._name = name @@ -1127,7 +1127,7 @@ def __init__( self._type = type self._expose = expose self._interval = interval - self._connections: OrderedSet[ConnectionData | str] = OrderedSet() + self._connections: set[Connection | str] = connections or set() # TODO: Shape should not be [] also!
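[Annotation: storing _connections as a plain set[Connection | str], seeded directly from the constructor argument, is what lets the ConnectionData-unwrapping loop below be deleted. A minimal sketch of the resulting call convention, mirroring the test updates later in this patch; the Mean wiring is illustrative only:

    from mithril.models import IOKey, Mean, Model, TBD

    model = Model()
    model += (mean1 := Mean(axis=TBD))
    model += (mean2 := Mean(axis=TBD))
    # One IOKey now merges several connections through a set literal.
    model += Mean(axis=TBD)(axis=IOKey(connections={mean1.axis, mean2.axis}))
]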
if self._value is not TBD and self._shape is not None and self._shape != []: @@ -1144,12 +1144,6 @@ def __init__( f"type is {self._type} while type of value is {value_type}" ) - connections = connections or [] - for item in connections: - conn: ConnectionData | str - conn = item.data if isinstance(item, Connection) else item - self._connections.add(conn) - def __eq__(self, other: object): if isinstance(other, int | float | bool | list | Connection | IOKey | tuple): return ExtendTemplate(connections=[self, other], model="eq") diff --git a/mithril/framework/logical/base.py b/mithril/framework/logical/base.py index 7a6210e6..d596d156 100644 --- a/mithril/framework/logical/base.py +++ b/mithril/framework/logical/base.py @@ -95,7 +95,7 @@ def __call__(self, **kwargs: ConnectionType) -> ExtendInfo: continue match con: case Connection(): - kwargs[key] = IOKey(value=val, connections=[con]) + kwargs[key] = IOKey(value=val, connections={con}) # TODO: Maybe we could check con's value if matches with val case item if isinstance(item, MainValueInstance) and con != val: raise ValueError( @@ -110,17 +110,13 @@ def __call__(self, **kwargs: ConnectionType) -> ExtendInfo: f"Given IOKey for local key: '{key}' is not valid!" ) else: - _conns: list[Connection | str] = [ - item.conn if isinstance(item, ConnectionData) else item - for item in con._connections - ] kwargs[key] = IOKey( name=con._name, value=val, shape=con._shape, type=con._type, expose=con._expose, - connections=_conns, + connections=con._connections, ) case ExtendTemplate(): raise ValueError( diff --git a/mithril/framework/logical/model.py b/mithril/framework/logical/model.py index 6b2a50ee..e9f374d0 100644 --- a/mithril/framework/logical/model.py +++ b/mithril/framework/logical/model.py @@ -311,7 +311,7 @@ def _convert_to_iokey( case str(): connection = IOKey(name=connection) case Connection(): - connection = IOKey(connections=[connection]) + connection = IOKey(connections={connection}) case ExtendTemplate(): # Unroll ExtendTemplate template_conn = model.conns.get_connection(key) @@ -319,7 +319,7 @@ def _convert_to_iokey( connection = self._unroll_template( connection, type(template_conn.metadata.data) ) - connection = IOKey(connections=[connection.conn], expose=False) + connection = IOKey(connections={connection.conn}, expose=False) case _ if isinstance(connection, MainValueInstance): # find_dominant_type returns the dominant type in a container. # If a container has a value of type Connection or ExtendTemplate @@ -338,26 +338,22 @@ def _convert_to_iokey( result = conv_model.conns.get_connection("output") assert result is not None - connection = IOKey(connections=[result.conn], expose=None) + connection = IOKey(connections={result.conn}, expose=None) else: assert isinstance(connection, MainValueInstance) connection = IOKey(value=connection) case IOKey(): expose = connection._expose name = connection._name - # TODO: This check should be removed: conn._connections==OrderedSet([]) + # TODO: This check should be removed: conn._connections==set() # We should not operate different if _connections is given. Fix this and # also fix corresponding tests and dict conversions with "connect". 
if ( expose is None and (name is None or self.conns.get_connection(name) is None) - and connection._connections == OrderedSet([]) + and connection._connections == set() ): expose = True - _conns: list[Connection | str] = [ - item.conn if isinstance(item, ConnectionData) else item - for item in connection._connections - ] # TODO: Add replicate method to IOKey (update def __call__ in BaseModel) connection = IOKey( name=name, value=connection._value, shape=connection._shape, type=connection._type, expose=expose, - connections=_conns, + connections=connection._connections, ) case NotAvailable(): raise ValueError( f"Given value for key: '{key}' is not available. " "Probably Canonical input/output connections are used, " "but the model canonical connections is not determined. Please " "provide connection/key explicitly, or set canonical connections." ) @@ -411,7 +407,7 @@ def _add_connection( if given_connection._value is not TBD: set_value = given_connection._value - if given_connection._connections == OrderedSet([]): + if given_connection._connections == set(): if outer_key is not None: con_obj = self.conns.get_connection(outer_key) if outer_key is None or con_obj is None: @@ -435,14 +431,15 @@ if isinstance(conn, str): _conn = self.conns.get_connection(conn) else: - _conn = self.conns.get_con_by_metadata(conn.metadata) + _conn = self.conns.get_con_by_metadata(conn.data.metadata) + if conn.data in model.conns.all.values(): + raise ValueError( + f"Given connection '{conn.data.key}' should not " + "belong to the extending model!" + ) + if not isinstance(_conn, ConnectionData): raise KeyError("Requires accessible connection to be processed!") - elif conn in model.conns.all.values(): - raise ValueError( - f"Given connection '{conn.key}' should not " # type: ignore - "belong to the extending model!" - ) if idx == 0: initial_conn = _conn if outer_key is not None: diff --git a/mithril/models/models.py b/mithril/models/models.py index 26917614..0959a492 100644 --- a/mithril/models/models.py +++ b/mithril/models/models.py @@ -776,7 +776,7 @@ def __init__( # Assumed input shape is [N, C, H, W] input_key = IOKey(name="input") - input_shape = input_key.shape() + input_shape = input_key.get_shape() B = input_shape[0] input_key = input_key.reshape((B, num_groups, -1)) @@ -2899,7 +2899,7 @@ def __init__( )("pred", "label", "metric_out", "pred_formatted", "label_formatted") true_predictions = self.metric_out == 0 - n_prediction = self.label_formatted.shape()[0] + n_prediction = self.label_formatted.get_shape()[0] self += Sum()(input=true_predictions, output="n_true_predictions") self += Divide()( @@ -3009,13 +3009,13 @@ def __init__( self += Divide()( numerator=sum_precision, - denominator=self.n_classes.shape()[0].tensor(), + denominator=self.n_classes.get_shape()[0].tensor(), output=IOKey(name="output"), ) elif average == "weighted": precision = None - n_element = self.label_formatted.shape()[0] + n_element = self.label_formatted.get_shape()[0] assert ( n_classes is not None ), "n_classes must be provided if average is 'weighted'" @@ -3153,7 +3153,7 @@ def __init__( self += Divide()( numerator=sum_recall, - denominator=self.n_classes.shape()[0].tensor(), + denominator=self.n_classes.get_shape()[0].tensor(), output=IOKey(name="output"), ) @@ -3162,7 +3162,7 @@ def __init__( assert ( n_classes is not None ), "n_classes must be provided if average is 'weighted'" - n_element = self.label_formatted.shape()[0] + n_element = self.label_formatted.get_shape()[0] for idx in range(n_classes): class_idxs = self.label_formatted == idx true_positive = (self.metric_out == 0) & class_idxs @@ -3299,7 +3299,7 @@ def __init__( self += Unique()(input=self.label_formatted, output="n_classes") self += Divide()(
numerator=sum_precision, - denominator=self.n_classes.shape()[0].tensor(), + denominator=self.n_classes.get_shape()[0].tensor(), output=IOKey(name="output"), ) @@ -3308,7 +3308,7 @@ def __init__( assert ( n_classes is not None ), "n_classes must be provided if average is 'weighted'" - n_element = self.label_formatted.shape()[0].tensor() + n_element = self.label_formatted.get_shape()[0].tensor() for idx in range(n_classes): class_idxs = self.label_formatted == idx true_positive = (self.metric_out == 0) & class_idxs diff --git a/mithril/utils/dict_conversions.py b/mithril/utils/dict_conversions.py index 8cb17f93..53a8e034 100644 --- a/mithril/utils/dict_conversions.py +++ b/mithril/utils/dict_conversions.py @@ -153,12 +153,12 @@ def dict_to_model(modelparams: dict[str, Any]) -> BaseModel: key = IOKey(**key_kwargs) mappings[k] = IOKey( **key_kwargs, - connections=[ + connections={ getattr(submodels_dict[value[0]], value[1]) if isinstance(value, Sequence) else value for value in conn["connect"] - ], + }, ) elif "name" in conn: key_kwargs = create_iokey_kwargs(conn) diff --git a/tests/scripts/test_constant_inputs.py b/tests/scripts/test_constant_inputs.py index 3aee00d6..5944e96b 100644 --- a/tests/scripts/test_constant_inputs.py +++ b/tests/scripts/test_constant_inputs.py @@ -969,7 +969,7 @@ def test_nontensor_extend_from_input_multiple_connection(): model += mean1 model += mean2 model += mean3 - model += mean4(axis=IOKey(connections=[mean1.axis, mean2.axis, mean3.axis])) + model += mean4(axis=IOKey(connections={mean1.axis, mean2.axis, mean3.axis})) assert ( mean1.axis.data.metadata == mean2.axis.data.metadata diff --git a/tests/scripts/test_extend_template.py b/tests/scripts/test_extend_template.py index 01b0979a..e3eeb83e 100644 --- a/tests/scripts/test_extend_template.py +++ b/tests/scripts/test_extend_template.py @@ -150,7 +150,7 @@ def test_shape_reshape(): # Create with shortcut.
model_1 = Model() model_1 += (lin_1 := Linear(dimension=1))(input="input_1", weight="w_1", bias="b_1") - shp = lin_1.input.shape() + shp = lin_1.input.get_shape() model_1 += (lin_2 := Linear(dimension=2))(input="input_2", weight="w_2", bias="b_2") reshaped = lin_2.output.reshape(shp) model_1 += Add()(left=lin_1.output, right=reshaped, output=IOKey(name="output")) @@ -210,7 +210,7 @@ def test_slice_item(): model_1 += (lin_1 := Linear(dimension=1))( input="input", weight="weight", bias="bias" ) - shp = lin_1.input.shape() + shp = lin_1.input.get_shape() item = shp[1].tensor() slc = shp[:].tensor() model_1 += Add()(left=item, right=slc, output=IOKey(name="output")) @@ -1391,68 +1391,6 @@ def test_invalid_input(): "asd" + model.input # type: ignore -# def test_coercion_models_1(): -# backend = JaxBackend() - -# data = {"left": backend.randn(3, 4, 5), "right": backend.randn(3, 4, 5)} - -# model1 = Model() -# model1 += (add_model := Add())(left="left", right="right") -# out = add_model.output -# scalar_item_output = out.shape()[1] -# tensor_item_output = out[1] -# model1 += Buffer()( -# input=scalar_item_output + tensor_item_output, output=IOKey(name="output") -# ) - -# model2 = Model() -# model2 += (add_model := Add())(left="left", right="right") -# model2 += (shp_model := Shape())(input=add_model.output) -# model2 += (to_tensor_model := ToTensor())(input=shp_model.output) -# model2 += (tensor_item_model1 := TensorItem())( -# input=to_tensor_model.output, index=1 -# ) -# model2 += (tensor_item_model2 := TensorItem())(input=add_model.output, index=1) -# model2 += (add_model_2 := Add())( -# left=tensor_item_model1.output, right=tensor_item_model2.output -# ) -# model2 += Buffer()(input=add_model_2.output, output=IOKey(name="output")) - -# compare_models(model1, model2, backend, data, check_internals=False) - - -# def test_coercion_models_2(): -# backend = JaxBackend() - -# data = {"left": backend.randn(5, 6, 2), "right": backend.randn(5, 6, 2)} - -# model1 = Model() -# model1 += (add_model := Add())(left="left", right="right") -# out = add_model.output -# scalar_item_output = out.shape()[1:3] -# tensor_item_output = out[1:3] -# model1 += Buffer()( -# input=scalar_item_output + tensor_item_output, output=IOKey(name="output") -# ) - -# model2 = Model() -# model2 += (add_model := Add())(left="left", right="right") -# model2 += (shp_model := Shape())(input=add_model.output) -# model2 += (to_tensor_model := ToTensor())(input=shp_model.output) -# model2 += (tensor_item_model1 := TensorSlice(start=TBD, stop=TBD, step=TBD))( -# input=to_tensor_model.output, start=1, stop=3, step=None -# ) -# model2 += (tensor_item_model2 := TensorSlice(start=TBD, stop=TBD, step=TBD))( -# input=add_model.output, start=1, stop=3, step=None -# ) -# model2 += (add_model_2 := Add())( -# left=tensor_item_model1.output, right=tensor_item_model2.output -# ) -# model2 += Buffer()(input=add_model_2.output, output=IOKey(name="output")) - -# compare_models(model1, model2, backend, data, check_internals=False) - - def test_tensoritem_multiple_slice_1(): model1 = Model() diff --git a/tests/scripts/test_io_key.py b/tests/scripts/test_io_key.py index a3425b78..e7f74189 100644 --- a/tests/scripts/test_io_key.py +++ b/tests/scripts/test_io_key.py @@ -272,7 +272,7 @@ def test_7(): model += (relu1 := Relu())(input="in1", output="relu1_output") model += (relu2 := Relu())(input="in2", output="relu2_output") model += (relu3 := Relu())( - input="", output=IOKey(name="my_input", connections=[relu1.input, relu2.input]) + input="", 
output=IOKey(name="my_input", connections={relu1.input, relu2.input}) ) assert ( model.dag[relu1]["input"].metadata @@ -448,7 +448,7 @@ def test_iokey_shapes_3(): input3=IOKey(name="input3", shape=[3, "a"]), ) - conns = [main_model.input1, main_model.input2, main_model.input3] # type: ignore + conns = {main_model.input1, main_model.input2, main_model.input3} # type: ignore key = IOKey(name="input", connections=conns) main_model += Buffer()(input=key, output="output1") @@ -1150,7 +1150,7 @@ def test_compare_models_5(): sigmoid = Sigmoid() add = Add() model2 += add(output=IOKey(name="output")) - conn = IOKey(connections=[add.left, add.right]) + conn = IOKey(connections={add.left, add.right}) model2 += sigmoid(input="input", output=conn) model2.set_shapes({"input": [2, 2]}) @@ -1245,7 +1245,7 @@ def test_iokey_template_4(): model = Model() left = IOKey("left") - res = left.shape()[0] + res = left.get_shape()[0] model += Buffer()(res.tensor(), IOKey("output")) diff --git a/tests/scripts/test_jittable.py b/tests/scripts/test_jittable.py index 9052c78f..044952ea 100644 --- a/tests/scripts/test_jittable.py +++ b/tests/scripts/test_jittable.py @@ -257,7 +257,7 @@ def test_logical_model_jittable_1(): model += (add1 := Add())(left="l1", right="l2", output=IOKey(name="out1")) model += (add2 := Add())(left="l3", right="l4") with pytest.raises(Exception) as error_info: - model += Item()(input=IOKey(name="input", connections=[add1.left, add2.left])) + model += Item()(input=IOKey(name="input", connections={add1.left, add2.left})) modified_msg = re.sub("\\s*", "", str(error_info.value)) expected_msg = ( "Model with enforced Jit can not be extended by a non-jittable model! \ @@ -274,7 +274,7 @@ def test_logical_model_jittable_2(): model += (add1 := Add())(left="l1", right="l2", output=IOKey(name="out1")) model += (add2 := Add())(left="l3", right="l4") model.enforce_jit = False - input = IOKey(name="input", connections=[add1.left, add2.left], expose=True) + input = IOKey(name="input", connections={add1.left, add2.left}, expose=True) model += Item()(input=input) assert not model.enforce_jit @@ -287,7 +287,7 @@ def test_logical_model_jittable_3(): model += (add1 := Add())(left="l1", right="l2", output=IOKey(name="out1")) model += (add2 := Add())(left="l3", right="l4") model.enforce_jit = False - input = IOKey(name="input", connections=[add1.left, add2.left], expose=True) + input = IOKey(name="input", connections={add1.left, add2.left}, expose=True) model += Item()(input=input) assert not model.enforce_jit @@ -300,7 +300,7 @@ def test_physical_model_jit_1(): model += (add1 := Add())(left="l1", right="l2", output=IOKey(name="out1")) model += (add2 := Add())(left="l3", right="l4") model.enforce_jit = False - input = IOKey(name="input", connections=[add1.left, add2.left], expose=True) + input = IOKey(name="input", connections={add1.left, add2.left}, expose=True) model += Item()(input=input) backend = JaxBackend() @@ -320,7 +320,7 @@ def test_physical_model_jit_2(): model += (add1 := Add())(left="l1", right="l2", output=IOKey(name="out1")) model += (add2 := Add())(left="l3", right="l4") model.enforce_jit = False - input = IOKey(name="input", connections=[add1.left, add2.left], expose=True) + input = IOKey(name="input", connections={add1.left, add2.left}, expose=True) model += Item()(input=input) backend = JaxBackend() @@ -367,7 +367,7 @@ def test_jit_2(): model = Model(enforce_jit=False) model += (add_model := Add())(left="left", right="right") in1 = add_model.output - out1 = in1.shape() + out1 = in1.get_shape() 
out2 = out1.tensor().sum() mean_model = Mean(axis=TBD) model += (to_list := Item())(input=out2) diff --git a/tests/scripts/test_key_values_in_init.py b/tests/scripts/test_key_values_in_init.py index bf4b8ef6..7a642233 100644 --- a/tests/scripts/test_key_values_in_init.py +++ b/tests/scripts/test_key_values_in_init.py @@ -16,7 +16,6 @@ import mithril as ml from mithril.models import Add, Model -from mithril.utils.utils import OrderedSet def test_directed_call_connection(): @@ -28,7 +27,7 @@ def test_directed_call_connection(): left_info = info._connections["left"] assert isinstance(left_info, ml.IOKey) - assert left_info._connections == OrderedSet([connection.data]) + assert left_info._connections == {connection} assert left_info._name is None assert left_info._value == 1 @@ -98,7 +97,7 @@ def test_directed_call_iokey_value_tbd(): def test_directed_call_connect_key_value_not_equal(): add1 = Add(left=1) - iokey = ml.IOKey("in1", value=2, connections=[Add().left]) + iokey = ml.IOKey("in1", value=2, connections={Add().left}) with pytest.raises(ValueError) as err_info: add1(left=iokey) @@ -108,25 +107,25 @@ def test_directed_call_connect_key_value_not_equal(): def test_directed_call_connect_key_none(): add1 = Add(left=1) connection = Add().left - con = ml.IOKey(connections=[connection]) + con = ml.IOKey(connections={connection}) info = add1(left=con, right="right") left_info = info._connections["left"] assert isinstance(left_info, ml.IOKey) - assert left_info._connections == OrderedSet([connection.data]) + assert left_info._connections == {connection} assert left_info._value == 1 # key is set to IOKey with val from factory_inputs def test_directed_call_connect_key_value_tbd(): add1 = Add(left=1) connection = Add().left - con = ml.IOKey(name="in1", connections=[connection]) + con = ml.IOKey(name="in1", connections={connection}) info = add1(left=con, right="right") left_info = info._connections["left"] assert isinstance(left_info, ml.IOKey) - assert left_info._connections == OrderedSet([connection.data]) + assert left_info._connections == {connection} assert isinstance(left_info, ml.IOKey) assert left_info._value == 1 # value is set to val from factory_inputs @@ -134,13 +133,13 @@ def test_directed_call_connect_key_value_tbd(): def test_directed_call_connect_key_value_equal(): add1 = Add(left=1) connection = Add().left - con = ml.IOKey("in1", value=1, connections=[connection]) + con = ml.IOKey("in1", value=1, connections={connection}) info = add1(left=con, right="right") left_info = info._connections["left"] assert isinstance(left_info, ml.IOKey) - assert left_info._connections == OrderedSet([connection.data]) + assert left_info._connections == {connection} assert left_info._value == 1 # value is set to val from factory_inputs @@ -256,7 +255,7 @@ def test_integration_call_arg_iokey_value_tbd(): def test_integration_call_arg_connect_key_value_not_equal(): add1 = Add(left=1) - connect = ml.IOKey("in1", value=2, connections=[Add().left]) + connect = ml.IOKey("in1", value=2, connections={Add().left}) model = Model() with pytest.raises(ValueError) as err_info: @@ -267,7 +266,7 @@ def test_integration_call_arg_connect_key_value_not_equal(): def test_integration_call_arg_connect_key_none(): add1 = Add(left=1) add2 = Add() - con = ml.IOKey(connections=[add2.left]) + con = ml.IOKey(connections={add2.left}) model = Model() model += add2(left="in1", right="in2") @@ -281,7 +280,7 @@ def test_integration_call_arg_connect_key_none(): def test_integration_call_arg_connect_key_value_tbd(): add1 = Add(left=1) 
add2 = Add() - con = ml.IOKey(name="in1", expose=True, connections=[add2.left]) + con = ml.IOKey(name="in1", expose=True, connections={add2.left}) model = Model() model += add2(right="in2") @@ -295,7 +294,7 @@ def test_integration_call_arg_connect_key_value_tbd(): def test_integration_call_arg_connect_key_value_equal(): add1 = Add(left=1) add2 = Add() - con = ml.IOKey(connections=[add2.left], value=1) + con = ml.IOKey(connections={add2.left}, value=1) model = Model() model += add2(right="in2") diff --git a/tests/scripts/test_model_to_dict_rtt.py b/tests/scripts/test_model_to_dict_rtt.py index 8876a549..c6efe302 100644 --- a/tests/scripts/test_model_to_dict_rtt.py +++ b/tests/scripts/test_model_to_dict_rtt.py @@ -489,7 +489,7 @@ def test_composite_9(): input="", weight="weight1", output=IOKey(name="output2") ) model += Linear(dimension=71)( - input="input", weight="weight2", output=IOKey(connections=[l1.input, l2.input]) + input="input", weight="weight2", output=IOKey(connections={l1.input, l2.input}) ) model_dict_created = dict_conversions.model_to_dict(model) @@ -516,7 +516,7 @@ def test_composite_10(): model += Linear(dimension=71)( input="input", weight="weight2", - output=IOKey(name="my_input", connections=["input1", "input2"]), + output=IOKey(name="my_input", connections={"input1", "input2"}), ) model_dict_created = dict_conversions.model_to_dict(model) @@ -543,7 +543,7 @@ def test_composite_10_expose_false(): model += Linear(dimension=71)( input="input", weight="weight2", - output=IOKey(name="my_input", connections=["input1", "input2"], expose=False), + output=IOKey(name="my_input", connections={"input1", "input2"}, expose=False), ) model_dict_created = dict_conversions.model_to_dict(model) @@ -588,7 +588,7 @@ def test_composite_12(): Linear(dimension=71), input="input", weight="weight2", - output=IOKey(name="my_input", connections=["input1", "input2"]), + output=IOKey(name="my_input", connections={"input1", "input2"}), ) model_dict_created = dict_conversions.model_to_dict(model) @@ -622,7 +622,7 @@ def test_composite_13(): Linear(dimension=71), input="input", weight="weight2", - output=IOKey(name="my_input", connections=["input1", "input2"]), + output=IOKey(name="my_input", connections={"input1", "input2"}), ) model_dict_created = dict_conversions.model_to_dict(model) diff --git a/tests/scripts/test_parallel.py b/tests/scripts/test_parallel.py index 24472dee..3641aa7d 100644 --- a/tests/scripts/test_parallel.py +++ b/tests/scripts/test_parallel.py @@ -383,7 +383,7 @@ def test_torch_parallel_2(): # primitive eye. model = Model() model += (linear := Linear(256))(input="input", weight="w", bias="b") - model += (e := Eye(N=TBD))(N=linear.output.shape()[0]) + model += (e := Eye(N=TBD))(N=linear.output.get_shape()[0]) model += Add()(left=linear.output, right=e.output, output="output") backend = create_parallel_backend(device_mesh=(4, 1)) backend.ones([256]) @@ -507,7 +507,7 @@ def test_torch_parallel_5(): # primitive eye. 
model = Model() model += (linear := Linear(256))(input="input", weight="w", bias="b") - model += (e := Eye(N=TBD))(N=linear.output.shape()[0]) + model += (e := Eye(N=TBD))(N=linear.output.get_shape()[0]) model += Add()(left=linear.output, right=e.output, output="output") backend = mithril.TorchBackend() @@ -957,7 +957,7 @@ def test_jax_parallel_2(): if "cuda" in mithril.JaxBackend.get_available_devices(): model = Model() model += (linear := Linear(256))(input="input", weight="w", bias="b") - model += (e := Eye(N=TBD))(N=linear.output.shape()[0]) + model += (e := Eye(N=TBD))(N=linear.output.get_shape()[0]) model += Add()(left=linear.output, right=e.output, output="output") backend = mithril.JaxBackend(device="cuda", device_mesh=(4, 1)) backend.ones([256]) @@ -1090,7 +1090,7 @@ def test_jax_parallel_5(): if "cuda" in mithril.JaxBackend.get_available_devices(): model = Model() model += (linear := Linear(256))(input="input", weight="w", bias="b") - model += (e := Eye(N=TBD))(N=linear.output.shape()[0]) + model += (e := Eye(N=TBD))(N=linear.output.get_shape()[0]) model += Add()(left=linear.output, right=e.output, output="output") backend = mithril.JaxBackend(device="cuda") diff --git a/tests/scripts/test_ref_counts.py b/tests/scripts/test_ref_counts.py index 62294f42..4e86d3a0 100644 --- a/tests/scripts/test_ref_counts.py +++ b/tests/scripts/test_ref_counts.py @@ -239,14 +239,14 @@ def test_deleted_variadic_ref_count_7(): model += add_5(left="") conn = IOKey( - connections=[ + connections={ add_1.left, add_1.right, add_2.left, add_2.right, add_3.left, add_3.right, - ] + } ) model += add_6(left=conn, right="right", output="output") @@ -912,7 +912,7 @@ def test_deleted_tensors_ref_count_3(): model += buffer4(input="input4", output=IOKey(name="output4")) model += buffer5(input="input5", output=IOKey(name="output5")) model += buffer6(input="input6", output=IOKey(name="output6")) - connections = [buffer1.input, buffer2.input, buffer3.input, model.output4] # type: ignore + connections = {buffer1.input, buffer2.input, buffer3.input, model.output4} # type: ignore conn = IOKey(connections=connections) model += buffer7(input=conn, output=IOKey(name="output")) @@ -1138,7 +1138,7 @@ def test_deleted_edge_ref_count_6(): output2=IOKey(name="output2"), output3=IOKey(name="output3"), ) - connections = [main_model.output1, main_model.input2] # type: ignore + connections = {main_model.output1, main_model.input2} # type: ignore conn = IOKey(name="abcd", expose=True, connections=connections) main_model += sigmoid4(input=conn, output=IOKey(name="output5")) diff --git a/tests/scripts/test_scripts.py b/tests/scripts/test_scripts.py index 7a711fe5..91ed86e8 100644 --- a/tests/scripts/test_scripts.py +++ b/tests/scripts/test_scripts.py @@ -40,7 +40,6 @@ GenericTensorType, IOKey, NotAvailable, - OrderedSet, ToBeDetermined, UniadicRecord, Variadic, @@ -110,6 +109,7 @@ Where, ) from mithril.utils.type_utils import is_list_int +from mithril.utils.utils import OrderedSet from .helper import assert_models_equal from .test_shapes import check_shapes_semantically @@ -262,7 +262,7 @@ def test_cyclic_extension_5(): left="input5", right="input6", output=IOKey( - name="my_input", expose=False, connections=[sum1.left, sum2.right] + name="my_input", expose=False, connections={sum1.left, sum2.right} ), ) @@ -2008,7 +2008,7 @@ def test_multiple_output_connections(): with pytest.raises(Exception) as err_info: model += add_1( - left="left", right="right", output=IOKey(connections=[add_2.left, "out2"]) + left="left", right="right", 
output=IOKey(connections={add_2.left, "out2"}) ) assert ( @@ -2025,7 +2025,7 @@ def test_multiple_output_connections_2(): model += add_1( left="left", right="right", - output=IOKey(name="my_internal_key", connections=[add_2.left, "in3"]), + output=IOKey(name="my_internal_key", connections={add_2.left, "in3"}), ) assert ( @@ -3945,7 +3945,7 @@ def test_connect_1(): relu3 = Relu() model += relu1(output="relu_output_1") model += relu2(input="", output="relu_output_2") - model += relu3(input="", output=IOKey(connections=[relu1.input, relu2.input])) + model += relu3(input="", output=IOKey(connections={relu1.input, relu2.input})) assert ( model.dag[relu1]["input"].metadata @@ -3962,7 +3962,7 @@ def test_connect_2(): model += relu1(input="in1", output="relu_output_1") model += relu2(input="in2", output="relu_output_2") model += relu3( - input="", output=IOKey(name="my_input", connections=[relu1.input, relu2.input]) + input="", output=IOKey(name="my_input", connections={relu1.input, relu2.input}) ) assert ( @@ -3979,7 +3979,7 @@ def test_connect_3(): relu3 = Relu() model += relu1(output="relu_output_1") model += relu2(input="", output="relu_output_2") - model += relu3(input=IOKey(connections=[relu1.input, relu2.input])) + model += relu3(input=IOKey(connections={relu1.input, relu2.input})) assert ( model.dag[relu1]["input"].metadata @@ -3995,7 +3995,7 @@ def test_connect_4(): relu3 = Relu() model += relu1(input="in1", output="relu_output_1") model += relu2(input="in2", output="relu_output_2") - model += relu3(input=IOKey(name="my_input", connections=[relu1.input, relu2.input])) + model += relu3(input=IOKey(name="my_input", connections={relu1.input, relu2.input})) assert ( model.dag[relu1]["input"].metadata @@ -4012,7 +4012,7 @@ def test_connect_5(): relu3 = Relu() model += relu1(input="in1", output="relu_output_1") model += relu2(input="", output="relu_output_2") - model += relu3(input=IOKey(connections=[relu1.input, relu2.input])) + model += relu3(input=IOKey(connections={relu1.input, relu2.input})) assert ( model.dag[relu1]["input"].key @@ -4035,7 +4035,7 @@ def test_connect_6(): model += relu2(input="in2", output="relu_output_2") with pytest.raises(KeyError) as error_info: - model += Relu()(input=IOKey(connections=[relu1.input, relu2.input])) + model += Relu()(input=IOKey(connections={relu1.input, relu2.input})) assert str(error_info.value) == ( "'Requires a connection to have only one unique key name but " @@ -4119,10 +4119,10 @@ def test_connect_composite_2_extend_from_inputs(): m2 = deepcopy(submodel) subcopy = deepcopy(submodel) model += m1(left="left", right="right") - model += m2(left=IOKey(connections=[m1.output]), right="right") # type: ignore + model += m2(left=IOKey(connections={m1.output}), right="right") # type: ignore model += subcopy( - left=IOKey(connections=[m2.output]), # type: ignore - right=IOKey(connections=[m2.output]), # type: ignore + left=IOKey(connections={m2.output}), # type: ignore + right=IOKey(connections={m2.output}), # type: ignore output="output", ) @@ -4140,9 +4140,9 @@ def test_composite_6_extend_from_inputs_connect(): relu3 = Relu() relu4 = Relu() model += relu1(output="output") - model += relu2(input=IOKey(connections=[relu1.input])) - model += relu3(input="my_input", output=IOKey(connections=[relu2.input])) - model += relu4(input=IOKey(connections=[relu3.input])) + model += relu2(input=IOKey(connections={relu1.input})) + model += relu3(input="my_input", output=IOKey(connections={relu2.input})) + model += relu4(input=IOKey(connections={relu3.input})) assert ( 
relu2.input.data.metadata @@ -4164,8 +4164,8 @@ def test_composite_4_extend_from_inputs_connect(): relu3 = Relu() relu4 = Relu() model += relu1(input="my_input", output=IOKey(name="output")) - model += relu2(input=IOKey(connections=[relu1.input])) - model += relu3(input=IOKey(connections=[relu2.input])) + model += relu2(input=IOKey(connections={relu1.input})) + model += relu3(input=IOKey(connections={relu2.input})) model += relu4(input="input1", output="my_input") backend = TorchBackend() @@ -4185,7 +4185,7 @@ def test_integration_composite_1_extend_from_inputs_1_with_connect(): m1 = Layer(dimension=2, activation=Sigmoid()) model += m2(weight="w1", bias="b1", output="output") model += m1( - input="input", weight="w0", bias="b0", output=IOKey(connections=[m2.input]) + input="input", weight="w0", bias="b0", output=IOKey(connections={m2.input}) ) assert m1.output.data.metadata == m2.input.data.metadata @@ -4230,7 +4230,7 @@ def test_connect_8(): r2 = Relu() model += t(output="output1") model += r1(input="input2", output="output2") - model += r2(input="", output=IOKey(connections=[t.input, r1.input])) + model += r2(input="", output=IOKey(connections={t.input, r1.input})) assert r1.input.data.metadata == r2.output.data.metadata == t.input.data.metadata @@ -4242,7 +4242,7 @@ def test_connect_9(): r2 = Relu() model += t(input="input1", output="output1") model += r1(input="", output="output2") - model += r2(input="", output=IOKey(connections=["input1", r1.input])) + model += r2(input="", output=IOKey(connections={"input1", r1.input})) assert ( r1.input.data.metadata @@ -4261,7 +4261,7 @@ def test_connect_10(): model += r1(input="input2", output=IOKey(name="output2")) model += r2( input="", - output=IOKey(connections=["input1", "input2"], expose=True, name="internal"), + output=IOKey(connections={"input1", "input2"}, expose=True, name="internal"), ) assert ( @@ -4293,7 +4293,7 @@ def test_connect_12(): model += add2(left="l3", right="l4", output=IOKey(name="out2")) model += add3( - left=IOKey(name="left", connections=[add1.left, add2.left]), + left=IOKey(name="left", connections={add1.left, add2.left}), right="right", output=IOKey(name="out3"), ) @@ -4311,7 +4311,7 @@ def test_connect_13(): buf = Buffer() model += add1(left="l1", right="l2", output=IOKey(name="out1")) model += add2(left="l3", right="l4") - model += buf(input=IOKey(name="input", connections=[add1.left, add2.left])) + model += buf(input=IOKey(name="input", connections={add1.left, add2.left})) model += Add()(left=add2.output, right=buf.output, output=IOKey(name="out2")) assert model._input_keys == {"input", "l2", "l4"} @@ -4335,7 +4335,7 @@ def test_connect_error_1(): with pytest.raises(Exception) as error_info: model += Relu()( input="input", - output=IOKey(name="my_input", connections=["input1", "input2", "output3"]), + output=IOKey(name="my_input", connections={"input1", "input2", "output3"}), ) assert ( @@ -4354,7 +4354,7 @@ def test_connect_error_2(): with pytest.raises(KeyError) as error_info: model += Relu()( input=IOKey( - name="my_input", connections=["input1", "input2", "output3", "output4"] + name="my_input", connections={"input1", "input2", "output3", "output4"} ) ) @@ -4371,7 +4371,7 @@ def test_connect_error_5(): with pytest.raises(KeyError) as error_info: model_2 += Relu()( - output=IOKey(expose=True, connections=[tanh.input, relu.input]) + output=IOKey(expose=True, connections={tanh.input, relu.input}) ) assert ( @@ -4389,7 +4389,7 @@ def test_connect_error_6(): model += l2(input="input1", weight="w1", 
output=IOKey(name="output2")) model += l3(input="", output=IOKey(name="output3")) model += l4( - input=IOKey(name="my_output", connections=["input1", "input2", "output3"]) + input=IOKey(name="my_output", connections={"input1", "input2", "output3"}) ) assert ( @@ -5122,7 +5122,7 @@ def test_dependency_map_latent_to_input(): # Add third model which changes name of a latent input and # makes it a real input of the model. - conn = IOKey(name="mean_axis", connections=[mean.axis], expose=True) + conn = IOKey(name="mean_axis", connections={mean.axis}, expose=True) model += (to_tensor := ToTensor())(conn, output="output") # Assert dependency map and connection keys status in model. output: ConnectionData = model.output.data # type: ignore @@ -6677,7 +6677,7 @@ def test_multi_write_7(): model += add1(left="left1", right="right1", output="output1") model += add2(left="left2", right="right2", output="output2") - out = IOKey(connections=[model.output1, model.output2]) # type: ignore + out = IOKey(connections={model.output1, model.output2}) # type: ignore with pytest.raises(KeyError) as err_info: model += Buffer()(input=out, output="output3") diff --git a/tests/scripts/test_set_types.py b/tests/scripts/test_set_types.py index 9fdcb46f..12397c79 100644 --- a/tests/scripts/test_set_types.py +++ b/tests/scripts/test_set_types.py @@ -200,7 +200,7 @@ def test_types_iokey_3(): output=IOKey(name="output2", type=float | int), ) - conn = IOKey("sub", connections=[buffer_model1.input, buffer_model2.input]) + conn = IOKey("sub", connections={buffer_model1.input, buffer_model2.input}) buffer_model3 = Buffer() diff --git a/tests/scripts/test_shapes.py b/tests/scripts/test_shapes.py index dd49a95d..1763be6d 100644 --- a/tests/scripts/test_shapes.py +++ b/tests/scripts/test_shapes.py @@ -560,7 +560,7 @@ def test_shapes_4(): input="", weight="weight1", output=IOKey(name="output2") ) model += Linear(dimension=71)( - input="input", weight="weight2", output=IOKey(connections=[l1.input, l2.input]) + input="input", weight="weight2", output=IOKey(connections={l1.input, l2.input}) ) shapes = {"input": [4, 256]} logical_ref: Mapping[str, list | None] = { @@ -2258,7 +2258,7 @@ def test_composite_3_extend_shapes_1(): m3 = Model() m3 += m2(right=IOKey(name="right")) m3 += Add()( - left=IOKey(name="left", connections=[m1.left], expose=True), # type: ignore + left=IOKey(name="left", connections={m1.left}, expose=True), # type: ignore right=m2.output, # type: ignore output=IOKey(name="output"), ) # type: ignore @@ -7466,7 +7466,7 @@ def __call__( # type: ignore[override] shapes: dict[str, list] = {"input": ["a", ("Var1", ...)]} buff_model.set_shapes(shapes) model += test_model - con = IOKey(connections=[test_model.input2, buff_model.input]) # type: ignore + con = IOKey(connections={test_model.input2, buff_model.input}) # type: ignore model += Buffer()(input=con, output=IOKey(name="output")) all_nodes = get_all_nodes(model) @@ -9662,7 +9662,7 @@ def test_connect_shapes(): model = Model() model += relu1(input="") model += relu2(input="") - model += relu3(input="input", output=IOKey(connections=[relu1.input, relu2.input])) + model += relu3(input="input", output=IOKey(connections={relu1.input, relu2.input})) assert model.shapes["input"] == [5, 7] diff --git a/tests/scripts/test_summary.py b/tests/scripts/test_summary.py index fcee738d..9e49c16f 100644 --- a/tests/scripts/test_summary.py +++ b/tests/scripts/test_summary.py @@ -186,7 +186,7 @@ def test_extract_logical_connections_4(): output3=IOKey(name="out_3"), ) model += model_2( - 
output1=IOKey(connections=[model_1.input1, model_1.input2]), # type: ignore + output1=IOKey(connections={model_1.input1, model_1.input2}), # type: ignore output2=IOKey(name="out_4"), output3=IOKey(name="out_5"), input1="in1", @@ -1487,7 +1487,7 @@ def test_logical_model_summary_9(): model = Model() add_1, add_2 = Add(), Add() model += add_1(left="left") - model += add_2(output=IOKey(connections=[add_1.left, add_1.right]), left="left_1") + model += add_2(output=IOKey(connections={add_1.left, add_1.right}), left="left_1") with redirect_stdout(StringIO()) as summary: model.summary(shapes=True, symbolic=True) @@ -1535,13 +1535,13 @@ def test_logical_model_summary_11(): ) model_n += model_2( input1="", - output1=IOKey(connections=[model_3.input1, model_3.input2, model_3.input3]), # type: ignore + output1=IOKey(connections={model_3.input1, model_3.input2, model_3.input3}), # type: ignore output2=IOKey(name="output4"), output3=IOKey(name="output5"), ) model_n += model_1( input1="", - output1=IOKey(connections=[model_2.input1, model_2.input2, model_2.input3]), # type: ignore + output1=IOKey(connections={model_2.input1, model_2.input2, model_2.input3}), # type: ignore ) with redirect_stdout(StringIO()) as summary: @@ -1581,7 +1581,7 @@ def test_logical_model_summary_12(): input1=model_1.output1, # type: ignore input2=model_1.output2, # type: ignore input3=model_1.output3, # type: ignore - output1=IOKey(connections=[model_3.input1, model_3.input2, model_3.input3]), # type: ignore + output1=IOKey(connections={model_3.input1, model_3.input2, model_3.input3}), # type: ignore output2=IOKey(name="output4"), output3=IOKey(name="output5"), ) diff --git a/tests/scripts/test_tuple_list_args_in_extend.py b/tests/scripts/test_tuple_list_args_in_extend.py index 48846f75..b8158d19 100644 --- a/tests/scripts/test_tuple_list_args_in_extend.py +++ b/tests/scripts/test_tuple_list_args_in_extend.py @@ -80,7 +80,7 @@ def test_tuple_argument_3(): add_model_2 = Add() model += add_model(left="left", right="right") model += add_model_2( - left=(add_model.left.shape(), add_model.right.shape()), + left=(add_model.left.get_shape(), add_model.right.get_shape()), right=add_model.left + add_model.right, output="output", ) @@ -110,7 +110,7 @@ def test_tuple_argument_4(): add_model_2 = Add() model += add_model(left="left", right="right") model += add_model_2( - left=(add_model.left.shape() * 2, add_model.right.shape() * 2), + left=(add_model.left.get_shape() * 2, add_model.right.get_shape() * 2), right=add_model.left + add_model.right, output="output", ) @@ -138,8 +138,8 @@ def test_tuple_argument_5(): model += add_model(left="left", right="right") model += add_model_2( left=( - (add_model.left.shape()[0], add_model.left.shape()[0]), - (add_model.left.shape()[0], add_model.left.shape()[0]), + (add_model.left.get_shape()[0], add_model.left.get_shape()[0]), + (add_model.left.get_shape()[0], add_model.left.get_shape()[0]), ), right=add_model.left + add_model.right, output="output", @@ -168,8 +168,8 @@ def test_list_tuple_mixed_argument_1(): model += add_model(left="left", right="right") model += add_model_2( left=( - [add_model.left.shape()[0], add_model.left.shape()[0]], - [add_model.left.shape()[0], add_model.left.shape()[0]], + [add_model.left.get_shape()[0], add_model.left.get_shape()[0]], + [add_model.left.get_shape()[0], add_model.left.get_shape()[0]], ), right=add_model.left + add_model.right, output="output", @@ -197,8 +197,8 @@ def test_list_tuple_mixed_argument_2(): model += add_model(left="left", right="right") - 
left_first_shape = add_model.left.shape()[0]
-    right_first_shape = add_model.right.shape()[0]
+    left_first_shape = add_model.left.get_shape()[0]
+    right_first_shape = add_model.right.get_shape()[0]
 
     matmul_left = ([left_first_shape, 0], [2, right_first_shape])
 
@@ -296,7 +296,7 @@ def test_list_argument_3():
 
     model += add_model(left="left", right="right")
     model += add_model_2(
-        left=[add_model.left.shape(), add_model.right.shape()],
+        left=[add_model.left.get_shape(), add_model.right.get_shape()],
         right=add_model.left + add_model.right,
         output="output",
     )
@@ -327,7 +327,7 @@ def test_list_argument_4():
 
     model += add_model(left="left", right="right")
     model += add_model_2(
-        left=[add_model.left.shape() * 2, add_model.right.shape() * 2],
+        left=[add_model.left.get_shape() * 2, add_model.right.get_shape() * 2],
         right=add_model.left + add_model.right,
         output="output",
     )
@@ -356,8 +356,8 @@ def test_list_argument_5():
 
     model += add_model(left="left", right="right")
     model += add_model_2(
         left=[
-            [add_model.left.shape()[0], add_model.left.shape()[0]],
-            [add_model.left.shape()[0], add_model.left.shape()[0]],
+            [add_model.left.get_shape()[0], add_model.left.get_shape()[0]],
+            [add_model.left.get_shape()[0], add_model.left.get_shape()[0]],
         ],
         right=add_model.left + add_model.right,
         output="output",
diff --git a/tests/scripts/test_type_coercion.py b/tests/scripts/test_type_coercion.py
index 0a04dd0e..616c80c0 100644
--- a/tests/scripts/test_type_coercion.py
+++ b/tests/scripts/test_type_coercion.py
@@ -116,7 +116,7 @@ def test_scalar_to_tensor_2():
     lin_2 = Linear(dimension=2)
     model += lin_1(input="input_1", weight="w_1", bias="b_1")
     model += lin_2(input="input_2", weight="w_2", bias="b_2")
-    shp_1 = lin_1.input.shape()
+    shp_1 = lin_1.input.get_shape()
     reshaped_1 = lin_2.output.reshape(shp_1)
     to_tensor = ToTensor()
     model += to_tensor(input=shp_1)
@@ -129,7 +129,7 @@ def test_scalar_to_tensor_2():
     lin_4 = Linear(dimension=2)
     model += lin_3(input="input_1", weight="w_1", bias="b_1")
     model += lin_4(input="input_2", weight="w_2", bias="b_2")
-    shp_2 = lin_3.input.shape()
+    shp_2 = lin_3.input.get_shape()
     reshaped_2 = lin_4.output.reshape(shp_2)
     model += Add()(left=shp_2.tensor(), right=reshaped_2, output="output")
     model_2 = model
@@ -185,7 +185,7 @@ def test_scalar_to_tensor_3():
 
 
 def test_tensor_to_scalar_1():
-    """Model enforces Jit so we reshape with to_tensor_1_output.shape().
+    """Model enforces Jit so we reshape with to_tensor_1_output.get_shape().
     We cannot directly reshape with to_tensor_1_output, which is valued
     as [2, 1] in the tensor domain, since it requires a TensorToList
     conversion before being an argument to the reshape method.
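
# A minimal, self-contained sketch of the jittable reshape pattern this
# docstring describes, distilled from the test below; it assumes only the
# public names these tests already use (Model, Add, Buffer, ml.IOKey):
import mithril as ml
from mithril.models import Add, Buffer, Model

model = Model()
add = Add()
left = ml.IOKey(value=[2, 1]).tensor()
right = ml.IOKey(value=[1, 1]).tensor()
model += add(left=left, right=right)
# get_shape() yields a symbolic shape connection, so the reshape stays in the
# jitted graph; reshaping with the tensor's runtime value would first require
# a TensorToList conversion.
reshaped = add.output.reshape(add.left.get_shape())
model += Buffer()(input=reshaped, output="output")
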
@@ -199,7 +199,7 @@ def test_tensor_to_scalar_1(): model += to_tensor_1(input=[2, 1]) model += to_tensor_2(input=[[1, 1]]) model += add_1(left=to_tensor_1.output, right=to_tensor_2.output) - reshaped_1 = add_1.output.reshape(to_tensor_1.output.shape()) + reshaped_1 = add_1.output.reshape(to_tensor_1.output.get_shape()) model += Buffer()(input=reshaped_1, output="output") model_1 = model @@ -209,7 +209,7 @@ def test_tensor_to_scalar_1(): left = IOKey(value=[2, 1]).tensor() right = IOKey(value=[1, 1]).tensor() model += add_2(left=left, right=right) - reshaped_2 = add_2.output.reshape(add_2.left.shape()) + reshaped_2 = add_2.output.reshape(add_2.left.get_shape()) model += Buffer()(input=reshaped_2, output="output") model_2 = model @@ -282,7 +282,7 @@ def test_slice_item_conversions(): model = Model() lin_2 = Linear(dimension=1) model += lin_2(input="input", weight="w", bias="b") - shp2 = lin_2.input.shape() + shp2 = lin_2.input.get_shape() shp2_1 = shp2[1] assert shp2_1 is not None shp_item = shp2_1.tensor() @@ -307,7 +307,7 @@ def test_tuple_conversion_1(): model = Model() lin_1 = Linear(dimension=2) model += lin_1(input="input", weight="w", bias="b") - shp1 = lin_1.output.shape() + shp1 = lin_1.output.get_shape() model += ToTensor()(input=(shp1[0], shp1[1]), output="output") model_1 = model @@ -316,7 +316,7 @@ def test_tuple_conversion_1(): lin_2 = Linear(dimension=2) tupl = ToTuple(n=2) model += lin_2(input="input", weight="w", bias="b") - shp2 = lin_2.output.shape() + shp2 = lin_2.output.get_shape() model += tupl(input1=shp2[0], input2=shp2[1]) model += ToTensor()(input=tupl.output, output="output") # type: ignore model_2 = model @@ -337,7 +337,7 @@ def test_tuple_conversion_2(): lin_1 = Linear(dimension=2) tt1 = ToTensor() model += lin_1(input=[[1], [2.0]], weight="w", bias="b") - shp1 = lin_1.input.shape() + shp1 = lin_1.input.get_shape() model += tt1(input=(shp1[0], shp1[1])) model += Add()(left=lin_1.output, right=tt1.output, output="output") model_1 = model @@ -383,7 +383,7 @@ def test_tuple_conversion_3(): lin_1 = Linear(dimension=3) tt1 = ToTensor() model += lin_1(input=[[1], [2.0]], weight="w", bias="b") - shp1 = lin_1.input.shape() + shp1 = lin_1.input.get_shape() model += tt1(input=(shp1[0], shp1[1], 3)) model += Add()(left=lin_1.output, right=tt1.output, output="output") model_1 = model @@ -429,7 +429,7 @@ def test_list_conversion_1(): lin_1 = Linear(dimension=3) tt1 = ToTensor() model += lin_1(input=[[1], [2.0]], weight="w", bias="b") - shp1 = lin_1.input.shape() + shp1 = lin_1.input.get_shape() model += tt1(input=[shp1[0], shp1[1], 3.0]) model += Add()(left=lin_1.output, right=tt1.output, output="output") model_1 = model @@ -474,7 +474,7 @@ def test_nested_list_conversion_1(): lin_1 = Linear(dimension=3) tt1 = ToTensor() model += lin_1(input=[[1], [2.0]], weight="w", bias="b") - shp1 = lin_1.input.shape() + shp1 = lin_1.input.get_shape() model += tt1(input=[[shp1[0], shp1[1], 3.0]]) model += Add()(left=lin_1.output, right=tt1.output, output="output") model_1 = model @@ -519,7 +519,7 @@ def test_nested_list_conversion_2(): lin_1 = Linear(dimension=3) tt1 = ToTensor() model += lin_1(input="input", weight="w", bias="b") - shp1 = lin_1.input.shape() + shp1 = lin_1.input.get_shape() model += tt1(input=[[shp1[0], shp1[1], 3.0]]) model += Add()(left=lin_1.output, right=tt1.output, output="output") model_1 = model @@ -838,7 +838,7 @@ def test_connect_type_conv_handling_1(): model.extend((a1 := Buffer()), input="input1") model.extend((a2 := Buffer()), input="input2") con_object = 
IOKey( - name="abcd", connections=[a1.input, a2.input], value=[[2.0]], expose=True + name="abcd", connections={a1.input, a2.input}, value=[[2.0]], expose=True ) model.extend( mat_mul := MatrixMultiply(), left=con_object, output=IOKey(name="output") @@ -851,7 +851,7 @@ def test_connect_type_conv_handling_1(): model.extend((a1 := Buffer()), input="input1") model.extend((a2 := Buffer()), input="input2") con_object = IOKey( - connections=["input1", "input2"], value=[[2.0]], name="abcd", expose=True + connections={"input1", "input2"}, value=[[2.0]], name="abcd", expose=True ) model.extend( (mat_mul := MatrixMultiply()), left=con_object, output=IOKey(name="output") @@ -864,7 +864,7 @@ def test_connect_type_conv_handling_1(): model.extend((a1 := Buffer()), input="input1") model.extend((a2 := Buffer()), input="input2") con_object = IOKey( - connections=["input1", a2.input], value=[[2.0]], name="abcd", expose=True + connections={"input1", a2.input}, value=[[2.0]], name="abcd", expose=True ) model.extend( (mat_mul := MatrixMultiply()), left=con_object, output=IOKey(name="output") @@ -895,8 +895,8 @@ def test_connect_1(): model += concat_model( input1="input1", input2="input2", input3="input3", output=IOKey(name="output") ) - conn_list = [concat_model.input1, concat_model.input2, concat_model.input3] # type: ignore - conn = IOKey(connections=conn_list, name="abcd", expose=True) + conns = {concat_model.input1, concat_model.input2, concat_model.input3} # type: ignore + conn = IOKey(connections=conns, name="abcd", expose=True) model += Sigmoid()(input=conn, output=IOKey(name="output1")) assert ( @@ -917,8 +917,8 @@ def test_connect_2(): model += concat_model( input1="input1", input2="input2", input3="input3", output=IOKey(name="output") ) - conn_list = [concat_model.input1, concat_model.input2, concat_model.input3] # type: ignore - conn = IOKey(connections=conn_list, name="abcd", expose=True) + conns = {concat_model.input1, concat_model.input2, concat_model.input3} # type: ignore + conn = IOKey(connections=conns, name="abcd", expose=True) model += ToTensor()(conn) @@ -939,8 +939,8 @@ def test_connect_3(): model += concat_model( input1="input1", input2="input2", input3="input3", output=IOKey(name="output") ) - conn_list = [concat_model.input1, concat_model.input2, concat_model.input3] # type: ignore - conn = IOKey(connections=conn_list, name="abcd", expose=True, value=3.0) + conns = {concat_model.input1, concat_model.input2, concat_model.input3} # type: ignore + conn = IOKey(connections=conns, name="abcd", expose=True, value=3.0) model += (to_tensor := ToTensor())(conn) @@ -972,13 +972,13 @@ def test_connect_4(): input1="input1", input2="input2", input3="input3", output=IOKey(name="output") ) model += union_model(input1="") - conn_list = [ + conns = { concat_model.input1, # type: ignore concat_model.input2, # type: ignore concat_model.input3, # type: ignore union_model.input1.tensor(), # type: ignore - ] - conn = IOKey(connections=conn_list, name="abcd", expose=True, value=(3, 2)) + } + conn = IOKey(connections=conns, name="abcd", expose=True, value=(3, 2)) model += Buffer()(input=conn, output=IOKey(name="output1")) pm = compile(model=model, backend=backend, jit=False, inference=True) @@ -1002,8 +1002,8 @@ def test_connect_6(): model = Model() concat_model = Concat(n=3) model += concat_model(input1=[[3.0]], output=IOKey(name="output")) - conn_list = [concat_model.input1, concat_model.input2, concat_model.input3] # type: ignore - conn = IOKey(connections=conn_list, name="abcd", expose=True) + conns = 
{concat_model.input1, concat_model.input2, concat_model.input3} # type: ignore + conn = IOKey(connections=conns, name="abcd", expose=True) model += Buffer()(input=conn, output=IOKey(name="output1")) @@ -1033,7 +1033,7 @@ def test_connect_7(): model += add_model_2(left="left1", right="right1") conn = IOKey( - connections=[add_model_2.output, model.right], # type: ignore + connections={add_model_2.output, model.right}, # type: ignore name="abcd", expose=False, ) @@ -1083,8 +1083,8 @@ def test_connect_7_expose_output(): add_model_2 = Add() model += add_model_1(left="left", right="right", output=IOKey(name="output2")) model += add_model_2(left="left1", right="right1") - conn_list = [add_model_2.output, model.right] # type: ignore - conn = IOKey(name="abcd", expose=True, connections=conn_list) # type: ignore + conns = {add_model_2.output, model.right} # type: ignore + conn = IOKey(name="abcd", expose=True, connections=conns) # type: ignore model += (buf := Buffer())(input=conn, output=IOKey(name="output")) assert ( @@ -1140,7 +1140,7 @@ def test_connect_8(): left=add_model_1.output, right="right1", output=IOKey(name="output1") ) conn = IOKey( - connections=[add_model_1.output, model.right1], # type: ignore + connections={add_model_1.output, model.right1}, # type: ignore name="abcd", expose=False, ) @@ -1177,8 +1177,8 @@ def test_connect_9(): model = Model() concat_model = Concat(n=3) model += concat_model(input1=[[3.0]], input2=[[2.0]], input3="input3") - con_list = [concat_model.input1, concat_model.input2, concat_model.input3] # type: ignore - conn = IOKey(connections=con_list) + conns = {concat_model.input1, concat_model.input2, concat_model.input3} # type: ignore + conn = IOKey(connections=conns) with pytest.raises(ValueError) as err_info: model += Buffer()(input=conn, output=IOKey(name="output")) @@ -1194,8 +1194,8 @@ def test_connect_10(): model = Model() concat_model = Concat(n=3) model += concat_model(input1=[[3.0]], input3="input3") - conn_list = [concat_model.input1, concat_model.input2, concat_model.input3] # type: ignore - conn = IOKey(connections=conn_list, value=2.0, expose=True) + conns = {concat_model.input1, concat_model.input2, concat_model.input3} # type: ignore + conn = IOKey(connections=conns, value=2.0, expose=True) with pytest.raises(ValueError) as err_info: model += Buffer()(input=conn, output=IOKey(name="output")) @@ -1220,13 +1220,13 @@ def test_connect_11(): union_model = PrimitiveUnion(n=2) model += concat_model(input1="", output=IOKey(name="output1")) model += union_model(input1="", output=IOKey(name="output2")) - conn_list = [ + conns = { concat_model.input1, # type: ignore concat_model.input2, # type: ignore union_model.input1, # type: ignore union_model.input2, # type: ignore - ] - conn = IOKey(connections=conn_list, value=(2.0,), expose=True) + } + conn = IOKey(connections=conns, value=(2.0,), expose=True) model += Buffer()(input=conn, output=IOKey(name="output3")) pm = compile(model=model, backend=backend, jit=False) @@ -1257,12 +1257,12 @@ def test_connect_12(): model += concat_model(input1="", output=IOKey(name="output1")) model += union_model(input1="", output=IOKey(name="output2")) conn = IOKey( - connections=[ + connections={ concat_model.input1, # type: ignore concat_model.input2, # type: ignore union_model.input1, # type: ignore union_model.input2, # type: ignore - ], + }, value=(2.0,), ) model += Buffer()(input=conn, output=IOKey(name="output3")) @@ -1315,7 +1315,7 @@ def test_tensor_to_scalar_connect_1(): axis2 = mean_model_2.axis axis3 = 
mean_model_3.axis - con = IOKey(connections=[axis1, axis2, axis3], name="axis4", value=(2, 3)) + con = IOKey(connections={axis1, axis2, axis3}, name="axis4", value=(2, 3)) model += Mean(axis=TBD)(axis=con) assert axis1.data.metadata == axis2.data.metadata == axis3.data.metadata @@ -1336,7 +1336,7 @@ def test_tensor_to_scalar_connect_3_error_existing_key(): model += mean_model_2(axis="axis2") model += mean_model_3(axis="axis3") - con = IOKey(connections=[axis1, axis2, axis3], name="axis2", value=(2, 3)) + con = IOKey(connections={axis1, axis2, axis3}, name="axis2", value=(2, 3)) model += Mean(axis=TBD)(axis=con) @@ -1511,7 +1511,7 @@ def test_tensor_to_scalar_template_1(): model += buff_model_1(input="input1") in1 = buff_model_1.output - out1 = in1.shape().tensor() ** 2 + out1 = in1.get_shape().tensor() ** 2 model += Buffer()(input=out1, output="output") model.set_shapes({"input1": [3, 4, 5, 6]}) @@ -1534,7 +1534,7 @@ def test_tensor_to_scalar_template_2(): in1 = buff_model_1.output in2 = buff_model_2.output in3 = buff_model_3.output - out1 = (in1.shape().tensor() ** 2 * in2) @ in3 / 2 + out1 = (in1.get_shape().tensor() ** 2 * in2) @ in3 / 2 model += Buffer()(input=out1, output="output") pm = compile(model=model, backend=backend) diff --git a/tests/scripts/test_type_consistencies.py b/tests/scripts/test_type_consistencies.py index 9f8103e1..c415b71d 100644 --- a/tests/scripts/test_type_consistencies.py +++ b/tests/scripts/test_type_consistencies.py @@ -390,7 +390,7 @@ def test_type_16(): with pytest.raises(TypeError) as err_info: model += sig_model_2( - input=IOKey(connections=[sig_model_1.input], value=[False, True]), + input=IOKey(connections={sig_model_1.input}, value=[False, True]), output=IOKey(name="output2"), ) assert str(err_info.value) == ( @@ -409,7 +409,7 @@ def test_type_17(): model.extend( sig_model_2, input=IOKey( - connections=[sig_model_1.input], + connections={sig_model_1.input}, value=[False, True], name="a", expose=True, From 00d7cc124b579384fd630fe239eedf64aa31e73c Mon Sep 17 00:00:00 2001 From: berat Tuna KARLI Date: Wed, 25 Dec 2024 13:17:45 +0300 Subject: [PATCH 8/9] Make IOKey a dataclass derived from BaseKey, remove __init__ of IOKey, update tests --- mithril/framework/common.py | 36 ++++++-------------- mithril/framework/logical/base.py | 15 +++------ mithril/framework/logical/model.py | 43 ++++++++++-------------- mithril/models/primitives.py | 2 +- mithril/utils/dict_conversions.py | 14 ++++---- tests/scripts/test_key_values_in_init.py | 30 ++++++++--------- 6 files changed, 56 insertions(+), 84 deletions(-) diff --git a/mithril/framework/common.py b/mithril/framework/common.py index 1007ce8a..48934901 100644 --- a/mithril/framework/common.py +++ b/mithril/framework/common.py @@ -1102,6 +1102,7 @@ def __init__( @dataclass class BaseKey: + name: str | None = None value: TensorValueType | MainValueType | ToBeDetermined | str = TBD shape: ShapeTemplateType | None = None type: NestedListType | UnionType | type | None = None @@ -1109,39 +1110,24 @@ class BaseKey: @dataclass -class IOKey(TemplateBase): - def __init__( - self, - name: str | None = None, - value: TensorValueType | MainValueType | ToBeDetermined | str = TBD, - shape: ShapeTemplateType | None = None, - type: NestedListType | UnionType | type | None = None, - expose: bool | None = None, - interval: list[float | int] | None = None, - connections: set[Connection | str] | None = None, - ) -> None: - super().__init__() - self._name = name - self._value = value - self._shape = shape - self._type = type - 
self._expose = expose - self._interval = interval - self._connections: set[Connection | str] = connections or set() +class IOKey(BaseKey, TemplateBase): + expose: bool | None = None + connections: set[Connection | str] = field(default_factory=lambda: set()) + def __post_init__(self): # TODO: Shape should not be [] also! - if self._value is not TBD and self._shape is not None and self._shape != []: + if self.value is not TBD and self.shape is not None and self.shape != []: raise ValueError( f"Scalar values are shapeless, shape should be None or []. " - f"Got {self._shape}." + f"Got {self.shape}." ) - if self._value is not TBD and self._type is not None: - value_type = find_type(self._value) - if find_intersection_type(value_type, self._type) is None: + if self.value is not TBD and self.type is not None: + value_type = find_type(self.value) + if find_intersection_type(value_type, self.type) is None: raise TypeError( f"type of the given value and given type does not match. Given " - f"type is {self._type} while type of value is {value_type}" + f"type is {self.type} while type of value is {value_type}" ) def __eq__(self, other: object): diff --git a/mithril/framework/logical/base.py b/mithril/framework/logical/base.py index d596d156..b7fffae2 100644 --- a/mithril/framework/logical/base.py +++ b/mithril/framework/logical/base.py @@ -16,7 +16,7 @@ import abc from collections.abc import Mapping -from dataclasses import dataclass +from dataclasses import dataclass, replace from itertools import chain from types import UnionType from typing import Any @@ -103,21 +103,14 @@ def __call__(self, **kwargs: ConnectionType) -> ExtendInfo: f"has already being set to {val}!" ) case str(): - kwargs[key] = IOKey(con, value=val, expose=False) + kwargs[key] = IOKey(name=con, value=val, expose=False) case IOKey(): - if con._value is not TBD and con._value != val: + if con.value is not TBD and con.value != val: raise ValueError( f"Given IOKey for local key: '{key}' is not valid!" ) else: - kwargs[key] = IOKey( - name=con._name, - value=val, - shape=con._shape, - type=con._type, - expose=con._expose, - connections=con._connections, - ) + kwargs[key] = replace(con, value=val) case ExtendTemplate(): raise ValueError( "Multi-write detected for a valued " diff --git a/mithril/framework/logical/model.py b/mithril/framework/logical/model.py index e9f374d0..6b69365e 100644 --- a/mithril/framework/logical/model.py +++ b/mithril/framework/logical/model.py @@ -15,6 +15,7 @@ from __future__ import annotations from collections.abc import Mapping +from dataclasses import replace from types import UnionType from typing import Any, Self @@ -343,26 +344,18 @@ def _convert_to_iokey( assert isinstance(connection, MainValueInstance) connection = IOKey(value=connection) case IOKey(): - expose = connection._expose - name = connection._name - # TODO: This check should be removed: conn._connections==set() + expose = connection.expose + name = connection.name + # TODO: This check should be removed: conn.connections==set() # We should not operate different if _connections is given. Fix this and # also fix corresponding tests and dict conversions with "connect". 
             if (
                 expose is None
                 and (name is None or self.conns.get_connection(name) is None)
-                and connection._connections == set()
+                and connection.connections == set()
             ):
                 expose = True
-            # TODO: Add replicate method to IOKey (update def __call__ in BaseModel)
-            connection = IOKey(
-                name=name,
-                value=connection._value,
-                shape=connection._shape,
-                type=connection._type,
-                expose=expose,
-                connections=connection._connections,
-            )
+            connection = replace(connection, name=name, expose=expose)
         case NotAvailable():
             raise ValueError(
                 f"Given value for key: '{key}' is not available. "
@@ -400,14 +393,14 @@ def _add_connection(
         is_not_valued = local_connection.metadata.data.value is TBD
 
         d_map = self.dependency_map._local_output_dependency_map
-        expose = given_connection._expose
-        outer_key = given_connection._name
+        expose = given_connection.expose
+        outer_key = given_connection.name
         con_obj = None
         set_value: ToBeDetermined | str | MainValueType | NullConnection = NOT_GIVEN
-        if given_connection._value is not TBD:
-            set_value = given_connection._value
+        if given_connection.value is not TBD:
+            set_value = given_connection.value
 
-        if given_connection._connections == set():
+        if given_connection.connections == set():
             if outer_key is not None:
                 con_obj = self.conns.get_connection(outer_key)
             if outer_key is None or con_obj is None:
@@ -427,7 +420,7 @@
             )
         else:
             initial_conn: ConnectionData
-            for idx, conn in enumerate(given_connection._connections):
+            for idx, conn in enumerate(given_connection.connections):
                 if isinstance(conn, str):
                     _conn = self.conns.get_connection(conn)
                 else:
@@ -715,11 +708,11 @@ def extend(
         }
 
         for local_key, value in io_keys.items():
-            if value._shape is not None:
-                shape_info |= {local_key: value._shape}
+            if value.shape is not None:
+                shape_info |= {local_key: value.shape}
 
-            if value._type is not None:
-                type_info[local_key] = value._type
+            if value.type is not None:
+                type_info[local_key] = value.type
 
             con_obj, _updates = self._add_connection(model, local_key, value)
             updates |= _updates
@@ -812,7 +805,7 @@ def _extend(self, info: ExtendInfo | PrimitiveModel | Model) -> Self:
             kwargs[model._canonical_input.key] = self.canonical_output
 
         for key, value in kwargs.items():
-            _value = value._name if isinstance(value, IOKey) else value
+            _value = value.name if isinstance(value, IOKey) else value
 
             if isinstance(_value, str) and _value == "":
                 if key in model._input_keys:
@@ -823,7 +816,7 @@
                 )
 
                 if isinstance(value, IOKey):
-                    value._name = None
+                    value.name = None
                 else:
                     kwargs[key] = _value
 
diff --git a/mithril/models/primitives.py b/mithril/models/primitives.py
index 983f8e96..a3beff13 100644
--- a/mithril/models/primitives.py
+++ b/mithril/models/primitives.py
@@ -2146,7 +2146,7 @@ def __call__(  # type: ignore[override]
             and attn_mask is not NOT_GIVEN
             and not isinstance(attn_mask, str)
             and isinstance(attn_mask, BaseKey)
-            and attn_mask._value is not None  # TODO: Here will be updated!
+            and attn_mask.value is not None  # TODO: Here will be updated!
         ):
             raise KeyError(
                 "Model does not have 'attn_mask' input. Got attn_mask argument!"
diff --git a/mithril/utils/dict_conversions.py b/mithril/utils/dict_conversions.py
index 53a8e034..c41f9bb2 100644
--- a/mithril/utils/dict_conversions.py
+++ b/mithril/utils/dict_conversions.py
@@ -538,21 +538,21 @@ def item_to_json(item: IOKey):
     # TODO: Currently type is not supported for Tensors.
     # Handle this with conversion test updates.
result: dict[str, Any] = {} - if not isinstance(item._value, ToBeDetermined): - result["value"] = item._value - if item._shape is not None: + if not isinstance(item.value, ToBeDetermined): + result["value"] = item.value + if item.shape is not None: shape_template = [] - for symbol in item._shape: + for symbol in item.shape: if isinstance(symbol, tuple): # variadic shape_template.append(f"{symbol[0]},...") else: shape_template.append(str(symbol)) result["shape_template"] = shape_template - elif isinstance(item._type, UnionType): - result["type"] = [type_to_str(item) for item in item._type.__args__] + elif isinstance(item.type, UnionType): + result["type"] = [type_to_str(item) for item in item.type.__args__] else: result["type"] = [ - type_to_str(item._type), + type_to_str(item.type), ] return result diff --git a/tests/scripts/test_key_values_in_init.py b/tests/scripts/test_key_values_in_init.py index 7a642233..4bd3cf5b 100644 --- a/tests/scripts/test_key_values_in_init.py +++ b/tests/scripts/test_key_values_in_init.py @@ -27,9 +27,9 @@ def test_directed_call_connection(): left_info = info._connections["left"] assert isinstance(left_info, ml.IOKey) - assert left_info._connections == {connection} - assert left_info._name is None - assert left_info._value == 1 + assert left_info.connections == {connection} + assert left_info.name is None + assert left_info.value == 1 def test_directed_call_int(): @@ -58,8 +58,8 @@ def test_directed_call_str(): left_info = info._connections["left"] assert isinstance(left_info, ml.IOKey) - assert left_info._name == "in1" - assert left_info._value == 1 + assert left_info.name == "in1" + assert left_info.value == 1 def test_directed_call_iokey_value_equal(): @@ -70,8 +70,8 @@ def test_directed_call_iokey_value_equal(): left_info = info._connections["left"] assert isinstance(left_info, ml.IOKey) - assert left_info._name == "in1" - assert left_info._value == 1 + assert left_info.name == "in1" + assert left_info.value == 1 def test_directed_call_iokey_value_not_equal(): @@ -91,8 +91,8 @@ def test_directed_call_iokey_value_tbd(): left_info = info._connections["left"] assert isinstance(left_info, ml.IOKey) - assert left_info._name == "in1" - assert left_info._value == 1 # value is set to val from factory_inputs + assert left_info.name == "in1" + assert left_info.value == 1 # value is set to val from factory_inputs def test_directed_call_connect_key_value_not_equal(): @@ -112,8 +112,8 @@ def test_directed_call_connect_key_none(): info = add1(left=con, right="right") left_info = info._connections["left"] assert isinstance(left_info, ml.IOKey) - assert left_info._connections == {connection} - assert left_info._value == 1 # key is set to IOKey with val from factory_inputs + assert left_info.connections == {connection} + assert left_info.value == 1 # key is set to IOKey with val from factory_inputs def test_directed_call_connect_key_value_tbd(): @@ -125,9 +125,9 @@ def test_directed_call_connect_key_value_tbd(): left_info = info._connections["left"] assert isinstance(left_info, ml.IOKey) - assert left_info._connections == {connection} + assert left_info.connections == {connection} assert isinstance(left_info, ml.IOKey) - assert left_info._value == 1 # value is set to val from factory_inputs + assert left_info.value == 1 # value is set to val from factory_inputs def test_directed_call_connect_key_value_equal(): @@ -139,8 +139,8 @@ def test_directed_call_connect_key_value_equal(): left_info = info._connections["left"] assert isinstance(left_info, ml.IOKey) - assert 
left_info._connections == {connection} - assert left_info._value == 1 # value is set to val from factory_inputs + assert left_info.connections == {connection} + assert left_info.value == 1 # value is set to val from factory_inputs def test_directed_call_extend_template(): From 45cacb1cb913e0f797f3fb6e7a81598778904603 Mon Sep 17 00:00:00 2001 From: berat Tuna KARLI Date: Thu, 26 Dec 2024 15:25:33 +0300 Subject: [PATCH 9/9] IOKey updated --- examples/flux/auto_encoder.py | 4 +- examples/flux/layers.py | 14 +++--- examples/gpt/model.py | 2 +- mithril/framework/common.py | 49 ++++++++++++------- mithril/framework/logical/base.py | 13 +++-- mithril/framework/logical/model.py | 22 ++++++--- mithril/models/models.py | 16 +++--- mithril/utils/dict_conversions.py | 10 ++-- tests/scripts/test_extend_template.py | 4 +- tests/scripts/test_io_key.py | 2 +- tests/scripts/test_jittable.py | 2 +- tests/scripts/test_key_values_in_init.py | 14 +++--- tests/scripts/test_parallel.py | 8 +-- .../scripts/test_tuple_list_args_in_extend.py | 24 ++++----- tests/scripts/test_type_coercion.py | 30 ++++++------ 15 files changed, 119 insertions(+), 95 deletions(-) diff --git a/examples/flux/auto_encoder.py b/examples/flux/auto_encoder.py index 6b8868da..5cda7307 100644 --- a/examples/flux/auto_encoder.py +++ b/examples/flux/auto_encoder.py @@ -66,7 +66,7 @@ def attn_block(n_channels: int, name: str | None = None): key = block.key # type: ignore[attr-defined] value = block.value # type: ignore[attr-defined] - shape = query.get_shape() + shape = query.shape query = query.transpose((0, 2, 3, 1)).reshape((shape[0], 1, -1, shape[1])) key = key.transpose((0, 2, 3, 1)).reshape((shape[0], 1, -1, shape[1])) @@ -92,7 +92,7 @@ def downsample(n_channels: int): def upsample(n_channels: int, name: str | None = None): block = Model(enforce_jit=False, name=name) # TODO: Remove enfor jit false input = IOKey("input") - input_shape = input.get_shape() + input_shape = input.shape B, C, H, W = input_shape[0], input_shape[1], input_shape[2], input_shape[3] input = input[:, :, :, None, :, None] diff --git a/examples/flux/layers.py b/examples/flux/layers.py index be327e85..dcf0758f 100644 --- a/examples/flux/layers.py +++ b/examples/flux/layers.py @@ -65,8 +65,8 @@ def apply_rope() -> Model: xk = IOKey("xk") freqs_cis = IOKey("freqs_cis") - xq_shape = xq.get_shape() - xk_shape = xk.get_shape() + xq_shape = xq.shape + xk_shape = xk.shape B, L, H = xq_shape[0], xq_shape[1], xq_shape[2] block += Reshape()(xq, shape=(B, L, H, -1, 1, 2), output="xq_") B, L, H = xk_shape[0], xk_shape[1], xk_shape[2] @@ -96,7 +96,7 @@ def attention() -> Model: ) # We can get named connection as model.'connection_name' - context_shape = block.context.get_shape() # type: ignore[attr-defined] + context_shape = block.context.shape # type: ignore[attr-defined] block += Transpose(axes=(0, 2, 1, 3))(block.context) # type: ignore[attr-defined] # NOTE: Reshape input is automatically connected to Transpose output block += Reshape()( @@ -137,7 +137,7 @@ def modulation(dim: int, double: bool, name: str | None = None): def rearrange(num_heads: int): block = Model() input = IOKey("input") - input_shaepe = input.get_shape() + input_shaepe = input.shape B, L = input_shaepe[0], input_shaepe[1] block += Reshape()(shape=(B, L, 3, num_heads, -1)) block += Transpose(axes=(2, 0, 3, 1, 4))(output=IOKey("output")) @@ -209,7 +209,7 @@ def double_stream_block( block += Concat(axis=2, n=2)(input1=txt_v, input2=img_v, output="v_concat") block += attention()(q="q_concat", k="k_concat", 
v="v_concat", pe=pe, output="attn") - # TODO: use'[:, txt.get_shape()[1] :]' when fixed. + # TODO: use'[:, txt.shape[1] :]' when fixed. img_attn = block.attn[:, 256:] # type: ignore[attr-defined] block += Linear(hidden_size, name="img_attn_proj")(img_attn, output="img_proj") @@ -234,7 +234,7 @@ def double_stream_block( ) img = img + block.img_mod_2[2] * block.img_mlp # type: ignore[attr-defined] - # TODO: Use txt.get_shape()[1]] + # TODO: Use txt.shape[1]] txt_attn = block.attn[:, :256] # type: ignore[attr-defined] block += Linear(hidden_size, name="txt_attn_proj")(txt_attn, output="txt_proj") @@ -355,7 +355,7 @@ def rope(dim: int, theta: int) -> Model: omega = 1.0 / (theta ** (block.arange / dim)) # type: ignore out = input[..., None] * omega - out_shape = out.get_shape() + out_shape = out.shape B, N, D = out_shape[0], out_shape[1], out_shape[2] block += Cosine()(out, output="cos") diff --git a/examples/gpt/model.py b/examples/gpt/model.py index 3f7a539f..4a73aaa2 100644 --- a/examples/gpt/model.py +++ b/examples/gpt/model.py @@ -41,7 +41,7 @@ def causal_attention(input_dim, num_heads, bias=True): model += Linear(input_dim * 3, name="c_attn")("input", output="c_attn_out") t_axes = (0, 2, 1, 3) - shp_con = model.input.get_shape() # type: ignore + shp_con = model.input.shape # type: ignore reshape_con = (shp_con[0], shp_con[1], num_heads, -1) model += Split(3, axis=-1)(model.c_attn_out, output="split_out") # type: ignore diff --git a/mithril/framework/common.py b/mithril/framework/common.py index 506ac769..90192e30 100644 --- a/mithril/framework/common.py +++ b/mithril/framework/common.py @@ -1001,7 +1001,8 @@ def abs(self): def len(self): return ExtendTemplate(connections=[self], model="len") - def get_shape(self): + @property + def shape(self): return ExtendTemplate(connections=[self], model="shape") def reshape(self, shape: tuple[int, ...] | TemplateBase): @@ -1102,40 +1103,50 @@ def __init__( @dataclass class BaseKey: - name: str | None = None value: TensorValueType | MainValueType | ToBeDetermined | str = TBD shape: ShapeTemplateType | None = None type: NestedListType | UnionType | type | None = None interval: list[float | int] | None = None -@dataclass -class IOKey(BaseKey, TemplateBase): - expose: bool | None = None - connections: set[Connection | str] = field(default_factory=lambda: set()) +class IOKey(TemplateBase): + def __init__( + self, + name: str | None = None, + value: TensorValueType | MainValueType | ToBeDetermined | str = TBD, + shape: ShapeTemplateType | None = None, + type: NestedListType | UnionType | type | None = None, + expose: bool | None = None, + interval: list[float | int] | None = None, + connections: set[Connection | str] | None = None, + ) -> None: + super().__init__() + self.name = name + self.expose = expose + if connections is None: + connections = set() + self.connections: set[Connection | str] = connections + self.data = BaseKey(value, shape, type, interval) - def __post_init__(self): # TODO: Shape should not be [] also! - if self.value is not TBD and self.shape is not None and self.shape != []: + if ( + self.data.value is not TBD + and self.data.shape is not None + and self.data.shape != [] + ): raise ValueError( f"Scalar values are shapeless, shape should be None or []. " - f"Got {self.shape}." + f"Got {self.data.shape}." 
) - if self.value is not TBD and self.type is not None: - value_type = find_type(self.value) - if find_intersection_type(value_type, self.type) is None: + if self.data.value is not TBD and self.data.type is not None: + value_type = find_type(self.data.value) + if find_intersection_type(value_type, self.data.type) is None: raise TypeError( f"type of the given value and given type does not match. Given " - f"type is {self.type} while type of value is {value_type}" + f"type is {self.data.type} while type of value is {value_type}" ) - def __eq__(self, other: object): - if isinstance(other, int | float | bool | list | Connection | IOKey | tuple): - return ExtendTemplate(connections=[self, other], model="eq") - else: - raise ValueError("Unsupported type for equality operation.") - class Connection(TemplateBase): def __init__(self, key: str, metadata: IOHyperEdge, is_key_autogenerated: bool): diff --git a/mithril/framework/logical/base.py b/mithril/framework/logical/base.py index 43a64fa6..74e30fe1 100644 --- a/mithril/framework/logical/base.py +++ b/mithril/framework/logical/base.py @@ -16,7 +16,7 @@ import abc from collections.abc import Mapping -from dataclasses import dataclass, replace +from dataclasses import dataclass from itertools import chain from types import UnionType from typing import Any @@ -105,12 +105,19 @@ def __call__(self, **kwargs: ConnectionType) -> ExtendInfo: case str(): kwargs[key] = IOKey(name=con, value=val, expose=False) case IOKey(): - if con.value is not TBD and con.value != val: + if con.data.value is not TBD and con.data.value != val: raise ValueError( f"Given IOKey for local key: '{key}' is not valid!" ) else: - kwargs[key] = replace(con, value=val) + kwargs[key] = IOKey( + name=con.name, + expose=con.expose, + connections=con.connections, + type=con.data.type, + shape=con.data.shape, + value=val, + ) case ExtendTemplate(): raise ValueError( "Multi-write detected for a valued " diff --git a/mithril/framework/logical/model.py b/mithril/framework/logical/model.py index 88a7ebd2..e9d976ea 100644 --- a/mithril/framework/logical/model.py +++ b/mithril/framework/logical/model.py @@ -15,7 +15,6 @@ from __future__ import annotations from collections.abc import Mapping -from dataclasses import replace from types import UnionType from typing import Any, Self @@ -355,7 +354,14 @@ def _convert_to_iokey( and connection.connections == set() ): expose = True - connection = replace(connection, name=name, expose=expose) + connection = IOKey( + name=name, + expose=expose, + connections=connection.connections, + type=connection.data.type, + shape=connection.data.shape, + value=connection.data.value, + ) case NotAvailable(): raise ValueError( f"Given value for key: '{key}' is not available. 
" @@ -395,8 +401,8 @@ def _add_connection( outer_key = given_connection.name con_obj = None set_value: ToBeDetermined | str | MainValueType | NullConnection = NOT_GIVEN - if given_connection.value is not TBD: - set_value = given_connection.value + if given_connection.data.value is not TBD: + set_value = given_connection.data.value if given_connection.connections == set(): if outer_key is not None: @@ -706,11 +712,11 @@ def extend( } for local_key, value in io_keys.items(): - if value.shape is not None: - shape_info |= {local_key: value.shape} + if value.data.shape is not None: + shape_info |= {local_key: value.data.shape} - if value.type is not None: - type_info[local_key] = value.type + if value.data.type is not None: + type_info[local_key] = value.data.type con_obj, _updates = self._add_connection(model, local_key, value) updates |= _updates diff --git a/mithril/models/models.py b/mithril/models/models.py index 0959a492..47ad5b2a 100644 --- a/mithril/models/models.py +++ b/mithril/models/models.py @@ -776,7 +776,7 @@ def __init__( # Assumed input shape is [N, C, H, W] input_key = IOKey(name="input") - input_shape = input_key.get_shape() + input_shape = input_key.shape B = input_shape[0] input_key = input_key.reshape((B, num_groups, -1)) @@ -2899,7 +2899,7 @@ def __init__( )("pred", "label", "metric_out", "pred_formatted", "label_formatted") true_predictions = self.metric_out == 0 - n_prediction = self.label_formatted.get_shape()[0] + n_prediction = self.label_formatted.shape[0] self += Sum()(input=true_predictions, output="n_true_predictions") self += Divide()( @@ -3009,13 +3009,13 @@ def __init__( self += Divide()( numerator=sum_precision, - denominator=self.n_classes.get_shape()[0].tensor(), + denominator=self.n_classes.shape[0].tensor(), output=IOKey(name="output"), ) elif average == "weighted": precision = None - n_element = self.label_formatted.get_shape()[0] + n_element = self.label_formatted.shape[0] assert ( n_classes is not None ), "n_classes must be provided if average is or 'weighted'" @@ -3153,7 +3153,7 @@ def __init__( self += Divide()( numerator=sum_recall, - denominator=self.n_classes.get_shape()[0].tensor(), + denominator=self.n_classes.shape[0].tensor(), output=IOKey(name="output"), ) @@ -3162,7 +3162,7 @@ def __init__( assert ( n_classes is not None ), "n_classes must be provided if average is or 'weighted'" - n_element = self.label_formatted.get_shape()[0] + n_element = self.label_formatted.shape[0] for idx in range(n_classes): class_idxs = self.label_formatted == idx true_positive = (self.metric_out == 0) & class_idxs @@ -3299,7 +3299,7 @@ def __init__( self += Unique()(input=self.label_formatted, output="n_classes") self += Divide()( numerator=sum_precision, - denominator=self.n_classes.get_shape()[0].tensor(), + denominator=self.n_classes.shape[0].tensor(), output=IOKey(name="output"), ) @@ -3308,7 +3308,7 @@ def __init__( assert ( n_classes is not None ), "n_classes must be provided if average is or 'weighted'" - n_element = self.label_formatted.get_shape()[0].tensor() + n_element = self.label_formatted.shape[0].tensor() for idx in range(n_classes): class_idxs = self.label_formatted == idx true_positive = (self.metric_out == 0) & class_idxs diff --git a/mithril/utils/dict_conversions.py b/mithril/utils/dict_conversions.py index c41f9bb2..a77406ca 100644 --- a/mithril/utils/dict_conversions.py +++ b/mithril/utils/dict_conversions.py @@ -538,8 +538,8 @@ def item_to_json(item: IOKey): # TODO: Currently type is not supported for Tensors. 
# Handle This whit conversion test updates. result: dict[str, Any] = {} - if not isinstance(item.value, ToBeDetermined): - result["value"] = item.value + if not isinstance(item.data.value, ToBeDetermined): + result["value"] = item.data.value if item.shape is not None: shape_template = [] for symbol in item.shape: @@ -549,10 +549,10 @@ def item_to_json(item: IOKey): shape_template.append(str(symbol)) result["shape_template"] = shape_template - elif isinstance(item.type, UnionType): - result["type"] = [type_to_str(item) for item in item.type.__args__] + elif isinstance(item.data.type, UnionType): + result["type"] = [type_to_str(item) for item in item.data.type.__args__] else: result["type"] = [ - type_to_str(item.type), + type_to_str(item.data.type), ] return result diff --git a/tests/scripts/test_extend_template.py b/tests/scripts/test_extend_template.py index e3eeb83e..a192cab1 100644 --- a/tests/scripts/test_extend_template.py +++ b/tests/scripts/test_extend_template.py @@ -150,7 +150,7 @@ def test_shape_reshape(): # Create with shortcut. model_1 = Model() model_1 += (lin_1 := Linear(dimension=1))(input="input_1", weight="w_1", bias="b_1") - shp = lin_1.input.get_shape() + shp = lin_1.input.shape model_1 += (lin_2 := Linear(dimension=2))(input="input_2", weight="w_2", bias="b_2") reshaped = lin_2.output.reshape(shp) model_1 += Add()(left=lin_1.output, right=reshaped, output=IOKey(name="output")) @@ -210,7 +210,7 @@ def test_slice_item(): model_1 += (lin_1 := Linear(dimension=1))( input="input", weight="weight", bias="bias" ) - shp = lin_1.input.get_shape() + shp = lin_1.input.shape item = shp[1].tensor() slc = shp[:].tensor() model_1 += Add()(left=item, right=slc, output=IOKey(name="output")) diff --git a/tests/scripts/test_io_key.py b/tests/scripts/test_io_key.py index e7f74189..ce717da9 100644 --- a/tests/scripts/test_io_key.py +++ b/tests/scripts/test_io_key.py @@ -1245,7 +1245,7 @@ def test_iokey_template_4(): model = Model() left = IOKey("left") - res = left.get_shape()[0] + res = left.shape[0] model += Buffer()(res.tensor(), IOKey("output")) diff --git a/tests/scripts/test_jittable.py b/tests/scripts/test_jittable.py index 044952ea..2ae0831c 100644 --- a/tests/scripts/test_jittable.py +++ b/tests/scripts/test_jittable.py @@ -367,7 +367,7 @@ def test_jit_2(): model = Model(enforce_jit=False) model += (add_model := Add())(left="left", right="right") in1 = add_model.output - out1 = in1.get_shape() + out1 = in1.shape out2 = out1.tensor().sum() mean_model = Mean(axis=TBD) model += (to_list := Item())(input=out2) diff --git a/tests/scripts/test_key_values_in_init.py b/tests/scripts/test_key_values_in_init.py index 4bd3cf5b..a0763227 100644 --- a/tests/scripts/test_key_values_in_init.py +++ b/tests/scripts/test_key_values_in_init.py @@ -29,7 +29,7 @@ def test_directed_call_connection(): assert isinstance(left_info, ml.IOKey) assert left_info.connections == {connection} assert left_info.name is None - assert left_info.value == 1 + assert left_info.data.value == 1 def test_directed_call_int(): @@ -59,7 +59,7 @@ def test_directed_call_str(): assert isinstance(left_info, ml.IOKey) assert left_info.name == "in1" - assert left_info.value == 1 + assert left_info.data.value == 1 def test_directed_call_iokey_value_equal(): @@ -71,7 +71,7 @@ def test_directed_call_iokey_value_equal(): assert isinstance(left_info, ml.IOKey) assert left_info.name == "in1" - assert left_info.value == 1 + assert left_info.data.value == 1 def test_directed_call_iokey_value_not_equal(): @@ -92,7 +92,7 @@ def 
test_directed_call_iokey_value_tbd(): assert isinstance(left_info, ml.IOKey) assert left_info.name == "in1" - assert left_info.value == 1 # value is set to val from factory_inputs + assert left_info.data.value == 1 # value is set to val from factory_inputs def test_directed_call_connect_key_value_not_equal(): @@ -113,7 +113,7 @@ def test_directed_call_connect_key_none(): left_info = info._connections["left"] assert isinstance(left_info, ml.IOKey) assert left_info.connections == {connection} - assert left_info.value == 1 # key is set to IOKey with val from factory_inputs + assert left_info.data.value == 1 # key is set to IOKey with val from factory_inputs def test_directed_call_connect_key_value_tbd(): @@ -127,7 +127,7 @@ def test_directed_call_connect_key_value_tbd(): assert isinstance(left_info, ml.IOKey) assert left_info.connections == {connection} assert isinstance(left_info, ml.IOKey) - assert left_info.value == 1 # value is set to val from factory_inputs + assert left_info.data.value == 1 # value is set to val from factory_inputs def test_directed_call_connect_key_value_equal(): @@ -140,7 +140,7 @@ def test_directed_call_connect_key_value_equal(): left_info = info._connections["left"] assert isinstance(left_info, ml.IOKey) assert left_info.connections == {connection} - assert left_info.value == 1 # value is set to val from factory_inputs + assert left_info.data.value == 1 # value is set to val from factory_inputs def test_directed_call_extend_template(): diff --git a/tests/scripts/test_parallel.py b/tests/scripts/test_parallel.py index 3641aa7d..ea8767a6 100644 --- a/tests/scripts/test_parallel.py +++ b/tests/scripts/test_parallel.py @@ -383,7 +383,7 @@ def test_torch_parallel_2(): # primitive eye. model = Model() model += (linear := Linear(256))(input="input", weight="w", bias="b") - model += (e := Eye(N=TBD))(N=linear.output.get_shape()[0]) + model += (e := Eye(N=TBD))(N=linear.output.shape[0]) model += Add()(left=linear.output, right=e.output, output="output") backend = create_parallel_backend(device_mesh=(4, 1)) backend.ones([256]) @@ -507,7 +507,7 @@ def test_torch_parallel_5(): # primitive eye. 
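# NOTE: a minimal sketch, not part of the diff, of the IOKey layout the test
# hunks above assert. After this patch IOKey is no longer a dataclass:
# value/shape/type/interval live on a nested BaseKey reachable via `.data`,
# while `name`, `expose` and `connections` stay top-level. Assumes
# `import mithril as ml` as in these tests.
#
#     key = ml.IOKey(name="in1", value=1)
#     assert key.name == "in1"          # public attribute, formerly `_name`
#     assert key.data.value == 1        # formerly `_value`, now on BaseKey
#     assert key.connections == set()   # defaults to an empty set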
model = Model() model += (linear := Linear(256))(input="input", weight="w", bias="b") - model += (e := Eye(N=TBD))(N=linear.output.get_shape()[0]) + model += (e := Eye(N=TBD))(N=linear.output.shape[0]) model += Add()(left=linear.output, right=e.output, output="output") backend = mithril.TorchBackend() @@ -957,7 +957,7 @@ def test_jax_parallel_2(): if "cuda" in mithril.JaxBackend.get_available_devices(): model = Model() model += (linear := Linear(256))(input="input", weight="w", bias="b") - model += (e := Eye(N=TBD))(N=linear.output.get_shape()[0]) + model += (e := Eye(N=TBD))(N=linear.output.shape[0]) model += Add()(left=linear.output, right=e.output, output="output") backend = mithril.JaxBackend(device="cuda", device_mesh=(4, 1)) backend.ones([256]) @@ -1090,7 +1090,7 @@ def test_jax_parallel_5(): if "cuda" in mithril.JaxBackend.get_available_devices(): model = Model() model += (linear := Linear(256))(input="input", weight="w", bias="b") - model += (e := Eye(N=TBD))(N=linear.output.get_shape()[0]) + model += (e := Eye(N=TBD))(N=linear.output.shape[0]) model += Add()(left=linear.output, right=e.output, output="output") backend = mithril.JaxBackend(device="cuda") diff --git a/tests/scripts/test_tuple_list_args_in_extend.py b/tests/scripts/test_tuple_list_args_in_extend.py index b8158d19..87de7266 100644 --- a/tests/scripts/test_tuple_list_args_in_extend.py +++ b/tests/scripts/test_tuple_list_args_in_extend.py @@ -80,7 +80,7 @@ def test_tuple_argument_3(): add_model_2 = Add() model += add_model(left="left", right="right") model += add_model_2( - left=(add_model.left.get_shape(), add_model.right.get_shape()), + left=(add_model.left.shape, add_model.right.shape), right=add_model.left + add_model.right, output="output", ) @@ -110,7 +110,7 @@ def test_tuple_argument_4(): add_model_2 = Add() model += add_model(left="left", right="right") model += add_model_2( - left=(add_model.left.get_shape() * 2, add_model.right.get_shape() * 2), + left=(add_model.left.shape * 2, add_model.right.shape * 2), right=add_model.left + add_model.right, output="output", ) @@ -138,8 +138,8 @@ def test_tuple_argument_5(): model += add_model(left="left", right="right") model += add_model_2( left=( - (add_model.left.get_shape()[0], add_model.left.get_shape()[0]), - (add_model.left.get_shape()[0], add_model.left.get_shape()[0]), + (add_model.left.shape[0], add_model.left.shape[0]), + (add_model.left.shape[0], add_model.left.shape[0]), ), right=add_model.left + add_model.right, output="output", @@ -168,8 +168,8 @@ def test_list_tuple_mixed_argument_1(): model += add_model(left="left", right="right") model += add_model_2( left=( - [add_model.left.get_shape()[0], add_model.left.get_shape()[0]], - [add_model.left.get_shape()[0], add_model.left.get_shape()[0]], + [add_model.left.shape[0], add_model.left.shape[0]], + [add_model.left.shape[0], add_model.left.shape[0]], ), right=add_model.left + add_model.right, output="output", @@ -197,8 +197,8 @@ def test_list_tuple_mixed_argument_2(): model += add_model(left="left", right="right") - left_first_shape = add_model.left.get_shape()[0] - right_first_shape = add_model.right.get_shape()[0] + left_first_shape = add_model.left.shape[0] + right_first_shape = add_model.right.shape[0] matmul_left = ([left_first_shape, 0], [2, right_first_shape]) @@ -296,7 +296,7 @@ def test_list_argument_3(): model += add_model(left="left", right="right") model += add_model_2( - left=[add_model.left.get_shape(), add_model.right.get_shape()], + left=[add_model.left.shape, add_model.right.shape], 
right=add_model.left + add_model.right, output="output", ) @@ -327,7 +327,7 @@ def test_list_argument_4(): model += add_model(left="left", right="right") model += add_model_2( - left=[add_model.left.get_shape() * 2, add_model.right.get_shape() * 2], + left=[add_model.left.shape * 2, add_model.right.shape * 2], right=add_model.left + add_model.right, output="output", ) @@ -356,8 +356,8 @@ def test_list_argument_5(): model += add_model(left="left", right="right") model += add_model_2( left=[ - [add_model.left.get_shape()[0], add_model.left.get_shape()[0]], - [add_model.left.get_shape()[0], add_model.left.get_shape()[0]], + [add_model.left.shape[0], add_model.left.shape[0]], + [add_model.left.shape[0], add_model.left.shape[0]], ], right=add_model.left + add_model.right, output="output", diff --git a/tests/scripts/test_type_coercion.py b/tests/scripts/test_type_coercion.py index 616c80c0..31a725ce 100644 --- a/tests/scripts/test_type_coercion.py +++ b/tests/scripts/test_type_coercion.py @@ -116,7 +116,7 @@ def test_scalar_to_tensor_2(): lin_2 = Linear(dimension=2) model += lin_1(input="input_1", weight="w_1", bias="b_1") model += lin_2(input="input_2", weight="w_2", bias="b_2") - shp_1 = lin_1.input.get_shape() + shp_1 = lin_1.input.shape reshaped_1 = lin_2.output.reshape(shp_1) to_tensor = ToTensor() model += to_tensor(input=shp_1) @@ -129,7 +129,7 @@ def test_scalar_to_tensor_2(): lin_4 = Linear(dimension=2) model += lin_3(input="input_1", weight="w_1", bias="b_1") model += lin_4(input="input_2", weight="w_2", bias="b_2") - shp_2 = lin_3.input.get_shape() + shp_2 = lin_3.input.shape reshaped_2 = lin_4.output.reshape(shp_2) model += Add()(left=shp_2.tensor(), right=reshaped_2, output="output") model_2 = model @@ -185,7 +185,7 @@ def test_scalar_to_tensor_3(): def test_tensor_to_scalar_1(): - """Model enforces Jit so we reshape with to_tensor_1_output.get_shape(). + """Model enforces Jit so we reshape with to_tensor_1_output.shape. We can not directly reshape with to_tensor_1_output which is valued as [2, 1] in tensor domain since it requires TensorToList conversion before being argument to reshape method. 
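# NOTE: a hedged before/after sketch, not part of the diff, of the rename the
# surrounding test hunks apply mechanically: the `get_shape()` method on
# connection templates became a `shape` property (see the TemplateBase change
# in mithril/framework/common.py above) that still returns an ExtendTemplate,
# so indexing, `.tensor()` and reshape chaining work as before.
#
#     input = IOKey("input")
#     B = input.shape[0]                # old spelling: input.get_shape()[0]
#     flat = input.reshape((B, -1))     # symbolic dims compose as before
#     dim0 = input.shape[0].tensor()    # .tensor() chaining also unchanged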
@@ -199,7 +199,7 @@ def test_tensor_to_scalar_1(): model += to_tensor_1(input=[2, 1]) model += to_tensor_2(input=[[1, 1]]) model += add_1(left=to_tensor_1.output, right=to_tensor_2.output) - reshaped_1 = add_1.output.reshape(to_tensor_1.output.get_shape()) + reshaped_1 = add_1.output.reshape(to_tensor_1.output.shape) model += Buffer()(input=reshaped_1, output="output") model_1 = model @@ -209,7 +209,7 @@ def test_tensor_to_scalar_1(): left = IOKey(value=[2, 1]).tensor() right = IOKey(value=[1, 1]).tensor() model += add_2(left=left, right=right) - reshaped_2 = add_2.output.reshape(add_2.left.get_shape()) + reshaped_2 = add_2.output.reshape(add_2.left.shape) model += Buffer()(input=reshaped_2, output="output") model_2 = model @@ -282,7 +282,7 @@ def test_slice_item_conversions(): model = Model() lin_2 = Linear(dimension=1) model += lin_2(input="input", weight="w", bias="b") - shp2 = lin_2.input.get_shape() + shp2 = lin_2.input.shape shp2_1 = shp2[1] assert shp2_1 is not None shp_item = shp2_1.tensor() @@ -307,7 +307,7 @@ def test_tuple_conversion_1(): model = Model() lin_1 = Linear(dimension=2) model += lin_1(input="input", weight="w", bias="b") - shp1 = lin_1.output.get_shape() + shp1 = lin_1.output.shape model += ToTensor()(input=(shp1[0], shp1[1]), output="output") model_1 = model @@ -316,7 +316,7 @@ def test_tuple_conversion_1(): lin_2 = Linear(dimension=2) tupl = ToTuple(n=2) model += lin_2(input="input", weight="w", bias="b") - shp2 = lin_2.output.get_shape() + shp2 = lin_2.output.shape model += tupl(input1=shp2[0], input2=shp2[1]) model += ToTensor()(input=tupl.output, output="output") # type: ignore model_2 = model @@ -337,7 +337,7 @@ def test_tuple_conversion_2(): lin_1 = Linear(dimension=2) tt1 = ToTensor() model += lin_1(input=[[1], [2.0]], weight="w", bias="b") - shp1 = lin_1.input.get_shape() + shp1 = lin_1.input.shape model += tt1(input=(shp1[0], shp1[1])) model += Add()(left=lin_1.output, right=tt1.output, output="output") model_1 = model @@ -383,7 +383,7 @@ def test_tuple_conversion_3(): lin_1 = Linear(dimension=3) tt1 = ToTensor() model += lin_1(input=[[1], [2.0]], weight="w", bias="b") - shp1 = lin_1.input.get_shape() + shp1 = lin_1.input.shape model += tt1(input=(shp1[0], shp1[1], 3)) model += Add()(left=lin_1.output, right=tt1.output, output="output") model_1 = model @@ -429,7 +429,7 @@ def test_list_conversion_1(): lin_1 = Linear(dimension=3) tt1 = ToTensor() model += lin_1(input=[[1], [2.0]], weight="w", bias="b") - shp1 = lin_1.input.get_shape() + shp1 = lin_1.input.shape model += tt1(input=[shp1[0], shp1[1], 3.0]) model += Add()(left=lin_1.output, right=tt1.output, output="output") model_1 = model @@ -474,7 +474,7 @@ def test_nested_list_conversion_1(): lin_1 = Linear(dimension=3) tt1 = ToTensor() model += lin_1(input=[[1], [2.0]], weight="w", bias="b") - shp1 = lin_1.input.get_shape() + shp1 = lin_1.input.shape model += tt1(input=[[shp1[0], shp1[1], 3.0]]) model += Add()(left=lin_1.output, right=tt1.output, output="output") model_1 = model @@ -519,7 +519,7 @@ def test_nested_list_conversion_2(): lin_1 = Linear(dimension=3) tt1 = ToTensor() model += lin_1(input="input", weight="w", bias="b") - shp1 = lin_1.input.get_shape() + shp1 = lin_1.input.shape model += tt1(input=[[shp1[0], shp1[1], 3.0]]) model += Add()(left=lin_1.output, right=tt1.output, output="output") model_1 = model @@ -1511,7 +1511,7 @@ def test_tensor_to_scalar_template_1(): model += buff_model_1(input="input1") in1 = buff_model_1.output - out1 = in1.get_shape().tensor() ** 2 + out1 = in1.shape.tensor() 
** 2 model += Buffer()(input=out1, output="output") model.set_shapes({"input1": [3, 4, 5, 6]}) @@ -1534,7 +1534,7 @@ def test_tensor_to_scalar_template_2(): in1 = buff_model_1.output in2 = buff_model_2.output in3 = buff_model_3.output - out1 = (in1.get_shape().tensor() ** 2 * in2) @ in3 / 2 + out1 = (in1.shape.tensor() ** 2 * in2) @ in3 / 2 model += Buffer()(input=out1, output="output") pm = compile(model=model, backend=backend)
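# NOTE: a closing sketch, not part of the diff. Because IOKey stopped being a
# dataclass, the `dataclasses.replace(con, ...)` calls in base.py and model.py
# above were rewritten as explicit reconstruction that copies the nested
# BaseKey fields by hand; `with_value` below is a hypothetical name for that
# recurring pattern, not a helper the patch introduces.
#
#     def with_value(con: IOKey, val) -> IOKey:
#         return IOKey(
#             name=con.name,
#             expose=con.expose,
#             connections=con.connections,
#             type=con.data.type,
#             shape=con.data.shape,
#             value=val,
#         )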