diff --git a/benchmarks/speed_benchmarks/speed_helper.py b/benchmarks/speed_benchmarks/speed_helper.py index 060e8995..b59c6961 100644 --- a/benchmarks/speed_benchmarks/speed_helper.py +++ b/benchmarks/speed_benchmarks/speed_helper.py @@ -18,7 +18,6 @@ from mithril.models import ( MLP, - BaseModel, Convolution2D, Flatten, MaxPool2D, @@ -84,7 +83,7 @@ def create_compl_conv( def create_compl_mlp( input_size: int, dimensions: Sequence[int | None], - activations: list[type[BaseModel]], + activations: list[type[Model]], ): """Mithril's MLP wrapper with input size diff --git a/mithril/__init__.py b/mithril/__init__.py index f568f77e..166f82a2 100644 --- a/mithril/__init__.py +++ b/mithril/__init__.py @@ -35,9 +35,10 @@ short, ) from .framework.codegen import code_gen_map -from .framework.common import TBD, Connection, IOKey +from .framework.common import TBD +from .framework.logical import Connection, IOKey from .framework.physical.model import PhysicalConstantType, PhysicalShapeType -from .models import BaseModel, PhysicalModel +from .models import Model, PhysicalModel from .models.train_model import TrainModel __all__ = [ @@ -97,7 +98,7 @@ def compile( - model: BaseModel, + model: Model, backend: Backend[DataType], *, constant_keys: PhysicalConstantType[DataType] | None = None, diff --git a/mithril/framework/codegen/numpy_gen.py b/mithril/framework/codegen/numpy_gen.py index 4b0da810..63cd7cf2 100644 --- a/mithril/framework/codegen/numpy_gen.py +++ b/mithril/framework/codegen/numpy_gen.py @@ -35,7 +35,7 @@ find_intersection_type, is_type_adjustment_required, ) -from ..logical import PrimitiveModel +from ..logical import Operator from .python_gen import PythonCodeGen, RawGradientType from .utils import check_repr_inequality @@ -212,11 +212,11 @@ def evaluate_gradients_wrapper_manualgrad( def get_primitive_details( self, output_key: str - ) -> tuple[PrimitiveModel, list[str], list[str]]: + ) -> tuple[Operator, list[str], list[str]]: model = self.pm.flat_graph.get_model(output_key) global_input_keys = self.pm.flat_graph.get_source_keys(output_key) - global_input_keys += [self.get_cache_name(output_key, model)] + global_input_keys += [self.get_cache_name(output_key)] local_input_keys = list(model.input_keys) + ["cache"] return model, global_input_keys, local_input_keys @@ -229,7 +229,7 @@ def is_static_scalar(self, key: str) -> bool: def call_primitive( self, - model: PrimitiveModel, + model: Operator, fn: Callable[..., Any], l_input_keys: list[str], g_input_keys: list[str], @@ -259,7 +259,7 @@ def call_primitive( return ast.Assign(targets, generated_fn), used_keys | _used_keys def create_primitive_call_targets( - self, output_key: str, model: PrimitiveModel, inference: bool + self, output_key: str, model: Operator, inference: bool ) -> tuple[list[ast.expr | ast.Name], set[str]]: targets: list[ast.expr | ast.Name] = [] @@ -271,27 +271,26 @@ def create_primitive_call_targets( if not self.pm.inference: # TODO: Change this with cache refactor - cache_name = output_key + f"_{model.cache_name}" + cache_name = output_key + f"_{Operator.cache_name}" used_keys.add(cache_name) targets.append( ast.Subscript( value=ast.Name(id=cache_name, ctx=ast.Load()), - slice=ast.Constant(value=PrimitiveModel.output_key), + slice=ast.Constant(value=Operator.output_key), ctx=ast.Store(), ) ) return targets, used_keys - def get_cache_name(self, output_key: str, model: PrimitiveModel) -> str: - cache_name = "_".join([output_key, model.cache_name]) + def get_cache_name(self, output_key: str) -> str: + cache_name = 
"_".join([output_key, Operator.cache_name]) if cache_name not in self.pm.flat_graph.all_data: - self.add_cache(model, output_key) + self.add_cache(output_key, cache_name) return cache_name - def add_cache(self, model: PrimitiveModel, output_key: str) -> None: - cache_name = "_".join([output_key, model.cache_name]) + def add_cache(self, output_key: str, cache_name: str) -> None: cache_value: dict[str, Any] | None = None if self.pm.inference else {} # Create a scalar for caches in manualgrad backend. self.pm.flat_graph.update_data( diff --git a/mithril/framework/codegen/python_gen.py b/mithril/framework/codegen/python_gen.py index aeaaa7e7..5b005428 100644 --- a/mithril/framework/codegen/python_gen.py +++ b/mithril/framework/codegen/python_gen.py @@ -31,7 +31,7 @@ EvaluateType, ParamsEvalType, ) -from ..logical import PrimitiveModel +from ..logical import Operator from ..physical.model import PhysicalModel from ..utils import GeneratedFunction from .code_gen import CodeGen @@ -307,7 +307,7 @@ def is_static_scalar(self, key: str) -> bool: def get_primitive_details( self, output_key: str - ) -> tuple[PrimitiveModel, list[str], list[str]]: + ) -> tuple[Operator, list[str], list[str]]: model = self.pm.flat_graph.get_model(output_key) global_input_keys = self.pm.flat_graph.get_source_keys(output_key) @@ -317,7 +317,7 @@ def get_primitive_details( def call_primitive( self, - model: PrimitiveModel, + model: Operator, fn: Callable[..., Any], l_input_keys: list[str], g_input_keys: list[str], @@ -354,7 +354,6 @@ def generate_evaluate(self) -> ast.FunctionDef: for output_key in self.pm.flat_graph.topological_order: model, g_input_keys, l_input_keys = self.get_primitive_details(output_key) formula_key = model.formula_key - primitive_function = ( self.pm.backend.primitive_function_dict[formula_key] if formula_key in self.pm.backend.primitive_function_dict @@ -568,7 +567,7 @@ def create_primitive_call( return generated_fn, used_keys def create_primitive_call_targets( - self, output_key: str, model: PrimitiveModel, inference: bool + self, output_key: str, model: Operator, inference: bool ) -> tuple[list[ast.expr], set[str]]: if ( keyword.iskeyword(output_key) diff --git a/mithril/framework/codegen/torch_gen.py b/mithril/framework/codegen/torch_gen.py index b3298a0a..826a807f 100644 --- a/mithril/framework/codegen/torch_gen.py +++ b/mithril/framework/codegen/torch_gen.py @@ -19,7 +19,7 @@ import torch from ...backends.with_autograd.torch_backend import TorchBackend -from ..logical import PrimitiveModel +from ..logical import Operator from ..physical.model import PhysicalModel from .python_gen import PythonCodeGen @@ -34,7 +34,7 @@ def __init__(self, pm: PhysicalModel[torch.Tensor]) -> None: def call_primitive( self, - model: PrimitiveModel, + model: Operator, fn: Callable[..., Any], l_input_keys: list[str], g_input_keys: list[str], diff --git a/mithril/framework/common.py b/mithril/framework/common.py index 9d07b780..0000069f 100644 --- a/mithril/framework/common.py +++ b/mithril/framework/common.py @@ -52,11 +52,10 @@ "get_shapes", "NOT_GIVEN", "TBD", - "IOKey", "KeyType", - "ConnectionType", + "ConnectionDataType", + "ConnectionDataInstanceType", "IOHyperEdge", - "Connection", "ConnectionData", "Connections", "Tensor", @@ -104,15 +103,6 @@ class NullConnection(SingletonObject): pass -class Auto(SingletonObject): - """ - A singleton class representing a configuration - setting of automatically handled arguments. 
- """ - - pass - - class ToBeDetermined(SingletonObject): """ A singleton class representing a null data indicating @@ -124,7 +114,6 @@ class ToBeDetermined(SingletonObject): NOT_GIVEN = NullConnection() TBD = ToBeDetermined() -AUTO = Auto() class UpdateType(Enum): @@ -1157,316 +1146,7 @@ def remove_constraint(self, constraint: Constraint) -> None: self.constraints[type].discard(constraint) -class TemplateBase: - def __getitem__( - self, - key: slice - | int - | EllipsisType - | tuple[slice | int | None | EllipsisType | TemplateBase, ...] - | IOKey - | TemplateBase - | None, - ) -> ExtendTemplate: - match key: - case slice(): - slice_output = ExtendTemplate( - connections=[key.start, key.stop, key.step], model="slice" - ) - output = ExtendTemplate( - connections=[self, slice_output], model="indexer" - ) - - case int() | EllipsisType() | None: - output = ExtendTemplate(connections=[self, key], model="indexer") - - case tuple(): - connections: list[TemplateBase | int | None | EllipsisType] = [] - for item in key: - if isinstance(item, slice): - slice_output = ExtendTemplate( - connections=[item.start, item.stop, item.step], - model="slice", - ) - connections.append(slice_output) - else: - connections.append(item) - tuple_template = ExtendTemplate( - connections=connections, # type: ignore - model="to_tuple", - defaults={"n": len(key)}, - ) - output = ExtendTemplate( - connections=[self, tuple_template], model="indexer" - ) - return output - - def __add__(self, other: TemplateConnectionType) -> ExtendTemplate: - return ExtendTemplate(connections=[self, other], model="add") - - def __radd__(self, other: TemplateConnectionType) -> ExtendTemplate: - return ExtendTemplate(connections=[other, self], model="add") - - def __sub__(self, other: TemplateConnectionType) -> ExtendTemplate: - return ExtendTemplate(connections=[self, other], model="sub") - - def __rsub__(self, other: TemplateConnectionType) -> ExtendTemplate: - return ExtendTemplate(connections=[other, self], model="sub") - - def __mul__(self, other: TemplateConnectionType) -> ExtendTemplate: - return ExtendTemplate(connections=[self, other], model="mul") - - def __rmul__(self, other: TemplateConnectionType) -> ExtendTemplate: - return ExtendTemplate(connections=[other, self], model="mul") - - def __truediv__(self, other: TemplateConnectionType) -> ExtendTemplate: - return ExtendTemplate(connections=[self, other], model="div") - - def __rtruediv__(self, other: TemplateConnectionType) -> ExtendTemplate: - return ExtendTemplate(connections=[other, self], model="div") - - def __floordiv__(self, other: TemplateConnectionType) -> ExtendTemplate: - return ExtendTemplate(connections=[self, other], model="fdiv") - - def __rfloordiv__(self, other: TemplateConnectionType) -> ExtendTemplate: - return ExtendTemplate(connections=[other, self], model="fdiv") - - def __pow__(self, other: TemplateConnectionType) -> ExtendTemplate: - return ExtendTemplate( - connections=[self, other], model="pow", defaults={"robust": False} - ) - - def __rpow__(self, other: TemplateConnectionType) -> ExtendTemplate: - return ExtendTemplate( - connections=[other, self], model="pow", defaults={"robust": False} - ) - - def __matmul__(self, other: TemplateConnectionType) -> ExtendTemplate: - return ExtendTemplate(connections=[self, other], model="matmul") - - def __gt__(self, other: TemplateConnectionType) -> ExtendTemplate: - return ExtendTemplate(connections=[self, other], model="gt") - - def __rgt__(self, other: TemplateConnectionType) -> ExtendTemplate: - return 
ExtendTemplate(connections=[other, self], model="gt") - - def __ge__(self, other: TemplateConnectionType) -> ExtendTemplate: - return ExtendTemplate(connections=[self, other], model="ge") - - def __rge__(self, other: TemplateConnectionType) -> ExtendTemplate: - return ExtendTemplate(connections=[other, self], model="ge") - - def __lt__(self, other: TemplateConnectionType) -> ExtendTemplate: - return ExtendTemplate(connections=[self, other], model="lt") - - def __rlt__(self, other: TemplateConnectionType) -> ExtendTemplate: - return ExtendTemplate(connections=[other, self], model="lt") - - def __le__(self, other: TemplateConnectionType) -> ExtendTemplate: - return ExtendTemplate(connections=[self, other], model="le") - - def __rle__(self, other: TemplateConnectionType) -> ExtendTemplate: - return ExtendTemplate(connections=[other, self], model="le") - - def __eq__(self, other: object) -> ExtendTemplate: # type: ignore[override] - if isinstance( - other, int | float | bool | list | Connection | IOKey | tuple | Tensor - ): - return ExtendTemplate(connections=[self, other], model="eq") - else: - raise ValueError("Unsupported type for equality operation.") - - def __req__(self, other: TemplateConnectionType) -> ExtendTemplate: - return ExtendTemplate(connections=[other, self], model="eq") - - def __ne__(self, other: object) -> ExtendTemplate: # type: ignore[override] - if isinstance( - other, int | float | bool | list | Connection | IOKey | tuple | Tensor - ): - return ExtendTemplate(connections=[self, other], model="ne") - else: - raise ValueError("Unsupported type for equality operation.") - - def __rne__(self, other: TemplateConnectionType) -> ExtendTemplate: - return ExtendTemplate(connections=[other, self], model="ne") - - def __and__(self, other: TemplateConnectionType) -> ExtendTemplate: - return ExtendTemplate(connections=[self, other], model="and") - - def __rand__(self, other: TemplateConnectionType) -> ExtendTemplate: - return ExtendTemplate(connections=[other, self], model="and") - - def __or__(self, other: TemplateConnectionType) -> ExtendTemplate: - return ExtendTemplate(connections=[self, other], model="or") - - def __ror__(self, other: TemplateConnectionType) -> ExtendTemplate: - return ExtendTemplate(connections=[other, self], model="or") - - def __xor__(self, other: TemplateConnectionType) -> ExtendTemplate: - return ExtendTemplate(connections=[self, other], model="xor") - - def __rxor__(self, other: TemplateConnectionType) -> ExtendTemplate: - return ExtendTemplate(connections=[other, self], model="xor") - - def __lshift__(self, other: TemplateConnectionType) -> ExtendTemplate: - return ExtendTemplate(connections=[self, other], model="lshift") - - def __rlshift__(self, other: TemplateConnectionType) -> ExtendTemplate: - return ExtendTemplate(connections=[other, self], model="lshift") - - def __rshift__(self, other: TemplateConnectionType) -> ExtendTemplate: - return ExtendTemplate(connections=[self, other], model="rshift") - - def __rrshift__(self, other: TemplateConnectionType) -> ExtendTemplate: - return ExtendTemplate(connections=[other, self], model="rshift") - - def __invert__(self) -> ExtendTemplate: - return ExtendTemplate(connections=[self], model="not") - - def __neg__(self) -> ExtendTemplate: - return ExtendTemplate(connections=[self], model="minus") - - def abs(self) -> ExtendTemplate: - return ExtendTemplate(connections=[self], model="abs") - - def len(self) -> ExtendTemplate: - return ExtendTemplate(connections=[self], model="len") - - @property - def shape(self) 
-> ExtendTemplate: - return ExtendTemplate(connections=[self], model="shape") - - def reshape( - self, shape: tuple[int | TemplateBase, ...] | TemplateBase - ) -> ExtendTemplate: - return ExtendTemplate(connections=[self, shape], model="reshape") - - def size( - self, dim: int | tuple[int, ...] | TemplateBase | None = None - ) -> ExtendTemplate: - return ExtendTemplate(connections=[self, dim], model="size") - - def tensor(self) -> ExtendTemplate: - return ExtendTemplate( - connections=[self], model="tensor", defaults={"dtype": None} - ) - - def mean( - self, - axis: int | tuple[int, ...] | TemplateBase | None = None, - keepdim: bool = False, - ) -> ExtendTemplate: - return ExtendTemplate(connections=[self, axis, keepdim], model="mean") - - def sum( - self, - axis: int | tuple[int, ...] | TemplateBase | None = None, - keepdim: bool = False, - ) -> ExtendTemplate: - return ExtendTemplate(connections=[self, axis, keepdim], model="sum") - - def max( - self, - axis: int | tuple[int, ...] | TemplateBase | None = None, - keepdim: bool = False, - ) -> ExtendTemplate: - return ExtendTemplate(connections=[self, axis, keepdim], model="max") - - def min( - self, - axis: int | tuple[int, ...] | TemplateBase | None = None, - keepdim: bool = False, - ) -> ExtendTemplate: - return ExtendTemplate(connections=[self, axis, keepdim], model="min") - - def prod( - self, - axis: int | tuple[int, ...] | TemplateBase | None = None, - keepdim: bool = False, - ) -> ExtendTemplate: - return ExtendTemplate(connections=[self, axis, keepdim], model="prod") - - def var( - self, - axis: int | tuple[int, ...] | TemplateBase | None = None, - keepdim: bool = False, - correction: float | None = 0.0, - ) -> ExtendTemplate: - return ExtendTemplate( - connections=[self, axis, keepdim, correction], model="var" - ) - - def sqrt(self) -> ExtendTemplate: - return ExtendTemplate( - connections=[self], model="sqrt", defaults={"robust": False} - ) - - def exp(self) -> ExtendTemplate: - return ExtendTemplate(connections=[self], model="exp") - - def transpose( - self, axes: tuple[int, ...] | TemplateBase | None = None - ) -> ExtendTemplate: - return ExtendTemplate(connections=[self, axes], model="transpose") - - def split(self, split_size: int, axis: int) -> ExtendTemplate: - return ExtendTemplate(connections=[self, split_size, axis], model="split") - - def item(self) -> ExtendTemplate: - return ExtendTemplate(connections=[self], model="item") - - def cast(self, dtype: Dtype | None = None) -> ExtendTemplate: - return ExtendTemplate(connections=[self, dtype], model="cast") - - def dtype(self) -> ExtendTemplate: - return ExtendTemplate(connections=[self], model="dtype") - - def sin(self) -> ExtendTemplate: - return ExtendTemplate(connections=[self], model="sin") - - def cos(self) -> ExtendTemplate: - return ExtendTemplate(connections=[self], model="cos") - - -class ExtendTemplate(TemplateBase): - output_connection: ConnectionData | None - - def __init__( - self, - connections: list[TemplateConnectionType], - model: str, - defaults: dict[str, Any] | None = None, - ) -> None: - for connection in connections: - if isinstance(connection, str): - raise ValueError( - "In extend template operations, 'str' is not a valid type." 
- ) - - self.connections = connections - self.model = model - - if defaults is None: - defaults = {} - self.defaults = defaults - self.output_connection = None - - -@dataclass class BaseKey: - value: ( - Tensor[int | float | bool] - | ScalarValueType - | TensorValueType - | ToBeDetermined - | str - ) = TBD - shape: ShapeTemplateType | None = None - type: UnionType | type | type[Tensor[int | float | bool]] | ScalarType | None = None - interval: list[float | int] | None = None - - -class IOKey(TemplateBase): def __init__( self, name: str | None = None, @@ -1482,11 +1162,12 @@ def __init__( | None = None, expose: bool | None = None, interval: list[float | int] | None = None, - connections: set[Connection | str] | None = None, + connections: set[ConnectionData | str] | None = None, ) -> None: - super().__init__() # If shape is provided, type should be Tensor. if shape is not None: + if type is Tensor: + type = Tensor[int | float | bool] if type is None: type = Tensor[int | float | bool] elif get_origin(type) is not Tensor: @@ -1499,53 +1180,30 @@ def __init__( self.expose = expose if connections is None: connections = set() - self.connections: set[Connection | str] = connections - self.data = BaseKey(value, shape, type, interval) + self.connections: set[ConnectionData | str] = connections # TODO: Shape should not be [] also! if ( - self.data.value is not TBD - and self.data.shape is not None - and self.data.shape != [] + value is not TBD + and not isinstance(value, Tensor) + and shape is not None + and shape != [] ): raise ValueError( f"Scalar values are shapeless, shape should be None or []. " - f"Got {self.data.shape}." + f"Got {shape}." ) - if self.data.value is not TBD and self.data.type is not None: - value_type = find_type(self.data.value) - if find_intersection_type(value_type, self.data.type) is None: + if value is not TBD and type is not None: + value_type = find_type(value) + if find_intersection_type(value_type, type) is None: raise TypeError( - f"type of the given value and given type does not match. Given " - f"type is {self.data.type} while type of value is {value_type}" + "type of the given value and given type does not match. 
Given " + f"type is {type} while type of value is {value_type}" ) - - -class Connection(TemplateBase): - def __init__(self, key: str, metadata: IOHyperEdge, is_key_autogenerated: bool): - self.data = ConnectionData(key, metadata, is_key_autogenerated, self) - - @property - def key(self) -> str: - return self.data.key - - @property - def metadata(self) -> IOHyperEdge: - return self.data.metadata - - def set_differentiable(self, differentiable: bool = True) -> None: - self.data.set_differentiable(differentiable) - - def __hash__(self) -> int: - return hash(id(self)) - - -ShapesType = ( - Mapping[str | Connection, ShapeTemplateType] - | Mapping[str, ShapeTemplateType] - | Mapping[Connection, ShapeTemplateType] -) -ShapeResultType = Mapping[str, ShapeTemplateType | list[ShapeTemplateType] | None] + self.value = value + self.value_shape = shape + self.type = type + self.interval = interval @dataclass @@ -1561,7 +1219,6 @@ class ConnectionData: metadata: IOHyperEdge # TODO: remove is_key_autogenerated field is_key_autogenerated: bool - conn: Connection def __hash__(self) -> int: return hash(id(self)) @@ -1579,36 +1236,25 @@ def set_differentiable(self, differentiable: bool = True) -> None: self.metadata.differentiable = differentiable -TemplateConnectionType = ( - TemplateBase - | int - | float - | list[int | float] - | EllipsisType - | tuple[slice | int | None | EllipsisType | TemplateBase, ...] - | None - | Tensor[int | float | bool] +ShapesType = ( + Mapping[str | ConnectionData, ShapeTemplateType] + | Mapping[str, ShapeTemplateType] + | Mapping[ConnectionData, ShapeTemplateType] ) +ShapeResultType = Mapping[str, ShapeTemplateType | list[ShapeTemplateType] | None] -ConnectionType = ( +ConnectionDataType = ( str | MainValueType - | ExtendTemplate | NullConnection - | IOKey - | Connection + | BaseKey + | ConnectionData | Tensor[int | float | bool] ) -ConnectionInstanceType = ( - str - | MainValueInstance - | ExtendTemplate - | NullConnection - | IOKey - | Connection - | Tensor # type: ignore +ConnectionDataInstanceType = ( + str | MainValueInstance | NullConnection | BaseKey | ConnectionData | Tensor # type: ignore ) @@ -1722,6 +1368,14 @@ def remove_connection(self, connection: ConnectionData) -> None: def get_data(self, key: str) -> IOHyperEdge: return self.get_metadata(key) + def get_type(self, key: ConnectionData) -> KeyType: + con_data = self.get_extracted_connection(key) + for _key_type in KeyType: + key_dict = self._connection_dict[_key_type] + if key_dict.get(con_data.key) is not None: + return _key_type + raise ValueError("No matching key type found!") + def get_non_diff_keys(self) -> set[str]: return {key for key, conn in self.all.items() if conn.metadata.is_non_diff} @@ -1768,15 +1422,15 @@ def get_shape_node(self, key: str) -> ShapeNode: def set_value(self, con: ConnectionData, value: MainValueType) -> None: self.get_data(con.key).set_value(value) - def extract_metadata(self, key: str | Connection) -> IOHyperEdge: - if isinstance(key, Connection): + def extract_metadata(self, key: str | ConnectionData) -> IOHyperEdge: + if isinstance(key, ConnectionData): # Extract the key from the Connection object. 
metadata = key.metadata else: metadata = self.get_metadata(key) return metadata - def get_extracted_connection(self, key: str | Connection) -> ConnectionData: + def get_extracted_connection(self, key: str | ConnectionData) -> ConnectionData: if (result := self.get_con_by_metadata(self.extract_metadata(key))) is None: raise KeyError("Connection is not found!") return result diff --git a/mithril/framework/logical/__init__.py b/mithril/framework/logical/__init__.py index 84074043..472a8994 100644 --- a/mithril/framework/logical/__init__.py +++ b/mithril/framework/logical/__init__.py @@ -13,6 +13,7 @@ # limitations under the License. from .base import * # noqa F403 -from .essential_primitives import * # noqa F403 +from .operators import * # noqa F403 from .model import * # noqa F403 +from .operator import * # noqa F403 from .primitive import * # noqa F403 diff --git a/mithril/framework/logical/base.py b/mithril/framework/logical/base.py index fcf2d057..23ded642 100644 --- a/mithril/framework/logical/base.py +++ b/mithril/framework/logical/base.py @@ -14,27 +14,30 @@ from __future__ import annotations -import abc from collections.abc import KeysView, Mapping -from dataclasses import dataclass from itertools import chain from types import UnionType from typing import Any from ...utils.utils import OrderedSet from ..common import ( + NOT_GIVEN, TBD, AssignedConstraintType, - Connection, + BaseKey, ConnectionData, + ConnectionDataType, Connections, - ConnectionType, Constraint, ConstraintFunctionType, ConstraintSolver, IOHyperEdge, + KeyType, + MainValueInstance, MainValueType, + NullConnection, ScalarType, + ScalarValueType, ShapeNode, ShapesType, ShapeTemplateType, @@ -50,35 +53,10 @@ ) from ..constraints import constraint_type_map -__all__ = ["BaseModel", "ExtendInfo"] +__all__ = ["BaseModel"] -@dataclass -class ExtendInfo: - _model: BaseModel - _connections: dict[str, ConnectionType] - - def __post_init__(self) -> None: - external_keys = ( - set(self._model.external_keys) - | {item.key for item in self._model.conns.couts} - | {item.key for item in self._model.conns.cins} - ) - - for key in self._connections: - if key not in external_keys: - raise KeyError(f"Key '{key}' is not a valid key for the model!") - - @property - def model(self) -> BaseModel: - return self._model - - @property - def connections(self) -> dict[str, ConnectionType]: - return self._connections - - -class BaseModel(abc.ABC): +class BaseModel: # Disposable models only used once for entire training session. # This attribute is only use for manual backends' code generation. @@ -87,13 +65,17 @@ class BaseModel(abc.ABC): # TODO: factory_args should be instance variable not class! factory_args: dict[str, Any] = {} - def __call__(self, **kwargs: ConnectionType) -> ExtendInfo: - return ExtendInfo(self, kwargs) + def __init__( + self, + name: str | None = None, + formula_key: str | None = None, + enforce_jit: bool = True, + ) -> None: + self.dag: dict[BaseModel, dict[str, ConnectionData]] = {} + self._formula_key: str | None = formula_key - def __init__(self, name: str | None = None, enforce_jit: bool = True) -> None: - self.parent: BaseModel | None = ( - None # TODO: maybe set it only to PrimitiveModel / Model. - ) + # TODO: maybe set it only to Operator / Model. 
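+        # "parent" is set by extend() when this model becomes a submodel of a
+        # composite model; it stays None for the outermost model.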
+ self.parent: BaseModel | None = None self.assigned_shapes: list[ShapesType] = [] self.assigned_types: dict[ str, @@ -109,19 +91,902 @@ def __init__(self, name: str | None = None, enforce_jit: bool = True) -> None: self.constraint_solver: ConstraintSolver = ConstraintSolver() self.safe_shapes: dict[str, ShapeTemplateType] = {} self.is_frozen = False + self.inter_key_count = 0 + + @property + def formula_key(self) -> str | None: + return self._formula_key + + def create_key_name(self) -> str: + self.inter_key_count += 1 + return "$" + str(self.inter_key_count) + + def _set_outputs( + self, *args: str | ConnectionData, **kwargs: str | ConnectionData + ) -> None: + if self.parent is not None: + raise Exception("Child model's outputs cannot be set.") + # Convert all args and kwargs to tuple. + # Convert all args and kwargs to tuple. + pairs = tuple([(None, arg) for arg in args]) + tuple(kwargs.items()) + + for pair in pairs: + new_name, name = pair + metadata = self.conns.extract_metadata(name) + + # Check the connection is valid. + if (conn_data := self.conns.get_con_by_metadata(metadata)) is None: + raise KeyError("Requires valid key or Connection to set output!") - @abc.abstractmethod - def summary( + # Check if given metadata is already an output. + if conn_data in self.conns.output_connections: + raise KeyError(f"'{conn_data.key}' key is already set as output!") + + if conn_data in self.conns.input_connections: + raise KeyError("Input of the overall model cannot be set as output.") + + # Autogenerated keys can not be set directly as output without a name. + if new_name is None and conn_data.key.startswith("$"): + raise KeyError( + "Autogenerated keys can only be set as output if" + " a name is provided for the connection as keyworded argument." + ) + + if new_name is not None: # Non-named connections. + if new_name in self.conns.all: + raise KeyError(f"Key '{new_name}' is already used!") + self.update_key_name(conn_data, new_name) + self.conns.set_connection_type(conn_data, KeyType.OUTPUT) + self.dependency_map.update_globals(OrderedSet({conn_data})) + + def _check_multi_write( self, - shapes: bool = True, - types: bool = False, - symbolic: bool = False, - name: str | None = None, - alternative_shapes: bool = False, - uni_cache: dict[UniadicRecord, str] | None = None, - var_cache: dict[Variadic, str] | None = None, + local_input: bool, + local_connection: ConnectionData, + connection: ConnectionData, + ) -> None: + conn_is_output = ( + self.dependency_map.local_output_dependency_map.get(connection, None) + is not None + ) + if local_connection.key in self.conns.all and connection.key in self.conns.all: + local_conn_is_output = ( + self.dependency_map.local_output_dependency_map.get( + local_connection, None + ) + is not None + ) + if ( + conn_is_output + and local_conn_is_output + and local_connection.key != connection.key + ): + # Check if 2 connections are part of main model. If it is the case, + # We expect at least one of them is not an input of the main model, + # otherwise condition is Multi-write error + raise Exception( + "Given connections are both output connections. Multi-write error!" + ) + + local_val = local_connection.metadata.value + global_val = connection.metadata.value + + if conn_is_output and not local_input: + # Check if 2 connections are both output of any models. + raise Exception( + "Given connections are both output connections. Multi-write error!" 
+ ) + elif ( + local_input + and local_val is not TBD + # and global_val is not TBD + and conn_is_output + and global_val != local_val + ): + raise ValueError( + "An input of the extending model tries to write " + "to an output connection in the extended model. " + "Multi-write error!" + ) + elif not local_input and global_val is not TBD and local_val != global_val: + raise ValueError( + "A valued connection of the extended model tries to write " + "to an output connection of the extending model. " + "Multi-write error!" + ) + + def _prepare_keys( + self, model: BaseModel, key: str, connection: ConnectionDataType + ) -> BaseKey: + local_connection = model.conns.get_connection(key) + assert local_connection is not None, "Connection is not found!" + match connection: + case NullConnection(): + _connection = BaseKey() + case str(): + _connection = BaseKey(name=connection) + case ConnectionData(): + _connection = BaseKey(connections={connection}) + case _ if isinstance(connection, MainValueInstance | Tensor): + # find_dominant_type returns the dominant type in a container. + # If a container has a value of type Connection or ExtendTemplate + # we add necessary models. + _connection = BaseKey(value=connection) + case BaseKey(): + expose = connection.expose + name = connection.name + # TODO: This check should be removed: conn.connections==set() + # We should not operate different if _connections is given. Fix this and + # also fix corresponding tests and dict conversions with "connect". + if ( + expose is None + and (name is None or self.conns.get_connection(name) is None) + and connection.connections == set() + ): + expose = True + _connection = BaseKey( + name=name, + expose=expose, + connections=connection.connections, + type=connection.type, + shape=connection.value_shape, + value=connection.value, + ) + + return _connection + + def _add_connection( + self, + model: BaseModel, + local_key: str, + given_connection: BaseKey, + updates: Updates, + ) -> tuple[ConnectionData, Updates]: + is_input = local_key in model.input_keys + local_connection = model.conns.get_connection(local_key) + assert local_connection is not None, "Connection is not found!" + is_not_valued = local_connection.metadata.value is TBD + + d_map = self.dependency_map.local_output_dependency_map + expose = given_connection.expose + outer_key = given_connection.name + con_obj = None + set_value: ( + ToBeDetermined + | str + | ScalarValueType + | Tensor[int | float | bool] + | NullConnection + ) = NOT_GIVEN + if given_connection.value is not TBD: + set_value = given_connection.value + + if given_connection.connections == set(): + if outer_key is not None: + con_obj = self.conns.get_connection(outer_key) + if outer_key is None or con_obj is None: + if expose is None and is_input and is_not_valued: + expose = True + con_obj = self._create_connection(local_connection.metadata, outer_key) + if ( + expose is False + and is_input + and set_value is NOT_GIVEN + and local_connection.metadata.value is TBD + and (con_obj is None or con_obj not in d_map) + ): + raise ValueError( + "Expose flag cannot be false when " + "no value is provided for input keys!" + ) + else: + initial_conn: ConnectionData + for idx, conn in enumerate(given_connection.connections): + if isinstance(conn, str): + _conn = self.conns.get_connection(conn) + else: + _conn = self.conns.get_con_by_metadata(conn.metadata) + if conn in model.conns.all.values(): + raise ValueError( + f"Given connection '{conn.key}' should not " + "belong to the extending model!" 
+ ) + + if not isinstance(_conn, ConnectionData): + raise KeyError("Requires accessible connection to be processed!") + if idx == 0: + initial_conn = _conn + if outer_key is not None: + self.update_key_name(initial_conn, outer_key) + else: + if _conn in d_map: + if initial_conn in d_map: + raise KeyError( + "IOKey object can not have more than one output " + "connection. Multi-write error!" + ) + initial_conn, _conn = _conn, initial_conn + if ( + not outer_key + and not initial_conn.is_key_autogenerated + and not _conn.is_key_autogenerated + ): + raise KeyError( + "Requires a connection to have only one unique key " + "name but encountered more!" + ) + updates |= self.merge_connections(initial_conn, _conn) + if outer_key is None and not initial_conn.is_key_autogenerated: + outer_key = initial_conn.key + if not outer_key and initial_conn in d_map and expose is True: + raise KeyError("Connection without a name cannot be set as output") + con_obj = initial_conn + + # Name "input" can only be used for input connections. + is_key_name_input = con_obj is not None and (con_key := con_obj.key) == "input" + if not is_input and (outer_key == "input" or is_key_name_input): + raise KeyError( + "The key 'input' is a reserved key which could not be used for " + "internal keys." + ) + + if not is_input and not isinstance(set_value, NullConnection): + raise KeyError( + f"{local_key} key is an output of the model, output values could " + "not be set in extend." + ) + + # Inherit submodel connections dict + self.conns.connections_dict.setdefault(local_connection.metadata, set()) + self.conns.connections_dict[local_connection.metadata] |= ( + model.conns.connections_dict.pop(local_connection.metadata, set()) + ) + + # If any value provided, set. + assert con_obj is not None + if not isinstance(set_value, NullConnection): + updates |= con_obj.metadata.set_value(set_value) + + # Check multi-write error for con_obj. + self._check_multi_write(is_input, local_connection, con_obj) + + # If match required, perform. + if con_obj.metadata != local_connection.metadata: + local_key_origin = local_connection.metadata.key_origin + updates |= self._match_hyper_edges( + con_obj.metadata, local_connection.metadata + ) + # If local_connection is an output of the model, + # update con_obj "key_origin" with local_connection's key_origin. + if ( + not is_input + and outer_key not in self.conns.output_keys + or con_obj.metadata.key_origin is None + ): + con_obj.metadata.key_origin = local_key_origin + + unexposed = not (expose or (is_input and con_key in self.conns.io_keys)) + if unexposed: + if is_input: + key_type = (KeyType.LATENT_INPUT, KeyType.INTERNAL)[con_obj in d_map] + else: + key_type = (KeyType.LATENT_OUTPUT, KeyType.INTERNAL)[con_obj in d_map] + else: + key_type = (KeyType.OUTPUT, KeyType.INPUT)[ + is_input and con_obj not in d_map + ] + if con_obj in d_map: + self.conns.couts.discard(con_obj) + self.conns.set_connection_type(con_obj, key_type) + # Update Canonicals + if ( + local_connection in model.conns.cins + and con_obj in self.conns.input_connections + and con_obj.metadata.value is TBD + ): + self.conns.cins.add(con_obj) + + if local_connection in model.conns.couts and ( + con_obj not in self.dependency_map.local_input_dependency_map + or con_obj in self.conns.output_connections + ): + self.conns.couts.add(con_obj) + + # If any type provided, set using models set_types method + # in order to execute constraint solver to propagate type + # updates along the model keys. 
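+        # (e.g. a BaseKey(type=Tensor[int | float | bool]) passed to extend()
+        # is applied here, on the submodel's local connection.)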
+ if (set_type := given_connection.type) is not None: + model._set_types({local_connection: set_type}) + + return con_obj, updates + + def update_key_name(self, connection: ConnectionData, key: str) -> None: + con_data = self.conns.get_extracted_connection(connection) + key_type = self.conns.get_type(con_data) + key_dict = self.conns._connection_dict[key_type] + key_dict[key] = key_dict.pop(con_data.key) + # Update con_data key + con_data.key = key + con_data.is_key_autogenerated = False + con_data.metadata.key_origin = key + + def merge_connections( + self, connection1: ConnectionData, connection2: ConnectionData + ) -> Updates: + # This method is used if there is 2 Connection objects to represent same Edge. + # In this case, connection2 is updated with connection1's data and it is removed + # from dag, dependency_map, self attribute (if exists) and Connections object. + + # TODO: Check multi-write error for Connect type. + + main_connection1 = self.conns.get_con_by_metadata(connection1.metadata) + main_connection2 = self.conns.get_con_by_metadata(connection2.metadata) + + if ( + main_connection1 is None + or main_connection2 is None + or main_connection1 == main_connection2 + ): + return Updates() + + # Remove main_connection2 from connections dict + con1_key = main_connection1.key + + if connection2 in self.conns.output_connections: + if con1_key not in self.conns.output_keys: + self.conns.set_connection_type(connection1, KeyType.OUTPUT) + if con1_key in self.input_keys: + self.conns.set_connection_type(main_connection1, KeyType.INTERNAL) + elif ( + main_connection2 in self.conns.internal_connections + and con1_key in self.input_keys + ): + self.conns.set_connection_type(main_connection1, KeyType.INTERNAL) + + # Switch all connection2 objects with connection1 object in current dag. 
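+        # main_connection1 survives; every stale reference to main_connection2
+        # below is redirected to it before it is removed from Connections.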
+        for m, m_info in self.dag.items():
+            local_conns = m.conns.get_cons_by_metadata(main_connection2.metadata)
+            if local_conns is None:
+                continue
+
+            for local_conn in local_conns:
+                if m_info.get(local_conn.key) is not None:
+                    self.dag[m][local_conn.key] = main_connection1
+
+        # Update dependency map; only the local maps need to be updated.
+        for (
+            o_conn,
+            key_info,
+        ) in self.dependency_map.local_output_dependency_map.items():
+            if main_connection2 in key_info[1]:
+                self.dependency_map.local_output_dependency_map[o_conn][1].remove(
+                    main_connection2
+                )
+                self.dependency_map.local_output_dependency_map[o_conn][1].add(
+                    main_connection1
+                )
+
+        if main_connection2 in self.dependency_map.local_output_dependency_map:
+            self.dependency_map.local_output_dependency_map[main_connection1] = (
+                self.dependency_map.local_output_dependency_map.pop(main_connection2)
+            )
+
+        if main_connection2 in self.dependency_map.local_input_dependency_map:
+            old_dependencies = self.dependency_map.local_input_dependency_map.pop(
+                main_connection2
+            )
+            self.dependency_map.local_input_dependency_map.setdefault(
+                main_connection1, old_dependencies
+            )
+            for dependency in old_dependencies:
+                if (
+                    dependency
+                    not in self.dependency_map.local_input_dependency_map[
+                        main_connection1
+                    ]
+                ):
+                    self.dependency_map.local_input_dependency_map[
+                        main_connection1
+                    ].append(dependency)
+
+        self.dependency_map.merge_global_connections(main_connection1, main_connection2)
+        self.dependency_map.merge_global_caches(main_connection1, main_connection2)
+        updates = self._match_hyper_edges(
+            main_connection1.metadata, main_connection2.metadata
+        )
+
+        self.conns.remove_connection(main_connection2)
+
+        main_connection2.key = main_connection1.key
+        main_connection2.is_key_autogenerated = main_connection1.is_key_autogenerated
+        return updates
+
+    def extend(
+        self,
+        model: BaseModel,
+        **kwargs: ConnectionDataType,
+    ) -> None:
+        # Check possible errors before the extension.
+        model.check_extendability()
+        if self.parent is not None:
+            raise AttributeError("Child model cannot be re-extended!")
+        if self == model:
+            raise KeyError("Model cannot be extended with itself!")
+        if self._enforce_jit and not model.jittable:
+            raise Exception(
+                "Model with enforced Jit cannot be extended by a non-jittable "
+                "model! Jit can be unforced by setting enforce_jit = False."
+            )
+        if model.name is not None:
+            # TODO: We could store model names in a set to check if they are unique.
+            for m in self.dag:
+                if m.name == model.name:
+                    raise KeyError(f"Model already has a submodel named {model.name}.")
+
+        model.parent = self
+        # Freeze the model.
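+        # Freezing fixes the submodel's interface (its canonical outputs become
+        # real outputs) and assigns names to its unnamed submodels.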
+        model._freeze()
+
+        updates = Updates()
+
+        shape_info: dict[str, ShapeTemplateType] = {}
+        type_info: dict[
+            str,
+            type | UnionType | ScalarType | type[Tensor[int | float | bool]],
+        ] = {}
+
+        submodel_dag: dict[str, ConnectionData] = {}
+        updates = self.constraint_solver.match(model.constraint_solver)
+
+        # Add canonical output if it is not in external_keys.
+        external_keys = list(model.external_keys)
+        external_keys += [
+            item.key for item in model.conns.couts if item.key not in external_keys
+        ]
+
+        io_keys: dict[str, BaseKey] = {
+            key: self._prepare_keys(model, key, kwargs.get(key, NOT_GIVEN))
+            for key in external_keys
+        }
+
+        for local_key, value in io_keys.items():
+            if value.value_shape is not None:
+                shape_info |= {local_key: value.value_shape}
+
+            if value.type is not None:
+                type_info[local_key] = value.type
+
+            con_obj, _updates = self._add_connection(model, local_key, value, updates)
+            updates |= _updates
+            submodel_dag[local_key] = con_obj
+            if con_obj.metadata.is_tensor:
+                updates.shape_updates.add(con_obj.metadata)
+
+        # Replace shape info keys, which are local keys, with global equivalents.
+        shape_info = {
+            submodel_dag[key].key: template for key, template in shape_info.items()
+        }
+        type_info = {
+            submodel_dag[key].key: template for key, template in type_info.items()
+        }
+
+        # Set given shapes.
+        self._set_shapes(
+            shape_info,
+            updates=updates,
+        )  # TODO: Should "trace" be set to True?
+
+        model.constraint_solver.clear()
+        model.conns.connections_dict = {}
+
+        # Insert the submodel into self.dag. Ideally this would be a FrozenDict,
+        # but since merge_connections updates the dag, a FrozenDict cannot be used.
+        self.dag[model] = model_dag = submodel_dag
+
+        self.dependency_map.add_model_dag(model, model_dag)
+
+        # Update jittability using the submodel's jittability.
+        self._jittable &= model.jittable
+
+    @staticmethod
+    def _update_key_name(
+        new_key: str,
+        underscored_keys: set[str],
+        raw_keys: dict[str, list[str]],
+        key_mappings: dict[str, str],
+        key_origin: str,
+        input_set: set[str],
+    ) -> tuple[str, str]:
+        # Add an underscore if the generated key name exists in the input keys.
+        key_prefix = "_"
+        # Check whether any key_prefix + raw_keys[key_origin] is in the input keys.
+        flag = True
+        while flag:
+            flag = False
+            for item in raw_keys[key_origin]:
+                if key_prefix + key_mappings[item] in input_set | set(
+                    key_mappings.values()
+                ):
+                    key_prefix += "_"
+                    flag = True
+
+        new_key = key_prefix + new_key
+        underscored_keys.add(key_origin)
+        # Update same-origin key names that have been previously added.
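+        # Both the per-key mappings and the raw_keys bucket are re-registered
+        # under the new, prefixed origin below.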
+ for raw_key in raw_keys[key_origin]: + key_mappings[raw_key] = key_prefix + key_mappings[raw_key] + raw_keys[key_prefix + key_origin] = raw_keys.pop(key_origin) + key_origin = key_prefix + key_origin + return new_key, key_origin + + def generate_keys( + self, + symbolic: bool = True, + include_internals: bool = True, + include_outputs: bool = False, + ) -> dict[str, str]: + if self.dag == {}: + return {} + key_mappings: dict[str, str] = {} + raw_keys: dict[str, list[str]] = {} + underscored_keys = set[str]() + + if include_outputs: + input_set = set(self.external_keys) + keys = "external_keys" + else: + input_set = set(self.input_keys) + keys = "input_keys" + + sorted_inputs = [ + self.dag[m][key].key + for m in self.get_models_in_topological_order() + for key in getattr(m, keys) + if self.dag[m][key].key in input_set + ] + # TODO: remove duplicate loop traverse + for key in sorted_inputs: + new_key = key + if key[0] != "$": + continue + # TODO: Discuss if we want to generate input key only + # if len(self.conns.cins) == 1, or we could name all canonical inputs. + if ( + len(self.conns.cins) == 1 + and key == self._cin.key + and "input" not in self.input_keys + ): + # Handle canonical input + new_key = "input" + else: + key_origin = self.conns.get_key_origin(key) + assert key_origin is not None + # Add prefix until key_origin not in underscored_keys and input_keys. + while ( + key_origin in (underscored_keys | self.input_keys) + or key_origin == "input" + ): + key_origin = "_" + key_origin + + raw_keys.setdefault(key_origin, []) + key_idx = len(raw_keys[key_origin]) + if key_idx == 0: + # Set key origin as is for the initial key. + key_suffix = "" + else: + key_suffix = "_" + str(key_idx) + if key_idx == 1: + # Update initial key if same key origin is encountered + # (add index to initial key). 
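+                        # e.g. the first "weight" key becomes "weight_0" once a
+                        # second key with the same origin is encountered.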
+ raw_key = raw_keys[key_origin][0] + key_mappings[raw_key] = key_mappings[raw_key] + "_0" + if key_mappings[raw_key] in self.input_keys: + new_key, key_origin = self._update_key_name( + new_key, + underscored_keys, + raw_keys, + key_mappings, + key_origin, + set(self.input_keys), + ) + + new_key = key_origin + key_suffix + if new_key in self.input_keys: + new_key, key_origin = self._update_key_name( + new_key, + underscored_keys, + raw_keys, + key_mappings, + key_origin, + set(self.input_keys), + ) + raw_keys[key_origin].append(key) + key_mappings[key] = new_key + + if include_internals: + sorted_models = self.get_models_in_topological_order() + internal_key_mappings: dict[str, str] = {} + for idx, m in enumerate(sorted_models): + for key in m.external_keys: + outer_conn = self.dag[m][key] + outer_key = outer_conn.key + if outer_key[0] == "$": + # if key is autogenerated, generate a name for the key + model_name = m.class_name + key_origin = outer_conn.metadata.key_origin + assert key_origin is not None + + generated_name = ( + "_" + model_name + "_" + str(idx) + "_" + key_origin + ) + + # if key is an output key, directly write it + # to the internal_key_mappings + # or + # if key is an input key, first check if the key + # is already in internal_key mappings to avoid + # overwrite + write_to_internal_key_mappings = ( + key in m.conns.output_keys + or internal_key_mappings.get( + outer_key, key_mappings.get(outer_key) + ) + is None + ) + + while ( + generated_name in internal_key_mappings.values() + and write_to_internal_key_mappings + ): + assert key_origin is not None + key_origin = "_" + key_origin + generated_name = ( + "_" + model_name + "_" + str(idx) + "_" + key_origin + ) + + if write_to_internal_key_mappings: + internal_key_mappings[outer_key] = generated_name + + key_mappings = internal_key_mappings | key_mappings + if symbolic: + key_mappings = {key: "$" + value for key, value in key_mappings.items()} + return key_mappings + + def get_unique_submodel_names(self) -> dict[BaseModel, str]: + name_mapping: dict[BaseModel, str] = {} + existing_names: set[str] = set() + model_type_dict: dict[str, list[BaseModel]] = {} + + # First, assign existing names and track used names. + # Also save unnamed models to model_type_dict. + for model in self.dag: + if model.name: + name_mapping[model] = model.name + existing_names.add(model.name) + else: + model_type_dict.setdefault(model.class_name, []).append(model) + + # Iterate over different model types among unnamed models. + for model_type, model_list in model_type_dict.items(): + counter = 0 + # Iterate over same class model objects to name them. + for i, model in enumerate(model_list): + if len(model_list) == 1: + # If there is only one model of a type, do not increment counter. + counter -= 1 + name = model_type + else: + name = f"{model_type}_{counter + i}" + while name in existing_names: + counter += 1 # counter is incremented until a unique name is found. + name = f"{model_type}_{counter + i}" + name_mapping[model] = name + existing_names.add(name) + return name_mapping + + def _freeze(self) -> None: + for cout in self.conns.couts: + self.conns.set_connection_type(cout, KeyType.OUTPUT, safe=False) + self.dependency_map.update_all_keys() + + # Name unnamed submodels before freezing considering the insertion order. + model_names = self.get_unique_submodel_names() + for m in self.dag: + if m.name is None: + m.name = model_names[m] + + if self.formula_key is not None: + # Must be convertable to primitive. 
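+            # A model with a formula_key can be lowered to a single primitive
+            # call, so it must expose exactly one output.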
+            assert len(self.conns.output_keys) == 1, (
+                "Logical models that have an alternative primitive implementation "
+                "must have only one output."
+            )
+        # super()._freeze()
+        self.is_frozen = True
+
+    @staticmethod
+    def _reverse_dfs(
+        node: BaseModel,
+        graph: dict[BaseModel, OrderedSet[BaseModel]],
+        top_order: list[BaseModel],
+        visited: set[BaseModel],
+    ) -> None:
+        visited.add(node)
+        for m in graph[node]:
+            if m not in visited:
+                BaseModel._reverse_dfs(m, graph, top_order, visited)
+        top_order.append(node)
+
+    def get_models_in_topological_order(self) -> list[BaseModel]:
+        dependency_map = self.dependency_map.local_output_dependency_map
+        graph = {
+            info[0]: OrderedSet(
+                [dependency_map[spec][0] for spec in info[1] if spec in dependency_map]
+            )
+            for info in dependency_map.values()
+        }
+        top_order: list[BaseModel] = list()
+        visited: set[BaseModel] = set()
+        for model in graph:
+            if model not in top_order:
+                BaseModel._reverse_dfs(model, graph, top_order, visited)
+        return top_order
+
+    # TODO: Summary should be isolated from the model.
+    def extract_connection_info(
+        self,
+        name_mappings: dict[BaseModel, str],
+        data_to_key_map: dict[IOHyperEdge, list[str]] | None = None,
+        data_memo: Mapping[int, IOHyperEdge] | None = None,
+    ) -> dict[str, tuple[dict[str, list[str]], dict[str, list[str]]]]:
+        conn_info: dict[str, tuple[dict[str, list[str]], dict[str, list[str]]]] = {}
+        if self.input_keys:
+            if data_to_key_map is None:
+                data_to_key_map = {}
+            if data_memo is None:
+                data_memo = {}
+            model_key_map: dict[BaseModel, dict[str, str]] = {}
+
+            # Handle the case where the model is constructed with the += operator.
+            # In that case, directly take the canonical output as the output_key.
+
+            # TODO: We may expose all canonical outputs for summary instead of
+            # checking only len(self.conns.couts) == 1.
+            output_keys = (
+                ([self._cout.key] if len(self.conns.couts) == 1 else [])
+                if not self.conns.output_keys
+                else self.conns.output_keys
+            )
+            # Extract key mappings and the data map of the outer model.
+            key_mappings = self.generate_keys(
+                include_internals=False, include_outputs=True
+            )
+            data_map = {key: conn.metadata for key, conn in self.conns.all.items()}
+
+            # Sort in topological order.
+            sorted_models = self.get_models_in_topological_order()
+
+            for model in sorted_models:
+                model_name = name_mappings[model]
+                m_info = self.dag[model]
+                # Set the default structure of conn_info and shape_info.
+                conns = conn_info.setdefault(model_name, ({}, {}))
+                # Include input keys with Tensor value.
+                input_keys = tuple(model.input_keys)
+                # Generate the submodel key map and data map.
+                model_key_map[model] = m_key_mappings = model.generate_keys(
+                    include_internals=False, include_outputs=True
+                )
+                m_data_map = {
+                    key: conn.metadata for key, conn in model.conns.all.items()
+                }
+                for inner_key in input_keys + tuple(model.conns.output_keys):
+                    # Find the data of the key; if a data memo is given, extract
+                    # its copied version and extract the shapes.
+                    key_data = data_memo.get(
+                        id(m_data_map[inner_key]), m_data_map[inner_key]
+                    )
+
+                    # Find the inner and outer keys. Also find their updated
+                    # versions based on their key mappings.
+                    updated_inner_key = m_key_mappings.get(inner_key, inner_key)
+                    outer_conn = m_info[inner_key]
+                    outer_key = outer_conn.key
+                    updated_outer_key = data_to_key_map.get(
+                        key_data, [key_mappings.get(outer_key, outer_key)]
+                    )
+
+                    # Take (setdefault) the connection list in which the update
+                    # will be done.
+                    conn = conns[inner_key in model.conns.output_keys].setdefault(
+                        updated_inner_key, []
+                    )
+                    if inner_key not in input_keys:
+                        continue
+
+                    if (val := key_data.value) is not TBD:
+                        conn.append(str(val))
+
+                    elif outer_key in self.input_keys:
+                        # If outer_key is in the input_keys of the overall model,
+                        # the input key is an overall input to the model. Do the
+                        # updates accordingly.
+                        input_name = ["'" + key + "'" for key in updated_outer_key]
+                        conn.extend(input_name)
+                    else:
+                        # If the input key is not in self.input_keys, it is
+                        # connected to another model, i.e. it is an internal
+                        # connection. Find the connected model and do the
+                        # initializations.
+                        con_model = self.dependency_map.local_output_dependency_map[
+                            outer_conn
+                        ][0]
+                        con_generated_keys = model_key_map.setdefault(
+                            con_model,
+                            con_model.generate_keys(
+                                include_internals=False, include_outputs=True
+                            ),
+                        )
+                        conn_info.setdefault(name_mappings[con_model], ({}, {}))
+                        model_conn = m_info[inner_key]
+                        con = con_model.conns.get_con_by_metadata(model_conn.metadata)
+                        assert con is not None, "Connection is not found"
+                        con_key = con.key
+
+                        con_key = con_generated_keys.get(con_key, con_key)
+                        # Since an internal key means a two-sided connection, two
+                        # updates on the conn_info dict need to be done: one for
+                        # the model's input key and one for the connected model's
+                        # output key. Do the updates accordingly.
+                        conn_info[model_name][0].setdefault(
+                            updated_inner_key, []
+                        ).append(name_mappings[con_model] + "." + con_key)
+                        conn_info[name_mappings[con_model]][1].setdefault(
+                            con_key, []
+                        ).append(model_name + "." + updated_inner_key)
+
+            for outer_key in output_keys:
+                # Lastly, traverse the output keys of the overall model. Find the
+                # connected model, and find the inner key via the metadata.
+                metadata = self.conns.get_metadata(outer_key)
+                outer_out_conn = self.conns.get_connection(outer_key)
+
+                assert metadata is not None, "Metadata is not found!"
+                assert outer_out_conn is not None, "Connection is not found"
+
+                model = self.dependency_map.local_output_dependency_map[outer_out_conn][
+                    0
+                ]
+                other_conn = model.conns.get_con_by_metadata(metadata)
+                assert other_conn is not None, "Connection is not found"
+
+                inner_key = other_conn.key
+                updated_inner_key = model_key_map[model].get(inner_key, inner_key)
+                key_data = data_memo.get(id(data_map[outer_key]), data_map[outer_key])
+                updated_outer_key = data_to_key_map.get(
+                    key_data, [key_mappings.get(outer_key, outer_key)]
+                )
+                if updated_outer_key[0][0] == "$":
+                    # An outer key can only start with a "$" sign if the model was
+                    # constructed with the += operator. In that case, the canonical
+                    # output is an external key even if the user has not named it,
+                    # so directly write it as $output.
+                    updated_outer_key = ["$output"]
+                model_name = name_mappings[model]
+                conn_info[model_name][1][updated_inner_key].extend(
+                    ["'" + key + "'" for key in updated_outer_key]
+                )
+
+        return conn_info
+
+    @property
+    def grad_formula(self) -> str:
+        if self.formula_key is None:
+            raise AttributeError("Model has no formula key!")
+        return self.formula_key + "_grad"
+
+    @property
+    def class_name(self) -> str:
+        return self.__class__.__name__
 
     @property
     def enforce_jit(self) -> bool:
@@ -169,14 +1034,6 @@ def _get_outermost_parent(self) -> BaseModel:
             model = model.parent
         return model
 
-    def generate_keys(
-        self,
-        symbolic: bool = True,
-        include_internals: bool = True,
-        include_outputs: bool = False,
-    ) -> dict[str, str]:
-        return {}
-
     def __setattr__(self, name: str, value: Any) -> None:
         # You need to be careful here to avoid infinite recursion
         if (
@@ -190,42 +1047,30 @@ def __setattr__(self, name: str, value: Any) -> None:
         else:
            super().__setattr__(name, value)
 
-    def _freeze(self) -> None:
-        self.is_frozen = True
-
-    @abc.abstractmethod
-    def extract_connection_info(
-        self,
-        name_mappings: dict[BaseModel, str],
-        data_to_key_map: dict[IOHyperEdge, list[str]] | None = None,
-        data_memo: Mapping[int, IOHyperEdge] | None = None,
-    ) -> dict[str, tuple[dict[str, list[str]], dict[str, list[str]]]]:
-        raise NotImplementedError("Implement extract_connection_info method!")
+    def _create_key_name(self) -> str:
+        self.inter_key_count += 1
+        return "$" + str(self.inter_key_count)
 
+    # TODO: Rename it to _create_connection_data
     def _create_connection(
-        self, metadata: IOHyperEdge, key: str, is_key_autogenerated: bool
-    ) -> Connection:
-        # Check if the key is already exist in the connections object.
-        if self.conns.get_connection(key) is not None:
+        self, metadata: IOHyperEdge, key: str | None = None
+    ) -> ConnectionData:
+        # If key is not provided, create a new key name and
+        # label it as auto-generated.
+        if key is not None and self.conns.get_connection(key) is not None:
             raise KeyError("Connection with name " + key + " already exists!")
+        if is_key_autogenerated := key is None:
+            key = self._create_key_name()
 
-        # Create connection object with given metadata, key
-        # and auto-generation status of the key.
-        con = Connection(
-            key=key,
-            metadata=metadata,
-            is_key_autogenerated=is_key_autogenerated,
+        con = ConnectionData(
+            key=key, metadata=metadata, is_key_autogenerated=is_key_autogenerated
         )
-        # Add ConnectionData to the connections object.
-        self.conns.add(con.data)
-        return con
+        if not is_key_autogenerated:
+            # Set key_origin into metadata
+            metadata.key_origin = key
 
-    def create_connection(self, metadata: IOHyperEdge, key: str) -> ConnectionData:
-        con = self._create_connection(metadata, key, False)
-        # Set key_origin into metadata
-        metadata.key_origin = key
-        setattr(self, key, con)
-        return con.data
+        self.conns.add(con)
+        return con
 
     def _set_shapes(
         self,
@@ -242,7 +1087,7 @@
         model = self._get_outermost_parent()
 
         used_keys: dict[str | int, ShapeType] = {}
-        shape_nodes: dict[str | Connection, tuple[ShapeNode, str]] = {}
+        shape_nodes: dict[str | ConnectionData, tuple[ShapeNode, str]] = {}
         # TODO: Can this be refactored to use a single loop?
        for key, shape in chain(shapes.items(), kwargs.items()):
            metadata = self.conns.extract_metadata(key)
@@ -270,6 +1115,46 @@
 
         model.constraint_solver(updates)
 
+    def _set_types(
+        self,
+        config: Mapping[
+            str | ConnectionData,
+            type | UnionType | ScalarType | type[Tensor[int | float | bool]],
+        ]
+        | Mapping[
+            ConnectionData,
+            type | UnionType | ScalarType | type[Tensor[int | float | bool]],
+        ]
+        | Mapping[
+            str,
+            type | UnionType | ScalarType | type[Tensor[int | float | bool]],
+        ]
+        | None = None,
+        **kwargs: type | UnionType | ScalarType | type[Tensor[int | float | bool]],
+    ) -> None:
+        if config is None:
+            config = {}
+
+        # Initialize the assigned types dictionary to store assigned types.
+        assigned_types: dict[
+            str,
+            type | UnionType | ScalarType | Tensor[int | float | bool],
+        ] = {}
+
+        # Get the outermost parent as all the updates will happen here.
+        model = self._get_outermost_parent()
+        updates = Updates()
+        for key, key_type in chain(config.items(), kwargs.items()):
+            metadata = self.conns.extract_metadata(key)
+            conn = self.conns.get_con_by_metadata(metadata)
+            assert conn is not None
+            inner_key = conn.key
+            assigned_types[inner_key] = key_type
+            updates |= metadata.set_type(key_type)
+        # Store assigned types in the model.
+        self.assigned_types |= assigned_types
+        # Run the constraints for updating affected connections.
+        model.constraint_solver(updates)
+
     def _set_value(
         self,
         key: ConnectionData,
@@ -293,19 +1178,12 @@
             # Data is scalar, set the value directly.
            return key.metadata.set_value(value)
 
-    def set_shapes(
-        self, config: ShapesType | None = None, **kwargs: ShapeTemplateType
-    ) -> None:
-        if config is None:
-            config = {}
-        self._set_shapes(config, trace=True, updates=None, **kwargs)
-
-    def set_values(
+    def _set_values(
         self,
         config: Mapping[
-            str | Connection, Tensor[int | float | bool] | MainValueType | str
+            str | ConnectionData, Tensor[int | float | bool] | MainValueType | str
         ]
-        | Mapping[Connection, Tensor[int | float | bool] | MainValueType | str]
+        | Mapping[ConnectionData, Tensor[int | float | bool] | MainValueType | str]
        | Mapping[str, Tensor[int | float | bool] | MainValueType | str]
        | None = None,
        **kwargs: Tensor[int | float | bool] | MainValueType | str,
@@ -344,63 +1222,6 @@
         # Solve constraints with the updated values.
         model.constraint_solver(updates)
 
-    def set_types(
-        self,
-        config: Mapping[
-            str | Connection,
-            type | UnionType | ScalarType | type[Tensor[int | float | bool]],
-        ]
-        | Mapping[
-            Connection,
-            type | UnionType | ScalarType | type[Tensor[int | float | bool]],
-        ]
-        | Mapping[
-            str,
-            type | UnionType | ScalarType | type[Tensor[int | float | bool]],
-        ]
-        | None = None,
-        **kwargs: type | UnionType | ScalarType | type[Tensor[int | float | bool]],
-    ) -> None:
-        """
-        Set types of any connection in the Model
-
-        This method updates types in given connections.
-        connections can be given either as Connection or their string
-        equivalent. Giving a valid type for given connections, this method
-        will update the connections's types and thereafter runs the
-        constraints to update affected connections' types.
-
-        Args:
-            values (dict[str | Connection, MainValueType]): A dictionary where
-                keys are either strings or Connection objects, and values are
-                of type of type or UnionType objects.
-
-        """
-        if config is None:
-            config = {}
-        # Initialize assigned shapes dictionary to store assigned shapes.
- assigned_types: dict[ - str, - type | UnionType | ScalarType | Tensor[int | float | bool], - ] = {} - - # Get the outermost parent as all the updates will happen here. - model = self._get_outermost_parent() - updates = Updates() - for key, key_type in chain(config.items(), kwargs.items()): - metadata = self.conns.extract_metadata(key) - conn = self.conns.get_con_by_metadata(metadata) - assert conn is not None - inner_key = conn.key - if key_type is Tensor: - key_type = Tensor[int | float | bool] - assigned_types[inner_key] = key_type - updates |= metadata.set_type(key_type) - # Store assigned types in the model. - self.assigned_types |= assigned_types - # Run the constraints for updating affected connections. - model.constraint_solver(updates) - def get_shapes( self, uni_keys: dict[UniadicRecord, str] | None = None, @@ -455,25 +1276,25 @@ def add_constraint( return self._add_constraint(fn, keys, type, dependencies) @property - def cin(self) -> Connection: + def _cin(self) -> ConnectionData: if (cin_len := len(self.conns.cins)) != 1: raise KeyError( f"Currently, there exists {cin_len} canonical inputs, model " "should have exactly one canonical input!" ) - return next(iter(self.conns.cins)).conn + return next(iter(self.conns.cins)) @property - def cout(self) -> Connection: + def _cout(self) -> ConnectionData: if (cout_len := len(self.conns.couts)) != 1: raise KeyError( f"Currently, there exists {cout_len} canonical outputs, model " "should have exactly one canonical output!" ) - return next(iter(self.conns.couts)).conn + return next(iter(self.conns.couts)) - def set_cin(self, *connections: str | Connection, safe: bool = True) -> None: + def _set_cin(self, *connections: str | ConnectionData, safe: bool = True) -> None: self.conns.cins = set() for given_conn in connections: conn = self.conns.get_extracted_connection(given_conn) @@ -493,7 +1314,7 @@ def set_cin(self, *connections: str | Connection, safe: bool = True) -> None: else: self.conns.cins.add(conn) - def set_cout(self, *connections: str | Connection, safe: bool = True) -> None: + def _set_cout(self, *connections: str | ConnectionData, safe: bool = True) -> None: self.conns.couts = set() for given_conn in connections: conn = self.conns.get_extracted_connection(given_conn) @@ -544,34 +1365,6 @@ def _match_hyper_edges(self, left: IOHyperEdge, right: IOHyperEdge) -> Updates: updates = left.match(right) return updates - def get_models_in_topological_order(self) -> list[BaseModel]: - dependency_map = self.dependency_map.local_output_dependency_map - graph = { - info[0]: OrderedSet( - [dependency_map[spec][0] for spec in info[1] if spec in dependency_map] - ) - for info in dependency_map.values() - } - top_order: list[BaseModel] = list() - visited: set[BaseModel] = set() - for model in graph: - if model not in top_order: - BaseModel._reverse_dfs(model, graph, top_order, visited) - return top_order - - @staticmethod - def _reverse_dfs( - node: BaseModel, - graph: dict[BaseModel, OrderedSet[BaseModel]], - top_order: list[BaseModel], - visited: set[BaseModel], - ) -> None: - visited.add(node) - for m in graph[node]: - if m not in visited: - BaseModel._reverse_dfs(m, graph, top_order, visited) - top_order.append(node) - class DependencyMap: """ diff --git a/mithril/framework/logical/model.py b/mithril/framework/logical/model.py index fbd9c14b..2e9187af 100644 --- a/mithril/framework/logical/model.py +++ b/mithril/framework/logical/model.py @@ -14,23 +14,22 @@ from __future__ import annotations -from collections.abc import Mapping -from types 
import UnionType +from collections.abc import KeysView, Mapping, Sequence +from dataclasses import dataclass +from types import EllipsisType, UnionType from typing import Any, Self -from ...utils.utils import OrderedSet, find_dominant_type +from ...core import Dtype as CoreDtype +from ...utils.utils import find_dominant_type from ..common import ( NOT_GIVEN, TBD, - Connection, + BaseKey, ConnectionData, - ConnectionInstanceType, - ConnectionType, - ExtendTemplate, + ConnectionDataType, IOHyperEdge, - IOKey, - KeyType, MainValueInstance, + MainValueType, NullConnection, ScalarType, ScalarValueType, @@ -38,292 +37,541 @@ Tensor, ToBeDetermined, UniadicRecord, - Updates, Variadic, get_summary, get_summary_shapes, get_summary_types, ) -from .base import BaseModel, ExtendInfo -from .essential_primitives import ( - Absolute, - Add, - Cast, - Cosine, - Divide, - Dtype, - Equal, - Exponential, - FloorDivide, - Greater, - GreaterEqual, - Indexer, - Item, - Length, - Less, - LessEqual, - LogicalAnd, - LogicalNot, - LogicalOr, - LogicalXOr, - MatrixMultiply, - Max, - Mean, - Min, - Minus, - Multiply, - NotEqual, - Power, - Prod, - Reshape, - Shape, - ShiftLeft, - ShiftRight, - Sine, - Size, - Slice, - Split, - Sqrt, - Subtract, - Sum, - TensorToList, - ToList, - ToTensor, - ToTuple, - Transpose, - Variance, + +# from .base import BaseModel, ConnectionDataType +from .base import BaseModel +from .operator import Operator +from .operators import ( + AbsoluteOp, + AddOp, + CastOp, + CosineOp, + DivideOp, + DtypeOp, + EqualOp, + ExponentialOp, + FloorDivideOp, + GreaterEqualOp, + GreaterOp, + IndexerOp, + ItemOp, + LengthOp, + LessEqualOp, + LessOp, + LogicalAndOp, + LogicalNotOp, + LogicalOrOp, + LogicalXOrOp, + MatrixMultiplyOp, + MaxOp, + MeanOp, + MinOp, + MinusOp, + MultiplyOp, + NotEqualOp, + PowerOp, + ProdOp, + ReshapeOp, + ShapeOp, + ShiftLeftOp, + ShiftRightOp, + SineOp, + SizeOp, + SliceOp, + SplitOp, + SqrtOp, + SubtractOp, + SumOp, + ToListOp, + ToTensorOp, + ToTupleOp, + TransposeOp, + VarianceOp, ) -from .primitive import PrimitiveModel - -__all__ = ["Model"] - -ops_table: dict[str, type[PrimitiveModel]] = { - "add": Add, - "sub": Subtract, - "div": Divide, - "fdiv": FloorDivide, - "mul": Multiply, - "pow": Power, - "matmul": MatrixMultiply, - "shape": Shape, - "reshape": Reshape, - "len": Length, - "size": Size, - "tensor": ToTensor, - "list": TensorToList, - "item": Item, - "indexer": Indexer, - "mean": Mean, - "sqrt": Sqrt, - "exp": Exponential, - "sum": Sum, - "max": Max, - "min": Min, - "abs": Absolute, - "prod": Prod, - "var": Variance, - "gt": Greater, - "ge": GreaterEqual, - "lt": Less, - "le": LessEqual, - "eq": Equal, - "ne": NotEqual, - "not": LogicalNot, - "and": LogicalAnd, - "or": LogicalOr, - "xor": LogicalXOr, - "lshift": ShiftLeft, - "rshift": ShiftRight, - "minus": Minus, - "transpose": Transpose, - "split": Split, - "slice": Slice, - "to_tuple": ToTuple, - "cast": Cast, - "dtype": Dtype, - "sin": Sine, - "cos": Cosine, -} +__all__ = [ + "Connection", + "IOKey", + "ExtendTemplate", + "ExtendInfo", + "Model", + "ConnectionType", + "ConnectionInstanceType", + "TemplateConnectionType", + "define_unique_names", +] + + +class TemplateBase: + def __getitem__( + self, + key: slice + | int + | EllipsisType + | tuple[slice | int | None | EllipsisType | TemplateBase, ...] 
+ | IOKey + | TemplateBase + | None, + ) -> ExtendTemplate: + match key: + case slice(): + slice_output = ExtendTemplate( + connections=[key.start, key.stop, key.step], model=SliceOp + ) + output = ExtendTemplate( + connections=[self, slice_output], model=IndexerOp + ) -class Model(BaseModel): - def __init__( + case int() | EllipsisType() | None: + output = ExtendTemplate(connections=[self, key], model=IndexerOp) + + case tuple(): + connections: list[TemplateBase | int | None | EllipsisType] = [] + for item in key: + if isinstance(item, slice): + slice_output = ExtendTemplate( + connections=[item.start, item.stop, item.step], + model=SliceOp, + ) + connections.append(slice_output) + else: + connections.append(item) + tuple_template = ExtendTemplate( + connections=connections, + model=ToTupleOp, + defaults={"n": len(key)}, + ) + output = ExtendTemplate( + connections=[self, tuple_template], model=IndexerOp + ) + return output + + def __add__(self, other: TemplateConnectionType) -> ExtendTemplate: + return ExtendTemplate(connections=[self, other], model=AddOp) + + def __radd__(self, other: TemplateConnectionType) -> ExtendTemplate: + return ExtendTemplate(connections=[other, self], model=AddOp) + + def __sub__(self, other: TemplateConnectionType) -> ExtendTemplate: + return ExtendTemplate(connections=[self, other], model=SubtractOp) + + def __rsub__(self, other: TemplateConnectionType) -> ExtendTemplate: + return ExtendTemplate(connections=[other, self], model=SubtractOp) + + def __mul__(self, other: TemplateConnectionType) -> ExtendTemplate: + return ExtendTemplate(connections=[self, other], model=MultiplyOp) + + def __rmul__(self, other: TemplateConnectionType) -> ExtendTemplate: + return ExtendTemplate(connections=[other, self], model=MultiplyOp) + + def __truediv__(self, other: TemplateConnectionType) -> ExtendTemplate: + return ExtendTemplate(connections=[self, other], model=DivideOp) + + def __rtruediv__(self, other: TemplateConnectionType) -> ExtendTemplate: + return ExtendTemplate(connections=[other, self], model=DivideOp) + + def __floordiv__(self, other: TemplateConnectionType) -> ExtendTemplate: + return ExtendTemplate(connections=[self, other], model=FloorDivideOp) + + def __rfloordiv__(self, other: TemplateConnectionType) -> ExtendTemplate: + return ExtendTemplate(connections=[other, self], model=FloorDivideOp) + + def __pow__(self, other: TemplateConnectionType) -> ExtendTemplate: + return ExtendTemplate( + connections=[self, other], model=PowerOp, defaults={"robust": False} + ) + + def __rpow__(self, other: TemplateConnectionType) -> ExtendTemplate: + return ExtendTemplate( + connections=[other, self], model=PowerOp, defaults={"robust": False} + ) + + def __matmul__(self, other: TemplateConnectionType) -> ExtendTemplate: + return ExtendTemplate(connections=[self, other], model=MatrixMultiplyOp) + + def __gt__(self, other: TemplateConnectionType) -> ExtendTemplate: + return ExtendTemplate(connections=[self, other], model=GreaterOp) + + def __rgt__(self, other: TemplateConnectionType) -> ExtendTemplate: + return ExtendTemplate(connections=[other, self], model=GreaterOp) + + def __ge__(self, other: TemplateConnectionType) -> ExtendTemplate: + return ExtendTemplate(connections=[self, other], model=GreaterEqualOp) + + def __rge__(self, other: TemplateConnectionType) -> ExtendTemplate: + return ExtendTemplate(connections=[other, self], model=GreaterEqualOp) + + def __lt__(self, other: TemplateConnectionType) -> ExtendTemplate: + return ExtendTemplate(connections=[self, other], 
model=LessOp) + + def __rlt__(self, other: TemplateConnectionType) -> ExtendTemplate: + return ExtendTemplate(connections=[other, self], model=LessOp) + + def __le__(self, other: TemplateConnectionType) -> ExtendTemplate: + return ExtendTemplate(connections=[self, other], model=LessEqualOp) + + def __rle__(self, other: TemplateConnectionType) -> ExtendTemplate: + return ExtendTemplate(connections=[other, self], model=LessEqualOp) + + def __eq__(self, other: object) -> ExtendTemplate: # type: ignore[override] + if isinstance( + other, int | float | bool | list | Connection | IOKey | tuple | Tensor + ): + return ExtendTemplate(connections=[self, other], model=EqualOp) + else: + raise ValueError("Unsupported type for equality operation.") + + def __req__(self, other: TemplateConnectionType) -> ExtendTemplate: + return ExtendTemplate(connections=[other, self], model=EqualOp) + + def __ne__(self, other: object) -> ExtendTemplate: # type: ignore[override] + if isinstance( + other, int | float | bool | list | Connection | IOKey | tuple | Tensor + ): + return ExtendTemplate(connections=[self, other], model=NotEqualOp) + else: + raise ValueError("Unsupported type for equality operation.") + + def __rne__(self, other: TemplateConnectionType) -> ExtendTemplate: + return ExtendTemplate(connections=[other, self], model=NotEqualOp) + + def __and__(self, other: TemplateConnectionType) -> ExtendTemplate: + return ExtendTemplate(connections=[self, other], model=LogicalAndOp) + + def __rand__(self, other: TemplateConnectionType) -> ExtendTemplate: + return ExtendTemplate(connections=[other, self], model=LogicalAndOp) + + def __or__(self, other: TemplateConnectionType) -> ExtendTemplate: + return ExtendTemplate(connections=[self, other], model=LogicalOrOp) + + def __ror__(self, other: TemplateConnectionType) -> ExtendTemplate: + return ExtendTemplate(connections=[other, self], model=LogicalOrOp) + + def __xor__(self, other: TemplateConnectionType) -> ExtendTemplate: + return ExtendTemplate(connections=[self, other], model=LogicalXOrOp) + + def __rxor__(self, other: TemplateConnectionType) -> ExtendTemplate: + return ExtendTemplate(connections=[other, self], model=LogicalXOrOp) + + def __lshift__(self, other: TemplateConnectionType) -> ExtendTemplate: + return ExtendTemplate(connections=[self, other], model=ShiftLeftOp) + + def __rlshift__(self, other: TemplateConnectionType) -> ExtendTemplate: + return ExtendTemplate(connections=[other, self], model=ShiftLeftOp) + + def __rshift__(self, other: TemplateConnectionType) -> ExtendTemplate: + return ExtendTemplate(connections=[self, other], model=ShiftRightOp) + + def __rrshift__(self, other: TemplateConnectionType) -> ExtendTemplate: + return ExtendTemplate(connections=[other, self], model=ShiftRightOp) + + def __invert__(self) -> ExtendTemplate: + return ExtendTemplate(connections=[self], model=LogicalNotOp) + + def __neg__(self) -> ExtendTemplate: + return ExtendTemplate(connections=[self], model=MinusOp) + + def abs(self) -> ExtendTemplate: + return ExtendTemplate(connections=[self], model=AbsoluteOp) + + def len(self) -> ExtendTemplate: + return ExtendTemplate(connections=[self], model=LengthOp) + + @property + def shape(self) -> ExtendTemplate: + return ExtendTemplate(connections=[self], model=ShapeOp) + + def reshape( + self, shape: tuple[int | TemplateBase, ...] | TemplateBase + ) -> ExtendTemplate: + return ExtendTemplate(connections=[self, shape], model=ReshapeOp) + + def size( + self, dim: int | tuple[int, ...] 
| TemplateBase | None = None + ) -> ExtendTemplate: + return ExtendTemplate(connections=[self, dim], model=SizeOp) + + def tensor(self) -> ExtendTemplate: + return ExtendTemplate( + connections=[self], model=ToTensorOp, defaults={"dtype": None} + ) + + def mean( self, - name: str | None = None, - enforce_jit: bool = True, - ) -> None: - self.dag: dict[BaseModel, dict[str, ConnectionData]] = {} - self.inter_key_count: int = 0 - self.formula_key: str | None = None + axis: int | tuple[int, ...] | TemplateBase | None = None, + keepdim: bool = False, + ) -> ExtendTemplate: + return ExtendTemplate(connections=[self, axis, keepdim], model=MeanOp) - super().__init__(name=name, enforce_jit=enforce_jit) + def sum( + self, + axis: int | tuple[int, ...] | TemplateBase | None = None, + keepdim: bool = False, + ) -> ExtendTemplate: + return ExtendTemplate(connections=[self, axis, keepdim], model=SumOp) - def create_key_name(self) -> str: - self.inter_key_count += 1 - return "$" + str(self.inter_key_count) + def max( + self, + axis: int | tuple[int, ...] | TemplateBase | None = None, + keepdim: bool = False, + ) -> ExtendTemplate: + return ExtendTemplate(connections=[self, axis, keepdim], model=MaxOp) - def create_connection( - self, metadata: IOHyperEdge, key: str | None = None - ) -> ConnectionData: - # If key is not provided, create a new key name and - # label it as auto-generated. - if is_key_autogenerated := key is None: - key = self.create_key_name() + def min( + self, + axis: int | tuple[int, ...] | TemplateBase | None = None, + keepdim: bool = False, + ) -> ExtendTemplate: + return ExtendTemplate(connections=[self, axis, keepdim], model=MinOp) - con = self._create_connection(metadata, key, is_key_autogenerated) + def prod( + self, + axis: int | tuple[int, ...] | TemplateBase | None = None, + keepdim: bool = False, + ) -> ExtendTemplate: + return ExtendTemplate(connections=[self, axis, keepdim], model=ProdOp) - if not is_key_autogenerated: - # Set key_origin into metadata - metadata.key_origin = key - setattr(self, key, con) + def var( + self, + axis: int | tuple[int, ...] | TemplateBase | None = None, + keepdim: bool = False, + correction: float | None = 0.0, + ) -> ExtendTemplate: + return ExtendTemplate( + connections=[self, axis, keepdim, correction], model=VarianceOp + ) - return con.data + def sqrt(self) -> ExtendTemplate: + return ExtendTemplate( + connections=[self], model=SqrtOp, defaults={"robust": False} + ) - def set_outputs(self, *args: str | Connection, **kwargs: str | Connection) -> None: - if self.parent is not None: - raise Exception("Child model's outputs cannot be set.") - # Convert all args and kwargs to tuple. - # Convert all args and kwargs to tuple. - pairs = tuple([(None, arg) for arg in args]) + tuple(kwargs.items()) + def exp(self) -> ExtendTemplate: + return ExtendTemplate(connections=[self], model=ExponentialOp) - for pair in pairs: - new_name, name = pair - metadata = self.conns.extract_metadata(name) + def transpose( + self, axes: tuple[int, ...] | TemplateBase | None = None + ) -> ExtendTemplate: + return ExtendTemplate(connections=[self, axes], model=TransposeOp) - # Check the connection is valid. - if (conn_data := self.conns.get_con_by_metadata(metadata)) is None: - raise KeyError("Requires valid key or Connection to set output!") + def split(self, split_size: int, axis: int) -> ExtendTemplate: + return ExtendTemplate(connections=[self, split_size, axis], model=SplitOp) - # Check if given metadata is already an output. 
- if conn_data in self.conns.output_connections: - raise KeyError(f"'{conn_data.key}' key is already set as output!") + def item(self) -> ExtendTemplate: + return ExtendTemplate(connections=[self], model=ItemOp) - if conn_data in self.conns.input_connections: - raise KeyError("Input of the overall model cannot be set as output.") + def cast(self, dtype: CoreDtype | None = None) -> ExtendTemplate: + return ExtendTemplate(connections=[self, dtype], model=CastOp) - # Autogenerated keys can not be set directly as output without a name. - if new_name is None and conn_data.key.startswith("$"): - raise KeyError( - "Autogenerated keys can only be set as output if" - " a name is provided for the connection as keyworded argument." - ) + def dtype(self) -> ExtendTemplate: + return ExtendTemplate(connections=[self], model=DtypeOp) - if new_name is None: # Non-named connections. - # Set connection as output and update dependency map. - self.conns.set_connection_type(conn_data, KeyType.OUTPUT) - self.dependency_map.update_globals(OrderedSet({conn_data})) + def sin(self) -> ExtendTemplate: + return ExtendTemplate(connections=[self], model=SineOp) - else: # Named connections. - # Create new output connection with given key name. - # TODO: Update here to use directly set_name method of Connections class - # after it is implemented. - edge = IOHyperEdge(metadata.edge_type) + def cos(self) -> ExtendTemplate: + return ExtendTemplate(connections=[self], model=CosineOp) - new_conn = self.create_connection(edge, new_name) - # Set connection as output and update dependency map. - self.conns.set_connection_type(new_conn, KeyType.OUTPUT) +class Connection(TemplateBase): + def __init__(self, data: ConnectionData) -> None: + self.data = data - # Merge new_conn with given connection. - self.merge_connections(new_conn, conn_data) + @property + def key(self) -> str: + return self.data.key - def _set_formula_key(self, formula_key: str) -> None: - self.formula_key = formula_key + @property + def metadata(self) -> IOHyperEdge: + return self.data.metadata - def _check_multi_write( + def set_differentiable(self, differentiable: bool = True) -> None: + self.data.set_differentiable(differentiable) + + def __hash__(self) -> int: + return hash(id(self)) + + +class IOKey(BaseKey, TemplateBase): + def __init__( self, - local_input: bool, - local_connection: ConnectionData, - connection: ConnectionData, + name: str | None = None, + value: Tensor[int | float | bool] + | ScalarValueType + | ToBeDetermined + | str = TBD, + shape: ShapeTemplateType | None = None, + type: UnionType + | type + | type[Tensor[int | float | bool]] + | ScalarType + | None = None, + expose: bool | None = None, + interval: list[float | int] | None = None, + connections: set[Connection | str] | None = None, ) -> None: - conn_is_output = ( - self.dependency_map.local_output_dependency_map.get(connection, None) - is not None + _connections: set[ConnectionData | str] = { + con.data if isinstance(con, Connection) else con + for con in connections or set() + } + + super().__init__( + name=name, + value=value, + shape=shape, + type=type, + expose=expose, + interval=interval, + connections=_connections, ) - if local_connection.key in self.conns.all and connection.key in self.conns.all: - local_conn_is_output = ( - self.dependency_map.local_output_dependency_map.get( - local_connection, None - ) - is not None - ) - if ( - conn_is_output - and local_conn_is_output - and local_connection.key != connection.key - ): - # Check if 2 connections are part of main model. 
If it is the case, - # We expect at least one of them is not an input of the main model, - # otherwise condition is Multi-write error - raise Exception( - "Given connections are both output connections. Multi-write error!" + + +class ExtendTemplate(TemplateBase): + output_connection: ConnectionData | None + + def __init__( + self, + connections: Sequence[TemplateConnectionType], + model: type[BaseModel], + defaults: dict[str, Any] | None = None, + ) -> None: + for connection in connections: + if isinstance(connection, str): + raise ValueError( + "In extend template operations, 'str' is not a valid type." ) - local_val = local_connection.metadata.value - global_val = connection.metadata.value + self.connections = connections + self.model = model + + if defaults is None: + defaults = {} + self.defaults = defaults + self.output_connection = None + + +@dataclass +class ExtendInfo: + _model: BaseModel + _connections: dict[str, ConnectionType] + + def __post_init__(self) -> None: + external_keys = ( + set(self._model.external_keys) + | {item.key for item in self._model.conns.couts} + | {item.key for item in self._model.conns.cins} + ) + + for key in self._connections: + if key not in external_keys: + raise KeyError(f"Key '{key}' is not a valid key for the model!") + + @property + def model(self) -> BaseModel: + return self._model + + @property + def connections(self) -> dict[str, ConnectionType]: + return self._connections + + +TemplateConnectionType = ( + TemplateBase + | int + | float + | list[int | float] + | EllipsisType + | tuple[slice | int | None | EllipsisType | TemplateBase, ...] + | None + | Tensor[int | float | bool] +) + +ConnectionType = ( + str + | MainValueType + | ExtendTemplate + | NullConnection + | IOKey + | Connection + | Tensor[int | float | bool] +) + +ConnectionInstanceType = ( + str + | MainValueInstance + | ExtendTemplate + | NullConnection + | IOKey + | Connection + | Tensor # type: ignore +) - if conn_is_output and not local_input: - # Check if 2 connections are both output of any models. - raise Exception( - "Given connections are both output connections. Multi-write error!" - ) - elif ( - local_input - and local_val is not TBD - # and global_val is not TBD - and conn_is_output - and global_val != local_val - ): - raise ValueError( - "An input of the extending model tries to write " - "to an output connection in the extended model. " - "Multi-write error!" - ) - elif not local_input and global_val is not TBD and local_val != global_val: - raise ValueError( - "A valued connection of the extended model tries to write " - "to an output connection of the extending model. " - "Multi-write error!" - ) - def _convert_to_iokey( - self, model: BaseModel, key: str, connection: ConnectionType - ) -> IOKey: +class Model(BaseModel): + def __init__( + self, + name: str | None = None, + formula_key: str | None = None, + enforce_jit: bool = True, + ) -> None: + super().__init__(name, formula_key, enforce_jit) + self.connection_map: dict[ConnectionData, Connection] = {} + + def __call__(self, **kwargs: ConnectionType) -> ExtendInfo: + return ExtendInfo(self, kwargs) + + def _create_connection( + self, metadata: IOHyperEdge, key: str | None = None + ) -> ConnectionData: + con_data = super()._create_connection(metadata, key) + con = Connection(con_data) + self.connection_map[con_data] = con + if not con_data.is_key_autogenerated: + assert key is not None + setattr(self, key, con) + return con_data + + # TODO: Refactor _prepare_keys / _unroll_template relation. 
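# ---------------------------------------------------------------------------
# Illustrative sketch (not from the patch): the TemplateBase dunders above all
# follow one expression-template pattern: each operator records a deferred
# node instead of computing a value, and Model._unroll_template later turns
# those nodes into Operator instances. The self-contained toy below shows the
# same idea; Node, unroll, and the "add"/"mul" tags are hypothetical stand-ins,
# not Mithril API.
from __future__ import annotations

from dataclasses import dataclass


@dataclass
class Node:
    op: str  # operator tag, playing the role of AddOp / MultiplyOp
    args: tuple[Node | int, ...]

    # Like TemplateBase.__add__, record the operation instead of running it.
    def __add__(self, other: Node | int) -> Node:
        return Node("add", (self, other))

    def __radd__(self, other: Node | int) -> Node:
        return Node("add", (other, self))

    def __mul__(self, other: Node | int) -> Node:
        return Node("mul", (self, other))

    def __rmul__(self, other: Node | int) -> Node:
        return Node("mul", (other, self))


def unroll(node: Node | int) -> list[str]:
    # Depth-first flattening, analogous to _unroll_template extending the
    # graph with one primitive per deferred node.
    if isinstance(node, int):
        return []
    steps = [step for arg in node.args for step in unroll(arg)]
    steps.append(node.op)
    return steps


leaf = Node("input", ())
expr = leaf + 2 * leaf  # builds nodes only; nothing is evaluated yet
print(unroll(expr))  # ['input', 'input', 'mul', 'add']
# ---------------------------------------------------------------------------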
+ def _prepare_keys( + self, + model: BaseModel, + key: str, + connection: ConnectionDataType | ConnectionType, + ) -> BaseKey: local_connection = model.conns.get_connection(key) assert local_connection is not None, "Connection is not found!" + _connection: BaseKey | ConnectionData | MainValueInstance | NullConnection | str match connection: - case NullConnection(): - _connection = IOKey() - case str(): - _connection = IOKey(name=connection) case Connection(): - _connection = IOKey(connections={connection}) + _connection = connection.data case ExtendTemplate(): # Unroll ExtendTemplate con_data = self._unroll_template(connection) - _connection = IOKey(connections={con_data.conn}, expose=False) - case _ if isinstance(connection, MainValueInstance | Tensor): + _connection = BaseKey(connections={con_data}, expose=False) + case _ if isinstance( + connection, MainValueInstance | Tensor + ) and not isinstance(connection, str): # find_dominant_type returns the dominant type in a container. # If a container has a value of type Connection or ExtendTemplate # we add necessary models. - if isinstance(connection, tuple | list) and find_dominant_type( - connection, raise_error=False - ) in [ConnectionData, ExtendTemplate, Connection, IOKey]: - kwargs = { - f"input{idx + 1}": item for idx, item in enumerate(connection) - } - connection_model = ( - ToTuple if isinstance(connection, tuple) else ToList - ) - conv_model = connection_model(n=len(connection)) - self.extend(conv_model, **kwargs) + types = [ConnectionData, ExtendTemplate, Connection, IOKey] + if ( + isinstance(connection, tuple | list) + and find_dominant_type(connection, False) in types + ): + _model = ToTupleOp if isinstance(connection, tuple) else ToListOp + et = ExtendTemplate(connection, _model, {"n": len(connection)}) + con_data = self._unroll_template(et) + _connection = BaseKey(connections={con_data}, expose=False) - result = conv_model.conns.get_connection("output") - assert result is not None - _connection = IOKey(connections={result.conn}, expose=None) else: assert isinstance(connection, MainValueInstance | Tensor) - _connection = IOKey(value=connection) + _connection = BaseKey(value=connection) case IOKey(): expose = connection.expose name = connection.name @@ -336,193 +584,31 @@ def _convert_to_iokey( and connection.connections == set() ): expose = True - _connection = IOKey( + _connection = BaseKey( name=name, expose=expose, connections=connection.connections, - type=connection.data.type, - shape=connection.data.shape, - value=connection.data.value, + type=connection.type, + shape=connection.value_shape, + value=connection.value, ) + case _: + _connection = connection # type: ignore + return super()._prepare_keys(model, key, _connection) - return _connection - - def update_key_name(self, connection: ConnectionData, key: str) -> None: - for key_type in KeyType: - key_dict = self.conns._connection_dict[key_type] - if key_dict.get(connection.key) is not None: - key_dict[key] = key_dict.pop(connection.conn.key) - # Update connection key - connection.key = key - connection.is_key_autogenerated = False - connection.metadata.key_origin = key - setattr(self, key, connection.conn) - break - - def _add_connection( - self, - model: BaseModel, - local_key: str, - given_connection: IOKey, - updates: Updates, - ) -> tuple[ConnectionData, Updates]: - is_input = local_key in model.input_keys - local_connection = model.conns.get_connection(local_key) - assert local_connection is not None, "Connection is not found!" 
- is_not_valued = local_connection.metadata.value is TBD - - d_map = self.dependency_map.local_output_dependency_map - expose = given_connection.expose - outer_key = given_connection.name - con_obj = None - set_value: ( - ToBeDetermined - | str - | ScalarValueType - | Tensor[int | float | bool] - | NullConnection - ) = NOT_GIVEN - if given_connection.data.value is not TBD: - set_value = given_connection.data.value - - if given_connection.connections == set(): - if outer_key is not None: - con_obj = self.conns.get_connection(outer_key) - if outer_key is None or con_obj is None: - if expose is None and is_input and is_not_valued: - expose = True - con_obj = self.create_connection(local_connection.metadata, outer_key) - if ( - expose is False - and is_input - and set_value is NOT_GIVEN - and local_connection.metadata.value is TBD - and (con_obj is None or con_obj not in d_map) - ): - raise ValueError( - "Expose flag cannot be false when " - "no value is provided for input keys!" - ) + def _get_conn_data(self, conn: str | ConnectionData) -> ConnectionData: + if isinstance(conn, str): + _conn = self.conns.get_connection(conn) else: - initial_conn: ConnectionData - for idx, conn in enumerate(given_connection.connections): - if isinstance(conn, str): - _conn = self.conns.get_connection(conn) - else: - _conn = self.conns.get_con_by_metadata(conn.data.metadata) - if conn.data in model.conns.all.values(): - raise ValueError( - f"Given connection '{conn.data.key}' should not " - "belong to the extending model!" - ) - - if not isinstance(_conn, ConnectionData): - raise KeyError("Requires accessible connection to be processed!") - if idx == 0: - initial_conn = _conn - if outer_key is not None: - self.update_key_name(initial_conn, outer_key) - else: - if _conn in d_map: - if initial_conn in d_map: - raise KeyError( - "IOKey object can not have more than one output " - "connection. Multi-write error!" - ) - initial_conn, _conn = _conn, initial_conn - if ( - not outer_key - and not initial_conn.is_key_autogenerated - and not _conn.is_key_autogenerated - ): - raise KeyError( - "Requires a connection to have only one unique key " - "name but encountered more!" - ) - updates |= self.merge_connections(initial_conn, _conn) - if outer_key is None and not initial_conn.is_key_autogenerated: - outer_key = initial_conn.key - if not outer_key and initial_conn in d_map and expose is True: - raise KeyError("Connection without a name cannot be set as output") - con_obj = initial_conn - - # Name "input" can only be used for input connections. - is_key_name_input = con_obj is not None and (con_key := con_obj.key) == "input" - if not is_input and (outer_key == "input" or is_key_name_input): - raise KeyError( - "The key 'input' is a reserved key which could not be used for " - "internal keys." - ) - - if not is_input and not isinstance(set_value, NullConnection): - raise KeyError( - f"{local_key} key is an output of the model, output values could " - "not be set in extend." - ) + _conn = self.conns.get_con_by_metadata(conn.metadata) + assert isinstance(_conn, ConnectionData) + return _conn - # Inherit submodel connections dict - self.conns.connections_dict.setdefault(local_connection.metadata, set()) - self.conns.connections_dict[local_connection.metadata] |= ( - model.conns.connections_dict.pop(local_connection.metadata, set()) - ) - - # If any value provided, set. 
- assert con_obj is not None - if not isinstance(set_value, NullConnection): - updates |= con_obj.metadata.set_value(set_value) - - # Check multi-write error for con_obj. - self._check_multi_write(is_input, local_connection, con_obj) - - # If match required, perform. - if con_obj.metadata != local_connection.metadata: - local_key_origin = local_connection.metadata.key_origin - updates |= self._match_hyper_edges( - con_obj.metadata, local_connection.metadata - ) - # If local_connection is an output of the model, - # update con_obj "key_origin" with local_connection's key_origin. - if ( - not is_input - and outer_key not in self.conns.output_keys - or con_obj.metadata.key_origin is None - ): - con_obj.metadata.key_origin = local_key_origin - - unexposed = not (expose or (is_input and con_key in self.conns.io_keys)) - if unexposed: - if is_input: - key_type = (KeyType.LATENT_INPUT, KeyType.INTERNAL)[con_obj in d_map] - else: - key_type = (KeyType.LATENT_OUTPUT, KeyType.INTERNAL)[con_obj in d_map] - else: - key_type = (KeyType.OUTPUT, KeyType.INPUT)[ - is_input and con_obj not in d_map - ] - if con_obj in d_map: - self.conns.couts.discard(con_obj) - self.conns.set_connection_type(con_obj, key_type) - # Update Canonicals - if ( - local_connection in model.conns.cins - and con_obj in self.conns.input_connections - and con_obj.metadata.value is TBD - ): - self.conns.cins.add(con_obj) - - if local_connection in model.conns.couts and ( - con_obj not in self.dependency_map.local_input_dependency_map - or con_obj in self.conns.output_connections - ): - self.conns.couts.add(con_obj) - - # If any type provided, set using models set_types method - # in order to execute constraint solver to propagate type - # updates along the model keys. - if (set_type := given_connection.data.type) is not None: - model.set_types({local_connection.conn: set_type}) - - return con_obj, updates + def update_key_name(self, connection: ConnectionData, key: str) -> None: + super().update_key_name(connection, key) + con_data = self.conns.get_extracted_connection(connection) + conn = self.connection_map[con_data] + setattr(self, key, conn) def _unroll_template(self, template: ExtendTemplate) -> ConnectionData: if template.output_connection is None: @@ -531,223 +617,34 @@ def _unroll_template(self, template: ExtendTemplate) -> ConnectionData: # given connections to the model after it is created. # If we don't do that, it will throw error because of # re-setting a Tensor or Scalar value again in extend. - model_type = ops_table[template.model] # TODO: Remove all TBD if default init arguments will be moved to call!!! - init_fun = model_type.__init__ + code = template.model.__init__.__code__ # "self" argument is common for all models, Exclude it by # starting co_varnames from 1st index. - default_args = init_fun.__code__.co_varnames[ - 1 : init_fun.__code__.co_argcount - ] - default_args_dict = {key: TBD for key in default_args} - default_args_dict |= template.defaults + default_args = code.co_varnames[1 : code.co_argcount] + default_args_dict = {key: TBD for key in default_args} | template.defaults default_args_dict.pop("name", None) # TODO: Reconsider type ignore! 
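# ---------------------------------------------------------------------------
# Illustrative sketch (not from the patch): the rewritten _unroll_template
# above pre-fills every constructor argument of the target Operator with TBD
# by inspecting __init__.__code__. The self-contained toy below demonstrates
# that introspection; ExampleOp and the TBD sentinel are hypothetical
# stand-ins for Mithril's own types.
TBD = object()  # stand-in for mithril's TBD singleton


class ExampleOp:
    def __init__(self, name=None, axis=0, keepdim=False):
        self.axis, self.keepdim = axis, keepdim


code = ExampleOp.__init__.__code__
# co_varnames lists parameter names first; skip "self" and stop at
# co_argcount so the function's local variables are excluded.
default_args = code.co_varnames[1 : code.co_argcount]
default_args_dict = {key: TBD for key in default_args}
default_args_dict.pop("name", None)  # "name" is handled separately, as above
print(sorted(default_args_dict))  # ['axis', 'keepdim']
# ---------------------------------------------------------------------------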
- model: PrimitiveModel = model_type(**default_args_dict) # type: ignore - connections: list[ConnectionType] = [] - for connection in template.connections: - if isinstance(connection, ExtendTemplate): - connections.append(self._unroll_template(connection).conn) - else: - assert isinstance( - connection, ConnectionInstanceType - ) # TODO: check if needed - connections.append(connection) - self.extend( - model, - **{ - local_key: outer_con - for local_key, outer_con in zip( - model.input_keys, connections, strict=False - ) - }, - ) + model: Operator = template.model(**default_args_dict) # type: ignore + keys = { + local_key: self._prepare_keys(model, local_key, outer_con) # type: ignore + for local_key, outer_con in zip( + model.input_keys, template.connections, strict=False + ) + } + self.extend(model, **keys) template.output_connection = model.conns.get_connection("output") assert template.output_connection is not None return template.output_connection - def merge_connections( - self, connection1: ConnectionData, connection2: ConnectionData - ) -> Updates: - # This method is used if there is 2 Connection objects to represent same Edge. - # In this case, connection2 is updated with connection1's data and it is removed - # from dag, dependency_map, self attribute (if exists) and Connections object. - - # TODO: Check multi-write error for Connect type. - - main_connection1 = self.conns.get_con_by_metadata(connection1.metadata) - main_connection2 = self.conns.get_con_by_metadata(connection2.metadata) - - if ( - main_connection1 is None - or main_connection2 is None - or main_connection1 == main_connection2 - ): - return Updates() - - # Remove main_connection2 from connections dict - con1_key = main_connection1.key - - if connection2 in self.conns.output_connections: - if con1_key not in self.conns.output_keys: - self.conns.set_connection_type(connection1, KeyType.OUTPUT) - if con1_key in self.input_keys: - self.conns.set_connection_type(main_connection1, KeyType.INTERNAL) - elif ( - main_connection2 in self.conns.internal_connections - and con1_key in self.input_keys - ): - self.conns.set_connection_type(main_connection1, KeyType.INTERNAL) - - # Switch all connection2 objects with connection1 object in current dag. 
- for m, m_info in self.dag.items(): - local_conns = m.conns.get_cons_by_metadata(main_connection2.metadata) - if local_conns is None: - continue - - for local_conn in local_conns: - if m_info.get(local_conn.key) is not None: - self.dag[m][local_conn.key] = main_connection1 - - # Update dependecy map, we need to update only local maps - for ( - o_conn, - key_info, - ) in self.dependency_map.local_output_dependency_map.items(): - if main_connection2 in key_info[1]: - self.dependency_map.local_output_dependency_map[o_conn][1].remove( - main_connection2 - ) - self.dependency_map.local_output_dependency_map[o_conn][1].add( - main_connection1 - ) - - if main_connection2 in self.dependency_map.local_output_dependency_map: - self.dependency_map.local_output_dependency_map[main_connection1] = ( - self.dependency_map.local_output_dependency_map.pop(main_connection2) - ) - - if main_connection2 in self.dependency_map.local_input_dependency_map: - old_dependencies = self.dependency_map.local_input_dependency_map.pop( - main_connection2 - ) - self.dependency_map.local_input_dependency_map.setdefault( - main_connection1, old_dependencies - ) - for dependecy in old_dependencies: - if ( - dependecy - not in self.dependency_map.local_input_dependency_map[ - main_connection1 - ] - ): - self.dependency_map.local_input_dependency_map[ - main_connection1 - ].append(dependecy) - - self.dependency_map.merge_global_connections(main_connection1, main_connection2) - self.dependency_map.merge_global_caches(main_connection1, main_connection2) - updates = self._match_hyper_edges( - main_connection1.metadata, main_connection2.metadata - ) - - self.conns.remove_connection(main_connection2) - - main_connection2.key = main_connection1.key - main_connection2.is_key_autogenerated = main_connection1.is_key_autogenerated - return updates - - def extend( - self, - model: Model | PrimitiveModel | BaseModel, - **kwargs: ConnectionType, - ) -> None: - # Check possible errors before the extension. - model.check_extendability() - if self.parent is not None: - raise AttributeError("Child model could not be re-extended!") - if self == model: - raise KeyError("Model can not extend with itself!") - if self._enforce_jit and not model.jittable: - raise Exception( - "Model with enforced Jit can not be extended by a non-jittable model! \ - Jit can be unforced by setting enforce_jit = False" - ) - if model.name is not None: - # TODO: We could store model names in a set to check if it is unique. - for m in self.dag: - if m.name == model.name: - raise KeyError(f"Model already has a submodel named {model.name}.") - - model.parent = self - # Freeze the model. 
- model._freeze() - - updates = Updates() - - shape_info: dict[str, ShapeTemplateType] = {} - type_info: dict[ - str, - type | UnionType | ScalarType | type[Tensor[int | float | bool]], - ] = {} - - submodel_dag: dict[str, ConnectionData] = {} - updates = self.constraint_solver.match(model.constraint_solver) - - # Add canonical output if it is not in external_keys - external_keys = list(model.external_keys) - external_keys += [ - item.key for item in model.conns.couts if item.key not in external_keys - ] - - io_keys: dict[str, IOKey] = { - key: self._convert_to_iokey(model, key, kwargs.get(key, NOT_GIVEN)) - for key in external_keys - } - - for local_key, value in io_keys.items(): - if value.data.shape is not None: - shape_info |= {local_key: value.data.shape} - - if value.data.type is not None: - type_info[local_key] = value.data.type - - con_obj, _updates = self._add_connection(model, local_key, value, updates) - updates |= _updates - submodel_dag[local_key] = con_obj - if con_obj.metadata.is_tensor: - updates.shape_updates.add(con_obj.metadata) - - # Replace shape info keys, which are local keys, with global equivalents. - shape_info = { - submodel_dag[key].key: template for key, template in shape_info.items() - } - type_info = { - submodel_dag[key].key: template for key, template in type_info.items() - } - - # Set given shapes. - self._set_shapes( - shape_info, - updates=updates, - ) # TODO: Should "trace" be set to True?. - - model.constraint_solver.clear() - model.conns.connections_dict = {} - - # Insert to self dag as a FrozenDict."" - # Since we update dag in merge_connections, we could not use FrozenDict. - self.dag[model] = model_dag = submodel_dag - - self.dependency_map.add_model_dag(model, model_dag) - - # Update jittablity by using model's jittablity. - self._jittable &= model.jittable - - def _extend(self, model: BaseModel, kwargs: dict[str, ConnectionType]) -> Self: + def _extend( + self, model: BaseModel, kwargs: dict[str, ConnectionType] | None = None + ) -> Self: + if kwargs is None: + kwargs = {} if self.is_frozen: raise AttributeError("Model is frozen and can not be extended!") @@ -767,10 +664,30 @@ def _extend(self, model: BaseModel, kwargs: dict[str, ConnectionType]) -> Self: else: kwargs[key] = _value - self.extend(model, **kwargs) + self.extend(model, **kwargs) # type: ignore return self - def __add__(self, info: ExtendInfo | BaseModel) -> Self: + @property + def cout(self) -> Connection: + return self.connection_map[self._cout] + + @property + def cin(self) -> Connection: + return self.connection_map[self._cin] + + def set_cin(self, *connections: str | Connection, safe: bool = True) -> None: + data: list[str | ConnectionData] = [ + item if isinstance(item, str) else item.data for item in connections + ] + self._set_cin(*data, safe=safe) + + def set_cout(self, *connections: str | Connection, safe: bool = True) -> None: + data: list[str | ConnectionData] = [ + item if isinstance(item, str) else item.data for item in connections + ] + self._set_cout(*data, safe=safe) + + def __add__(self, info: ExtendInfo | Model) -> Self: # TODO: Check if info is a valid info for canonical connections. # TODO: Add canonical connection information to info. if isinstance(info, BaseModel): @@ -794,227 +711,106 @@ def __add__(self, info: ExtendInfo | BaseModel) -> Self: __iadd__ = __add__ - def __or__(self, info: ExtendInfo | BaseModel) -> Self: + def __or__(self, info: ExtendInfo | Model) -> Self: # TODO: Check if info is a valid info for extend. 
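# ---------------------------------------------------------------------------
# Usage sketch (not from the patch): with the operators defined here, `+=`
# extends the model through its canonical input/output connections, while
# `|=` only adds the submodel and leaves wiring to the ExtendInfo kwargs
# produced by calling it. This is a hypothetical example: it assumes `Linear`
# keeps its `dimension` argument and its "input"/"output" keys after the
# refactor, so treat it as a sketch of the intended API rather than verified
# behavior.
from mithril import IOKey
from mithril.models import Linear, Model

pipeline = Model()
pipeline += Linear(dimension=32)(input="input")  # wired via canonical keys
pipeline += Linear(dimension=8)  # chained onto the previous canonical output

named = Model()
# With |=, connections are spelled out explicitly through the call kwargs.
named |= Linear(dimension=8)(input="input", output=IOKey("logits"))
# ---------------------------------------------------------------------------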
- if isinstance(info, BaseModel): + if isinstance(info, Model): info = info() return self._extend(info.model, info.connections) __ior__ = __or__ - @staticmethod - def _update_key_name( - new_key: str, - underscored_keys: set[str], - raw_keys: dict[str, list[str]], - key_mappings: dict[str, str], - key_origin: str, - input_set: set[str], - ) -> tuple[str, str]: - # Add underscore if generated key name exists in input keys - key_prefix = "_" - # check any of key_prefix + raw_keys[key_origin] in input keys - flag = True - while flag: - flag = False - for item in raw_keys[key_origin]: - if key_prefix + key_mappings[item] in input_set | set( - key_mappings.values() - ): - key_prefix += "_" - flag = True - - new_key = key_prefix + new_key - underscored_keys.add(key_origin) - # Update same origin key names that has been previously added. - for raw_key in raw_keys[key_origin]: - key_mappings[raw_key] = key_prefix + key_mappings[raw_key] - raw_keys[key_prefix + key_origin] = raw_keys.pop(key_origin) - key_origin = key_prefix + key_origin - return new_key, key_origin - - def generate_keys( + ShapeType = ( + Mapping[str | Connection, ShapeTemplateType] + | Mapping[str, ShapeTemplateType] + | Mapping[Connection, ShapeTemplateType] + ) + + def set_shapes( + self, config: ShapeType | None = None, **kwargs: ShapeTemplateType + ) -> None: + if config is None: + config = {} + _config: dict[str | ConnectionData, ShapeTemplateType] = { + key.data if isinstance(key, Connection) else key: value + for key, value in config.items() + } + self._set_shapes(_config, trace=True, updates=None, **kwargs) + + def set_values( self, - symbolic: bool = True, - include_internals: bool = True, - include_outputs: bool = False, - ) -> dict[str, str]: - key_mappings: dict[str, str] = {} - raw_keys: dict[str, list[str]] = {} - underscored_keys = set[str]() - - if include_outputs: - input_set = set(self.external_keys) - keys = "external_keys" - else: - input_set = set(self.input_keys) - keys = "input_keys" - - sorted_inputs = [ - self.dag[m][key].key - for m in self.get_models_in_topological_order() - for key in getattr(m, keys) - if self.dag[m][key].key in input_set + config: Mapping[ + str | Connection, Tensor[int | float | bool] | MainValueType | str ] - # TODO: remove duplicate loop traverse - for key in sorted_inputs: - new_key = key - if key[0] != "$": - continue - # TODO: Discuss if we want to generate input key only - # if len(self.conns.cins) == 1, or we could name all canonical inputs. - if ( - len(self.conns.cins) == 1 - and key == self.cin.key - and "input" not in self.input_keys - ): - # Handle canonical input - new_key = "input" - else: - key_origin = self.conns.get_key_origin(key) - assert key_origin is not None - # Add prefix until key_origin not in underscored_keys and input_keys. - while ( - key_origin in (underscored_keys | self.input_keys) - or key_origin == "input" - ): - key_origin = "_" + key_origin - - raw_keys.setdefault(key_origin, []) - key_idx = len(raw_keys[key_origin]) - if key_idx == 0: - # Set key origin as is for the initial key. - key_suffix = "" - else: - key_suffix = "_" + str(key_idx) - if key_idx == 1: - # Update initial key if same key origin is encountered - # (add index to initial key). 
- raw_key = raw_keys[key_origin][0] - key_mappings[raw_key] = key_mappings[raw_key] + "_0" - if key_mappings[raw_key] in self.input_keys: - new_key, key_origin = self._update_key_name( - new_key, - underscored_keys, - raw_keys, - key_mappings, - key_origin, - set(self.input_keys), - ) - - new_key = key_origin + key_suffix - if new_key in self.input_keys: - new_key, key_origin = self._update_key_name( - new_key, - underscored_keys, - raw_keys, - key_mappings, - key_origin, - set(self.input_keys), - ) - raw_keys[key_origin].append(key) - key_mappings[key] = new_key - - if include_internals: - sorted_models = self.get_models_in_topological_order() - internal_key_mappings: dict[str, str] = {} - for idx, m in enumerate(sorted_models): - for key in m.external_keys: - outer_conn = self.dag[m][key] - outer_key = outer_conn.key - if outer_key[0] == "$": - # if key is autogenerated, generate a name for the key - model_name = m.__class__.__name__ - key_origin = outer_conn.metadata.key_origin - assert key_origin is not None - - generated_name = ( - "_" + model_name + "_" + str(idx) + "_" + key_origin - ) + | Mapping[Connection, Tensor[int | float | bool] | MainValueType | str] + | Mapping[str, Tensor[int | float | bool] | MainValueType | str] + | None = None, + **kwargs: Tensor[int | float | bool] | MainValueType | str, + ) -> None: + if config is None: + config = {} + _config: dict[ + str | ConnectionData, Tensor[int | float | bool] | MainValueType | str + ] = { + key if isinstance(key, str) else key.data: value + for key, value in config.items() + } + self._set_values(_config, **kwargs) - # if key is an output key, directly write it - # to the internal_key_mappings - # or - # if key is an input key, first check if the key - # is already in internal_key mappings to avoid - # overwrite - write_to_internal_key_mappings = ( - key in m.conns.output_keys - or internal_key_mappings.get( - outer_key, key_mappings.get(outer_key) - ) - is None - ) + def set_types( + self, + config: Mapping[ + str | Connection, + type | UnionType | ScalarType | type[Tensor[int | float | bool]], + ] + | Mapping[ + Connection, + type | UnionType | ScalarType | type[Tensor[int | float | bool]], + ] + | Mapping[ + str, + type | UnionType | ScalarType | type[Tensor[int | float | bool]], + ] + | None = None, + **kwargs: type | UnionType | ScalarType | type[Tensor[int | float | bool]], + ) -> None: + """ + Set types of any connection in the Model + + This method updates types in given connections. + connections can be given either as Connection or their string + equivalent. Giving a valid type for given connections, this method + will update the connections's types and thereafter runs the + constraints to update affected connections' types. + + Args: + values (dict[str | Connection, MainValueType]): A dictionary where + keys are either strings or Connection objects, and values are + of type of type or UnionType objects. 
+ + """ + if config is None: + config = {} + _config: dict[ + str | ConnectionData, + type | UnionType | ScalarType | type[Tensor[int | float | bool]], + ] = { + key if isinstance(key, str) else key.data: value + for key, value in config.items() + } + self._set_types(_config, **kwargs) - while ( - generated_name in internal_key_mappings.values() - and write_to_internal_key_mappings - ): - assert key_origin is not None - key_origin = "_" + key_origin - generated_name = ( - "_" + model_name + "_" + str(idx) + "_" + key_origin - ) - - if write_to_internal_key_mappings: - internal_key_mappings[outer_key] = generated_name - - key_mappings = internal_key_mappings | key_mappings - if symbolic: - key_mappings = {key: "$" + value for key, value in key_mappings.items()} - return key_mappings - - def get_unique_submodel_names(self) -> dict[BaseModel, str]: - name_mapping: dict[BaseModel, str] = {} - existing_names: set[str] = set() - model_type_dict: dict[str, list[BaseModel]] = {} - - # First, assign existing names and track used names. - # Also save unnamed models to model_type_dict. - for model in self.dag: - if model.name: - name_mapping[model] = model.name - existing_names.add(model.name) - else: - model_type_dict.setdefault(model.__class__.__name__, []).append(model) - - # Iterate over different model types among unnamed models. - for model_type, model_list in model_type_dict.items(): - counter = 0 - # Iterate over same class model objects to name them. - for i, model in enumerate(model_list): - if len(model_list) == 1: - # If there is only one model of a type, do not increment counter. - counter -= 1 - name = model_type - else: - name = f"{model_type}_{counter + i}" - while name in existing_names: - counter += 1 # counter is incremented until a unique name is found. - name = f"{model_type}_{counter + i}" - name_mapping[model] = name - existing_names.add(name) - return name_mapping - - def _freeze(self) -> None: - for cout in self.conns.couts: - self.conns.set_connection_type(cout, KeyType.OUTPUT, safe=False) - self.dependency_map.update_all_keys() - - # Name unnamed submodels before freezing considering the insertion order. - model_names = self.get_unique_submodel_names() - for m in self.dag: - if m.name is None: - m.name = model_names[m] - - if self.formula_key is not None: - # Must be convertable to primitive. - assert len(self.conns.output_keys) == 1, ( - "Logical models have altenative primitive implementation must " - "have only 1 output." 
- ) - super()._freeze() + def set_outputs(self, *args: str | Connection, **kwargs: str | Connection) -> None: + _args: list[str | ConnectionData] = [ + key if isinstance(key, str) else key.data for key in args + ] + _kwargs: dict[str, str | ConnectionData] = { + key: value if isinstance(value, str) else value.data + for key, value in kwargs.items() + } + self._set_outputs(*_args, **_kwargs) + # TODO: Update summary, this should work same with both + # Logical Model and Operator def summary( self, shapes: bool = True, @@ -1055,7 +851,7 @@ def summary( # TODO: Remove name argument from summary method if not name and (name := self.name) is None: - name = self.__class__.__name__ + name = self.class_name # construct the table based on relevant information table = get_summary( @@ -1080,151 +876,31 @@ def summary( "var_cache": var_cache, "types": types, } - if isinstance(model, PrimitiveModel): + if isinstance(model, Operator): kwargs.pop("depth") - model.summary(**kwargs) - - def extract_connection_info( - self, - name_mappings: dict[BaseModel, str], - data_to_key_map: dict[IOHyperEdge, list[str]] | None = None, - data_memo: Mapping[int, IOHyperEdge] | None = None, - ) -> dict[str, tuple[dict[str, list[str]], dict[str, list[str]]]]: - conn_info: dict[str, tuple[dict[str, list[str]], dict[str, list[str]]]] = {} - if self.input_keys: - if data_to_key_map is None: - data_to_key_map = {} - if data_memo is None: - data_memo = {} - model_key_map: dict[BaseModel, dict[str, str]] = {} - - # handle the case when model is constructed with += operation. In that case, - # directly take canonical output as the output_key. - - # TODO: We may expose all canonical outputs for summary instead of - # checking only len(self.conns.couts) == 1. - output_keys = ( - ([self.cout.key] if len(self.conns.couts) == 1 else []) - if not self.conns.output_keys - else self.conns.output_keys - ) - # extract key mappings and data map of outer model - key_mappings = self.generate_keys( - include_internals=False, include_outputs=True - ) - data_map = {key: conn.metadata for key, conn in self.conns.all.items()} - - # Sort in topological order - sorted_models = self.get_models_in_topological_order() - - for model in sorted_models: - model_name = name_mappings[model] - m_info = self.dag[model] - # set default structure of conn_info and shape_info - conns = conn_info.setdefault(model_name, ({}, {})) - # include input keys with Tensor value - input_keys = tuple(model.input_keys) - # Generate sub_model key_map and data map - model_key_map[model] = m_key_mappings = model.generate_keys( - include_internals=False, include_outputs=True - ) - m_data_map = { - key: conn.metadata for key, conn in model.conns.all.items() - } - for inner_key in input_keys + tuple(model.conns.output_keys): - # Find the data of the key, if data memo is given, extract its - # copied version and extract the shapes - key_data = data_memo.get( - id(m_data_map[inner_key]), m_data_map[inner_key] - ) - - # Find inner and outer keys. 
Also find their updated version based - # on their key mappings - updated_inner_key = m_key_mappings.get(inner_key, inner_key) - outer_conn = m_info[inner_key] - outer_key = outer_conn.key - updated_outer_key = data_to_key_map.get( - key_data, [key_mappings.get(outer_key, outer_key)] - ) - - # take and setdefault connection list in which update will be done - conn = conns[inner_key in model.conns.output_keys].setdefault( - updated_inner_key, [] - ) - if inner_key not in input_keys: - continue - - if (val := key_data.value) is not TBD: - conn.append(str(val)) - - elif outer_key in self.input_keys: - # If outer_key in input_keys of overall model, it means - # the input key is overall input to the model. Do the - # updates accordingly - input_name = ["'" + key + "'" for key in updated_outer_key] - conn.extend(input_name) - else: - # if input_key is not in self.input_keys, that means this - # input key connected to a model and it is an internal - # connection. Find the connected model and do the intializations - con_model = self.dependency_map.local_output_dependency_map[ - outer_conn - ][0] - con_generated_keys = model_key_map.setdefault( - con_model, - con_model.generate_keys( - include_internals=False, include_outputs=True - ), - ) - conn_info.setdefault(name_mappings[con_model], ({}, {})) - model_conn = m_info[inner_key] - con = con_model.conns.get_con_by_metadata(model_conn.metadata) - assert con is not None, "Connection is not found" - con_key = con.key - - con_key = con_generated_keys.get(con_key, con_key) - # Since being internal key means having two sided connection, - # Two updates on conn_info dict needs to be done. one for - # model's input key and other for connected_model's output - # key. do the updates accordingly. - conn_info[model_name][0].setdefault( - updated_inner_key, [] - ).append(name_mappings[con_model] + "." + con_key) - conn_info[name_mappings[con_model]][1].setdefault( - con_key, [] - ).append(model_name + "." + updated_inner_key) - - for outer_key in output_keys: - # Lastly, traverse through output keys of the overall model - # Find the connected model, and find the inner key by finding - # the metadata - metadata = self.conns.get_metadata(outer_key) - outer_out_conn = self.conns.get_connection(outer_key) - - assert metadata is not None, "Metadata is not found!" - assert outer_out_conn is not None, "Connection is not found" - - model = self.dependency_map.local_output_dependency_map[outer_out_conn][ - 0 - ] - other_conn = model.conns.get_con_by_metadata(metadata) - assert other_conn is not None, "Connection is not found" - - inner_key = other_conn.key - updated_inner_key = model_key_map[model].get(inner_key, inner_key) - key_data = data_memo.get(id(data_map[outer_key]), data_map[outer_key]) - updated_outer_key = data_to_key_map.get( - key_data, [key_mappings.get(outer_key, outer_key)] - ) - if updated_outer_key[0][0] == "$": - # There is only possibilty of outer key is found to be with $ sign. - # That is, if model is constructed with += operator. In that case, - # canonical output will be external key even if it is not named by - # user. 
Therefore, handle the case with dicrectly writing $output
-                updated_outer_key = ["$output"]
-            model_name = name_mappings[model]
-            conn_info[model_name][1][updated_inner_key].extend(
-                ["'" + key + "'" for key in updated_outer_key]
-            )
+            model.summary(**kwargs)  # type: ignore
+
+
+def define_unique_names(
+    models: list[BaseModel] | KeysView[BaseModel],
+) -> dict[BaseModel, str]:
+    # TODO: Move this to Physical model (currently it is only used there)
+    # TODO: Also add short-naming logic to this function
+    model_name_dict = {}
+    single_model_dict = {}
+    model_count_dict: dict[str, int] = {}
+
+    for model in models:
+        class_name = model.name or model.class_name
+        if model_count_dict.setdefault(class_name, 0) == 0:
+            single_model_dict[class_name] = model
+        else:
+            single_model_dict.pop(class_name, None)
+        model_name_dict[model] = (
+            str(class_name) + "_" + str(model_count_dict[class_name])
+        )
+        model_count_dict[class_name] += 1
 
-    return conn_info
+    for m in single_model_dict.values():
+        model_name_dict[m] = m.name or str(m.class_name)
+    return model_name_dict
diff --git a/mithril/framework/logical/operator.py b/mithril/framework/logical/operator.py
new file mode 100644
index 00000000..3e679b1e
--- /dev/null
+++ b/mithril/framework/logical/operator.py
@@ -0,0 +1,146 @@
+# Copyright 2022 Synnada, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import get_args, get_origin
+
+from ...utils.utils import OrderedSet
+from ..common import (
+    BaseKey,
+    ConnectionDataType,
+    IOHyperEdge,
+    KeyType,
+    Tensor,
+    ToBeDetermined,
+    Updates,
+    create_shape_map,
+)
+from .base import BaseModel
+
+
+class Operator(BaseModel):
+    """This class contains the simplest / primitive
+    building blocks of composite models.
+    """
+
+    _model_name: str = ""
+    output_key: str = "output"
+    cache_name: str = "cache"
+
+    def __init__(
+        self,
+        formula_key: str,
+        name: str | None = None,
+        **keys: BaseKey | IOHyperEdge,
+    ) -> None:
+        super().__init__(name, formula_key)
+
+        self.random_keys: set[str] = set()
+        # Get shape_templates of TensorTypes and create corresponding shapes.
+        shape_templates = {
+            key: value.value_shape
+            for key, value in keys.items()
+            if isinstance(value, BaseKey) and value.value_shape is not None
+        }
+        shapes = create_shape_map(shape_templates, self.constraint_solver)
+        data_set: set[IOHyperEdge] = set()
+        is_diff = False
+        output_data: IOHyperEdge | None = None
+        for key, value in keys.items():
+            if isinstance(value, BaseKey):
+                if get_origin(value.type) is Tensor:
+                    if not isinstance(tensor := value.value, Tensor):
+                        assert isinstance(value.value, ToBeDetermined)
+                        tensor = Tensor(
+                            type=get_args(value.type)[0],
+                            shape=shapes[key].node,
+                        )
+                    edge = IOHyperEdge(value=tensor, interval=value.interval)
+                    data_set.add(edge)
+                else:
+                    edge_type = ToBeDetermined if value.type is None else value.type
+                    edge = IOHyperEdge(
+                        type=edge_type,
+                        value=value.value,
+                        interval=value.interval,
+                    )
+            else:
+                raise TypeError(
+                    "Operators can only be instantiated with BaseKey type keys!"
+                )
+
+            conn_data = self._create_connection(edge, key)
+
+            if key == Operator.output_key:
+                self.conns.set_connection_type(conn_data, KeyType.OUTPUT)
+                output_data = edge
+            else:
+                self.conns.set_connection_type(conn_data, KeyType.INPUT)
+                is_diff |= not edge.is_non_diff
+        if isinstance(output_data, IOHyperEdge) and isinstance(
+            output_data.edge_type, Tensor
+        ):
+            output_data.differentiable = is_diff
+
+        # Initially run all given tensors' constraints
+        self.constraint_solver.update_shapes(Updates(data_set))
+
+        input_conns = OrderedSet(conn for conn in self.conns.input_connections)
+        out_conn = self.conns.get_connection("output")
+        assert out_conn is not None
+        output_conns = OrderedSet({out_conn})
+
+        for conn in self.conns.input_connections:
+            self.dependency_map.local_input_dependency_map[conn] = [
+                (self, output_conns)
+            ]
+
+        for conn in output_conns:
+            self.dependency_map.local_output_dependency_map[conn] = (self, input_conns)
+
+        self.dependency_map.cache_internal_references(out_conn, input_conns)
+        self.dependency_map.update_all_keys()
+
+        # Link canonicals
+        canonical_input_key = (
+            "input" if "input" in self.input_keys else next(iter(self.input_keys))
+        )
+        canonical_input_conn = self.conns.get_connection(canonical_input_key)
+        if canonical_input_conn is not None:
+            self._set_cin(canonical_input_conn, safe=False)
+
+        canonical_output_key = (
+            "output"
+            if "output" in self.conns.output_keys
+            else next(iter(self.conns.output_keys))
+        )
+        canonical_output_conn = self.conns.get_connection(canonical_output_key)
+        if canonical_output_conn is not None:
+            self._set_cout(canonical_output_conn, safe=False)
+        self._freeze()
+
+    @property
+    def formula_key(self) -> str:
+        assert self._formula_key is not None
+        return self._formula_key
+
+    @property
+    def class_name(self) -> str:
+        return self._model_name
+
+    def extend(
+        self,
+        model: BaseModel,
+        **kwargs: ConnectionDataType,
+    ) -> None:
+        raise NotImplementedError("Operators cannot be extended!")
diff --git a/mithril/framework/logical/essential_primitives.py b/mithril/framework/logical/operators.py
similarity index 66%
rename from mithril/framework/logical/essential_primitives.py
rename to mithril/framework/logical/operators.py
index 65bbd13e..ac59d7d7 100644
--- a/mithril/framework/logical/essential_primitives.py
+++ b/mithril/framework/logical/operators.py
@@ -19,11 +19,8 @@

 from ... 
import core from ..common import ( - NOT_GIVEN, TBD, BaseKey, - Connection, - ConnectionType, Constraint, ScalarValueType, ShapeTemplateType, @@ -62,102 +59,88 @@ to_tensor_constraints, to_tuple_constraints, ) -from .base import ExtendInfo -from .primitive import PrimitiveModel +from .operator import Operator __all__ = [ - "PrimitiveModel", - "Buffer", - "ToTuple", - "Power", - "Add", - "Subtract", - "Multiply", - "Divide", - "FloorDivide", - "Minus", - "MatrixMultiply", - "Shape", - "Reshape", - "Length", - "Size", - "Exponential", - "Item", - "Indexer", - "ToTensor", - "ToList", - "TensorToList", - "Mean", - "Sum", - "Max", - "Min", - "Prod", - "Variance", - "Absolute", - "Equal", - "NotEqual", - "Greater", - "GreaterEqual", - "Less", - "LessEqual", - "LogicalNot", - "LogicalOr", - "LogicalAnd", - "LogicalXOr", - "ShiftLeft", - "ShiftRight", - "ArgMax", - "ArgMin", - "Cast", - "Transpose", - "Sqrt", - "Split", - "Slice", - "Dtype", - "Sine", - "Cosine", - "Minimum", - "Maximum", + "Operator", + "BufferOp", + "ToTupleOp", + "PowerOp", + "AddOp", + "SubtractOp", + "MultiplyOp", + "DivideOp", + "FloorDivideOp", + "MinusOp", + "MatrixMultiplyOp", + "ShapeOp", + "ReshapeOp", + "LengthOp", + "SizeOp", + "ExponentialOp", + "ItemOp", + "IndexerOp", + "ToTensorOp", + "ToListOp", + "TensorToListOp", + "MeanOp", + "SumOp", + "MaxOp", + "MinOp", + "ProdOp", + "VarianceOp", + "AbsoluteOp", + "EqualOp", + "NotEqualOp", + "GreaterOp", + "GreaterEqualOp", + "LessOp", + "LessEqualOp", + "LogicalNotOp", + "LogicalOrOp", + "LogicalAndOp", + "LogicalXOrOp", + "ShiftLeftOp", + "ShiftRightOp", + "ArgMaxOp", + "ArgMinOp", + "CastOp", + "TransposeOp", + "SqrtOp", + "SplitOp", + "SliceOp", + "DtypeOp", + "SineOp", + "CosineOp", + "MinimumOp", + "MaximumOp", ] ConstantType = float | int | core.Constant -class Buffer(PrimitiveModel): - input: Connection - output: Connection +class BufferOp(Operator): + _model_name: str = "Buffer" def __init__( self, input: Tensor[int | float | bool] | ScalarValueType | ToBeDetermined = TBD, - *, - name: str | None = None, ) -> None: super().__init__( formula_key="buffer", - name=name, output=BaseKey(), input=BaseKey(value=input), ) - self._add_constraint( - fn=buffer_constraint, keys=[PrimitiveModel.output_key, "input"] - ) + self._add_constraint(fn=buffer_constraint, keys=[Operator.output_key, "input"]) - def __call__( # type: ignore[override] - self, - input: ConnectionType = NOT_GIVEN, - output: ConnectionType = NOT_GIVEN, - ) -> ExtendInfo: - return super().__call__(input=input, output=output) +class ToTupleOp(Operator): + _model_name: str = "ToTuple" -class ToTuple(PrimitiveModel): def __init__( self, n: int, - *, - name: str | None = None, **kwargs: Tensor[int | float | bool] | ScalarValueType | ToBeDetermined, ) -> None: self.factory_args = {"n": n} @@ -176,17 +159,15 @@ def __init__( for idx in range(n) } - super().__init__(formula_key="to_tuple", name=name, **key_definitions) + super().__init__(formula_key="to_tuple", name=None, **key_definitions) self._add_constraint( fn=to_tuple_constraints, - keys=[PrimitiveModel.output_key] + [key for key in self.input_keys], + keys=[Operator.output_key] + [key for key in self.input_keys], ) -class ArithmeticOperation(PrimitiveModel): - left: Connection - right: Connection - output: Connection +class ArithmeticOp(Operator): + _model_name: str = "Arithmetic" def __init__( self, @@ -206,41 +187,31 @@ def __init__( edge_constraint = self._add_constraint( fn=edge_type_constraint, - keys=[PrimitiveModel.output_key, "left", "right"], + 
keys=[Operator.output_key, "left", "right"], ) self._add_constraint( fn=general_tensor_type_constraint, - keys=[PrimitiveModel.output_key, "left", "right"], + keys=[Operator.output_key, "left", "right"], dependencies={edge_constraint}, ) bcast_constraint = self._add_constraint( fn=bcast, - keys=[PrimitiveModel.output_key, "left", "right"], + keys=[Operator.output_key, "left", "right"], dependencies={edge_constraint}, ) self._add_constraint( fn=bcast_error_check, - keys=[PrimitiveModel.output_key, "left", "right"], + keys=[Operator.output_key, "left", "right"], dependencies={bcast_constraint}, ) self.edge_constraint = edge_constraint - def __call__( # type: ignore[override] - self, - left: ConnectionType = NOT_GIVEN, - right: ConnectionType = NOT_GIVEN, - output: ConnectionType = NOT_GIVEN, - ) -> ExtendInfo: - return super().__call__(left=left, right=right, output=output) - -class Power(PrimitiveModel): - base: Connection - exponent: Connection - output: Connection +class PowerOp(Operator): + _model_name: str = "Power" def __init__( self, @@ -284,63 +255,40 @@ def __init__( ) edge_constraint = self._add_constraint( fn=edge_type_constraint, - keys=[PrimitiveModel.output_key, "base", "exponent"], + keys=[Operator.output_key, "base", "exponent"], ) constrs = {edge_constraint} self._add_constraint( fn=general_tensor_type_constraint, - keys=[PrimitiveModel.output_key, "base", "exponent"], + keys=[Operator.output_key, "base", "exponent"], dependencies=constrs, ) bcast_constraint = self._add_constraint( fn=bcast, - keys=[PrimitiveModel.output_key, "base", "exponent"], + keys=[Operator.output_key, "base", "exponent"], dependencies=constrs, ) self._add_constraint( fn=bcast_error_check, - keys=[PrimitiveModel.output_key, "base", "exponent"], + keys=[Operator.output_key, "base", "exponent"], dependencies={bcast_constraint}, ) self._add_constraint( partial(general_forward_constraint, callable=lambda x, y: x**y), - keys=[PrimitiveModel.output_key, "base", "exponent"], + keys=[Operator.output_key, "base", "exponent"], dependencies=constrs, ) constrs = constrs - def __call__( # type: ignore[override] - self, - base: ConnectionType = NOT_GIVEN, - exponent: ConnectionType = NOT_GIVEN, - output: ConnectionType = NOT_GIVEN, - *, - name: str | None = None, - threshold: ConnectionType = core.Constant.MIN_POSITIVE_NORMAL, - ) -> ExtendInfo: - kwargs = {"base": base, "exponent": exponent, "output": output} - default = ( - isinstance(threshold, core.Constant) - and threshold == core.Constant.MIN_POSITIVE_NORMAL - ) - if self.robust: - # NOTE: Since we can not provide Tensor objects as default - # arguments, we need to convert default value. 
- if default: - threshold = Tensor(threshold) # type: ignore - kwargs["threshold"] = threshold - elif not default: - raise ValueError("Threshold cannot be specified when robust mode is off") - - return super().__call__(**kwargs) - - -class Add(ArithmeticOperation): + +class AddOp(ArithmeticOp): + _model_name: str = "Add" + def __init__( self, left: Tensor[int | float | bool] | ScalarValueType | ToBeDetermined = TBD, @@ -352,12 +300,14 @@ def __init__( self._add_constraint( partial(general_forward_constraint, callable=lambda x, y: x + y), - keys=[PrimitiveModel.output_key, "left", "right"], + keys=[Operator.output_key, "left", "right"], dependencies={self.edge_constraint}, ) -class Subtract(ArithmeticOperation): +class SubtractOp(ArithmeticOp): + _model_name: str = "Subtract" + def __init__( self, left: Tensor[int | float | bool] | ScalarValueType | ToBeDetermined = TBD, @@ -369,12 +319,14 @@ def __init__( self._add_constraint( partial(general_forward_constraint, callable=lambda x, y: x - y), - keys=[PrimitiveModel.output_key, "left", "right"], + keys=[Operator.output_key, "left", "right"], dependencies={self.edge_constraint}, ) -class Multiply(ArithmeticOperation): +class MultiplyOp(ArithmeticOp): + _model_name: str = "Multiply" + def __init__( self, left: Tensor[int | float | bool] | ScalarValueType | ToBeDetermined = TBD, @@ -388,35 +340,35 @@ def __init__( self._add_constraint( partial(general_forward_constraint, callable=lambda x, y: x * y), - keys=[PrimitiveModel.output_key, "left", "right"], + keys=[Operator.output_key, "left", "right"], dependencies={self.edge_constraint}, ) -class Minimum(ArithmeticOperation): +class MinimumOp(ArithmeticOp): + _model_name: str = "Minimum" + def __init__( self, - name: str | None = None, left: TensorValueType | ToBeDetermined = TBD, right: TensorValueType | ToBeDetermined = TBD, ) -> None: - super().__init__(formula_key="minimum", name=name, left=left, right=right) + super().__init__(formula_key="minimum", left=left, right=right) + +class MaximumOp(ArithmeticOp): + _model_name: str = "Maximum" -class Maximum(ArithmeticOperation): def __init__( self, - name: str | None = None, left: TensorValueType | ToBeDetermined = TBD, right: TensorValueType | ToBeDetermined = TBD, ) -> None: - super().__init__(formula_key="maximum", name=name, left=left, right=right) + super().__init__(formula_key="maximum", left=left, right=right) -class Divide(PrimitiveModel): - numerator: Connection - denominator: Connection - output: Connection +class DivideOp(Operator): + _model_name: str = "Divide" def __init__( self, @@ -436,48 +388,36 @@ def __init__( ) edge_constraint = self._add_constraint( fn=edge_type_constraint, - keys=[PrimitiveModel.output_key, "numerator", "denominator"], + keys=[Operator.output_key, "numerator", "denominator"], ) self._add_constraint( fn=divide_type_constraint, - keys=[PrimitiveModel.output_key, "numerator", "denominator"], + keys=[Operator.output_key, "numerator", "denominator"], dependencies={edge_constraint}, ) bcast_constraint = self._add_constraint( fn=bcast, - keys=[PrimitiveModel.output_key, "numerator", "denominator"], + keys=[Operator.output_key, "numerator", "denominator"], dependencies={edge_constraint}, ) self._add_constraint( partial(general_forward_constraint, callable=lambda x, y: x / y), - keys=[PrimitiveModel.output_key, "numerator", "denominator"], + keys=[Operator.output_key, "numerator", "denominator"], dependencies={edge_constraint}, ) self._add_constraint( fn=bcast_error_check, - keys=[PrimitiveModel.output_key, "numerator", 
"denominator"], + keys=[Operator.output_key, "numerator", "denominator"], dependencies={bcast_constraint}, ) - def __call__( # type: ignore[override] - self, - numerator: ConnectionType = NOT_GIVEN, - denominator: ConnectionType = NOT_GIVEN, - output: ConnectionType = NOT_GIVEN, - ) -> ExtendInfo: - return super().__call__( - numerator=numerator, denominator=denominator, output=output - ) - -class FloorDivide(PrimitiveModel): - numerator: Connection - denominator: Connection - output: Connection +class FloorDivideOp(Operator): + _model_name: str = "FloorDivide" def __init__( self, @@ -497,48 +437,36 @@ def __init__( ) edge_constraint = self._add_constraint( fn=edge_type_constraint, - keys=[PrimitiveModel.output_key, "numerator", "denominator"], + keys=[Operator.output_key, "numerator", "denominator"], ) self._add_constraint( fn=floor_divide_type_constraint, - keys=[PrimitiveModel.output_key, "numerator", "denominator"], + keys=[Operator.output_key, "numerator", "denominator"], dependencies={edge_constraint}, ) bcast_constraint = self._add_constraint( fn=bcast, - keys=[PrimitiveModel.output_key, "numerator", "denominator"], + keys=[Operator.output_key, "numerator", "denominator"], dependencies={edge_constraint}, ) self._add_constraint( partial(general_forward_constraint, callable=lambda x, y: x // y), - keys=[PrimitiveModel.output_key, "numerator", "denominator"], + keys=[Operator.output_key, "numerator", "denominator"], dependencies={edge_constraint}, ) self._add_constraint( fn=bcast_error_check, - keys=[PrimitiveModel.output_key, "numerator", "denominator"], + keys=[Operator.output_key, "numerator", "denominator"], dependencies={bcast_constraint}, ) - def __call__( # type: ignore[override] - self, - numerator: ConnectionType = NOT_GIVEN, - denominator: ConnectionType = NOT_GIVEN, - output: ConnectionType = NOT_GIVEN, - ) -> ExtendInfo: - return super().__call__( - numerator=numerator, denominator=denominator, output=output - ) - -class MatrixMultiply(PrimitiveModel): - left: Connection - right: Connection - output: Connection +class MatrixMultiplyOp(Operator): + _model_name: str = "MatrixMultiply" def __init__( self, @@ -555,32 +483,23 @@ def __init__( right=BaseKey(shape=[("Var2", ...), "y", "z"], type=Tensor, value=right), ) bcast_constraint = self._add_constraint( - fn=bcast_matrix_mult, keys=[PrimitiveModel.output_key, "left", "right"] + fn=bcast_matrix_mult, keys=[Operator.output_key, "left", "right"] ) self._add_constraint( fn=bcast_mat_mul_check, - keys=[PrimitiveModel.output_key, "left", "right"], + keys=[Operator.output_key, "left", "right"], dependencies={bcast_constraint}, ) self._add_constraint( fn=general_tensor_type_constraint, - keys=[PrimitiveModel.output_key, "left", "right"], + keys=[Operator.output_key, "left", "right"], ) - def __call__( # type: ignore[override] - self, - left: ConnectionType = NOT_GIVEN, - right: ConnectionType = NOT_GIVEN, - output: ConnectionType = NOT_GIVEN, - ) -> ExtendInfo: - return super().__call__(left=left, right=right, output=output) - -class Shape(PrimitiveModel): - input: Connection - output: Connection +class ShapeOp(Operator): + _model_name: str = "Shape" def __init__( self, @@ -596,16 +515,9 @@ def __init__( ) self._add_constraint(fn=shape_constraints, keys=["output", "input"]) - def __call__( # type: ignore[override] - self, input: ConnectionType = NOT_GIVEN, output: ConnectionType = NOT_GIVEN - ) -> ExtendInfo: - return super().__call__(input=input, output=output) - -class Reshape(PrimitiveModel): - input: Connection - shape: 
Connection - output: Connection +class ReshapeOp(Operator): + _model_name: str = "Reshape" def __init__( self, @@ -629,18 +541,9 @@ def __init__( ) self._add_constraint(fn=reshape_constraints, keys=["output", "input", "shape"]) - def __call__( # type: ignore[override] - self, - input: ConnectionType = NOT_GIVEN, - shape: ConnectionType = NOT_GIVEN, - output: ConnectionType = NOT_GIVEN, - ) -> ExtendInfo: - return super().__call__(input=input, shape=shape, output=output) - -class Length(PrimitiveModel): - input: Connection - output: Connection +class LengthOp(Operator): + _model_name: str = "Length" def __init__( self, @@ -655,16 +558,9 @@ def __init__( input=BaseKey(shape=[("Var", ...)], type=Tensor, value=input), ) - def __call__( # type: ignore[override] - self, input: ConnectionType = NOT_GIVEN, output: ConnectionType = NOT_GIVEN - ) -> ExtendInfo: - return super().__call__(input=input, output=output) - -class Cast(PrimitiveModel): - input: Connection - dtype: Connection - output: Connection +class CastOp(Operator): + _model_name: str = "Cast" def __init__( self, dtype: core.Dtype | ToBeDetermined = TBD, *, name: str | None = None @@ -677,18 +573,9 @@ def __init__( dtype=BaseKey(type=core.Dtype, value=dtype), ) - def __call__( # type: ignore[override] - self, - input: ConnectionType = NOT_GIVEN, - dtype: ConnectionType = NOT_GIVEN, - output: ConnectionType = NOT_GIVEN, - ) -> ExtendInfo: - return super().__call__(input=input, dtype=dtype, output=output) - -class Dtype(PrimitiveModel): - input: Connection - output: Connection +class DtypeOp(Operator): + _model_name: str = "Dtype" def __init__( self, @@ -703,16 +590,9 @@ def __init__( input=BaseKey(shape=[("Var", ...)], type=Tensor, value=input), ) - def __call__( # type: ignore[override] - self, input: ConnectionType = NOT_GIVEN, output: ConnectionType = NOT_GIVEN - ) -> ExtendInfo: - return super().__call__(input=input, output=output) - -class Size(PrimitiveModel): - input: Connection - dim: Connection - output: Connection +class SizeOp(Operator): + _model_name: str = "Size" def __init__( self, @@ -731,18 +611,9 @@ def __init__( ) self._add_constraint(fn=size_constraints, keys=["output", "input", "dim"]) - def __call__( # type: ignore[override] - self, - input: ConnectionType = NOT_GIVEN, - dim: ConnectionType = NOT_GIVEN, - output: ConnectionType = NOT_GIVEN, - ) -> ExtendInfo: - return super().__call__(input=input, dim=dim, output=output) - -class Item(PrimitiveModel): - input: Connection - output: Connection +class ItemOp(Operator): + _model_name: str = "Item" def __init__( self, @@ -756,22 +627,13 @@ def __init__( output=BaseKey(type=int | float), input=BaseKey(shape=[("Var", ...)], type=Tensor, value=input), ) - self._add_constraint( - fn=item_constraints, keys=[PrimitiveModel.output_key, "input"] - ) + self._add_constraint(fn=item_constraints, keys=[Operator.output_key, "input"]) self._jittable = False - def __call__( # type: ignore[override] - self, input: ConnectionType = NOT_GIVEN, output: ConnectionType = NOT_GIVEN - ) -> ExtendInfo: - return super().__call__(input=input, output=output) - -class ToTensor(PrimitiveModel): - input: Connection - dtype: Connection - output: Connection +class ToTensorOp(Operator): + _model_name: str = "ToTensor" def __init__( self, @@ -789,20 +651,12 @@ def __init__( ) self._add_constraint( - fn=to_tensor_constraints, keys=[PrimitiveModel.output_key, "input"] + fn=to_tensor_constraints, keys=[Operator.output_key, "input"] ) - def __call__( # type: ignore[override] - self, - input: ConnectionType = 
NOT_GIVEN, - dtype: ConnectionType = NOT_GIVEN, - output: ConnectionType = NOT_GIVEN, - ) -> ExtendInfo: - return super().__call__(input=input, dtype=dtype, output=output) - -class ToList(PrimitiveModel): - output: Connection +class ToListOp(Operator): + _model_name: str = "ToList" def __init__( self, @@ -828,13 +682,12 @@ def __init__( self._add_constraint( fn=to_list_constraints, - keys=[PrimitiveModel.output_key] + [key for key in self.input_keys], + keys=[Operator.output_key] + [key for key in self.input_keys], ) -class TensorToList(PrimitiveModel): - input: Connection - output: Connection +class TensorToListOp(Operator): + _model_name: str = "TensorToList" def __init__( self, @@ -849,25 +702,17 @@ def __init__( input=BaseKey(shape=[("Var", ...)], type=Tensor, value=input), ) self._add_constraint( - fn=tensor_to_list_constraints, keys=[PrimitiveModel.output_key, "input"] + fn=tensor_to_list_constraints, keys=[Operator.output_key, "input"] ) self._add_constraint( - fn=tensor_to_list_type_constraint, keys=[PrimitiveModel.output_key, "input"] + fn=tensor_to_list_type_constraint, keys=[Operator.output_key, "input"] ) self._jittable = False - def __call__( # type: ignore[override] - self, input: ConnectionType = NOT_GIVEN, output: ConnectionType = NOT_GIVEN - ) -> ExtendInfo: - return super().__call__(input=input, output=output) - -class Reduce(PrimitiveModel): - input: Connection - axis: Connection - keepdim: Connection - output: Connection +class ReduceOp(Operator): + _model_name: str = "Reduce" def __init__( self, @@ -879,7 +724,6 @@ def __init__( name: str | None = None, **kwargs: BaseKey, ) -> None: - # TODO: Handle axis type for conditional cases below. self.factory_args = {"axis": axis, "keepdim": keepdim} axis_type: UnionType | type if isinstance(axis, tuple): @@ -902,20 +746,13 @@ def __init__( self._add_constraint( fn=reduce_constraints, - keys=[PrimitiveModel.output_key, "input", "axis", "keepdim"], + keys=[Operator.output_key, "input", "axis", "keepdim"], ) - def __call__( # type: ignore[override] - self, - input: ConnectionType = NOT_GIVEN, - axis: ConnectionType = NOT_GIVEN, - keepdim: ConnectionType = NOT_GIVEN, - output: ConnectionType = NOT_GIVEN, - ) -> ExtendInfo: - return super().__call__(input=input, axis=axis, keepdim=keepdim, output=output) +class MeanOp(ReduceOp): + _model_name: str = "Mean" -class Mean(Reduce): # TODO: Torch expects float input for mean reduction, JAX accepts all types. def __init__( self, @@ -935,11 +772,13 @@ def __init__( ) -class Sum(Reduce): +class SumOp(ReduceOp): + _model_name: str = "Sum" + def __init__( self, axis: int | tuple[int, ...] | None | ToBeDetermined = None, - keepdim: bool = False, + keepdim: bool | ToBeDetermined = False, input: Tensor[int | float | bool] | ToBeDetermined = TBD, *, name: str | None = None, @@ -948,15 +787,17 @@ def __init__( formula_key="reduce_sum", name=name, axis=axis, keepdim=keepdim, input=input ) self._add_constraint( - fn=reduce_type_constraint, keys=[PrimitiveModel.output_key, "input"] + fn=reduce_type_constraint, keys=[Operator.output_key, "input"] ) -class Max(Reduce): +class MaxOp(ReduceOp): + _model_name: str = "Max" + def __init__( self, axis: int | tuple[int, ...] 
| None | ToBeDetermined = None, - keepdim: bool = False, + keepdim: bool | ToBeDetermined = False, input: Tensor[int | float | bool] | ToBeDetermined = TBD, *, name: str | None = None, @@ -965,15 +806,17 @@ def __init__( formula_key="reduce_max", name=name, axis=axis, keepdim=keepdim, input=input ) self._add_constraint( - fn=general_tensor_type_constraint, keys=[PrimitiveModel.output_key, "input"] + fn=general_tensor_type_constraint, keys=[Operator.output_key, "input"] ) -class ArgMax(Reduce): +class ArgMaxOp(ReduceOp): + _model_name: str = "ArgMax" + def __init__( self, axis: int | None | ToBeDetermined = None, - keepdim: bool = False, + keepdim: bool | ToBeDetermined = False, input: Tensor[int | float | bool] | ToBeDetermined = TBD, *, name: str | None = None, @@ -984,16 +827,18 @@ def __init__( axis=axis, keepdim=keepdim, input=input, - # axis = Scalar(axis_type, axis), # TODO: Change axis type to int + # axis = Scalar(axis_type, axis), # TODO: Change axis type to int output=BaseKey(shape=[("Var_out", ...)], type=Tensor[int]), ) -class Min(Reduce): +class MinOp(ReduceOp): + _model_name: str = "Min" + def __init__( self, axis: int | tuple[int, ...] | None | ToBeDetermined = None, - keepdim: bool = False, + keepdim: bool | ToBeDetermined = False, input: Tensor[int | float | bool] | ToBeDetermined = TBD, *, name: str | None = None, @@ -1002,15 +847,17 @@ def __init__( formula_key="reduce_min", name=name, axis=axis, keepdim=keepdim, input=input ) self._add_constraint( - fn=general_tensor_type_constraint, keys=[PrimitiveModel.output_key, "input"] + fn=general_tensor_type_constraint, keys=[Operator.output_key, "input"] ) -class ArgMin(Reduce): +class ArgMinOp(ReduceOp): + _model_name: str = "ArgMin" + def __init__( self, axis: int | tuple[int, ...] | None | ToBeDetermined = None, - keepdim: bool = False, + keepdim: bool | ToBeDetermined = False, input: Tensor[int | float | bool] | ToBeDetermined = TBD, *, name: str | None = None, @@ -1021,16 +868,18 @@ def __init__( axis=axis, keepdim=keepdim, input=input, - # axis = Scalar(axis_type, axis), # TODO: Change axis type to int + # axis = Scalar(axis_type, axis), # TODO: Change axis type to int output=BaseKey(shape=[("Var_out", ...)], type=Tensor[int]), ) -class Prod(Reduce): +class ProdOp(ReduceOp): + _model_name: str = "Prod" + def __init__( self, axis: int | tuple[int, ...] | None | ToBeDetermined = None, - keepdim: bool = False, + keepdim: bool | ToBeDetermined = False, input: Tensor[int | float | bool] | ToBeDetermined = TBD, *, name: str | None = None, @@ -1043,17 +892,17 @@ def __init__( input=input, ) self._add_constraint( - fn=reduce_type_constraint, keys=[PrimitiveModel.output_key, "input"] + fn=reduce_type_constraint, keys=[Operator.output_key, "input"] ) -class Variance(Reduce): - correction: Connection +class VarianceOp(ReduceOp): + _model_name: str = "Variance" def __init__( self, axis: int | tuple[int, ...] | None | ToBeDetermined = None, - keepdim: bool = False, + keepdim: bool | ToBeDetermined = False, correction: int | float | None = 0.0, input: Tensor[int | float | bool] | ToBeDetermined = TBD, *, @@ -1071,26 +920,9 @@ def __init__( self.factory_args = {"axis": axis, "correction": correction, "keepdim": keepdim} # TODO: Should we remove axis, correction and keepdim from factory_args? 
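A note on the reduce-op signatures above: every reduce op widens keepdim from bool to bool | ToBeDetermined, so the flag may now be left unresolved at construction time. A minimal sketch of what this permits, assuming the import paths introduced by this diff (TBD is the framework's to-be-determined sentinel):

    from mithril.framework.common import TBD
    from mithril.framework.logical.operators import SumOp

    eager = SumOp(axis=0, keepdim=False)  # concrete value, as before
    lazy = SumOp(axis=0, keepdim=TBD)     # left unresolved until a value is supplied
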
- def __call__( # type: ignore[override] - self, - input: ConnectionType = NOT_GIVEN, - axis: ConnectionType = NOT_GIVEN, - keepdim: ConnectionType = NOT_GIVEN, - correction: ConnectionType = NOT_GIVEN, - output: ConnectionType = NOT_GIVEN, - ) -> ExtendInfo: - return super(Reduce, self).__call__( - input=input, - axis=axis, - keepdim=keepdim, - correction=correction, - output=output, - ) - -class SingleInputOperation(PrimitiveModel): - input: Connection - output: Connection +class SingleInputOperationOp(Operator): + _model_name: str = "SingleInputOperation" def __init__( self, @@ -1107,21 +939,18 @@ def __init__( ) # Finalize kwargs. new_kwargs: Mapping[str, BaseKey] = default_kwargs | kwargs - super().__init__(formula_key, name=name, **new_kwargs) + super().__init__(formula_key=formula_key, name=name, **new_kwargs) if polymorphic_constraint: self._add_constraint( fn=general_tensor_type_constraint, - keys=[PrimitiveModel.output_key, "input"], + keys=[Operator.output_key, "input"], ) - def __call__( # type: ignore[override] - self, input: ConnectionType = NOT_GIVEN, output: ConnectionType = NOT_GIVEN - ) -> ExtendInfo: - return super().__call__(input=input, output=output) +class AbsoluteOp(SingleInputOperationOp): + _model_name: str = "Absolute" -class Absolute(SingleInputOperation): def __init__( self, input: Tensor[int | float | bool] | ToBeDetermined = TBD, @@ -1131,7 +960,9 @@ def __init__( super().__init__(formula_key="abs", name=name, input=input) -class Minus(SingleInputOperation): +class MinusOp(SingleInputOperationOp): + _model_name: str = "Minus" + def __init__( self, input: Tensor[int | float | bool] | ToBeDetermined = TBD, @@ -1141,7 +972,9 @@ def __init__( super().__init__(formula_key="minus", name=name, input=input) -class Exponential(SingleInputOperation): +class ExponentialOp(SingleInputOperationOp): + _model_name: str = "Exponential" + def __init__( self, input: Tensor[int | float | bool] | ToBeDetermined = TBD, @@ -1157,9 +990,8 @@ def __init__( ) -class Sqrt(PrimitiveModel): - input: Connection - output: Connection +class SqrtOp(Operator): + _model_name: str = "Sqrt" def __init__( self, @@ -1188,35 +1020,9 @@ def __init__( input=BaseKey(shape=[("Var", ...)], type=Tensor, value=input), ) - def __call__( # type: ignore[override] - self, - input: ConnectionType = NOT_GIVEN, - output: ConnectionType = NOT_GIVEN, - *, - cutoff: ConnectionType = core.Constant.MIN_POSITIVE_NORMAL, - ) -> ExtendInfo: - kwargs = {"input": input, "output": output} - - default = ( - isinstance(cutoff, core.Constant) - and cutoff == core.Constant.MIN_POSITIVE_NORMAL - ) - if self.robust: - if default: - # NOTE: Since we can not provide Tensor objects as default - # arguments, we need to convert default value. 
- cutoff = Tensor(cutoff) # type: ignore - kwargs["cutoff"] = cutoff - elif not default: - raise ValueError("Cutoff cannot be specified when robust mode is off") - - return super().__call__(**kwargs) - -class RelationalOperators(PrimitiveModel): - left: Connection - right: Connection - output: Connection +class RelationalOperatorsOp(Operator): + _model_name: str = "RelationalOperators" def __init__( self, @@ -1253,21 +1059,15 @@ def __init__( self._add_constraint( fn=bcast_error_check, - keys=[PrimitiveModel.output_key, "left", "right"], + keys=[Operator.output_key, "left", "right"], dependencies={bcast_constraint}, ) self.edge_constraint = edge_constraint - def __call__( # type: ignore[override] - self, - left: ConnectionType = NOT_GIVEN, - right: ConnectionType = NOT_GIVEN, - output: ConnectionType = NOT_GIVEN, - ) -> ExtendInfo: - return super().__call__(left=left, right=right, output=output) +class GreaterOp(RelationalOperatorsOp): + _model_name: str = "Greater" -class Greater(RelationalOperators): def __init__( self, left: Tensor[int | float | bool] | ScalarValueType | ToBeDetermined = TBD, @@ -1279,12 +1079,14 @@ def __init__( self._add_constraint( partial(general_forward_constraint, callable=lambda x, y: x > y), - keys=[PrimitiveModel.output_key, "left", "right"], + keys=[Operator.output_key, "left", "right"], dependencies={self.edge_constraint}, ) -class Less(RelationalOperators): +class LessOp(RelationalOperatorsOp): + _model_name: str = "Less" + def __init__( self, left: Tensor[int | float | bool] | ScalarValueType | ToBeDetermined = TBD, @@ -1296,12 +1098,14 @@ def __init__( self._add_constraint( partial(general_forward_constraint, callable=lambda x, y: x < y), - keys=[PrimitiveModel.output_key, "left", "right"], + keys=[Operator.output_key, "left", "right"], dependencies={self.edge_constraint}, ) -class Equal(RelationalOperators): +class EqualOp(RelationalOperatorsOp): + _model_name: str = "Equal" + def __init__( self, left: Tensor[int | float | bool] | ScalarValueType | ToBeDetermined = TBD, @@ -1313,12 +1117,14 @@ def __init__( self._add_constraint( partial(general_forward_constraint, callable=lambda x, y: x == y), - keys=[PrimitiveModel.output_key, "left", "right"], + keys=[Operator.output_key, "left", "right"], dependencies={self.edge_constraint}, ) -class NotEqual(RelationalOperators): +class NotEqualOp(RelationalOperatorsOp): + _model_name: str = "NotEqual" + def __init__( self, left: Tensor[int | float | bool] | ScalarValueType | ToBeDetermined = TBD, @@ -1330,12 +1136,14 @@ def __init__( self._add_constraint( partial(general_forward_constraint, callable=lambda x, y: x != y), - keys=[PrimitiveModel.output_key, "left", "right"], + keys=[Operator.output_key, "left", "right"], dependencies={self.edge_constraint}, ) -class LessEqual(RelationalOperators): +class LessEqualOp(RelationalOperatorsOp): + _model_name: str = "LessEqual" + def __init__( self, left: Tensor[int | float | bool] | ScalarValueType | ToBeDetermined = TBD, @@ -1347,12 +1155,14 @@ def __init__( self._add_constraint( partial(general_forward_constraint, callable=lambda x, y: x <= y), - keys=[PrimitiveModel.output_key, "left", "right"], + keys=[Operator.output_key, "left", "right"], dependencies={self.edge_constraint}, ) -class GreaterEqual(RelationalOperators): +class GreaterEqualOp(RelationalOperatorsOp): + _model_name: str = "GreaterEqual" + def __init__( self, left: Tensor[int | float | bool] | ScalarValueType | ToBeDetermined = TBD, @@ -1364,14 +1174,13 @@ def __init__( self._add_constraint( 
partial(general_forward_constraint, callable=lambda x, y: x >= y), - keys=[PrimitiveModel.output_key, "left", "right"], + keys=[Operator.output_key, "left", "right"], dependencies={self.edge_constraint}, ) -class LogicalNot(PrimitiveModel): - input: Connection - output: Connection +class LogicalNotOp(Operator): + _model_name: str = "LogicalNot" def __init__( self, @@ -1386,16 +1195,9 @@ def __init__( input=BaseKey(shape=[("Var", ...)], type=Tensor[bool], value=input), ) - def __call__( # type: ignore[override] - self, input: ConnectionType = NOT_GIVEN, output: ConnectionType = NOT_GIVEN - ) -> ExtendInfo: - return super().__call__(input=input, output=output) - -class BitwiseOperators(PrimitiveModel): - left: Connection - right: Connection - output: Connection +class BitwiseOperatorsOp(Operator): + _model_name: str = "BitwiseOperators" def __init__( self, @@ -1414,16 +1216,10 @@ def __init__( ) self._add_constraint(bcast, ["output", "left", "right"]) - def __call__( # type: ignore[override] - self, - left: ConnectionType = NOT_GIVEN, - right: ConnectionType = NOT_GIVEN, - output: ConnectionType = NOT_GIVEN, - ) -> ExtendInfo: - return super().__call__(left=left, right=right, output=output) +class LogicalAndOp(BitwiseOperatorsOp): + _model_name: str = "LogicalAnd" -class LogicalAnd(BitwiseOperators): def __init__( self, left: Tensor[int | float | bool] | ToBeDetermined = TBD, @@ -1434,7 +1230,9 @@ def __init__( super().__init__(formula_key="logical_and", name=name, left=left, right=right) -class LogicalOr(BitwiseOperators): +class LogicalOrOp(BitwiseOperatorsOp): + _model_name: str = "LogicalOr" + def __init__( self, left: Tensor[int | float | bool] | ToBeDetermined = TBD, @@ -1445,7 +1243,9 @@ def __init__( super().__init__(formula_key="logical_or", name=name, left=left, right=right) -class LogicalXOr(BitwiseOperators): +class LogicalXOrOp(BitwiseOperatorsOp): + _model_name: str = "LogicalXOr" + def __init__( self, left: Tensor[int | float | bool] | ToBeDetermined = TBD, @@ -1457,10 +1257,8 @@ def __init__( self.factory_args = {"left": left, "right": right} -class ShiftLeft(PrimitiveModel): - input: Connection - shift: Connection - output: Connection +class ShiftLeftOp(Operator): + _model_name: str = "ShiftLeft" def __init__( self, @@ -1479,19 +1277,9 @@ def __init__( self._add_constraint(bcast, ["output", "input", "shift"]) - def __call__( # type: ignore[override] - self, - input: ConnectionType = NOT_GIVEN, - shift: ConnectionType = NOT_GIVEN, - output: ConnectionType = NOT_GIVEN, - ) -> ExtendInfo: - return super().__call__(input=input, shift=shift, output=output) - -class ShiftRight(PrimitiveModel): - input: Connection - shift: Connection - output: Connection +class ShiftRightOp(Operator): + _model_name: str = "ShiftRight" def __init__( self, @@ -1510,21 +1298,9 @@ def __init__( self._add_constraint(bcast, ["output", "input", "shift"]) - def __call__( # type: ignore[override] - self, - input: ConnectionType = NOT_GIVEN, - shift: ConnectionType = NOT_GIVEN, - output: ConnectionType = NOT_GIVEN, - ) -> ExtendInfo: - return super().__call__(input=input, shift=shift, output=output) - -class Transpose(PrimitiveModel): - # NOTE: Consider if axes type list[int] is conventionally True since it is generally - # used tuple[int] in these type of cases - input: Connection - axes: Connection - output: Connection +class TransposeOp(Operator): + _model_name: str = "Transpose" def __init__( self, @@ -1575,20 +1351,9 @@ def __init__( fn=general_tensor_type_constraint, keys=["output", "input"] ) - def 
__call__( # type: ignore[override] - self, - input: ConnectionType = NOT_GIVEN, - axes: ConnectionType = NOT_GIVEN, - output: ConnectionType = NOT_GIVEN, - ) -> ExtendInfo: - return super().__call__(input=input, axes=axes, output=output) - -class Split(PrimitiveModel): - split_size: Connection - axis: Connection - input: Connection - output: Connection +class SplitOp(Operator): + _model_name: str = "Split" def __init__( self, @@ -1611,23 +1376,9 @@ def __init__( fn=split_constraints, keys=["output", "input", "split_size", "axis"] ) - def __call__( # type: ignore[override] - self, - input: ConnectionType = NOT_GIVEN, - split_size: ConnectionType = NOT_GIVEN, - axis: ConnectionType = NOT_GIVEN, - output: ConnectionType = NOT_GIVEN, - ) -> ExtendInfo: - return super().__call__( - input=input, split_size=split_size, axis=axis, output=output - ) - -class Slice(PrimitiveModel): - start: Connection - stop: Connection - step: Connection - output: Connection +class SliceOp(Operator): + _model_name: str = "Slice" def __init__( self, @@ -1649,20 +1400,9 @@ def __init__( fn=slice_constraints, keys=["output", "start", "stop", "step"] ) - def __call__( # type: ignore[override] - self, - start: ConnectionType = NOT_GIVEN, - stop: ConnectionType = NOT_GIVEN, - step: ConnectionType = NOT_GIVEN, - output: ConnectionType = NOT_GIVEN, - ) -> ExtendInfo: - return super().__call__(start=start, stop=stop, step=step, output=output) - -class Indexer(PrimitiveModel): - input: Connection - index: Connection - output: Connection +class IndexerOp(Operator): + _model_name: str = "Indexer" def __init__( self, @@ -1687,37 +1427,31 @@ def __init__( ) edge_constraints = self._add_constraint( - fn=edge_type_constraint, keys=[PrimitiveModel.output_key, "input"] + fn=edge_type_constraint, keys=[Operator.output_key, "input"] ) indexer_initial_constraints = self._add_constraint( fn=indexer_initial_type_constraint, - keys=[PrimitiveModel.output_key, "input", "index"], + keys=[Operator.output_key, "input", "index"], dependencies={edge_constraints}, ) self._add_constraint( fn=indexer_constraints, - keys=[PrimitiveModel.output_key, "input", "index"], + keys=[Operator.output_key, "input", "index"], dependencies={indexer_initial_constraints}, ) self._add_constraint( fn=indexer_type_constraint, - keys=[PrimitiveModel.output_key, "input", "index"], + keys=[Operator.output_key, "input", "index"], dependencies={indexer_initial_constraints}, ) - def __call__( # type: ignore[override] - self, - input: ConnectionType = NOT_GIVEN, - index: ConnectionType = NOT_GIVEN, - output: ConnectionType = NOT_GIVEN, - ) -> ExtendInfo: - return super().__call__(input=input, index=index, output=output) +class SineOp(SingleInputOperationOp): + _model_name: str = "Sine" -class Sine(SingleInputOperation): def __init__( self, input: Tensor[int | float | bool] | ToBeDetermined = TBD, @@ -1733,7 +1467,9 @@ def __init__( ) -class Cosine(SingleInputOperation): +class CosineOp(SingleInputOperationOp): + _model_name: str = "Cosine" + def __init__( self, input: Tensor[int | float | bool] | ToBeDetermined = TBD, diff --git a/mithril/framework/logical/primitive.py b/mithril/framework/logical/primitive.py index b6e95482..911d1d2b 100644 --- a/mithril/framework/logical/primitive.py +++ b/mithril/framework/logical/primitive.py @@ -12,232 +12,50 @@ # See the License for the specific language governing permissions and # limitations under the License. 
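With the rename of essential_primitives.py to operators.py above, each former model class becomes an Op-suffixed Operator whose public name is carried by the _model_name class attribute rather than by the Python class name (the class_name property in operator.py simply returns _model_name). A quick sketch, with the module path taken from the rename header and the printed values inferred from the definitions above:

    from mithril.framework.logical.operators import AddOp

    op = AddOp()               # all value keys default to TBD
    print(op.class_name)       # "Add"   -- comes from _model_name
    print(type(op).__name__)   # "AddOp" -- the actual Python class
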
-from collections.abc import Mapping -from typing import get_args, get_origin - -from ...utils.utils import OrderedSet -from ..common import ( - TBD, - BaseKey, - Connection, - IOHyperEdge, - KeyType, - Tensor, - ToBeDetermined, - UniadicRecord, - Updates, - Variadic, - create_shape_map, - get_summary, - get_summary_shapes, - get_summary_types, -) +from ... import core +from ..common import BaseKey, ConnectionDataType from .base import BaseModel +from .model import IOKey, Model +from .operator import Operator +__all__ = ["PrimitiveModel", "OperatorModel"] -class PrimitiveModel(BaseModel): - """This class contains the simplest / primitive - building blocks of composite models. - """ +ConstantType = float | int | core.Constant - output_key = "output" - cache_name = "cache" - output: Connection +class OperatorModel(Model): def __init__( self, - formula_key: str, + model: Operator, *, name: str | None = None, - **kwargs: BaseKey | IOHyperEdge, ) -> None: - self._formula_key = formula_key - self._grad_formula = formula_key + "_grad" - - super().__init__(name=name) - - self.random_keys: set[str] = set() - # Get shape_templates of TensorTypes and create corresponding shapes. - shape_templates = { - key: value.shape - for key, value in kwargs.items() - if isinstance(value, BaseKey) and value.shape is not None - } - shapes = create_shape_map(shape_templates, self.constraint_solver) - data_set: set[IOHyperEdge] = set() - is_diff = False - output_data: IOHyperEdge | None = None - for key, value in kwargs.items(): - if isinstance(value, BaseKey): - if get_origin(value.type) is Tensor or value.type is Tensor: - val_type = ( - Tensor[int | float | bool] - if value.type is Tensor - else value.type - ) - if not isinstance(tensor := value.value, Tensor): - assert isinstance(value.value, ToBeDetermined) - tensor = Tensor( - type=get_args(val_type)[0], - shape=shapes[key].node, - ) - edge = IOHyperEdge(value=tensor, interval=value.interval) - data_set.add(edge) - else: - edge_type = ToBeDetermined if value.type is None else value.type - edge = IOHyperEdge( - type=edge_type, value=value.value, interval=value.interval - ) - else: - raise TypeError( - "PrimitiveModel's can only be instantiated with BaseKey type keys!" 
- ) - - conn_data = self.create_connection(edge, key) - - if key == PrimitiveModel.output_key: - self.conns.set_connection_type(conn_data, KeyType.OUTPUT) - output_data = edge - else: - self.conns.set_connection_type(conn_data, KeyType.INPUT) - is_diff |= not edge.is_non_diff - if isinstance(output_data, IOHyperEdge) and isinstance( - output_data.edge_type, Tensor - ): - output_data.differentiable = is_diff - - # Initially run all given tensors' constraints - self.constraint_solver.update_shapes(Updates(data_set)) - - input_conns = OrderedSet({conn for conn in self.conns.input_connections}) - out_conn = self.conns.get_connection("output") - assert out_conn is not None - output_conns = OrderedSet({out_conn}) - - for conn in self.conns.input_connections: - self.dependency_map.local_input_dependency_map[conn] = [ - (self, output_conns) - ] - - for conn in output_conns: - self.dependency_map.local_output_dependency_map[conn] = (self, input_conns) - - self.dependency_map.cache_internal_references(out_conn, input_conns) - self.dependency_map.update_all_keys() - - # Link canonicals - canonical_input_key = ( - "input" if "input" in self.input_keys else next(iter(self.input_keys)) - ) - canonical_input_conn = self.conns.get_connection(canonical_input_key) - if canonical_input_conn is not None: - self.set_cin(canonical_input_conn.conn, safe=False) - - canonical_output_key = ( - "output" - if "output" in self.conns.output_keys - else next(iter(self.conns.output_keys)) - ) - canonical_output_conn = self.conns.get_connection(canonical_output_key) - if canonical_output_conn is not None: - self.set_cout(canonical_output_conn.conn, safe=False) - self._freeze() - - def __iadd__(self, other: BaseModel) -> BaseModel: - raise Exception( - f"Primitive '{self.__class__.__name__}' model can not be extended!" - ) - - @property - def formula_key(self) -> str: - return self._formula_key + super().__init__(name=name, enforce_jit=model._jittable) + self._extend(model, {k: IOKey(k, expose=True) for k in model.external_keys}) @property - def grad_formula(self) -> str: - return self._grad_formula + def submodel(self) -> Operator: + m = next(iter(self.dag.keys())) + assert isinstance(m, Operator) + return m - def extract_connection_info( + def extend( self, - name_mappings: dict[BaseModel, str], - data_to_key_map: dict[IOHyperEdge, list[str]] | None = None, - data_memo: Mapping[int, IOHyperEdge] | None = None, - ) -> dict[str, tuple[dict[str, list[str]], dict[str, list[str]]]]: - if data_to_key_map is None: - data_to_key_map = {} - if data_memo is None: - data_memo = {} - # construct the data_map - data_map = {key: conn.metadata for key, conn in self.conns.all.items()} - model_name = next(iter(name_mappings.values())) - - conns: tuple[dict[str, list[str]], dict[str, list[str]]] = ({}, {}) - - # Take the input_keys with tensor values - input_keys = tuple(self.input_keys) - - for key in tuple(input_keys) + tuple(self.conns.output_keys): - # find data of the key. - # If data_memo is given, take its copied version in physical model - key_data = data_memo.get(id(data_map[key]), data_map[key]) - - conn = conns[key in self.conns.output_keys].setdefault(key, []) - # try to find outer key's real name in data_to_key_map - outer_key = data_to_key_map.get(key_data, [key]) - outer_key = ["'" + key + "'" for key in outer_key] - if not key_data.is_tensor and key_data.value is not TBD: - # If value of the scalar is determined, write that value directly. 
-                outer_key = [str(key_data.value)]
-            conn.extend(outer_key)
+        model: BaseModel,
+        **kwargs: ConnectionDataType,
+    ) -> None:
+        if len(self.dag) > 0:
+            raise RuntimeError("Primitive models cannot have submodels.")
+        super().extend(model, **kwargs)
 
-        return {model_name: conns}
 
-    def summary(
+class PrimitiveModel(OperatorModel):
+    def __init__(
         self,
-        shapes: bool = True,
-        types: bool = False,
-        symbolic: bool = False,
+        formula_key: str,
+        *,
         name: str | None = None,
-        alternative_shapes: bool = False,
-        uni_cache: dict[UniadicRecord, str] | None = None,
-        var_cache: dict[Variadic, str] | None = None,
+        **kwargs: BaseKey,
     ) -> None:
-        if uni_cache is None:
-            uni_cache = {}
-        if var_cache is None:
-            var_cache = {}
-
-        type_info = None
-        shape_info = None
-        name_mappings: dict[BaseModel, str] = {
-            self: name if name else self.__class__.__name__
-        }
-        # extract model topology
-        conn_info = self.extract_connection_info(name_mappings)
-
-        model_shapes = {
-            sub_model_name: sub_model.get_shapes(
-                uni_cache, var_cache, symbolic, alternative_shapes
-            )
-            for sub_model, sub_model_name in name_mappings.items()
-        }
-        if shapes:
-            # extract model shapes
-            shape_info = get_summary_shapes(model_shapes, conn_info)
-
-        if types:
-            # extract model types
-            type_info = get_summary_types(name_mappings)
-
-        if not name:
-            name = self.__class__.__name__
-
-        # construct the table based on relevant information
-        table = get_summary(
-            conns=conn_info,
-            name=name,
-            shape=shape_info,  # type: ignore
-            types=type_info,
-        )
-
-        table.compile()
-        table.display()
+        model = Operator(formula_key, self.class_name, **kwargs)
+        super().__init__(model=model, name=name)
diff --git a/mithril/framework/physical/data_store.py b/mithril/framework/physical/data_store.py
index 80e8c6a4..d07d69ea 100644
--- a/mithril/framework/physical/data_store.py
+++ b/mithril/framework/physical/data_store.py
@@ -23,6 +23,7 @@
     DataEvalType,
     IOHyperEdge,
     MainValueType,
+    Tensor,
     ToBeDetermined,
     Updates,
 )
@@ -127,6 +128,25 @@ def _set_data_value(self, key: str, data: IOHyperEdge) -> None:
             value = getattr(self.backend, value.name)
         self.data_values[key] = value  # type: ignore
 
+    # Add constant values given in the model's __call__ to constant_keys, if any. 
+ # TODO: merge convert_data_to_physical with _set_data_value + @staticmethod + def convert_data_to_physical( + value: AllValueType, backend: Backend[DataType] + ) -> DataType | AllValueType: + match value: + case Constant(): + value = epsilon_table[backend.precision][value] + case Dtype(): + value = getattr(backend, value.name) + case Tensor(): + value = backend.array( + StaticDataStore.convert_data_to_physical(value.value, backend) + ) + case _: + value = value + return value + def _infer_tensor_value_type( self, value: DataType ) -> type[bool] | type[int] | type[float]: diff --git a/mithril/framework/physical/flat_graph.py b/mithril/framework/physical/flat_graph.py index 7b840f6d..d9b5e789 100644 --- a/mithril/framework/physical/flat_graph.py +++ b/mithril/framework/physical/flat_graph.py @@ -14,7 +14,7 @@ from __future__ import annotations -from collections.abc import Callable, Iterable, Mapping, Sequence +from collections.abc import Callable, KeysView, Mapping, Sequence from copy import deepcopy from dataclasses import dataclass @@ -26,7 +26,6 @@ from ..common import ( TBD, AllValueType, - Connection, ConstraintSolver, DataEvalType, IOHyperEdge, @@ -39,8 +38,9 @@ ValueType, is_type_adjustment_required, ) -from ..logical import Buffer -from ..logical.primitive import PrimitiveModel +from ..logical.model import Connection +from ..logical.operator import Operator +from ..logical.operators import BufferOp from .data_store import StaticDataStore @@ -74,7 +74,7 @@ class Node: """A node representing a primitive model and its connections in the graph. Attributes: - model (PrimitiveModel): The primitive model associated with this node. + model (Operator): The primitive model associated with this node. connections (dict[str, Connection]): A dictionary mapping connection names to Connection objects. @@ -83,7 +83,7 @@ class Node: connection of the node. """ - model: PrimitiveModel + model: Operator connections: dict[str, GConnection] def __hash__(self) -> int: @@ -111,7 +111,7 @@ def __init__( memo = {} self.backend: ml.Backend[DataType] = backend - self.nodes: dict[PrimitiveModel, Node] = {} + self.nodes: dict[Operator, Node] = {} self.connections: dict[ str, GConnection ] = {} # Assumed connections added in topological order. 
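The flat-graph hunks below read output_key and cache_name as class attributes on Operator instead of going through a model instance. A small illustration of the key convention they rely on; the "matmul_output" key name here is hypothetical:

    from mithril.framework.logical.operator import Operator

    assert Operator.output_key == "output"
    assert Operator.cache_name == "cache"
    # numpy-backend cache keys are built as "<output key>_<cache name>":
    cache_key = "_".join(["matmul_output", Operator.cache_name])  # "matmul_output_cache"
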
@@ -197,8 +197,8 @@ def set_random_seed_values(self, **seed_mapping: int) -> None: def update_cached_data(self, updates: Updates) -> set[str]: return self.data_store.update_cached_data(updates) - def add_value(self, model: PrimitiveModel, keys: dict[str, str]) -> None: - output_key = keys[PrimitiveModel.output_key] + def add_value(self, model: Operator, keys: dict[str, str]) -> None: + output_key = keys[Operator.output_key] if model.random_keys: self.random_keys |= {keys[key] for key in model.random_keys} @@ -209,11 +209,11 @@ def add_value(self, model: PrimitiveModel, keys: dict[str, str]) -> None: out_conn = GConnection(node, output_key, [], [], set()) self.connections[output_key] = out_conn - node.connections[PrimitiveModel.output_key] = out_conn + node.connections[Operator.output_key] = out_conn # Create input connections for inner_key, outer_key in keys.items(): - if inner_key == PrimitiveModel.output_key: + if inner_key == Operator.output_key: continue conn = self.connections.get(outer_key, None) @@ -229,7 +229,7 @@ def add_value(self, model: PrimitiveModel, keys: dict[str, str]) -> None: self.nodes[model] = node self._all_target_keys.add(output_key) - self._topological_order.append(node.connections[PrimitiveModel.output_key].key) + self._topological_order.append(node.connections[Operator.output_key].key) for conn in node.connections.values(): self._update_connection_keys(conn) @@ -269,8 +269,7 @@ def all_source_keys(self) -> set[str]: def _update_topological_order(self) -> None: self._topological_order = [ - node.connections[PrimitiveModel.output_key].key - for node in self.nodes.values() + node.connections[Operator.output_key].key for node in self.nodes.values() ] def _update_all_source_keys(self) -> None: @@ -295,7 +294,7 @@ def _update_connection_keys(self, connection: GConnection) -> None: if connection.node is not None: for inner_key, conn in connection.node.connections.items(): - if inner_key == PrimitiveModel.output_key: + if inner_key == Operator.output_key: continue key = conn.key source_keys.append(key) @@ -310,14 +309,10 @@ def get_target_keys(connection: GConnection) -> list[str]: target_keys += get_target_keys(connection) if ( connection.node is not None - and connection.key - != connection.node.connections[PrimitiveModel.output_key].key - and connection.node.connections[PrimitiveModel.output_key].key - in self.connections + and connection.key != connection.node.connections[Operator.output_key].key + and connection.node.connections[Operator.output_key].key in self.connections ): - target_keys.append( - connection.node.connections[PrimitiveModel.output_key].key - ) + target_keys.append(connection.node.connections[Operator.output_key].key) # Make sure connection key registered all_source and all_target keys if len(target_keys) > 0: @@ -328,23 +323,23 @@ def get_target_keys(connection: GConnection) -> list[str]: connection.target_keys = list(target_keys) connection.source_keys = list(source_keys) - def get_model(self, key: str) -> PrimitiveModel: + def get_model(self, key: str) -> Operator: conn = self.connections.get(key, None) if conn is None or conn.node is None: raise ValueError(f"Model not found for key: {key}") return conn.node.model - def get_model_out_key(self, model: PrimitiveModel) -> str | None: + def get_model_out_key(self, model: Operator) -> str | None: node = self.nodes.get(model, None) if node is None: return None - return node.connections[PrimitiveModel.output_key].key + return node.connections[Operator.output_key].key - def get_model_outer_key(self, 
model: PrimitiveModel, inner_key: str) -> str: + def get_model_outer_key(self, model: Operator, inner_key: str) -> str: return self.nodes[model].connections[inner_key].key - def get_model_connections(self, model: PrimitiveModel): # type: ignore + def get_model_connections(self, model: Operator): # type: ignore return self.nodes[model].connections.values() def get_connection(self, key: str) -> GConnection | None: @@ -391,7 +386,7 @@ def prune_duplicate_nodes( self._remove_node(node) continue - if isinstance(node.model, Buffer): + if isinstance(node.model, BufferOp): input_conn = node.connections["input"] input_conn = self._temp_connection_info.get(input_conn, input_conn) output_conn = node.connections["output"] @@ -523,7 +518,7 @@ def _prune_node(self, node: Node, conn: GConnection) -> None: self._update_connection_keys(conn_) if ( - key := node.connections[PrimitiveModel.output_key].key + key := node.connections[Operator.output_key].key ) not in self.output_keys and key in self._all_target_keys: self._all_target_keys.remove(key) @@ -536,7 +531,7 @@ def _prune_node(self, node: Node, conn: GConnection) -> None: def _remove_node(self, node: Node) -> None: connections = set(node.connections.values()) - output_conn = node.connections[PrimitiveModel.output_key] + output_conn = node.connections[Operator.output_key] # To remove node, node should not be used any other nodes or # Output of this node is already cached, so we can remove this node. @@ -622,7 +617,7 @@ def infer_ignore_step( keys.add(value) queue.add(value) - def get_models(self) -> Iterable[PrimitiveModel]: + def get_models(self) -> KeysView[Operator]: return self.nodes.keys() def infer_static_keys(self) -> Updates: diff --git a/mithril/framework/physical/model.py b/mithril/framework/physical/model.py index 5b8b7f6b..f93465e2 100644 --- a/mithril/framework/physical/model.py +++ b/mithril/framework/physical/model.py @@ -27,15 +27,12 @@ from ..common import ( NOT_GIVEN, TBD, - Connection, ConnectionData, - ConnectionType, DataEvalType, EvaluateAllType, EvaluateGradientsType, EvaluateType, IOHyperEdge, - IOKey, MainValueType, ParamsEvalType, ShapeResultType, @@ -53,9 +50,13 @@ get_summary_types, ) from ..logical.base import BaseModel -from ..logical.model import Model -from ..logical.primitive import PrimitiveModel -from ..utils import define_unique_names +from ..logical.model import ( + Connection, + Model, + define_unique_names, +) +from ..logical.operator import Operator +from .data_store import StaticDataStore from .flat_graph import FlatGraph __all__ = ["PhysicalModel"] @@ -81,7 +82,7 @@ class PhysicalModel(GenericDataType[DataType]): def __init__( self, - model: BaseModel, + model: Model, backend: Backend[DataType], *, discard_keys: StringOrConnectionSetType, @@ -94,27 +95,15 @@ def __init__( safe_names: bool, use_short_namings: bool, ) -> None: - if isinstance(model, PrimitiveModel): - # TODO: Remove wrapping with Model in the future. - _model = deepcopy(model) - extend_info = _model() - model_keys: dict[str, ConnectionType] = {} - for key in _model.external_keys: - value = extend_info.connections.get(key, NOT_GIVEN) - # NOTE: Do not set default value if it is given in constant_keys. - value = (value, NOT_GIVEN)[key in constant_keys] - default_val = _model.conns.get_data(key).value - if (value is NOT_GIVEN and default_val is TBD) or ( - key in _model.output_keys - ): - # Non-valued connections are only named with their key names. 
- model_keys[key] = key - else: - val = default_val if default_val is not TBD else value - model_keys[key] = IOKey(key, val) # type: ignore + if len(model.conns.output_keys) == 0 and len(model.conns.couts) == 0: + raise KeyError("Models with no output keys can not be compiled.") - model = Model() - model |= _model(**model_keys) + # TODO: Update StaticDataStore.convert_data_to_physical function. + constant_keys = { # type: ignore + key: StaticDataStore.convert_data_to_physical(value, backend) # type: ignore + for key, value in model().connections.items() + if value is not NOT_GIVEN + } | constant_keys self.backend: Backend[DataType] = backend self._output_keys: set[str] = set(model.conns.output_keys) @@ -134,9 +123,6 @@ def __init__( # TODO: This is a temporary solution, a better way will be implemented # in another PR. if len(model.conns.output_keys) == 0: - if len(model.conns.couts) == 0: - raise KeyError("Models with no output keys can not be compiled.") - for cout in model.conns.couts: current_name = flat_model.assigned_edges[cout.metadata].name key_origin = cout.metadata.key_origin @@ -235,7 +221,7 @@ def __init__( # solver to propagate updates. self.flat_graph.constraint_solver(updates) - output = PrimitiveModel.output_key + output = Operator.output_key _data_dict: dict[str, IOHyperEdge] = {} self._infer_differentiability(model_data) @@ -247,7 +233,7 @@ def __init__( # NOTE: maybe move adding cache to generate_code methods. if self.backend.backend_type == "numpy": - cache_name = "_".join([mappings[output], p_model.cache_name]) + cache_name = "_".join([mappings[output], Operator.cache_name]) mappings["cache"] = cache_name # TODO: Why do we have to provide cache_value here? It is # NONE | dict(). @@ -431,7 +417,7 @@ def input_keys(self) -> set[str]: def _infer_differentiability(self, model_data: dict[str, IOHyperEdge]) -> None: # Infer output differentiability only for the models # that have a Tensor type output. 
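Two behavioral changes in the `__init__` rewrite above are worth calling out: models that expose neither output keys nor canonical outputs are now rejected up front, and valued connections are folded into `constant_keys` before compilation, with caller-supplied constants winning on collision. A tiny sketch of that precedence rule follows; the key names are invented for illustration.

```python
# dict union keeps the right-hand operand on key collisions, so constants
# passed to compile() override defaults read off the model's connections.
defaults_from_model = {"eps": 1e-5, "momentum": 0.1}  # hypothetical keys
user_constant_keys = {"eps": 1e-3}

merged = defaults_from_model | user_constant_keys
assert merged == {"eps": 1e-3, "momentum": 0.1}
```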
- output_key = PrimitiveModel.output_key + output_key = Operator.output_key output_edge = model_data[output_key] if output_edge.is_tensor: # If any of the inputs are differentiable, then @@ -570,9 +556,7 @@ def _pre_compile( ) ) ): - self.ignore_grad_keys.add( - node.connections[PrimitiveModel.output_key].key - ) + self.ignore_grad_keys.add(node.connections[Operator.output_key].key) if len(self._output_keys - self.ignore_grad_keys) == 0 and not self.inference: raise ValueError("All outputs gradient are ignored.") @@ -755,14 +739,14 @@ def summary( # Extract all summary information dag: list[BaseModel] | dict[BaseModel, dict[str, ConnectionData]] if model is not None: - dag = model.dag if isinstance(model, Model) else [model] + dag = list(model.dag) if isinstance(model, Model) else [model] name_mappings = define_unique_names(dag) conn_info = model.extract_connection_info( name_mappings, data_to_key_map, self.flat_graph.data_memo ) else: # Remove unused models and cached models - all_models = list(self.flat_graph.get_models()) + all_models: list[BaseModel] = list(self.flat_graph.get_models()) for key in self.flat_graph.unused_keys | self.flat_graph.cached_data.keys(): if ( unused_model := self.flat_graph.connections.get(key) @@ -801,7 +785,7 @@ def summary( # if verbose, find the name of the model and create the table object and # display it based on extracted infos if name is None: - name = model.__class__.__name__ if model else self.__class__.__name__ + name = model.class_name if model else self.__class__.__name__ table = get_summary( conns=conn_info, name=name, @@ -814,7 +798,7 @@ def summary( table.display() if depth > 0: for model, model_name in name_mappings.items(): - if not isinstance(model, PrimitiveModel): + if not isinstance(model, Operator): self.summary( model=model, depth=depth - 1, @@ -827,12 +811,12 @@ def summary( ) def extract_connection_info( - self, name_mappings: dict[PrimitiveModel, str] | None = None + self, name_mappings: dict[Operator, str] | None = None ) -> dict[str, tuple[dict[str, list[str]], dict[str, list[str]]]]: if name_mappings is None: - name_mappings = define_unique_names(self.flat_graph.get_models()) + name_mappings = define_unique_names(self.flat_graph.get_models()) # type: ignore conn_info: dict[str, tuple[dict[str, list[str]], dict[str, list[str]]]] = {} - + assert name_mappings is not None for model, model_name in name_mappings.items(): conn_info.setdefault(model_name, ({}, {})) model_node = self.flat_graph.nodes[model] @@ -899,7 +883,7 @@ def _step_random_seed_values(self) -> None: def _replace_with_primitive( self, model: Model, key_mappings: dict[str, str] - ) -> tuple[PrimitiveModel, dict[str, str]]: + ) -> tuple[Operator, dict[str, str]]: assert model.formula_key is not None formula = self.backend.primitive_function_dict[model.formula_key] primitive_input_keys = formula.__code__.co_varnames[ @@ -929,9 +913,7 @@ def _replace_with_primitive( kwargs = {key: model.conns.all[key].metadata for key in external_keys} - primitive = PrimitiveModel( - formula_key=model.formula_key, name=model.name, **kwargs - ) + primitive = Operator(formula_key=model.formula_key, name=model.name, **kwargs) primitive.parent = model.parent p_key_mappings: dict[str, str] = {} @@ -1040,7 +1022,7 @@ def __init__( short_namings (bool): Flag to determine if short namings should be used. 
""" - self.mappings: dict[PrimitiveModel, dict[str, Name]] = {} + self.mappings: dict[Operator, dict[str, Name]] = {} self.assigned_edges: dict[IOHyperEdge, Name] = {} self.assigned_names: dict[str, Name] = {} self.external_edges: dict[IOHyperEdge, str] = {} @@ -1048,7 +1030,7 @@ def __init__( self.key_origins: dict[str, int] = {} self.reserved_keys: set[str] = reserved_keys if reserved_keys else set() self.queued_models: dict[ - IOHyperEdge, list[tuple[PrimitiveModel, dict[str, str], str]] + IOHyperEdge, list[tuple[Operator, dict[str, str], str]] ] = {} self._external_mapping: dict[str, Name] = {} self.model = model @@ -1202,7 +1184,7 @@ def generate_keys( if mappings is None: mappings = {} - if isinstance(model, PrimitiveModel): + if isinstance(model, Operator): if not self._is_primitive_ready(model): self._add_primitive_to_queue(model, mappings, parent_name) return @@ -1212,16 +1194,16 @@ def generate_keys( elif isinstance(model, Model): self._process_model(model, mappings, parent_name) else: - raise ValueError("Model must be either PrimitiveModel or Model") + raise ValueError("Model must be either Operator or Model") def _process_primitive_model( - self, model: PrimitiveModel, mappings: dict[str, str], parent_name: str + self, model: Operator, mappings: dict[str, str], parent_name: str ) -> None: """ Process a primitive model. Args: - model (PrimitiveModel): The primitive model. + model (Operator): The primitive model. mappings (dict[str, str]): The mappings of keys. """ @@ -1246,9 +1228,11 @@ def _process_primitive_model( self.assigned_edges[conn.metadata] = name self.mappings[model][key] = name - output_edge = model.output.metadata - self.used_edges.add(output_edge) - self._check_for_queue(output_edge) + # output_edge = model.output.metadata + output_con = model.conns.get_connection("output") + assert output_con is not None + self.used_edges.add(output_con.metadata) + self._check_for_queue(output_con.metadata) def _process_model( self, model: Model, mappings: dict[str, str], parent_name: str @@ -1292,12 +1276,12 @@ def _check_for_queue(self, hyperedge: IOHyperEdge) -> None: m, mappings=mappings, parent_name=parent_name ) - def _is_primitive_ready(self, model: PrimitiveModel) -> bool: + def _is_primitive_ready(self, model: Operator) -> bool: """ Check if a primitive model is ready to be processed. Args: - model (PrimitiveModel): The primitive model. + model (Operator): The primitive model. Returns: bool: True if the model is ready, False otherwise. @@ -1309,13 +1293,13 @@ def _is_primitive_ready(self, model: PrimitiveModel) -> bool: return True def _add_primitive_to_queue( - self, model: PrimitiveModel, mappings: dict[str, str], parent_name: str + self, model: Operator, mappings: dict[str, str], parent_name: str ) -> None: """ Add a primitive model to the queue. Args: - model (PrimitiveModel): The primitive model. + model (Operator): The primitive model. input_edges (set[IOHyperEdge]): The input edges. mappings (dict[str, str]): The mappings of keys. 
""" @@ -1377,6 +1361,6 @@ def __iter__(self) -> FlatModel: self._iter = iter(self.mappings.items()) return self - def __next__(self) -> tuple[PrimitiveModel, dict[str, str]]: + def __next__(self) -> tuple[Operator, dict[str, str]]: model, mapping = next(self._iter) return model, {key: name.name for key, name in mapping.items()} diff --git a/mithril/framework/utils.py b/mithril/framework/utils.py index 82b053d7..445dbe2c 100644 --- a/mithril/framework/utils.py +++ b/mithril/framework/utils.py @@ -12,40 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -from collections.abc import Callable, Iterable +from collections.abc import Callable from functools import reduce from itertools import product from types import FunctionType, GenericAlias, UnionType -from typing import TYPE_CHECKING, Any, TypeVar - -if TYPE_CHECKING: - from .logical.base import BaseModel - - -T = TypeVar("T", bound="BaseModel") - - -def define_unique_names(models: Iterable[T]) -> dict[T, str]: - # TODO: Move this to Physical model (currently it is only used there) - # TODO: Also add short-naming logic to this function - model_name_dict: dict[T, str] = {} - single_model_dict: dict[str, T] = {} - model_count_dict: dict[str, int] = {} - - for model in models: - class_name = model.__class__.__name__ - if model_count_dict.setdefault(class_name, 0) == 0: - single_model_dict[class_name] = model - else: - single_model_dict.pop(class_name, None) - model_name_dict[model] = ( - str(class_name) + "_" + str(model_count_dict[class_name]) - ) - model_count_dict[class_name] += 1 - - for m in single_model_dict.values(): - model_name_dict[m] = str(m.__class__.__name__) - return model_name_dict +from typing import Any def align_shapes(all_dicts: list[dict[Any, Any]]) -> None: diff --git a/mithril/models/models.py b/mithril/models/models.py index 1d93660a..b621b48c 100644 --- a/mithril/models/models.py +++ b/mithril/models/models.py @@ -19,83 +19,82 @@ from collections.abc import Sequence from copy import deepcopy -from ..framework import Model from ..framework.common import ( NOT_GIVEN, TBD, - Connection, - ConnectionType, - IOKey, MainValueType, ShapeTemplateType, Tensor, ToBeDetermined, ) from ..framework.constraints import polynomial_kernel_constraint -from ..framework.logical.base import BaseModel, ExtendInfo -from ..framework.logical.essential_primitives import ( - Absolute, - Add, - ArgMax, - Buffer, - Cast, - Divide, - Dtype, - Exponential, - Greater, - Indexer, - Length, - MatrixMultiply, - Mean, - Minus, - Multiply, - Power, - Reshape, - Shape, - Size, - Slice, - Sqrt, - Subtract, - Sum, - Transpose, - Variance, +from ..framework.logical.model import ( + Connection, + ConnectionType, + ExtendInfo, + IOKey, + Model, ) -from ..framework.logical.primitive import PrimitiveModel from ..utils.utils import PaddingType, convert_to_list, convert_to_tuple from .primitives import ( + Absolute, + Add, + ArgMax, AUCCore, + Buffer, CartesianDifference, + Cast, Cholesky, Concat, DistanceMatrix, + Divide, + Dtype, Eigvalsh, + Exponential, Eye, EyeComplement, GPRAlpha, GPRVOuter, + Greater, + Indexer, KLDivergence, + Length, Log, + MatrixMultiply, + Mean, + Minus, + Multiply, NormModifier, PaddingConverter1D, PaddingConverter2D, PermuteTensor, PolynomialFeatures, + Power, PrimitiveConvolution1D, PrimitiveConvolution2D, PrimitiveMaxPool1D, PrimitiveMaxPool2D, + Reshape, + Shape, Sigmoid, Sign, + Size, + Slice, + Sqrt, Square, Squeeze, StableReciprocal, StrideConverter, + Subtract, + Sum, 
Tanh, + Transpose, TransposedDiagonal, Trapezoid, TsnePJoint, TupleConverter, Unique, + Variance, Where, ) @@ -155,7 +154,7 @@ class Pool1D(Model): output: Connection @property - def pool_model(self) -> type[BaseModel]: + def pool_model(self) -> type[Model]: raise NotImplementedError("Pool Model should be indicated!") def __init__( @@ -222,7 +221,7 @@ def __call__( # type: ignore[override] class MaxPool1D(Pool1D): @property - def pool_model(self) -> type[PrimitiveModel]: + def pool_model(self) -> type[Model]: return PrimitiveMaxPool1D @@ -247,7 +246,7 @@ class Pool2D(Model): output: Connection @property - def pool_model(self) -> type[BaseModel]: + def pool_model(self) -> type[Model]: raise NotImplementedError("Pool Model should be indicated!") def __init__( @@ -324,7 +323,7 @@ def __call__( # type: ignore[override] class MaxPool2D(Pool2D): @property - def pool_model(self) -> type[PrimitiveModel]: + def pool_model(self) -> type[Model]: return PrimitiveMaxPool2D @@ -628,7 +627,7 @@ class Layer(Model): def __init__( self, - activation: BaseModel, + activation: Model, dimension: int | None = None, input: Tensor[int | float | bool] | ToBeDetermined = TBD, weight: Tensor[int | float | bool] | ToBeDetermined = TBD, @@ -1086,7 +1085,7 @@ class KernelizedSVM(Model): def __init__( self, - kernel: BaseModel, + kernel: Model, weight: Tensor[int | float | bool] | ToBeDetermined = TBD, bias: Tensor[int | float | bool] | ToBeDetermined = TBD, *, @@ -1257,7 +1256,7 @@ class MLP(Model): def __init__( self, - activations: list[BaseModel], + activations: list[Model], dimensions: Sequence[int | None], input_name_templates: dict[str, str] | None = None, input: Tensor[int | float | bool] | ToBeDetermined = TBD, diff --git a/mithril/models/primitives.py b/mithril/models/primitives.py index dc0134fc..ef4e37a0 100644 --- a/mithril/models/primitives.py +++ b/mithril/models/primitives.py @@ -14,16 +14,19 @@ from __future__ import annotations +from collections.abc import Sequence from types import NoneType +from typing import Any -from ..core import Constant, Dtype +from .. 
import core +from ..core import Constant from ..framework.common import ( NOT_GIVEN, TBD, BaseKey, - Connection, - ConnectionType, + ScalarValueType, Tensor, + TensorValueType, ToBeDetermined, ) from ..framework.constraints import ( @@ -51,9 +54,63 @@ tuple_converter_constraint, where_constrains, ) -from ..framework.logical.base import BaseModel -from ..framework.logical.essential_primitives import SingleInputOperation -from ..models import ExtendInfo, PrimitiveModel +from ..framework.logical import Model +from ..framework.logical.model import Connection, ConnectionType, ExtendInfo +from ..framework.logical.operator import Operator +from ..framework.logical.operators import ( + AbsoluteOp, + AddOp, + ArgMaxOp, + ArgMinOp, + BufferOp, + CastOp, + CosineOp, + DivideOp, + DtypeOp, + EqualOp, + ExponentialOp, + FloorDivideOp, + GreaterEqualOp, + GreaterOp, + IndexerOp, + ItemOp, + LengthOp, + LessEqualOp, + LessOp, + LogicalAndOp, + LogicalNotOp, + LogicalOrOp, + LogicalXOrOp, + MatrixMultiplyOp, + MaximumOp, + MaxOp, + MeanOp, + MinimumOp, + MinOp, + MinusOp, + MultiplyOp, + NotEqualOp, + PowerOp, + ProdOp, + ReshapeOp, + ShapeOp, + ShiftLeftOp, + ShiftRightOp, + SineOp, + SizeOp, + SliceOp, + SplitOp, + SqrtOp, + SubtractOp, + SumOp, + TensorToListOp, + ToListOp, + ToTensorOp, + ToTupleOp, + TransposeOp, + VarianceOp, +) +from ..framework.logical.primitive import OperatorModel, PrimitiveModel from ..utils.utils import PaddingType __all__ = [ @@ -118,8 +175,59 @@ "Trapezoid", "Pad", "Randn", + "PrimitiveModel", + "Buffer", + "ToTuple", + "Power", + "Add", + "Subtract", + "Multiply", + "Divide", + "FloorDivide", + "Minus", + "MatrixMultiply", + "Shape", + "Reshape", + "Length", + "Size", + "Exponential", + "Item", + "Indexer", + "ToTensor", + "ToList", + "TensorToList", + "Mean", + "Sum", + "Max", + "Min", + "Prod", + "Variance", + "Absolute", + "Equal", + "NotEqual", + "Greater", + "GreaterEqual", + "Less", + "LessEqual", + "LogicalNot", + "LogicalOr", + "LogicalAnd", + "LogicalXOr", + "ShiftLeft", + "ShiftRight", + "ArgMax", + "ArgMin", + "Cast", + "Transpose", + "Sqrt", + "Split", + "Slice", + "Dtype", + "Sine", + "Cosine", + "Minimum", + "Maximum", ] - # Define types used to define keys: ConstantType = float | int | Constant @@ -139,7 +247,7 @@ class SupervisedLoss(PrimitiveModel): Parameters ---------- - PrimitiveModel : _type_ + Operator : _type_ _description_ """ @@ -167,22 +275,22 @@ def __init__( # Set constraints. 
bcast_constraint = self._add_constraint( - fn=bcast, keys=[PrimitiveModel.output_key, "input", "target"] + fn=bcast, keys=[Operator.output_key, "input", "target"] ) self._add_constraint( fn=bcast_error_check, - keys=[PrimitiveModel.output_key, "input", "target"], + keys=[Operator.output_key, "input", "target"], dependencies={bcast_constraint}, ) if polymorphic_constraint: self._add_constraint( fn=general_tensor_type_constraint, - keys=[PrimitiveModel.output_key, "input", "target"], + keys=[Operator.output_key, "input", "target"], ) - self.safe_shapes = { + self.submodel.safe_shapes = { "output": ["N", ("Var", ...)], "input": ["N", ("Var", ...)], "target": ["N", ("Var", ...)], @@ -287,21 +395,21 @@ def __init__( ) bcast_constraint = self._add_constraint( - fn=bcast, keys=[PrimitiveModel.output_key, "input", "target"] + fn=bcast, keys=[Operator.output_key, "input", "target"] ) self._add_constraint( fn=bcast_error_check, - keys=[PrimitiveModel.output_key, "input", "target"], + keys=[Operator.output_key, "input", "target"], dependencies={bcast_constraint}, ) self._add_constraint( fn=general_tensor_type_constraint, - keys=[PrimitiveModel.output_key, "input", "target", "quantile"], + keys=[Operator.output_key, "input", "target", "quantile"], ) - self.safe_shapes = { + self.submodel.safe_shapes = { "output": ["N", ("Var", ...)], "input": ["N", ("Var", ...)], "target": ["N", ("Var", ...)], @@ -409,7 +517,7 @@ def __call__( # type: ignore[override] "output": output, } # Check if the given argument set is valid. - if self.formula_key == "cross_entropy_with_log_probs": + if self.submodel.formula_key == "cross_entropy_with_log_probs": args: list[str] = [] if robust is not False: args.append("robust") @@ -455,18 +563,18 @@ def __init__( cutoff=BaseKey(shape=[], type=Tensor, value=cutoff), ) - self.safe_shapes = { + self.submodel.safe_shapes = { "output": ["N", ("Var", ...)], "input": ["N", ("Var", ...)], "target": ["N", ("Var", ...)], } bcast_constraint = self._add_constraint( - fn=bcast, keys=[PrimitiveModel.output_key, "input", "target"] + fn=bcast, keys=[Operator.output_key, "input", "target"] ) self._add_constraint( fn=bcast_error_check, - keys=[PrimitiveModel.output_key, "input", "target"], + keys=[Operator.output_key, "input", "target"], dependencies={bcast_constraint}, ) @@ -547,12 +655,12 @@ def __init__( super().__init__(formula_key=formula_key, name=name, **kwargs) bcast_constraint = self._add_constraint( - fn=bcast, keys=[PrimitiveModel.output_key, "input", "target"] + fn=bcast, keys=[Operator.output_key, "input", "target"] ) self._add_constraint( fn=bcast_error_check, - keys=[PrimitiveModel.output_key, "input", "target"], + keys=[Operator.output_key, "input", "target"], dependencies={bcast_constraint}, ) @@ -666,7 +774,10 @@ def __call__( # type: ignore[override] return super().__call__(input=input, cutoff=cutoff, output=output) -class Sign(SingleInputOperation): +class Sign(PrimitiveModel): + input: Connection + output: Connection + def __init__( self, input: Tensor[int | float | bool] | ToBeDetermined = TBD, @@ -676,20 +787,37 @@ def __init__( super().__init__( formula_key="sign", name=name, - polymorphic_constraint=False, - input=input, - output=BaseKey(shape=[("Var", ...)], type=Tensor[int]), + output=BaseKey(shape=[("Var", ...)], type=Tensor[float]), + input=BaseKey(shape=[("Var", ...)], type=Tensor, value=input), ) + def __call__( # type: ignore[override] + self, input: ConnectionType = NOT_GIVEN, output: ConnectionType = NOT_GIVEN + ) -> ExtendInfo: + return 
super().__call__(input=input, output=output) + + +class Square(PrimitiveModel): + input: Connection + output: Connection -class Square(SingleInputOperation): def __init__( self, input: Tensor[int | float | bool] | ToBeDetermined = TBD, *, name: str | None = None, ) -> None: - super().__init__(formula_key="square", name=name, input=input) + super().__init__( + formula_key="square", + name=name, + output=BaseKey(shape=[("Var", ...)], type=Tensor[float]), + input=BaseKey(shape=[("Var", ...)], type=Tensor, value=input), + ) + + def __call__( # type: ignore[override] + self, input: ConnectionType = NOT_GIVEN, output: ConnectionType = NOT_GIVEN + ) -> ExtendInfo: + return super().__call__(input=input, output=output) ############################# Activation Types ############################## @@ -716,12 +844,12 @@ def __init__( ) # Finalize kwargs. kwargs = default_kwargs | kwargs - super().__init__(formula_key, name=name, **kwargs) + super().__init__(name=name, formula_key=formula_key, **kwargs) if polymorphic_constraint: self._add_constraint( fn=general_tensor_type_constraint, - keys=[PrimitiveModel.output_key, "input"], + keys=[Operator.output_key, "input"], ) def __call__( # type: ignore[override] @@ -767,9 +895,7 @@ def __call__( # type: ignore[override] approximate: ConnectionType = NOT_GIVEN, output: ConnectionType = NOT_GIVEN, ) -> ExtendInfo: - return BaseModel.__call__( - self, input=input, approximate=approximate, output=output - ) + return Model.__call__(self, input=input, approximate=approximate, output=output) class Sigmoid(Activation): @@ -799,7 +925,7 @@ def __call__( # type: ignore[override] axis: ConnectionType = -1, output: ConnectionType = NOT_GIVEN, ) -> ExtendInfo: - return BaseModel.__call__(self, input=input, axis=axis, output=output) + return Model.__call__(self, input=input, axis=axis, output=output) class Softplus(Activation): @@ -847,7 +973,7 @@ def __call__( # type: ignore[override] slope: ConnectionType = 0.01, output: ConnectionType = NOT_GIVEN, ) -> ExtendInfo: - return PrimitiveModel.__call__(self, input=input, slope=slope, output=output) + return Model.__call__(self, input=input, slope=slope, output=output) class StopGradient(PrimitiveModel): @@ -860,9 +986,10 @@ def __init__( *, name: str | None = None, ) -> None: + # super().__init__(formula_key="stop_gradient", name=name) super().__init__( - formula_key="stop_gradient", name=name, + formula_key="stop_gradient", output=BaseKey(shape=[("Var", ...)], type=Tensor), input=BaseKey(shape=[("Var", ...)], type=Tensor, value=input), ) @@ -885,16 +1012,17 @@ def __init__( *, name: str | None = None, ) -> None: + # super().__init__(formula_key="cartesian_diff", name=name) super().__init__( - formula_key="cartesian_diff", name=name, + formula_key="cartesian_diff", output=BaseKey(shape=["N", "M", "dim"], type=Tensor), left=BaseKey(shape=["N", "dim"], type=Tensor, value=left), right=BaseKey(shape=["M", "dim"], type=Tensor, value=right), ) self._add_constraint( fn=general_tensor_type_constraint, - keys=[PrimitiveModel.output_key, "left", "right"], + keys=[Operator.output_key, "left", "right"], ) def __call__( # type: ignore[override] @@ -939,7 +1067,7 @@ def __init__( ) self._add_constraint( fn=general_tensor_type_constraint, - keys=[PrimitiveModel.output_key] + input_keys, + keys=[Operator.output_key] + input_keys, ) @@ -991,7 +1119,7 @@ def __init__( ) self._add_constraint( - fn=general_tensor_type_constraint, keys=[PrimitiveModel.output_key, "input"] + fn=general_tensor_type_constraint, keys=[Operator.output_key, "input"] 
)
         self.indices.set_differentiable(False)
@@ -1058,7 +1186,7 @@ def __init__(
 
         self._add_constraint(
             fn=general_tensor_type_constraint,
-            keys=[PrimitiveModel.output_key] + constraint_keys,
+            keys=[Operator.output_key] + constraint_keys,
         )
 
     def __call__(  # type: ignore[override]
@@ -1082,7 +1210,7 @@ def __call__(  # type: ignore[override]
         }
 
         if "bias" not in self.input_keys and bias != NOT_GIVEN:
-            raise ValueError(f"Model does not have 'bias' input. \
+            raise ValueError(f"Operator does not have 'bias' input. \
 Got {bias} as bias argument!")
         elif "bias" in self.input_keys:
             kwargs |= {"bias": bias}
@@ -1139,7 +1267,7 @@ def __init__(
             formula_key = "conv2d"
             kwargs.pop("bias")
 
-        super().__init__(formula_key, name=name, **kwargs)
+        super().__init__(formula_key=formula_key, name=name, **kwargs)
 
         self._add_constraint(
             fn=conv_2d_constraints,
@@ -1151,7 +1279,7 @@ def __init__(
             constraint_keys.append("bias")
         self._add_constraint(
             fn=general_tensor_type_constraint,
-            keys=[PrimitiveModel.output_key] + constraint_keys,
+            keys=[Operator.output_key] + constraint_keys,
         )
 
     def __call__(  # type: ignore[override]
@@ -1176,7 +1304,7 @@ def __call__(  # type: ignore[override]
 
         if "bias" not in self.input_keys and bias != NOT_GIVEN:
             raise ValueError(
-                f"Model does not have 'bias' input. Got {bias} as bias argument!"
+                f"Operator does not have 'bias' input. Got {bias} as bias argument!"
             )
         elif "bias" in self.input_keys:
             kwargs |= {"bias": bias}
@@ -1209,10 +1337,10 @@ def __init__(
 
         self._add_constraint(
             fn=flatten_constrains,
-            keys=[PrimitiveModel.output_key, "input", "start_dim", "end_dim"],
+            keys=[Operator.output_key, "input", "start_dim", "end_dim"],
         )
         self._add_constraint(
-            fn=general_tensor_type_constraint, keys=[PrimitiveModel.output_key, "input"]
+            fn=general_tensor_type_constraint, keys=[Operator.output_key, "input"]
         )
 
     def __call__(  # type: ignore[override]
@@ -1261,7 +1389,7 @@ def __init__(
         )
 
        # TODO: Torch does not accept any int type inputs but JAX implementation does. 
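The convolution `__call__` overrides above share a guard: `bias` may only be forwarded when the model was built with a bias input, and passing one otherwise is an immediate error. A standalone sketch of that pattern, with a hypothetical `build_call_kwargs` helper and a local `NOT_GIVEN` sentinel:

```python
# Hypothetical helper, not library API; it only mirrors the guard logic used
# in the PrimitiveConvolution __call__ methods above.
NOT_GIVEN = object()


def build_call_kwargs(input_keys: set[str], bias: object = NOT_GIVEN) -> dict[str, object]:
    kwargs: dict[str, object] = {}
    if "bias" not in input_keys and bias is not NOT_GIVEN:
        # Same message shape as the error raised in the diff above.
        raise ValueError(f"Operator does not have 'bias' input. Got {bias!r} as bias argument!")
    if "bias" in input_keys:
        kwargs["bias"] = bias
    return kwargs


assert build_call_kwargs({"input", "weight", "bias"}, bias=0.1) == {"bias": 0.1}
assert build_call_kwargs({"input", "weight"}) == {}
```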
self._add_constraint( - fn=general_tensor_type_constraint, keys=[PrimitiveModel.output_key, "input"] + fn=general_tensor_type_constraint, keys=[Operator.output_key, "input"] ) def __call__( # type: ignore[override] @@ -1305,7 +1433,7 @@ def __init__( self._add_constraint( fn=padding_1d_constraint, - keys=[PrimitiveModel.output_key, "input", "kernel_size"], + keys=[Operator.output_key, "input", "kernel_size"], ) def __call__( # type: ignore[override] @@ -1351,7 +1479,7 @@ def __init__( self._add_constraint( fn=padding_2d_constraint, - keys=[PrimitiveModel.output_key, "input", "kernel_size"], + keys=[Operator.output_key, "input", "kernel_size"], ) def __call__( # type: ignore[override] @@ -1384,7 +1512,7 @@ def __init__( ) self._add_constraint( fn=stride_constraint, - keys=[PrimitiveModel.output_key, "input", "kernel_size"], + keys=[Operator.output_key, "input", "kernel_size"], ) def __call__( # type: ignore[override] @@ -1426,7 +1554,7 @@ def __init__( ), ) self._add_constraint( - fn=tuple_converter_constraint, keys=[PrimitiveModel.output_key, "input"] + fn=tuple_converter_constraint, keys=[Operator.output_key, "input"] ) def __call__( # type: ignore[override] @@ -1479,7 +1607,7 @@ def __init__( keys=["output", "input", "stride", "padding", "dilation", "kernel_size"], ) self._add_constraint( - fn=general_tensor_type_constraint, keys=[PrimitiveModel.output_key, "input"] + fn=general_tensor_type_constraint, keys=[Operator.output_key, "input"] ) def __call__( # type: ignore[override] @@ -1537,7 +1665,7 @@ def __init__( ) self._add_constraint( - fn=general_tensor_type_constraint, keys=[PrimitiveModel.output_key, "input"] + fn=general_tensor_type_constraint, keys=[Operator.output_key, "input"] ) def __call__( # type: ignore[override] @@ -1571,7 +1699,7 @@ def __init__( self._add_constraint( fn=general_tensor_type_constraint, - keys=[PrimitiveModel.output_key, "left", "right", "norm"], + keys=[Operator.output_key, "left", "right", "norm"], ) def __call__( # type: ignore[override] @@ -1608,7 +1736,7 @@ def __init__( fn=polynomial_features_constraints, keys=["output", "input", "degree"] ) self._add_constraint( - fn=general_tensor_type_constraint, keys=[PrimitiveModel.output_key, "input"] + fn=general_tensor_type_constraint, keys=[Operator.output_key, "input"] ) def __call__( # type: ignore[override] @@ -1674,7 +1802,7 @@ def __init__( self, N: int | ToBeDetermined = TBD, M: int | ToBeDetermined | None = None, - dtype: Dtype | None = None, + dtype: core.Dtype | None = None, *, name: str | None = None, ) -> None: @@ -1684,7 +1812,7 @@ def __init__( output=BaseKey(shape=["N", "M"], type=Tensor[float]), N=BaseKey(type=int, value=N), M=BaseKey(type=int | None, value=M), - dtype=BaseKey(type=Dtype | None, value=dtype), + dtype=BaseKey(type=core.Dtype | None, value=dtype), ) self._add_constraint(fn=eye_constraints, keys=["output", "N", "M"]) @@ -1708,7 +1836,7 @@ def __init__( self, N: int | ToBeDetermined = TBD, M: int | ToBeDetermined | None = None, - dtype: Dtype | None = None, + dtype: core.Dtype | None = None, *, name: str | None = None, ) -> None: @@ -1718,7 +1846,7 @@ def __init__( output=BaseKey(shape=["N", "M"], type=Tensor[float]), N=BaseKey(type=int, value=N), M=BaseKey(type=int | None, value=M), - dtype=BaseKey(type=Dtype | None, value=dtype), + dtype=BaseKey(type=core.Dtype | None, value=dtype), ) self._add_constraint(fn=eye_constraints, keys=["output", "N", "M"]) @@ -1841,7 +1969,7 @@ def __init__( ) self._add_constraint( - fn=general_tensor_type_constraint, keys=[PrimitiveModel.output_key, 
"input"] + fn=general_tensor_type_constraint, keys=[Operator.output_key, "input"] ) def __call__( # type: ignore[override] @@ -1862,7 +1990,7 @@ def __init__( start: int | float | ToBeDetermined = 0, stop: int | float | ToBeDetermined = TBD, step: int | float | ToBeDetermined = 1, - dtype: Dtype | None = None, + dtype: core.Dtype | None = None, *, name: str | None = None, ) -> None: @@ -1890,7 +2018,7 @@ def __init__( start=BaseKey(type=int | float, value=start), stop=BaseKey(type=int | float, value=stop), step=BaseKey(type=int | float, value=step), - dtype=BaseKey(type=Dtype | None, value=dtype), + dtype=BaseKey(type=core.Dtype | None, value=dtype), ) # self.set_canonical_input("stop") self.set_cin("stop", safe=False) @@ -1901,7 +2029,7 @@ def __init__( ) self._add_constraint( fn=general_tensor_type_constraint, - keys=[PrimitiveModel.output_key, "start", "stop", "step"], + keys=[Operator.output_key, "start", "stop", "step"], ) def __call__( # type: ignore[override] @@ -1927,7 +2055,7 @@ def __init__( self, shape: tuple[int, ...] | ToBeDetermined = TBD, key: int | ToBeDetermined = TBD, - dtype: Dtype | None = None, + dtype: core.Dtype | None = None, *, name: str | None = None, ) -> None: @@ -1937,10 +2065,12 @@ def __init__( output=BaseKey(shape=[("output", ...)], type=Tensor), shape=BaseKey(type=tuple[int, ...], value=shape), key=BaseKey(type=int, value=key), - dtype=BaseKey(type=Dtype | None, value=dtype), + dtype=BaseKey(type=core.Dtype | None, value=dtype), ) - self.random_keys.add("key") + self.submodel.random_keys.add( + "key" + ) # since random_keys must be in primitive models self.add_constraint(randn_constraints, keys=["output", "shape"]) def __call__( # type: ignore[override] @@ -1977,7 +2107,7 @@ def __init__( fn=broadcast_to_constraints, keys=["output", "shape", "input"] ) self._add_constraint( - fn=general_tensor_type_constraint, keys=[PrimitiveModel.output_key, "input"] + fn=general_tensor_type_constraint, keys=[Operator.output_key, "input"] ) def __call__( # type: ignore[override] @@ -2041,7 +2171,7 @@ def __init__( self._add_constraint(fn=squeeze_constraints, keys=["output", "input"]) self._add_constraint( - fn=general_tensor_type_constraint, keys=[PrimitiveModel.output_key, "input"] + fn=general_tensor_type_constraint, keys=[Operator.output_key, "input"] ) def __call__( # type: ignore[override] @@ -2105,7 +2235,7 @@ def __init__( self._add_constraint( fn=general_tensor_type_constraint, - keys=[PrimitiveModel.output_key, "weight"], + keys=[Operator.output_key, "weight"], ) def __call__( # type: ignore[override] @@ -2186,7 +2316,7 @@ def __call__( # type: ignore[override] and attn_mask.value is not None # TODO: Here will be updated! ): raise KeyError( - "Model does not have 'attn_mask' input. Got attn_mask argument!" + "Operator does not have 'attn_mask' input." " Got attn_mask argument!" 
) return super().__call__( @@ -2268,7 +2398,7 @@ def __init__( fn=swap_axes_constraints, keys=["output", "input", "axis1", "axis2"] ) self._add_constraint( - fn=general_tensor_type_constraint, keys=[PrimitiveModel.output_key, "input"] + fn=general_tensor_type_constraint, keys=[Operator.output_key, "input"] ) def __call__( # type: ignore[override] @@ -2309,7 +2439,7 @@ def __init__( ) self._add_constraint( fn=general_tensor_type_constraint, - keys=[PrimitiveModel.output_key, "input1", "input2"], + keys=[Operator.output_key, "input1", "input2"], ) def __call__( # type: ignore[override] @@ -2457,7 +2587,7 @@ def __init__( ) self._add_constraint( - fn=pad_constraints, keys=[PrimitiveModel.output_key, "input", "pad_width"] + fn=pad_constraints, keys=[Operator.output_key, "input", "pad_width"] ) def __call__( # type: ignore[override] @@ -2490,3 +2620,967 @@ def __call__( # type: ignore[override] self, input: ConnectionType = NOT_GIVEN, output: ConnectionType = NOT_GIVEN ) -> ExtendInfo: return super().__call__(input=input, output=output) + + +class Buffer(OperatorModel): + input: Connection + output: Connection + + def __init__( + self, + input: Tensor[int | float | bool] | ScalarValueType | ToBeDetermined = TBD, + *, + name: str | None = None, + ) -> None: + super().__init__(name=name, model=BufferOp(input=input)) + + def __call__( # type: ignore[override] + self, + input: ConnectionType = NOT_GIVEN, + output: ConnectionType = NOT_GIVEN, + ) -> ExtendInfo: + return super().__call__(input=input, output=output) + + +class ToTuple(OperatorModel): + input: Connection + output: Connection + + def __init__( + self, + n: int, + *, + name: str | None = None, + **kwargs: Tensor[int | float | bool] | ScalarValueType | ToBeDetermined, + ) -> None: + super().__init__(name=name, model=ToTupleOp(n, **kwargs)) + + +class ArithmeticOperation(OperatorModel): + left: Connection + right: Connection + output: Connection + + def __init__( + self, + model: Operator, + *, + name: str | None = None, + ) -> None: + super().__init__(name=name, model=model) + + def __call__( # type: ignore[override] + self, + left: ConnectionType = NOT_GIVEN, + right: ConnectionType = NOT_GIVEN, + output: ConnectionType = NOT_GIVEN, + ) -> ExtendInfo: + return super().__call__(left=left, right=right, output=output) + + +class Power(OperatorModel): + base: Connection + exponent: Connection + output: Connection + + def __init__( + self, + robust: bool = False, + base: Tensor[int | float | bool] | int | float | ToBeDetermined = TBD, + exponent: Tensor[int | float | bool] | int | float | ToBeDetermined = TBD, + *, + name: str | None = None, + ) -> None: + self.robust = robust + m = PowerOp(robust=robust, base=base, exponent=exponent) + super().__init__(name=name, model=m) + + def __call__( # type: ignore[override] + self, + base: ConnectionType = NOT_GIVEN, + exponent: ConnectionType = NOT_GIVEN, + output: ConnectionType = NOT_GIVEN, + *, + threshold: ConnectionType = core.Constant.MIN_POSITIVE_NORMAL, + ) -> ExtendInfo: + kwargs = {"base": base, "exponent": exponent, "output": output} + default = ( + isinstance(threshold, core.Constant) + and threshold == core.Constant.MIN_POSITIVE_NORMAL + ) + if self.robust: + # NOTE: Since we can not provide Tensor objects as default + # arguments, we need to convert default value. 
+ if default: + threshold = Tensor(threshold) # type: ignore + kwargs["threshold"] = threshold + elif not default: + raise ValueError("Threshold cannot be specified when robust mode is off") + + return super().__call__(**kwargs) + + +class Add(ArithmeticOperation): + def __init__( + self, + left: Tensor[int | float | bool] | ScalarValueType | ToBeDetermined = TBD, + right: Tensor[int | float | bool] | ScalarValueType | ToBeDetermined = TBD, + *, + name: str | None = None, + ) -> None: + super().__init__(AddOp(left=left, right=right), name=name) + + +class Subtract(ArithmeticOperation): + def __init__( + self, + left: Tensor[int | float | bool] | ScalarValueType | ToBeDetermined = TBD, + right: Tensor[int | float | bool] | ScalarValueType | ToBeDetermined = TBD, + *, + name: str | None = None, + ) -> None: + super().__init__(SubtractOp(left=left, right=right), name=name) + + +class Multiply(ArithmeticOperation): + def __init__( + self, + left: Tensor[int | float | bool] | ScalarValueType | ToBeDetermined = TBD, + right: Tensor[int | float | bool] | ScalarValueType | ToBeDetermined = TBD, + *, + name: str | None = None, + ) -> None: + super().__init__(MultiplyOp(left=left, right=right), name=name) + + +class Minimum(ArithmeticOperation): + def __init__( + self, + left: TensorValueType | ToBeDetermined = TBD, + right: TensorValueType | ToBeDetermined = TBD, + *, + name: str | None = None, + ) -> None: + super().__init__(MinimumOp(left=left, right=right), name=name) + + +class Maximum(ArithmeticOperation): + def __init__( + self, + left: TensorValueType | ToBeDetermined = TBD, + right: TensorValueType | ToBeDetermined = TBD, + *, + name: str | None = None, + ) -> None: + super().__init__(MaximumOp(left=left, right=right), name=name) + + +class Divide(OperatorModel): + numerator: Connection + denominator: Connection + output: Connection + + def __init__( + self, + numerator: Tensor[int | float | bool] | ScalarValueType | ToBeDetermined = TBD, + denominator: Tensor[int | float | bool] + | ScalarValueType + | ToBeDetermined = TBD, + *, + name: str | None = None, + ) -> None: + m = DivideOp(numerator=numerator, denominator=denominator) + super().__init__(name=name, model=m) + + def __call__( # type: ignore[override] + self, + numerator: ConnectionType = NOT_GIVEN, + denominator: ConnectionType = NOT_GIVEN, + output: ConnectionType = NOT_GIVEN, + ) -> ExtendInfo: + return super().__call__( + numerator=numerator, denominator=denominator, output=output + ) + + +class FloorDivide(OperatorModel): + numerator: Connection + denominator: Connection + output: Connection + + def __init__( + self, + numerator: Tensor[int | float | bool] | ToBeDetermined = TBD, + denominator: Tensor[int | float | bool] | ToBeDetermined = TBD, + *, + name: str | None = None, + ) -> None: + m = FloorDivideOp(numerator=numerator, denominator=denominator) + super().__init__(name=name, model=m) + + def __call__( # type: ignore[override] + self, + numerator: ConnectionType = NOT_GIVEN, + denominator: ConnectionType = NOT_GIVEN, + output: ConnectionType = NOT_GIVEN, + ) -> ExtendInfo: + return super().__call__( + numerator=numerator, denominator=denominator, output=output + ) + + +class MatrixMultiply(OperatorModel): + left: Connection + right: Connection + output: Connection + + def __init__( + self, + left: Tensor[int | float | bool] | ToBeDetermined = TBD, + right: Tensor[int | float | bool] | ToBeDetermined = TBD, + *, + name: str | None = None, + ) -> None: + super().__init__(name=name, model=MatrixMultiplyOp(left=left, 
right=right)) + + def __call__( # type: ignore[override] + self, + left: ConnectionType = NOT_GIVEN, + right: ConnectionType = NOT_GIVEN, + output: ConnectionType = NOT_GIVEN, + ) -> ExtendInfo: + return super().__call__(left=left, right=right, output=output) + + +class Shape(OperatorModel): + input: Connection + output: Connection + + def __init__( + self, + input: Tensor[int | float | bool] | ToBeDetermined = TBD, + *, + name: str | None = None, + ) -> None: + super().__init__(name=name, model=ShapeOp(input=input)) + + def __call__( # type: ignore[override] + self, input: ConnectionType = NOT_GIVEN, output: ConnectionType = NOT_GIVEN + ) -> ExtendInfo: + return super().__call__(input=input, output=output) + + +class Reshape(OperatorModel): + input: Connection + shape: Connection + output: Connection + + def __init__( + self, + shape: tuple[int | None, ...] | list[int] | ToBeDetermined = TBD, + input: Tensor[int | float | bool] | ToBeDetermined = TBD, + *, + name: str | None = None, + ) -> None: + super().__init__(name=name, model=ReshapeOp(shape=shape, input=input)) + + def __call__( # type: ignore[override] + self, + input: ConnectionType = NOT_GIVEN, + shape: ConnectionType = NOT_GIVEN, + output: ConnectionType = NOT_GIVEN, + ) -> ExtendInfo: + return super().__call__(input=input, shape=shape, output=output) + + +class Length(OperatorModel): + input: Connection + output: Connection + + def __init__( + self, + input: Tensor[int | float | bool] | ToBeDetermined = TBD, + *, + name: str | None = None, + ) -> None: + super().__init__(name=name, model=LengthOp(input=input)) + + def __call__( # type: ignore[override] + self, input: ConnectionType = NOT_GIVEN, output: ConnectionType = NOT_GIVEN + ) -> ExtendInfo: + return super().__call__(input=input, output=output) + + +class Cast(OperatorModel): + input: Connection + dtype: Connection + output: Connection + + def __init__( + self, dtype: core.Dtype | ToBeDetermined = TBD, *, name: str | None = None + ) -> None: + super().__init__(name=name, model=CastOp(dtype=dtype)) + + def __call__( # type: ignore[override] + self, + input: ConnectionType = NOT_GIVEN, + dtype: ConnectionType = NOT_GIVEN, + output: ConnectionType = NOT_GIVEN, + ) -> ExtendInfo: + return super().__call__(input=input, dtype=dtype, output=output) + + +class Dtype(OperatorModel): + input: Connection + output: Connection + + def __init__( + self, + input: Tensor[int | float | bool] | ToBeDetermined = TBD, + *, + name: str | None = None, + ) -> None: + super().__init__(name=name, model=DtypeOp(input=input)) + + def __call__( # type: ignore[override] + self, input: ConnectionType = NOT_GIVEN, output: ConnectionType = NOT_GIVEN + ) -> ExtendInfo: + return super().__call__(input=input, output=output) + + +class Size(OperatorModel): + input: Connection + dim: Connection + output: Connection + + def __init__( + self, + dim: int | tuple[int, ...] 
| None | ToBeDetermined = None, + input: Tensor[int | float | bool] | ToBeDetermined = TBD, + *, + name: str | None = None, + ) -> None: + super().__init__(name=name, model=SizeOp(input=input, dim=dim)) + + def __call__( # type: ignore[override] + self, + input: ConnectionType = NOT_GIVEN, + dim: ConnectionType = NOT_GIVEN, + output: ConnectionType = NOT_GIVEN, + ) -> ExtendInfo: + return super().__call__(input=input, dim=dim, output=output) + + +class Item(OperatorModel): + input: Connection + output: Connection + + def __init__( + self, + input: Tensor[int | float | bool] | ToBeDetermined = TBD, + *, + name: str | None = None, + ) -> None: + super().__init__(name=name, model=ItemOp(input=input)) + + def __call__( # type: ignore[override] + self, input: ConnectionType = NOT_GIVEN, output: ConnectionType = NOT_GIVEN + ) -> ExtendInfo: + return super().__call__(input=input, output=output) + + +class ToTensor(OperatorModel): + input: Connection + dtype: Connection + output: Connection + + def __init__( + self, + input: TensorValueType | ToBeDetermined = TBD, + dtype: core.Dtype | None = None, + *, + name: str | None = None, + ) -> None: + super().__init__(name=name, model=ToTensorOp(input=input, dtype=dtype)) + + def __call__( # type: ignore[override] + self, + input: ConnectionType = NOT_GIVEN, + dtype: ConnectionType = NOT_GIVEN, + output: ConnectionType = NOT_GIVEN, + ) -> ExtendInfo: + return super().__call__(input=input, dtype=dtype, output=output) + + +class ToList(OperatorModel): + output: Connection + + def __init__( + self, + n: int, + *, + name: str | None = None, + **kwargs: ScalarValueType | ToBeDetermined, + ) -> None: + super().__init__(name=name, model=ToListOp(n, name=name, **kwargs)) + + +class TensorToList(OperatorModel): + input: Connection + output: Connection + + def __init__( + self, + input: Tensor[int | float | bool] | ToBeDetermined = TBD, + *, + name: str | None = None, + ) -> None: + self._enforce_jit = False + m = TensorToListOp(input=input) + m._enforce_jit = False + super().__init__(name=name, model=m) + + def __call__( # type: ignore[override] + self, input: ConnectionType = NOT_GIVEN, output: ConnectionType = NOT_GIVEN + ) -> ExtendInfo: + return super().__call__(input=input, output=output) + + +class Reduce(OperatorModel): + input: Connection + axis: Connection + keepdim: Connection + output: Connection + + def __init__(self, model: Operator, *, name: str | None = None) -> None: + super().__init__(name=name, model=model) + + def __call__( # type: ignore[override] + self, + input: ConnectionType = NOT_GIVEN, + axis: ConnectionType = NOT_GIVEN, + keepdim: ConnectionType = NOT_GIVEN, + output: ConnectionType = NOT_GIVEN, + ) -> ExtendInfo: + return super().__call__(input=input, axis=axis, keepdim=keepdim, output=output) + + +class Mean(Reduce): + def __init__( + self, + axis: int | tuple[int, ...] | None | ToBeDetermined = None, + keepdim: bool | ToBeDetermined = False, + input: Tensor[int | float | bool] | ToBeDetermined = TBD, + *, + name: str | None = None, + ) -> None: + self.factory_args = {"axis": axis, "keepdim": keepdim} + super().__init__(MeanOp(axis=axis, keepdim=keepdim, input=input), name=name) + + +class Sum(Reduce): + def __init__( + self, + axis: int | tuple[int, ...] 
| None | ToBeDetermined = None, + keepdim: bool | ToBeDetermined = False, + input: Tensor[int | float | bool] | ToBeDetermined = TBD, + *, + name: str | None = None, + ) -> None: + self.factory_args = {"axis": axis, "keepdim": keepdim} + super().__init__(SumOp(axis=axis, keepdim=keepdim, input=input), name=name) + + +class Max(Reduce): + def __init__( + self, + axis: int | tuple[int, ...] | None | ToBeDetermined = None, + keepdim: bool | ToBeDetermined = False, + input: Tensor[int | float | bool] | ToBeDetermined = TBD, + *, + name: str | None = None, + ) -> None: + self.factory_args = {"axis": axis, "keepdim": keepdim} + super().__init__(MaxOp(axis=axis, keepdim=keepdim, input=input), name=name) + + +class ArgMax(Reduce): + def __init__( + self, + axis: int | None | ToBeDetermined = None, + keepdim: bool | ToBeDetermined = False, + input: Tensor[int | float | bool] | ToBeDetermined = TBD, + *, + name: str | None = None, + ) -> None: + self.factory_args = {"axis": axis, "keepdim": keepdim} + super().__init__(ArgMaxOp(axis=axis, keepdim=keepdim, input=input), name=name) + + +class Min(Reduce): + def __init__( + self, + axis: int | tuple[int, ...] | None | ToBeDetermined = None, + keepdim: bool | ToBeDetermined = False, + input: Tensor[int | float | bool] | ToBeDetermined = TBD, + *, + name: str | None = None, + ) -> None: + self.factory_args = {"axis": axis, "keepdim": keepdim} + super().__init__(MinOp(axis=axis, keepdim=keepdim, input=input), name=name) + + +class ArgMin(Reduce): + def __init__( + self, + axis: int | tuple[int, ...] | None | ToBeDetermined = None, + keepdim: bool | ToBeDetermined = False, + input: Tensor[int | float | bool] | ToBeDetermined = TBD, + *, + name: str | None = None, + ) -> None: + self.factory_args = {"axis": axis, "keepdim": keepdim} + super().__init__(ArgMinOp(axis=axis, keepdim=keepdim, input=input), name=name) + + +class Prod(Reduce): + def __init__( + self, + axis: int | tuple[int, ...] | None | ToBeDetermined = None, + keepdim: bool | ToBeDetermined = False, + input: Tensor[int | float | bool] | ToBeDetermined = TBD, + *, + name: str | None = None, + ) -> None: + self.factory_args = {"axis": axis, "keepdim": keepdim} + super().__init__(ProdOp(axis=axis, keepdim=keepdim, input=input), name=name) + + +class Variance(Reduce): + correction: Connection + + def __init__( + self, + axis: int | tuple[int, ...] 
| None | ToBeDetermined = None, + keepdim: bool | ToBeDetermined = False, + correction: int | float | None = 0.0, + input: Tensor[int | float | bool] | ToBeDetermined = TBD, + *, + name: str | None = None, + ) -> None: + self.factory_args = {"axis": axis, "keepdim": keepdim, "correction": correction} + super().__init__( + VarianceOp(axis=axis, keepdim=keepdim, input=input, correction=correction), + name=name, + ) + + def __call__( # type: ignore[override] + self, + input: ConnectionType = NOT_GIVEN, + axis: ConnectionType = NOT_GIVEN, + keepdim: ConnectionType = NOT_GIVEN, + correction: ConnectionType = NOT_GIVEN, + output: ConnectionType = NOT_GIVEN, + ) -> ExtendInfo: + return super(Reduce, self).__call__( + input=input, + axis=axis, + keepdim=keepdim, + correction=correction, + output=output, + ) + + +class SingleInputModel(OperatorModel): + input: Connection + output: Connection + + def __call__( # type: ignore[override] + self, input: ConnectionType = NOT_GIVEN, output: ConnectionType = NOT_GIVEN + ) -> ExtendInfo: + return super().__call__(input=input, output=output) + + +class Absolute(SingleInputModel): + def __init__( + self, + input: Tensor[int | float | bool] | ToBeDetermined = TBD, + *, + name: str | None = None, + ) -> None: + super().__init__(name=name, model=AbsoluteOp(input=input)) + + +class Minus(SingleInputModel): + def __init__( + self, + input: Tensor[int | float | bool] | ToBeDetermined = TBD, + *, + name: str | None = None, + ) -> None: + super().__init__(name=name, model=MinusOp(input=input)) + + +class Exponential(SingleInputModel): + def __init__( + self, + input: Tensor[int | float | bool] | ToBeDetermined = TBD, + *, + name: str | None = None, + ) -> None: + super().__init__(name=name, model=ExponentialOp(input=input)) + + +class Sqrt(OperatorModel): + input: Connection + output: Connection + + def __init__( + self, + robust: bool = False, + input: Tensor[int | float | bool] | ToBeDetermined = TBD, + *, + cutoff: Tensor[int | float | bool] | ToBeDetermined = TBD, + name: str | None = None, + ) -> None: + self.robust = robust + m = SqrtOp(robust=robust, input=input, cutoff=cutoff) + super().__init__(name=name, model=m) + + def __call__( # type: ignore[override] + self, + input: ConnectionType = NOT_GIVEN, + output: ConnectionType = NOT_GIVEN, + *, + cutoff: ConnectionType = core.Constant.MIN_POSITIVE_NORMAL, + ) -> ExtendInfo: + kwargs = {"input": input, "output": output} + + default = ( + isinstance(cutoff, core.Constant) + and cutoff == core.Constant.MIN_POSITIVE_NORMAL + ) + if self.robust: + if default: + # NOTE: Since we can not provide Tensor objects as default + # arguments, we need to convert default value. 
+ cutoff = Tensor(cutoff) # type: ignore + kwargs["cutoff"] = cutoff + elif not default: + raise ValueError("Cutoff cannot be specified when robust mode is off") + + return super().__call__(**kwargs) + + +class RelationalModel(OperatorModel): + left: Connection + right: Connection + output: Connection + + def __call__( # type: ignore[override] + self, + left: ConnectionType = NOT_GIVEN, + right: ConnectionType = NOT_GIVEN, + output: ConnectionType = NOT_GIVEN, + ) -> ExtendInfo: + return super().__call__(left=left, right=right, output=output) + + +class Greater(RelationalModel): + def __init__( + self, + left: Tensor[int | float | bool] | ToBeDetermined = TBD, + right: Tensor[int | float | bool] | ToBeDetermined = TBD, + *, + name: str | None = None, + ) -> None: + super().__init__(name=name, model=GreaterOp(left=left, right=right)) + + +class Less(RelationalModel): + def __init__( + self, + left: Tensor[int | float | bool] | ToBeDetermined = TBD, + right: Tensor[int | float | bool] | ToBeDetermined = TBD, + *, + name: str | None = None, + ) -> None: + super().__init__(name=name, model=LessOp(left=left, right=right)) + + +class Equal(RelationalModel): + def __init__( + self, + left: Tensor[int | float | bool] | ToBeDetermined = TBD, + right: Tensor[int | float | bool] | ToBeDetermined = TBD, + *, + name: str | None = None, + ) -> None: + super().__init__(name=name, model=EqualOp(left=left, right=right)) + + +class NotEqual(RelationalModel): + def __init__( + self, + left: Tensor[int | float | bool] | ToBeDetermined = TBD, + right: Tensor[int | float | bool] | ToBeDetermined = TBD, + *, + name: str | None = None, + ) -> None: + super().__init__(name=name, model=NotEqualOp(left=left, right=right)) + + +class LessEqual(RelationalModel): + def __init__( + self, + left: Tensor[int | float | bool] | ToBeDetermined = TBD, + right: Tensor[int | float | bool] | ToBeDetermined = TBD, + *, + name: str | None = None, + ) -> None: + super().__init__(name=name, model=LessEqualOp(left=left, right=right)) + + +class GreaterEqual(RelationalModel): + def __init__( + self, + left: Tensor[int | float | bool] | ToBeDetermined = TBD, + right: Tensor[int | float | bool] | ToBeDetermined = TBD, + *, + name: str | None = None, + ) -> None: + super().__init__(name=name, model=GreaterEqualOp(left=left, right=right)) + + +class LogicalNot(OperatorModel): + input: Connection + output: Connection + + def __init__( + self, + input: Tensor[int | float | bool] | ToBeDetermined = TBD, + *, + name: str | None = None, + ) -> None: + super().__init__(name=name, model=LogicalNotOp(input=input)) + + def __call__( # type: ignore[override] + self, input: ConnectionType = NOT_GIVEN, output: ConnectionType = NOT_GIVEN + ) -> ExtendInfo: + return super().__call__(input=input, output=output) + + +class BitwiseOperators(OperatorModel): + left: Connection + right: Connection + output: Connection + + def __call__( # type: ignore[override] + self, + left: ConnectionType = NOT_GIVEN, + right: ConnectionType = NOT_GIVEN, + output: ConnectionType = NOT_GIVEN, + ) -> ExtendInfo: + return super().__call__(left=left, right=right, output=output) + + +class LogicalAnd(BitwiseOperators): + def __init__( + self, + left: Tensor[int | float | bool] | ToBeDetermined = TBD, + right: Tensor[int | float | bool] | ToBeDetermined = TBD, + *, + name: str | None = None, + ) -> None: + super().__init__(name=name, model=LogicalAndOp(left=left, right=right)) + + +class LogicalOr(BitwiseOperators): + def __init__( + self, + left: Tensor[int | float | bool] | 
ToBeDetermined = TBD,
+        right: Tensor[int | float | bool] | ToBeDetermined = TBD,
+        *,
+        name: str | None = None,
+    ) -> None:
+        super().__init__(name=name, model=LogicalOrOp(left=left, right=right))
+
+
+class LogicalXOr(BitwiseOperators):
+    def __init__(
+        self,
+        left: Tensor[int | float | bool] | ToBeDetermined = TBD,
+        right: Tensor[int | float | bool] | ToBeDetermined = TBD,
+        *,
+        name: str | None = None,
+    ) -> None:
+        super().__init__(name=name, model=LogicalXOrOp(left=left, right=right))
+
+
+class ShiftLeft(OperatorModel):
+    input: Connection
+    shift: Connection
+    output: Connection
+
+    def __init__(
+        self,
+        input: Tensor[int | float | bool] | ToBeDetermined = TBD,
+        shift: Tensor[int | float | bool] | ToBeDetermined = TBD,
+        *,
+        name: str | None = None,
+    ) -> None:
+        super().__init__(name=name, model=ShiftLeftOp(input=input, shift=shift))
+
+    def __call__(  # type: ignore[override]
+        self,
+        input: ConnectionType = NOT_GIVEN,
+        shift: ConnectionType = NOT_GIVEN,
+        output: ConnectionType = NOT_GIVEN,
+    ) -> ExtendInfo:
+        return super().__call__(input=input, shift=shift, output=output)
+
+
+class ShiftRight(OperatorModel):
+    input: Connection
+    shift: Connection
+    output: Connection
+
+    def __init__(
+        self,
+        input: Tensor[int | float | bool] | ToBeDetermined = TBD,
+        shift: Tensor[int | float | bool] | ToBeDetermined = TBD,
+        *,
+        name: str | None = None,
+    ) -> None:
+        super().__init__(name=name, model=ShiftRightOp(input=input, shift=shift))
+
+    def __call__(  # type: ignore[override]
+        self,
+        input: ConnectionType = NOT_GIVEN,
+        shift: ConnectionType = NOT_GIVEN,
+        output: ConnectionType = NOT_GIVEN,
+    ) -> ExtendInfo:
+        return super().__call__(input=input, shift=shift, output=output)
+
+
+class Transpose(OperatorModel):
+    # NOTE: Consider whether list[int] is the conventional type for axes;
+    # tuple[int, ...] is what is generally used in cases like this.
+    input: Connection
+    axes: Connection
+    output: Connection
+
+    def __init__(
+        self,
+        axes: int | list[int] | tuple[int, ...] | None | ToBeDetermined = None,
+        input: Tensor[int | float | bool] | ToBeDetermined = TBD,
+        *,
+        name: str | None = None,
+    ) -> None:
+        super().__init__(name=name, model=TransposeOp(input=input, axes=axes))
+
+    def __call__(  # type: ignore[override]
+        self,
+        input: ConnectionType = NOT_GIVEN,
+        axes: ConnectionType = NOT_GIVEN,
+        output: ConnectionType = NOT_GIVEN,
+    ) -> ExtendInfo:
+        return super().__call__(input=input, axes=axes, output=output)
+
+
+class Split(OperatorModel):
+    split_size: Connection
+    axis: Connection
+    input: Connection
+    output: Connection
+
+    def __init__(
+        self,
+        split_size: int,  # TODO: should we add a default for split_size?
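+        # (left as a required positional argument; the TODO above tracks
+        # whether a sensible default should be added)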
+ axis: int = 0, + input: Tensor[int | float | bool] | ToBeDetermined = TBD, + *, + name: str | None = None, + ): + m = SplitOp(split_size=split_size, axis=axis, input=input) + super().__init__(name=name, model=m) + + def __call__( # type: ignore[override] + self, + input: ConnectionType = NOT_GIVEN, + split_size: ConnectionType = NOT_GIVEN, + axis: ConnectionType = NOT_GIVEN, + output: ConnectionType = NOT_GIVEN, + ) -> ExtendInfo: + return super().__call__( + input=input, split_size=split_size, axis=axis, output=output + ) + + +class Slice(OperatorModel): + start: Connection + stop: Connection + step: Connection + output: Connection + + def __init__( + self, + start: int | None | ToBeDetermined = TBD, + stop: int | None | ToBeDetermined = TBD, + step: int | None | ToBeDetermined = TBD, + *, + name: str | None = None, + ): + super().__init__(name=name, model=SliceOp(start=start, stop=stop, step=step)) + + def __call__( # type: ignore[override] + self, + start: ConnectionType = NOT_GIVEN, + stop: ConnectionType = NOT_GIVEN, + step: ConnectionType = NOT_GIVEN, + output: ConnectionType = NOT_GIVEN, + ) -> ExtendInfo: + return super().__call__(start=start, stop=stop, step=step, output=output) + + +class Indexer(OperatorModel): + input: Connection + index: Connection + output: Connection + + def __init__( + self, + index: int | ToBeDetermined = TBD, + input: Tensor[int | float | bool] | Sequence[Any] | ToBeDetermined = TBD, + *, + name: str | None = None, + ) -> None: + super().__init__(name=name, model=IndexerOp(input=input, index=index)) + + def __call__( # type: ignore[override] + self, + input: ConnectionType = NOT_GIVEN, + index: ConnectionType = NOT_GIVEN, + output: ConnectionType = NOT_GIVEN, + ) -> ExtendInfo: + return super().__call__(input=input, index=index, output=output) + + +class Sine(SingleInputModel): + def __init__( + self, + input: Tensor[int | float | bool] | ToBeDetermined = TBD, + *, + name: str | None = None, + ) -> None: + super().__init__(name=name, model=SineOp(input=input)) + + +class Cosine(SingleInputModel): + def __init__( + self, + input: Tensor[int | float | bool] | ToBeDetermined = TBD, + *, + name: str | None = None, + ) -> None: + super().__init__(name=name, model=CosineOp(input=input)) diff --git a/mithril/models/train_model.py b/mithril/models/train_model.py index b7d7903e..ebf8e9dc 100644 --- a/mithril/models/train_model.py +++ b/mithril/models/train_model.py @@ -19,15 +19,14 @@ from copy import deepcopy from typing import Any, Self, TypedDict -from ..framework import BaseModel, ExtendInfo, Model +from ..framework import BaseModel from ..framework.common import ( NOT_GIVEN, TBD, - Connection, + BaseKey, ConnectionData, - ConnectionType, + ConnectionDataType, IOHyperEdge, - IOKey, KeyType, Table, Tensor, @@ -36,8 +35,16 @@ get_shapes, get_summary_shapes, ) -from ..framework.logical import ( +from ..framework.logical.model import ( + Connection, + ExtendInfo, + IOKey, + Model, +) +from ..framework.physical.model import FinalCost, LossKey +from .primitives import ( Buffer, + Concat, Divide, Max, Mean, @@ -49,8 +56,6 @@ Sum, ToTensor, ) -from ..framework.physical.model import FinalCost, LossKey -from .primitives import Concat __all__ = ["TrainModel"] @@ -79,7 +84,7 @@ def _create_size() -> Model: class TrainModel(Model): - def __init__(self, model: BaseModel) -> None: + def __init__(self, model: Model) -> None: super().__init__() self._model = model self._losses: list[LossModelDict] = [] @@ -103,20 +108,18 @@ def __init__(self, model: BaseModel) -> None: raise 
KeyError( f"'{FinalCost}' could not be used as an external key in TrainModel!" ) - + # TODO: We can use _extend instead of extend in TrainModel. self.extend(model, **extend_kwargs) # self.loss_keys: dict[str, Connection] = {} self.loss_keys: dict[str, str] = {} self.regularization_keys: list[str] = [] self.metric_keys: list[str] = [] self.loss_combiner: BaseModel = Sum() - # TODO: Update type of reg_coef_map. float | Tensor[int | float | bool] - # is not correct. self.reg_coef_map: dict[ - float | Tensor[int | float | bool], set[Connection] + float | Tensor[int | float | bool], set[ConnectionData] ] = {} - self.geomean_map: dict[str, list[tuple[Connection, float]]] = {} - self.reduce_inputs: dict[str, list[tuple[Connection, Connection]]] = {} + self.geomean_map: dict[str, list[tuple[ConnectionData, float]]] = {} + self.reduce_inputs: dict[str, list[tuple[ConnectionData, ConnectionData]]] = {} def __add__(self, model: ExtendInfo | BaseModel) -> Self: """This function allows models to be added sequentially via "+=" operator. @@ -165,7 +168,7 @@ def check_fn(context: TrainModel, *args: Any, **kwargs: Any) -> T: def add_loss( self, loss_model: Model, - reduce_steps: list[BaseModel] | None = None, + reduce_steps: list[Model] | None = None, key_name: str | None = None, coef: float | None = None, **kwargs: Any, @@ -236,7 +239,7 @@ def add_loss( # TODO: Currently kwargs contains only input keys of # first (loss) model. # We may want to add output key for the final model's output key. - reduce_inputs: list[tuple[Connection, Connection]] = [] + reduce_inputs: list[tuple[ConnectionData, ConnectionData]] = [] for key in kwargs: if key in loss_model.conns.output_keys: raise KeyError("Output of the loss model cannot be defined!") @@ -252,27 +255,27 @@ def add_loss( if i == len(reduce_steps) - 1 and key_name is not None and coef is None: out_key = self.get_single_output(m).key # self.extend(m, **{in_key: prev_out_key.conn, out_key: key_name}) - info: dict[str, ConnectionType] = { - in_key: prev_out_key.conn, - out_key: IOKey(key_name), + info: dict[str, ConnectionDataType] = { + in_key: prev_out_key, + out_key: BaseKey(key_name), } self.extend(m, **info) else: - self.extend(m, **{in_key: prev_out_key.conn}) + self.extend(m, **{in_key: prev_out_key}) # Save all reduce inputs for geo-mean if isinstance(m, Min | Max | Mean): if (axis := m.conns.get_connection("axis")) is None: raise KeyError("Reduce model should have axis key.") - reduce_inputs.append((prev_out_key.conn, axis.conn)) + reduce_inputs.append((prev_out_key, axis)) prev_out_key = self.get_single_output(m).data # Apply coef if coef is not None: # kwargs = {"left": prev_out_key.conn, "right": coef, "output": key_name} kwargs = { - "left": prev_out_key.conn, + "left": prev_out_key, "right": coef, - "output": IOKey(name=key_name), + "output": BaseKey(name=key_name), } if key_name is None: kwargs.pop("output") @@ -316,8 +319,8 @@ def add_regularization( "args": kwargs, } ) - canonical_inputs = {con_data.conn for con_data in self.conns.cins} - canonical_outputs = {con_data.conn for con_data in self.conns.couts} + canonical_inputs = {self.connection_map[data] for data in self.conns.cins} + canonical_outputs = {self.connection_map[data] for data in self.conns.couts} self._add_regularization(model, coef, reg_key, key_name, **kwargs) self.set_cin(*canonical_inputs) self.set_cout(*canonical_outputs) @@ -391,8 +394,8 @@ def _add_regularization( keywords = {} for key, value in model(**kwargs).connections.items(): - if isinstance(value, ConnectionData): - 
keywords[key] = value.conn + if isinstance(value, Connection): + keywords[key] = value.data else: keywords[key] = value self.extend(model, **keywords) @@ -402,22 +405,26 @@ def _add_regularization( if (out_con := model.conns.get_connection("output")) is None: raise KeyError("Given key does not belong to the Model!") - self.geomean_map.setdefault(outer_key, []).append((out_con.conn, coef)) - self.reg_coef_map.setdefault(coef, set()).add(out_con.conn) + self.geomean_map.setdefault(outer_key, []).append((out_con, coef)) + self.reg_coef_map.setdefault(coef, set()).add(out_con) @check_finalized def add_metric( self, model: Model, - reduce_steps: list[BaseModel] | None = None, + reduce_steps: list[Model] | None = None, key_name: str | None = None, **kwargs: Any, ) -> None: # TODO: Somehow we need to imply metric is attached and self model # could not be extended or be used as another model's child model. - # self._extend(model, **kwargs) - # self.extend(model, **kwargs) - self.extend(model, **model(**kwargs).connections) + self.extend( + model, + **{ + key: value.data if isinstance(value, Connection) else value + for key, value in model(**kwargs).connections.items() + }, # type: ignore + ) if not reduce_steps: reduce_steps = [Buffer()] @@ -428,13 +435,13 @@ def add_metric( if i == len(reduce_steps) - 1 and key_name is not None: out = self.get_single_output(m).data # self.extend(m, **{in_key: prev_out_key, out.key: key_name}) - info: dict[str, ConnectionType] = { - in_key: prev_out_key, - out.key: IOKey(name=key_name), + info: dict[str, ConnectionDataType] = { + in_key: prev_out_key.data, + out.key: BaseKey(name=key_name), } self.extend(m, **info) else: - self.extend(m, **{in_key: prev_out_key}) + self.extend(m, **{in_key: prev_out_key.data}) prev_out_con = self.get_single_output(m) assert prev_out_con is not None prev_out_key = prev_out_con @@ -461,12 +468,12 @@ def _add_loss_combiner(self) -> None: if not concat_model.conns.is_key_non_diff(key): concat_kwargs[key] = self.conns.all[ list(self.loss_keys.values())[idx] - ].conn + ] idx += 1 self.extend(concat_model, **concat_kwargs) self.extend( self.loss_combiner, - input=concat_model.output, + input=concat_model.output.data, output=IOKey(name=loss_output_key), ) elif num_of_loss_keys == 1: @@ -479,7 +486,7 @@ def _add_loss_combiner(self) -> None: buffer_model = Buffer() self.extend( buffer_model, - input=self.conns.all[list(self.loss_keys.values())[0]].conn, + input=self.conns.all[list(self.loss_keys.values())[0]], # input=list(self.loss_keys.values())[0], output=IOKey(name=loss_output_key), ) @@ -490,7 +497,7 @@ def finalize(self) -> None: self._add_geo_mean() self._add_loss_combiner() if self.reg_coef_map: - reg_concat_args: list[str | Connection] = [LossKey] + reg_concat_args: list[str | ConnectionData] = [LossKey] for coef, o_set in self.reg_coef_map.items(): concat_inputs = { f"input{idx + 1}": o for idx, o in enumerate(o_set) @@ -498,9 +505,9 @@ def finalize(self) -> None: self.extend( concat := Concat(n=len(o_set), axis=None), **concat_inputs ) - self.extend(add := Sum(), input=concat.output) - self.extend(mult := Multiply(), left=add.output, right=coef) - reg_concat_args.append(mult.output) + self.extend(add := Sum(), input=concat.output.data) + self.extend(mult := Multiply(), left=add.output.data, right=coef) + reg_concat_args.append(mult.output.data) # TODO: add concat and sum if len(reg_concat_args) > 1 self.extend( reg_concat := Concat(n=len(reg_concat_args), axis=None), @@ -510,7 +517,7 @@ def finalize(self) -> None: }, ) 
self.extend( - Sum(), input=reg_concat.output, output=IOKey(name=FinalCost) + Sum(), input=reg_concat.output.data, output=IOKey(name=FinalCost) ) self.set_cout(FinalCost) loss_con = self.conns.get_connection(LossKey) @@ -591,7 +598,7 @@ def summary( loss_conn = self.conns.get_connection(loss_key) assert loss_conn is not None model = self.dependency_map.local_output_dependency_map[loss_conn][0] - t_list.append([model.__class__.__name__]) + t_list.append([model.class_name]) m_name = name_mappings[model] conns = conn_info[m_name][0] shape = shape_info[m_name][0] @@ -606,10 +613,11 @@ def summary( else: for reduce in loss_dict["reduce_steps"]: axis = reduce.factory_args["axis"] + reduce_str += reduce.class_name if axis is None: - reduce_str += reduce.__class__.__name__ + "()" + reduce_str += "()" else: - reduce_str += reduce.__class__.__name__ + f"(axis = {axis})" + reduce_str += f"(axis = {axis})" reduce_str += ", " t_list.append([reduce_str[:-2]]) coef = loss_dict["coef"] @@ -638,11 +646,11 @@ def summary( model = self.dependency_map.local_output_dependency_map[conn_data][ 0 ] - r_list.append([model.__class__.__name__]) + r_list.append([model.class_name]) m_name = name_mappings[model] conns = conn_info[m_name][0] shape = shape_info[m_name][0] - reg_key = model.cin.key + reg_key = model._cin.key updated_reg_key = model.generate_keys(include_outputs=True).get( reg_key, reg_key ) @@ -668,7 +676,7 @@ def summary( m_conn = self.conns.get_connection(m_key) assert m_conn is not None model = self.dependency_map.local_output_dependency_map[m_conn][0] - m_list.append([model.__class__.__name__]) + m_list.append([model.class_name]) m_name = name_mappings[model] conns = conn_info[m_name][0] shape = shape_info[m_name][0] @@ -685,7 +693,8 @@ def _add_geo_mean(self) -> None: # Find all loss / reg_key dependencies. # geo_mappings: dict[Connection, list[tuple[Connection, Connection]]] = {} geo_mappings: dict[ - tuple[Connection, float], list[list[tuple[Connection, Connection]]] + tuple[ConnectionData, float], + list[list[tuple[ConnectionData, ConnectionData]]], ] = {} # Find all loss dependencies with corresponding regularization keys. for key, value in self.loss_keys.items(): @@ -702,7 +711,7 @@ def _add_geo_mean(self) -> None: geo_mappings[reg_info].append(self.reduce_inputs[key]) for reg_info, loss_connections in geo_mappings.items(): - final_outputs: list[Connection | Tensor[int]] = [] + final_outputs: list[ConnectionData | Tensor[int]] = [] for reduce in loss_connections: final_outputs.append(self._add_reduce_sizes(reduce)) if final_outputs: @@ -710,7 +719,7 @@ def _add_geo_mean(self) -> None: final_output = final_outputs[0] if (n_final_outputs := len(final_outputs)) > 0: concat_model = Concat(n=n_final_outputs, axis=None) - concat_kwargs: dict[str, Tensor[int] | Connection] = {} + concat_kwargs: dict[str, Tensor[int] | ConnectionData] = {} idx = 0 for key in concat_model.input_keys: if not concat_model.conns.is_key_non_diff(key): @@ -718,8 +727,8 @@ def _add_geo_mean(self) -> None: idx += 1 self.extend(concat_model, **concat_kwargs) - self.extend(prod := Prod(), input=concat_model.output) - final_output = prod.output + self.extend(prod := Prod(), input=concat_model.output.data) + final_output = prod.output.data # Add geo-mean result as final_output if n_final_outputs > 1: @@ -728,7 +737,7 @@ def _add_geo_mean(self) -> None: base=final_output, exponent=Tensor([1 / n_final_outputs]), ) - final_output = power.output + final_output = power.output.data # Add Divide Model to divide final_output to geo_mean. 
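             # (the Divide output replaces reg_con in reg_coef_map below, so
             # the coefficient is applied to the geo-mean-normalized term)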
reg_con, coef = reg_info self.extend( @@ -737,30 +746,30 @@ def _add_geo_mean(self) -> None: self.reg_coef_map[coef].remove(reg_con) out_con = divide.conns.get_connection("output") assert out_con is not None - self.reg_coef_map[coef].add(out_con.conn) + self.reg_coef_map[coef].add(out_con) def _add_reduce_sizes( - self, reduce_list: list[tuple[Connection, Connection]] - ) -> Connection | Tensor[int]: - final_output: Connection | Tensor[int] = Tensor(1) - sizes: list[Connection] = [] + self, reduce_list: list[tuple[ConnectionData, ConnectionData]] + ) -> ConnectionData | Tensor[int]: + final_output: ConnectionData | Tensor[int] = Tensor(1) + sizes: list[ConnectionData] = [] for input, dim in reduce_list: m = _create_size() self.extend(m, input=input, dim=dim) out_con = m.conns.get_connection("output") assert out_con is not None - sizes.append(out_con.conn) - final_output = out_con.conn + sizes.append(out_con) + final_output = out_con if (num_of_sizes := len(sizes)) > 0: concat_model = Concat(n=num_of_sizes, axis=None) - concat_kwargs: dict[str, int | Connection] = {} + concat_kwargs: dict[str, int | ConnectionData] = {} idx = 0 for key in concat_model.input_keys: if not concat_model.conns.is_key_non_diff(key): concat_kwargs[key] = sizes[idx] idx += 1 self.extend(concat_model, **concat_kwargs) - self.extend(prod := Prod(), input=concat_model.output) - final_output = prod.output + self.extend(prod := Prod(), input=concat_model.output.data) + final_output = prod.output.data return final_output diff --git a/mithril/utils/dict_conversions.py b/mithril/utils/dict_conversions.py index 19d775a0..b0c3aa3e 100644 --- a/mithril/utils/dict_conversions.py +++ b/mithril/utils/dict_conversions.py @@ -13,7 +13,6 @@ # limitations under the License. from __future__ import annotations -import abc import re from collections.abc import Callable, Sequence from copy import deepcopy @@ -27,7 +26,6 @@ AssignedConstraintType, ConnectionData, IOHyperEdge, - IOKey, MainValueType, ShapesType, ShapeTemplateType, @@ -36,11 +34,11 @@ ToBeDetermined, ) from ..framework.constraints import constrain_fn_dict -from ..framework.logical import essential_primitives +from ..framework.logical import primitive +from ..framework.logical.model import IOKey from ..models import ( BaseModel, Connection, - CustomPrimitiveModel, Model, models, primitives, @@ -103,17 +101,17 @@ class TrainModelDict(TypedDict): model_dict = { item[0].lower(): item[1] for item in models.__dict__.items() - if isinstance(item[1], abc.ABCMeta) and issubclass(item[1], BaseModel) + if isinstance(item[1], type) and issubclass(item[1], BaseModel) } model_dict |= { item[0].lower(): item[1] for item in primitives.__dict__.items() - if isinstance(item[1], abc.ABCMeta) and issubclass(item[1], BaseModel) + if isinstance(item[1], type) and issubclass(item[1], BaseModel) } model_dict |= { item[0].lower(): item[1] - for item in essential_primitives.__dict__.items() - if isinstance(item[1], abc.ABCMeta) and issubclass(item[1], BaseModel) + for item in primitive.__dict__.items() + if isinstance(item[1], type) and issubclass(item[1], BaseModel) } model_dict |= {"trainmodel": TrainModel} @@ -151,7 +149,7 @@ def create_iokey_kwargs( def dict_to_model( modelparams: ModelDict | TrainModelDict | str, -) -> BaseModel: +) -> Model: """Convert given dictionary to a model object. 
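 
     The model description may be supplied directly (ModelDict /
     TrainModelDict) or as a str, as the signature above indicates.
 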
Parameter
@@ -198,13 +196,14 @@ def dict_to_model(
             args[k] = enum_dict[enum_key][v]
         model = model_class(**args)
-
     else:  # Custom model
         args |= handle_dict_to_model_args(model_name, params.get("args", {}))
         attrs: dict[str, Callable[..., Any]] = {
-            "__init__": lambda self: super(self.__class__, self).__init__(**args)  # pyright: ignore
+            "__init__": lambda self: super(self.__class__, self).__init__(
+                formula_key=args.pop("formula_key")
+            )  # pyright: ignore
         }
-        model = type(model_name, (CustomPrimitiveModel,), attrs)()
+        model = type(model_name, (Model,), attrs)(**args)
 
     types: dict[str, str] = params.get("types", {})
     # TODO: Set all types in a bulk.
@@ -217,7 +216,7 @@
             # way to convert strings into types and generic types.
             set_types[key] = eval(typ)
     if set_types:
-        model.set_types(set_types)
+        model.set_types(set_types)  # type: ignore
 
     unnamed_keys: list[str] = params.get("unnamed_keys", [])
     differentiability_info: dict[str, bool] = params.get("differentiability_info", {})
@@ -227,6 +226,7 @@
         "canonical_keys", {}
     )
 
+    assert isinstance(model, Model)
     submodels_dict = {}
     for m_key, v in submodels.items():
         m = dict_to_model(v)
@@ -242,13 +242,11 @@
             elif isinstance(conn, dict):
                 if (io_key := conn.get("key")) is not None:
                     # TODO: Update this part according to new IOKey structure.
-                    key_kwargs = create_iokey_kwargs(io_key, submodels_dict)
+                    key_kwargs = create_iokey_kwargs(io_key, submodels_dict)  # type: ignore
                     mappings[k] = IOKey(**key_kwargs)
                 elif "tensor" in conn:
                     mappings[k] = Tensor(conn["tensor"])
 
-        assert isinstance(model, Model)
-
-        # model += m(**mappings)
         model |= m(**mappings)
 
     if "model" in canonical_keys:
@@ -275,6 +273,7 @@
     if len(assigned_shapes) > 0:
         model.set_shapes(dict_to_shape(assigned_shapes))
 
+    assert isinstance(model, Model)
     return model
 
 
@@ -651,21 +650,21 @@ def item_to_json(item: IOKey) -> dict[str, Any]:
     # TODO: Currently type is not supported for Tensors.
     # Handle this with conversion test updates.
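     # Emit "value" when a concrete value is pinned, then either the shape
     # template (when a shape is set) or the key's type(s) otherwise.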
result: dict[str, Any] = {} - if not isinstance(item.data.value, ToBeDetermined): - result["value"] = item.data.value - if item.data.shape is not None: + if not isinstance(item.value, ToBeDetermined): + result["value"] = item.value + if item.value_shape is not None: shape_template: list[str] = [] - for symbol in item.data.shape: + for symbol in item.value_shape: if isinstance(symbol, tuple): # variadic shape_template.append(f"{symbol[0]},...") else: shape_template.append(str(symbol)) result["shape_template"] = shape_template - elif isinstance(item.data.type, UnionType): - result["type"] = [type_to_str(item) for item in item.data.type.__args__] + elif isinstance(item.type, UnionType): + result["type"] = [type_to_str(item) for item in item.type.__args__] else: result["type"] = [ - type_to_str(item.data.type), # type: ignore + type_to_str(item.type), # type: ignore ] return result diff --git a/tests/scripts/helper.py b/tests/scripts/helper.py index 8ade0550..3e961693 100644 --- a/tests/scripts/helper.py +++ b/tests/scripts/helper.py @@ -15,7 +15,7 @@ from mithril import Backend, Constant, compile, epsilon_table from mithril.framework.common import IOHyperEdge, Tensor -from mithril.models import BaseModel, Model, PrimitiveModel, TrainModel +from mithril.models import BaseModel, Model, Operator, TrainModel from mithril.utils.dict_conversions import dict_to_model, model_to_dict from tests.scripts.test_utils import ( assert_all_conn_key_are_same, @@ -230,12 +230,12 @@ def assert_models_equal(model1: BaseModel, model2: BaseModel): model1.factory_args.items(), model2.factory_args.items(), strict=False ): assert key1 == key2 - if isinstance(arg1, Model | PrimitiveModel): + if isinstance(arg1, Model | Operator): assert_models_equal(arg1, arg2) else: assert arg1 == arg2 - if isinstance(model1, Model) and isinstance(model2, Model): + if isinstance(model1, Operator) and isinstance(model2, Operator): assert len(model1.dag) == len(model2.dag) for submodel1, submodel2 in zip( model1.get_models_in_topological_order(), diff --git a/tests/scripts/test_all_models.py b/tests/scripts/test_all_models.py index 0bc9f422..372a7909 100644 --- a/tests/scripts/test_all_models.py +++ b/tests/scripts/test_all_models.py @@ -29,7 +29,6 @@ Arange, ArgMax, ArgMin, - BaseModel, BroadcastTo, Buffer, Cast, @@ -93,7 +92,7 @@ def list_full(fill_value, *shapes): def compile_and_compare( - model: BaseModel, + model: Model, compile_kwargs: dict[str, Any], data: dict[str, Any], params: dict[str, Any], diff --git a/tests/scripts/test_c_backend.py b/tests/scripts/test_c_backend.py index 20e2c7a3..6ef3d790 100644 --- a/tests/scripts/test_c_backend.py +++ b/tests/scripts/test_c_backend.py @@ -19,8 +19,8 @@ from mithril import CBackend, NumpyBackend, compile from mithril.backends.with_manualgrad.c_backend.src.array import PyArray -from mithril.framework.common import IOKey, Tensor -from mithril.models import Add, Model, Multiply +from mithril.framework.common import Tensor +from mithril.models import Add, IOKey, Model, Multiply from ..utils import with_temp_file diff --git a/tests/scripts/test_canonicality.py b/tests/scripts/test_canonicality.py index 75a57818..63e1d151 100644 --- a/tests/scripts/test_canonicality.py +++ b/tests/scripts/test_canonicality.py @@ -17,12 +17,13 @@ import pytest import mithril as ml -from mithril.framework.common import IOKey, Tensor +from mithril.framework.common import Tensor from mithril.models import ( Add, Buffer, Convolution2D, Gelu, + IOKey, Linear, LogisticRegression, Model, diff --git 
a/tests/scripts/test_constant_inputs.py b/tests/scripts/test_constant_inputs.py
index f0cb6154..edfdc987 100644
--- a/tests/scripts/test_constant_inputs.py
+++ b/tests/scripts/test_constant_inputs.py
@@ -28,9 +28,6 @@ from mithril.framework.common import (
     NOT_GIVEN,
     TBD,
-    Connection,
-    ConnectionType,
-    IOKey,
     Tensor,
     ToBeDetermined,
 )
@@ -42,6 +39,8 @@
     BinaryCrossEntropy,
     Buffer,
     Concat,
+    Connection,
+    ConnectionType,
     Convolution1D,
     Convolution2D,
     CrossEntropy,
@@ -49,6 +48,7 @@
     Flatten,
     Greater,
     Indexer,
+    IOKey,
     LeakyRelu,
     Linear,
     Log,
diff --git a/tests/scripts/test_constr_counter.py b/tests/scripts/test_constr_counter.py
index 10b43c2d..0ffb91ac 100644
--- a/tests/scripts/test_constr_counter.py
+++ b/tests/scripts/test_constr_counter.py
@@ -19,9 +19,7 @@
     NOT_GIVEN,
     TBD,
     BaseKey,
-    ConnectionType,
     IOHyperEdge,
-    IOKey,
     ShapeRepr,
     Tensor,
     Uniadic,
@@ -32,14 +30,17 @@
     Add,
     Buffer,
     Connection,
+    ConnectionType,
     ExtendInfo,
     Indexer,
+    IOKey,
     Model,
-    PrimitiveModel,
+    Operator,
     Relu,
     Slice,
     Transpose,
 )
+from mithril.models.primitives import PrimitiveModel
 
 
 def dummy_constraint(output: IOHyperEdge, input: IOHyperEdge):
@@ -150,9 +151,7 @@ def __init__(self, left, right, output) -> None:
             left=BaseKey(shape=left, type=Tensor),
             right=BaseKey(shape=right, type=Tensor),
         )
-        self._add_constraint(
-            fn=bcast, keys=[PrimitiveModel.output_key, "left", "right"]
-        )
+        self._add_constraint(fn=bcast, keys=[Operator.output_key, "left", "right"])
 
     def __call__(  # type: ignore
         self,
diff --git a/tests/scripts/test_constraint_graph.py b/tests/scripts/test_constraint_graph.py
index 5c4705cb..991f29d3 100644
--- a/tests/scripts/test_constraint_graph.py
+++ b/tests/scripts/test_constraint_graph.py
@@ -21,7 +21,7 @@
     Updates,
     UpdateType,
 )
-from mithril.models import PrimitiveModel, Sigmoid
+from mithril.models import Model, Sigmoid
 
 
 class ConstrGraphTestBase:
@@ -29,7 +29,7 @@ class ConstrGraphTestBase:
     trace_list: list[str]
     result_map: dict[tuple[bool, ...], list[str]]
 
-    def model(self) -> PrimitiveModel:
+    def model(self) -> Model:
         # model with constraint graph
         raise NotImplementedError()
 
@@ -92,11 +92,11 @@ class ThreeConstraintsTest(ThreeConstraints):
     subclass should have two fields defined:
 
     1. result_map: dict[tuple[bool, ...], list[str]]
-    2. model: Callable[[], PrimitiveModel]
+    2. model: Callable[[], Operator]
 
     result_map defines the result for each possible combination of constraint conditions
 
-    model should return a PrimitiveModel object with constraints defined
+    model should return an Operator object with constraints defined.
     Each constraint, if its condition is true, appends its name to trace_list
     (namely, c1, c2, c3).
However, if a constraint has dependencies, it cannot be @@ -148,7 +148,7 @@ class TestThreeSequential(ThreeConstraintsTest): (False, False, False): [], } - def model(self) -> PrimitiveModel: + def model(self) -> Model: # constr_1 ----> constr_2 ----> constr_3 model = Sigmoid() @@ -174,7 +174,7 @@ class TestThreeOneToMany(ThreeConstraintsTest): (False, False, False): [], } - def model(self) -> PrimitiveModel: + def model(self) -> Model: # + ---> constr_2 # | # constr_1 --- + @@ -204,7 +204,7 @@ class TestThreeManyToOne(ThreeConstraintsTest): (False, False, False): [], } - def model(self) -> PrimitiveModel: + def model(self) -> Model: # constr_1 --- + # | # + ---> constr_3 @@ -240,7 +240,7 @@ class TestFourDiamond(FourConstraintsTest): (False, False, False, False): [], } - def model(self) -> PrimitiveModel: + def model(self) -> Model: # + ---> constr_2 --- + # | | # constr_1 --- + + ---> constr_4 @@ -284,7 +284,7 @@ class TestTwoPhaseDiamond(TestFourDiamond): (False, False, False, False): ["c1", "c2"], } - def model(self) -> PrimitiveModel: + def model(self) -> Model: model = super().model() self.conditions = (True, True, False, False) model.set_shapes(input=[("V1", ...), "a"]) @@ -311,7 +311,7 @@ class TestTwoPhaseDiamondAtInit(TestFourDiamond): (False, False, False, False): ["c1", "c2"], } - def model(self) -> PrimitiveModel: + def model(self) -> Model: self.conditions = (True, True, False, False) model = super().model() return model @@ -337,7 +337,7 @@ class TestFourManyToMany(FourConstraintsTest): (False, False, False, False): [], } - def model(self) -> PrimitiveModel: + def model(self) -> Model: # constr_1 --- + + ---> constr_3 # | | # + --- + @@ -379,7 +379,7 @@ class TestFourTwoSequential(FourConstraintsTest): (False, False, False, False): [], } - def model(self) -> PrimitiveModel: + def model(self) -> Model: # constr_1 ----> constr_2 # # constr_3 ----> constr_4 diff --git a/tests/scripts/test_data_store.py b/tests/scripts/test_data_store.py index 314dfa14..a5a6f90f 100644 --- a/tests/scripts/test_data_store.py +++ b/tests/scripts/test_data_store.py @@ -155,7 +155,7 @@ def test_data_store_4(): def test_data_store_5(): """Tests infer_static and prune runs together""" - # TODO: This test is expects cached_data to be "input" and "output" but + # TODO: This test expects cached_data to be "input" and "output" but # after we fix corresponding flat_graph handlings, it will be changed # to expect only "output" as cached_data and "input" as unused_keys. 
backend = TorchBackend() diff --git a/tests/scripts/test_differentiablity.py b/tests/scripts/test_differentiablity.py index 028ae732..bd04b48c 100644 --- a/tests/scripts/test_differentiablity.py +++ b/tests/scripts/test_differentiablity.py @@ -14,8 +14,8 @@ import mithril from mithril import JaxBackend -from mithril.framework.common import IOKey, Tensor -from mithril.models import Add, Buffer, Linear, Model, Multiply +from mithril.framework.common import Tensor +from mithril.models import Add, Buffer, IOKey, Linear, Model, Multiply def test_data_linear(): diff --git a/tests/scripts/test_flat_graph.py b/tests/scripts/test_flat_graph.py index f6b118dd..0b9b3d63 100644 --- a/tests/scripts/test_flat_graph.py +++ b/tests/scripts/test_flat_graph.py @@ -24,9 +24,9 @@ def test_flatgraph_1(): graph = FlatGraph( {"input1", "input2"}, {"output"}, ml.JaxBackend(), ConstraintSolver() ) - graph.add_value(Relu(), {"input": "input1", "output": "relu_out"}) - graph.add_value(Buffer(), {"input": "relu_out", "output": "buffer_output"}) - graph.add_value(Buffer(), {"input": "buffer_output", "output": "output"}) + graph.add_value(Relu().submodel, {"input": "input1", "output": "relu_out"}) + graph.add_value(Buffer().submodel, {"input": "relu_out", "output": "buffer_output"}) + graph.add_value(Buffer().submodel, {"input": "buffer_output", "output": "output"}) graph.prune_duplicate_nodes({}, {}) expected_connections = ["input1", "relu_out"] @@ -42,11 +42,11 @@ def test_flatgraph_2(): ml.JaxBackend(), ConstraintSolver(), ) - graph.add_value(Relu(), {"input": "input1", "output": "relu_out"}) - graph.add_value(Buffer(), {"input": "relu_out", "output": "output1"}) - graph.add_value(Buffer(), {"input": "output1", "output": "output2"}) - graph.add_value(Buffer(), {"input": "output2", "output": "output3"}) - graph.add_value(Buffer(), {"input": "output3", "output": "output4"}) + graph.add_value(Relu().submodel, {"input": "input1", "output": "relu_out"}) + graph.add_value(Buffer().submodel, {"input": "relu_out", "output": "output1"}) + graph.add_value(Buffer().submodel, {"input": "output1", "output": "output2"}) + graph.add_value(Buffer().submodel, {"input": "output2", "output": "output3"}) + graph.add_value(Buffer().submodel, {"input": "output3", "output": "output4"}) graph.prune_duplicate_nodes({}, {}) expected_connections = ["input1", "relu_out"] @@ -65,9 +65,9 @@ def test_flatgraph_3(): ml.JaxBackend(), ConstraintSolver(), ) - graph.add_value(Relu(), {"input": "input1", "output": "relu_out"}) - graph.add_value(Relu(), {"input": "relu_out", "output": "output1"}) - graph.add_value(Relu(), {"input": "output1", "output": "output2"}) + graph.add_value(Relu().submodel, {"input": "input1", "output": "relu_out"}) + graph.add_value(Relu().submodel, {"input": "relu_out", "output": "output1"}) + graph.add_value(Relu().submodel, {"input": "output1", "output": "output2"}) graph.prune_duplicate_nodes({}, {}) expected_connections = ["input1", "output1", "output2", "relu_out"] @@ -135,7 +135,7 @@ def test_infer_static_2(): inference=True, ) - assert len(pm.flat_graph.nodes) == 1 and add in pm.flat_graph.nodes + assert len(pm.flat_graph.nodes) == 1 and add.submodel in pm.flat_graph.nodes assert pm.flat_graph.all_source_keys == {"relu_out", "input2"} assert pm.flat_graph.all_target_keys == {"output"} assert pm.flat_graph.topological_order == ["output"] @@ -156,7 +156,7 @@ def test_infer_static_3(): inference=True, ) - assert len(pm.flat_graph.nodes) == 1 and add in pm.flat_graph.nodes + assert len(pm.flat_graph.nodes) == 1 and 
add.submodel in pm.flat_graph.nodes assert pm.flat_graph.all_source_keys == {"relu_out", "input2"} assert pm.flat_graph.all_target_keys == {"output"} assert pm.flat_graph.topological_order == ["output"] @@ -198,7 +198,7 @@ def test_discard_primitive(): inference=True, ) - assert len(pm.flat_graph.nodes) == 1 and relu in pm.flat_graph.nodes + assert len(pm.flat_graph.nodes) == 1 and relu.submodel in pm.flat_graph.nodes assert pm.flat_graph.all_source_keys == {"input2"} assert pm.flat_graph.all_target_keys == {"output2"} assert pm.flat_graph.topological_order == ["output2"] @@ -221,8 +221,8 @@ def test_discard_partial_of_sequence(): assert ( len(pm.flat_graph.nodes) == 2 - and relu2 in pm.flat_graph.nodes - and sig in pm.flat_graph.nodes + and relu2.submodel in pm.flat_graph.nodes + and sig.submodel in pm.flat_graph.nodes ) assert pm.flat_graph.all_source_keys == {"input1", "input2"} assert pm.flat_graph.all_target_keys == {"output1", "output2"} @@ -244,7 +244,7 @@ def test_discard_whole_sequence(): inference=True, ) - assert len(pm.flat_graph.nodes) == 1 and relu in pm.flat_graph.nodes + assert len(pm.flat_graph.nodes) == 1 and relu.submodel in pm.flat_graph.nodes assert pm.flat_graph.all_source_keys == {"input2"} assert pm.flat_graph.all_target_keys == {"output2"} assert pm.flat_graph.topological_order == ["output2"] diff --git a/tests/scripts/test_flatmodel.py b/tests/scripts/test_flatmodel.py index c631ef00..1d981cfe 100644 --- a/tests/scripts/test_flatmodel.py +++ b/tests/scripts/test_flatmodel.py @@ -12,16 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -from copy import deepcopy import numpy as np import mithril as ml from mithril import JaxBackend -from mithril.framework.common import IOKey, Tensor +from mithril.framework.common import Tensor from mithril.framework.physical.model import FlatModel from mithril.models import ( Add, + IOKey, Linear, Model, Relu, @@ -34,21 +34,22 @@ def test_with_all_defined(): model = Model() model += (add := Add())(left="a", right="b", output="c") - f_model = FlatModel(model, short_namings=True) - assert f_model.mappings == {add: {"left": "a", "right": "b", "output": "c"}} + _add = add.submodel + assert f_model.mappings == {_add: {"left": "a", "right": "b", "output": "c"}} f_model = FlatModel(model, short_namings=False) - assert f_model.mappings == {add: {"left": "a", "right": "b", "output": "c"}} + assert f_model.mappings == {_add: {"left": "a", "right": "b", "output": "c"}} def test_with_some_undefined(): model = Model() model += (add := Add())(right="b", output="c") + _add = add.submodel f_model = FlatModel(model, short_namings=True) - assert f_model.mappings == {add: {"left": "left", "right": "b", "output": "c"}} + assert f_model.mappings == {_add: {"left": "left", "right": "b", "output": "c"}} f_model = FlatModel(model, short_namings=False) - assert f_model.mappings == {add: {"left": "add_left", "right": "b", "output": "c"}} + assert f_model.mappings == {_add: {"left": "add_left", "right": "b", "output": "c"}} def test_with_all_undefined(): @@ -57,13 +58,14 @@ def test_with_all_undefined(): f_model = FlatModel(model) assert f_model.mappings == { - add: {"left": "left", "right": "right", "output": "output"} + add.submodel: {"left": "left", "right": "right", "output": "output"} } def test_multi_level_name_with_lowest_definition(): model2 = Model("adder") model2 += (add := Add())(left="a", right="b", output="c") + _add = add.submodel model1 = Model(name="model") model1 += model2 @@ 
-71,10 +73,10 @@ def test_multi_level_name_with_lowest_definition(): model += model1 f_model = FlatModel(model) - assert f_model.mappings == {add: {"left": "a", "right": "b", "output": "c"}} + assert f_model.mappings == {_add: {"left": "a", "right": "b", "output": "c"}} f_model = FlatModel(model, short_namings=False) assert f_model.mappings == { - add: { + _add: { "left": "model_adder_a", "right": "model_adder_b", "output": "model_adder_c", @@ -84,7 +86,8 @@ def test_multi_level_name_with_lowest_definition(): def test_multi_level_name_with_lowest_definition_higher_redefinition_1(): model2 = Model(name="adder") - model2 += (add := Add())(left="a", right="b", output="c") + model2 += (add_model := Add())(left="a", right="b", output="c") + _add = add_model.submodel model1 = Model(name="namer") model1 += model2(a="d", b="e") @@ -92,16 +95,17 @@ def test_multi_level_name_with_lowest_definition_higher_redefinition_1(): model += model1(e="f") f_model = FlatModel(model) - assert f_model.mappings == {add: {"left": "d", "right": "f", "output": "c"}} + assert f_model.mappings == {_add: {"left": "d", "right": "f", "output": "c"}} f_model = FlatModel(model, short_namings=False) assert f_model.mappings == { - add: {"left": "namer_d", "right": "f", "output": "namer_adder_c"} + _add: {"left": "namer_d", "right": "f", "output": "namer_adder_c"} } def test_multi_level_name_with_lowest_definition_higher_redefinition_2(): model2 = Model() - model2 += (add := Add())(left="a", right="b", output="c") + model2 += (add_model := Add())(left="a", right="b", output="c") + _add = add_model.submodel model1 = Model(name="middle") model1 += model2(a="d", b="e") @@ -109,16 +113,17 @@ def test_multi_level_name_with_lowest_definition_higher_redefinition_2(): model += model1 f_model = FlatModel(model) - assert f_model.mappings == {add: {"left": "d", "right": "e", "output": "c"}} + assert f_model.mappings == {_add: {"left": "d", "right": "e", "output": "c"}} f_model = FlatModel(model, short_namings=False) assert f_model.mappings == { - add: {"left": "middle_d", "right": "middle_e", "output": "middle_model_c"} + _add: {"left": "middle_d", "right": "middle_e", "output": "middle_model_c"} } def test_collision_from_different_levels(): model2 = Model() - model2 += (add := Add())(left="a", right="b", output="e") + model2 += (add_model := Add())(left="a", right="b", output="e") + _add = add_model.submodel model1 = Model(name="middle") model1 += model2(a="d", b="e") @@ -126,16 +131,17 @@ def test_collision_from_different_levels(): model += model1 f_model = FlatModel(model) - assert f_model.mappings == {add: {"left": "d", "right": "e", "output": "e_0"}} + assert f_model.mappings == {_add: {"left": "d", "right": "e", "output": "e_0"}} f_model = FlatModel(model, short_namings=False) assert f_model.mappings == { - add: {"left": "middle_d", "right": "middle_e", "output": "middle_model_e"} + _add: {"left": "middle_d", "right": "middle_e", "output": "middle_model_e"} } def test_collision_from_different_levels_2(): model2 = Model(name="lower") - model2 += (add := Add())(left="a", right="b", output="e") + model2 += (add_model := Add())(left="a", right="b", output="e") + _add = add_model.submodel model1 = Model(name="middle2") model1 += model2(a="d", b="e") @@ -146,10 +152,10 @@ def test_collision_from_different_levels_2(): model = Model(name="upper") model += model3() f_model = FlatModel(model) - assert f_model.mappings == {add: {"left": "d", "right": "e", "output": "e_0"}} + assert f_model.mappings == {_add: {"left": "d", "right": "e", 
"output": "e_0"}} f_model = FlatModel(model, short_namings=False) assert f_model.mappings == { - add: { + _add: { "left": "middle1_d", "right": "middle1_middle2_e", "output": "middle1_middle2_lower_e", @@ -159,7 +165,8 @@ def test_collision_from_different_levels_2(): def test_collision_from_different_levels_3(): model2 = Model() - model2 += (add := Add())(left="a", right="b", output="e") + model2 += (add_model := Add())(left="a", right="b", output="e") + _add = add_model.submodel model1 = Model() model1 += model2(a="d", b="e") @@ -167,26 +174,30 @@ def test_collision_from_different_levels_3(): model += model1(e="e") f_model = FlatModel(model) - assert f_model.mappings == {add: {"left": "d", "right": "e", "output": "e_0"}} + assert f_model.mappings == {_add: {"left": "d", "right": "e", "output": "e_0"}} f_model = FlatModel(model, short_namings=False) assert f_model.mappings == { - add: {"left": "model_d", "right": "e", "output": "model_model_e"} + _add: {"left": "model_d", "right": "e", "output": "model_model_e"} } def test_collision_from_different_models(): model1 = Model() - model1 += Add()(left="l", right="r", output="o") + model1 += (add := Add())(left="l", right="r", output="o") + add1 = add.submodel + + model2 = Model() + model2 += (add := Add())(left="l", right="r", output="o") + add2 = add.submodel - model2 = deepcopy(model1) model = Model() model += model1 model += model2 f_model = FlatModel(model) expected_mapping = { - list(model1.dag.keys())[0]: {"left": "l", "right": "r_0", "output": "o_0"}, - list(model2.dag.keys())[0]: {"left": "o_0", "right": "r_1", "output": "o_1"}, + add1: {"left": "l", "right": "r_0", "output": "o_0"}, + add2: {"left": "o_0", "right": "r_1", "output": "o_1"}, } assert f_model.mappings == expected_mapping @@ -194,49 +205,53 @@ def test_collision_from_different_models(): def test_output_first_1(): model = Model() - model += Relu()(input="in1", output="out1") - model += Sigmoid()(input="in2", output="in1") + model += (relu := Relu())(input="in1", output="out1") + model += (sig := Sigmoid())(input="in2", output="in1") + _sig = sig.submodel + _relu = relu.submodel f_model = FlatModel(model) assert f_model.mappings == { - list(model.dag.keys())[1]: { + _sig: { "input": "in2", "output": "output", }, # TODO: Why this is output? - list(model.dag.keys())[0]: {"input": "output", "output": "out1"}, + _relu: {"input": "output", "output": "out1"}, } f_model = FlatModel(model, short_namings=False) assert f_model.mappings == { - list(model.dag.keys())[1]: { + _sig: { "input": "in2", "output": "output", }, - list(model.dag.keys())[0]: {"input": "output", "output": "out1"}, + _relu: {"input": "output", "output": "out1"}, } def test_output_first_2(): model = Model() model += (relu := Relu())(output="out1") - model += Sigmoid()(input="in2", output=relu.input) + model += (sig := Sigmoid())(input="in2", output=relu.input) + _sig = sig.submodel + _relu = relu.submodel f_model = FlatModel(model) assert f_model.mappings == { - list(model.dag.keys())[1]: { + _sig: { "input": "in2", "output": "output", }, # TODO: Why this is output? 
-        list(model.dag.keys())[0]: {"input": "output", "output": "out1"},
+        _relu: {"input": "output", "output": "out1"},
     }
 
     f_model = FlatModel(model, short_namings=False)
     assert f_model.mappings == {
-        list(model.dag.keys())[1]: {
+        _sig: {
             "input": "in2",
             "output": "sigmoid_output",
         },
-        list(model.dag.keys())[0]: {"input": "sigmoid_output", "output": "out1"},
+        _relu: {"input": "sigmoid_output", "output": "out1"},
     }
 
 
@@ -244,16 +259,18 @@
 def test_output_first_3():
     model = Model()
     model += (relu := Relu())(output="out1")
     model += (sig := Sigmoid())(input="in2", output=relu.input)
+    _sig = next(iter(sig.dag))
+    _relu = next(iter(relu.dag))
 
     f_model = FlatModel(model)
     assert f_model.mappings == {
-        sig: {"input": "in2", "output": "output"},
-        relu: {"input": "output", "output": "out1"},
+        _sig: {"input": "in2", "output": "output"},
+        _relu: {"input": "output", "output": "out1"},
     }
 
     f_model = FlatModel(model, short_namings=False)
     assert f_model.mappings == {
-        sig: {"input": "in2", "output": "sigmoid_output"},
-        relu: {"input": "sigmoid_output", "output": "out1"},
+        _sig: {"input": "in2", "output": "sigmoid_output"},
+        _relu: {"input": "sigmoid_output", "output": "out1"},
     }
 
 
@@ -277,19 +294,19 @@
 
     f_model = FlatModel(model)
     expected_mapping = {
-        relu: {"input": "input", "output": "output1_0"},
-        softp: {"input": "output1_0", "output": "output1_1"},
-        sig: {"input": "output1_1", "output": "output2"},
-        tanh: {"input": "output2", "output": "output"},
+        relu.submodel: {"input": "input", "output": "output1_0"},
+        softp.submodel: {"input": "output1_0", "output": "output1_1"},
+        sig.submodel: {"input": "output1_1", "output": "output2"},
+        tanh.submodel: {"input": "output2", "output": "output"},
     }
     assert f_model.mappings == expected_mapping
 
     f_model = FlatModel(model, short_namings=False)
     expected_mapping = {
-        relu: {"input": "input", "output": "model_0_output1"},
-        softp: {"input": "model_0_output1", "output": "model_1_output1"},
-        sig: {"input": "model_1_output1", "output": "model_0_output2"},
-        tanh: {"input": "model_0_output2", "output": "output"},
+        relu.submodel: {"input": "input", "output": "model_0_output1"},
+        softp.submodel: {"input": "model_0_output1", "output": "model_1_output1"},
+        sig.submodel: {"input": "model_1_output1", "output": "model_0_output2"},
+        tanh.submodel: {"input": "model_0_output2", "output": "output"},
     }
     assert f_model.mappings == expected_mapping
 
@@ -298,18 +315,22 @@
 def test_linear_flat():
     model = Model()
     model += (lin := Linear(21))(output="qwe")
     f_model = FlatModel(model)
     expected_mapping = {
         list(lin.dag.keys())[0]: {
             "input": "weight",
             "axes": "axes",
             "output": "output_0",
         },
-        list(lin.dag.keys())[1]: {
+        next(iter(list(lin.dag.keys())[1].dag.keys())): {
             "left": "input",
             "right": "output_0",
             "output": "output_1",
         },
-        list(lin.dag.keys())[2]: {"left": "output_1", "right": "bias", "output": "qwe"},
+        next(iter(list(lin.dag.keys())[2].dag.keys())): {
+            "left": "output_1",
+            "right": "bias",
+            "output": "qwe",
+        },
     }
     assert f_model.mappings == expected_mapping
diff --git a/tests/scripts/test_functions.py b/tests/scripts/test_functions.py
index 6dca509f..2cfa84e3 100644
--- a/tests/scripts/test_functions.py
+++ b/tests/scripts/test_functions.py
@@ -19,9 +19,7 @@
 import mithril
 from mithril import CBackend, JaxBackend, NumpyBackend, TorchBackend
 from mithril.backends.with_manualgrad.numpy_backend.ops_grad import add_grad
-from mithril.framework import NOT_GIVEN, ConnectionType, ExtendInfo
-from mithril.framework.common 
import BaseKey, IOKey, Tensor -from mithril.framework.constraints import bcast +from mithril.framework.common import Tensor from mithril.models import ( Absolute, Add, @@ -32,6 +30,7 @@ Cosine, CrossEntropy, Divide, + IOKey, Layer, Linear, LinearSVM, @@ -39,7 +38,6 @@ Model, Multiply, Power, - PrimitiveModel, Relu, Sigmoid, Sine, @@ -52,7 +50,7 @@ from mithril.utils.utils import BiMultiMap from tests.scripts.test_utils import compare_callables -from ..utils import with_temp_file +from ..utils import MyAdder, with_temp_file # ruff: noqa: F821 @@ -208,7 +206,7 @@ def test_flatten_dag_1(): ] assert flatted_primitive_model_list == [ - model.__class__ for model in ordered_model_list + model.submodel.__class__ for model in ordered_model_list ] @@ -270,7 +268,7 @@ def test_flatten_dag_2(): ] assert flatted_primitive_model_list == [ - model.__class__ for model in ordered_model_list + model.submodel.__class__ for model in ordered_model_list ] @@ -315,7 +313,7 @@ def test_flatten_dag_3(): ] assert flatted_primitive_model_list == [ - model.__class__ for model in ordered_model_list + model.submodel.__class__ for model in ordered_model_list ] @@ -427,33 +425,12 @@ def evaluate(params, data, cache): def test_code_generator_4(file_path: str): model = Model() - def my_adder(input, rhs, cache: None): - return input + rhs + def my_adder(left, right, cache: None): + return left + right NumpyBackend.register_primitive(my_adder, add_grad) - class MyAdder(PrimitiveModel): - def __init__(self) -> None: - super().__init__( - formula_key="my_adder", - output=BaseKey(shape=[("Var_out", ...)], type=Tensor), - input=BaseKey(shape=[("Var_1", ...)], type=Tensor), - rhs=BaseKey(shape=[("Var_2", ...)], type=Tensor), - ) - self.add_constraint( - fn=bcast, keys=[PrimitiveModel.output_key, "input", "rhs"] - ) - - def __call__( # type: ignore[override] - self, - input: ConnectionType = NOT_GIVEN, - rhs: ConnectionType = NOT_GIVEN, - output: ConnectionType = NOT_GIVEN, - ) -> ExtendInfo: - kwargs = {"input": input, "rhs": rhs, "output": output} - return ExtendInfo(self, kwargs) - - model += MyAdder()(input="input", rhs="rhs", output=IOKey(name="output")) + model += MyAdder()(left="left", right="right", output=IOKey(name="output")) context = TrainModel(model) context.add_loss( BinaryCrossEntropy(), reduce_steps=[Mean()], input="output", target="target" @@ -471,13 +448,15 @@ def __call__( # type: ignore[override] @typing.no_type_check def evaluate(params, data, cache): - input = params["input"] + left = params["left"] output_0_cache = cache["output_0_cache"] output_1_cache = cache["output_1_cache"] output_cache = cache["output_cache"] - rhs = params["rhs"] + right = params["right"] target = data["target"] - output = output_cache["output"] = make_array(my_adder(input, rhs, output_cache)) + output = output_cache["output"] = make_array( + my_adder(left, right, output_cache) + ) output_0 = output_0_cache["output"] = make_array( binary_cross_entropy_with_logits( output, target, 2.2250738585072014e-308, cache=output_0_cache @@ -491,13 +470,13 @@ def evaluate(params, data, cache): @typing.no_type_check def evaluate_gradients(params, gradients, data, cache): - input = params["input"] + left = params["left"] output = cache["output_cache"]["output"] output_0 = cache["output_0_cache"]["output"] output_0_cache = cache["output_0_cache"] output_1_cache = cache["output_1_cache"] output_cache = cache["output_cache"] - rhs = params["rhs"] + right = params["right"] target = data["target"] gradients["output_1"] += gradients["final_cost"] 
gradients["output_0"] += accumulate_grads( @@ -518,15 +497,15 @@ def evaluate_gradients(params, gradients, data, cache): 2.2250738585072014e-308, ) ) - gradients["input"] += accumulate_grads( - make_array(add_grad(gradients["output"], output_cache, 0, input, rhs)), - input, + gradients["left"] += accumulate_grads( + make_array(add_grad(gradients["output"], output_cache, 0, left, right)), + left, output_cache, 0, ) - gradients["rhs"] += accumulate_grads( - make_array(add_grad(gradients["output"], output_cache, 1, input, rhs)), - rhs, + gradients["right"] += accumulate_grads( + make_array(add_grad(gradients["output"], output_cache, 1, left, right)), + right, output_cache, 1, ) @@ -540,33 +519,12 @@ def evaluate_gradients(params, gradients, data, cache): def test_code_generator_5(file_path: str): model = Model() - def my_adder(input, rhs): - return input + rhs + def my_adder(left, right): + return left + right JaxBackend.register_primitive(my_adder) - class MyAdder(PrimitiveModel): - def __init__(self) -> None: - super().__init__( - formula_key="my_adder", - output=BaseKey(shape=[("Var_out", ...)], type=Tensor), - input=BaseKey(shape=[("Var_1", ...)], type=Tensor), - rhs=BaseKey(shape=[("Var_2", ...)], type=Tensor), - ) - self.add_constraint( - fn=bcast, keys=[PrimitiveModel.output_key, "input", "rhs"] - ) - - def __call__( # type: ignore[override] - self, - input: ConnectionType = NOT_GIVEN, - rhs: ConnectionType = NOT_GIVEN, - output: ConnectionType = NOT_GIVEN, - ) -> ExtendInfo: - kwargs = {"input": input, "rhs": rhs, "output": output} - return ExtendInfo(self, kwargs) - - model += MyAdder()(input="input", rhs="rhs", output=IOKey(name="output")) + model += MyAdder()(left="left", right="right", output=IOKey(name="output")) context = TrainModel(model) add = Add() add.set_types(right=Tensor) @@ -586,15 +544,15 @@ def __call__( # type: ignore[override] @typing.no_type_check def evaluate(params, data, cache): - input = params["input"] - rhs = params["rhs"] + left = params["left"] right = params["right"] + right_0 = params["right_0"] target = data["target"] - output = my_adder(input, rhs) + output = my_adder(left, right) output_0 = binary_cross_entropy_with_logits( output, target, 2.2250738585072014e-308 ) - output_1 = add(output_0, right) + output_1 = add(output_0, right_0) del output_0 return {"final_cost": output_1, "output": output} diff --git a/tests/scripts/test_io_key.py b/tests/scripts/test_io_key.py index ae5835de..d3c6ace0 100644 --- a/tests/scripts/test_io_key.py +++ b/tests/scripts/test_io_key.py @@ -20,10 +20,11 @@ import mithril from mithril import TorchBackend -from mithril.framework.common import TBD, IOKey, Tensor, ToBeDetermined +from mithril.framework.common import TBD, Tensor, ToBeDetermined from mithril.models import ( Add, Buffer, + IOKey, Linear, Mean, Model, diff --git a/tests/scripts/test_key_namings.py b/tests/scripts/test_key_namings.py index 9407c8b6..b0530634 100644 --- a/tests/scripts/test_key_namings.py +++ b/tests/scripts/test_key_namings.py @@ -16,11 +16,12 @@ import mithril from mithril import TorchBackend -from mithril.framework.common import IOKey, Tensor +from mithril.framework.common import Tensor from mithril.models import ( Add, Buffer, Concat, + IOKey, Linear, Model, Sigmoid, diff --git a/tests/scripts/test_model_to_dict_rtt.py b/tests/scripts/test_model_to_dict_rtt.py index c897168f..6247d582 100644 --- a/tests/scripts/test_model_to_dict_rtt.py +++ b/tests/scripts/test_model_to_dict_rtt.py @@ -14,9 +14,11 @@ import re +import pytest + import mithril 
from mithril import JaxBackend, TorchBackend -from mithril.framework.common import TBD, BaseKey, IOKey, Tensor +from mithril.framework.common import TBD, BaseKey, Tensor from mithril.framework.constraints import squeeze_constraints from mithril.models import ( L2, @@ -25,16 +27,18 @@ Buffer, Convolution2D, CrossEntropy, - CustomPrimitiveModel, + IOKey, Layer, Linear, Mean, Model, + Operator, Relu, Sigmoid, SquaredError, TrainModel, ) +from mithril.models.primitives import PrimitiveModel from mithril.utils import dict_conversions from .helper import assert_evaluations_equal, assert_models_equal @@ -928,6 +932,7 @@ def test_set_values_ellipsis_2(): assert_models_equal(model, model_recreated) +@pytest.mark.skip(reason="Waiting for the fix in the conversion bug") def test_make_shape_constraint(): model = Model() @@ -936,7 +941,7 @@ def my_adder(input, rhs): TorchBackend.register_primitive(my_adder) # After serialization is this available? - class MyAdder(CustomPrimitiveModel): + class MyAdder(PrimitiveModel): def __init__(self, threshold=3) -> None: threshold *= 2 super().__init__( @@ -946,7 +951,7 @@ def __init__(self, threshold=3) -> None: rhs=BaseKey(type=int, value=threshold), ) self.add_constraint( - fn=squeeze_constraints, keys=[CustomPrimitiveModel.output_key, "input"] + fn=squeeze_constraints, keys=[Operator.output_key, "input"] ) model += MyAdder()(input="input") diff --git a/tests/scripts/test_parallel.py b/tests/scripts/test_parallel.py index 7b99fe8d..3b37754d 100644 --- a/tests/scripts/test_parallel.py +++ b/tests/scripts/test_parallel.py @@ -29,11 +29,12 @@ from mithril import compile from mithril.backends.with_autograd.torch_backend.parallel import TorchParallel from mithril.backends.with_autograd.torch_backend.utils import SharedCyclicQueue -from mithril.framework.common import IOKey, Tensor +from mithril.framework.common import Tensor from mithril.models import ( TBD, Add, Eye, + IOKey, Linear, Model, Multiply, diff --git a/tests/scripts/test_randomized_models_all_backends.py b/tests/scripts/test_randomized_models_all_backends.py index 2d5b71eb..9dc73bea 100644 --- a/tests/scripts/test_randomized_models_all_backends.py +++ b/tests/scripts/test_randomized_models_all_backends.py @@ -62,7 +62,7 @@ "Recall", "LogicalNot", "ToTensor", - "PrimitiveModel", + "Operator", "LessEqual", "Greater", "LogicalAnd", diff --git a/tests/scripts/test_recurrent_models.py b/tests/scripts/test_recurrent_models.py index eb61a003..b7dd7486 100644 --- a/tests/scripts/test_recurrent_models.py +++ b/tests/scripts/test_recurrent_models.py @@ -20,15 +20,17 @@ import mithril from mithril import TorchBackend -from mithril.framework.common import NOT_GIVEN, ConnectionType, IOKey, Tensor +from mithril.framework.common import NOT_GIVEN, Tensor from mithril.models import ( AbsoluteError, Add, Buffer, Cell, + ConnectionType, EncoderDecoder, ExtendInfo, Indexer, + IOKey, LSTMCell, ManyToOne, MatrixMultiply, diff --git a/tests/scripts/test_ref_counts.py b/tests/scripts/test_ref_counts.py index 3c19d632..0d2aa089 100644 --- a/tests/scripts/test_ref_counts.py +++ b/tests/scripts/test_ref_counts.py @@ -19,8 +19,6 @@ from mithril.framework.common import ( NOT_GIVEN, BaseKey, - Connection, - ConnectionType, ShapeTemplateType, Tensor, ) @@ -28,6 +26,8 @@ Add, BaseModel, Buffer, + Connection, + ConnectionType, Convolution1D, ExtendInfo, IOKey, @@ -35,12 +35,12 @@ MatrixMultiply, MaxPool1D, Model, - PrimitiveModel, PrimitiveUnion, Relu, Sigmoid, Sum, ) +from mithril.models.primitives import PrimitiveModel from .test_utils 
import ( get_all_data, diff --git a/tests/scripts/test_scalar_inference.py b/tests/scripts/test_scalar_inference.py index e37794c6..44af2bc0 100644 --- a/tests/scripts/test_scalar_inference.py +++ b/tests/scripts/test_scalar_inference.py @@ -18,7 +18,8 @@ import pytest import mithril as ml -from mithril.framework.common import TBD, Connection, IOKey +from mithril.framework.common import TBD +from mithril.framework.logical.model import Connection, IOKey from mithril.models import ( Add, Divide, @@ -27,7 +28,6 @@ Model, Multiply, Power, - PrimitiveModel, Shape, Subtract, ) @@ -46,7 +46,7 @@ class SupportsOutput(Protocol): class TestScalarInference: lambda_map: dict[ - type[PrimitiveModel], + type[Model], Callable[[int | float | bool, int | float | bool], int | float | bool], ] = { Add: lambda left, right: left + right, @@ -65,7 +65,7 @@ class TestScalarInference: ) def test_one_model( self, - model: PrimitiveModel, + model: Model, inputs: tuple[int | float | bool, int | float | bool], ): model = deepcopy(model) diff --git a/tests/scripts/test_scripts.py b/tests/scripts/test_scripts.py index 4319ef6b..776e9667 100644 --- a/tests/scripts/test_scripts.py +++ b/tests/scripts/test_scripts.py @@ -35,16 +35,14 @@ TBD, BaseKey, ConnectionData, - ConnectionType, IOHyperEdge, - IOKey, Tensor, ToBeDetermined, UniadicRecord, Variadic, create_shape_map, ) -from mithril.framework.constraints import bcast +from mithril.framework.logical.operators import BufferOp from mithril.models import ( L1, L2, @@ -58,6 +56,7 @@ Buffer, Concat, Connection, + ConnectionType, ConstraintSolver, Convolution1D, Convolution2D, @@ -69,6 +68,7 @@ FloorDivide, Gelu, Greater, + IOKey, Layer, LeakyRelu, Less, @@ -83,7 +83,6 @@ Multiply, PolynomialFeatures, Power, - PrimitiveModel, Prod, Relu, Reshape, @@ -104,9 +103,11 @@ TrainModel, Where, ) +from mithril.models.primitives import PrimitiveModel from mithril.utils.type_utils import is_list_int from mithril.utils.utils import OrderedSet +from ..utils import MyAdder from .helper import assert_models_equal from .test_shapes import check_shapes_semantically from .test_utils import ( @@ -675,22 +676,6 @@ def my_adder_grad(x): u_numpy_backend = pickle.loads(pickled_numpy) u_torch_backend = pickle.loads(pickled_torch) - class MyAdder(PrimitiveModel): - def __init__(self) -> None: - super().__init__( - formula_key="my_adder", - output=BaseKey(shape=[("Var_out", ...)], type=Tensor), - left=BaseKey(shape=[("Var_1", ...)], type=Tensor), - right=BaseKey(shape=[("Var_2", ...)], type=Tensor), - ) - self.add_constraint( - fn=bcast, keys=[PrimitiveModel.output_key, "left", "right"] - ) - - def __call__(self, left, right, output): # type: ignore[override] - kwargs = {"left": left, "right": right, "output": output} - return ExtendInfo(self, kwargs) - model = Model() model += MyAdder()(left="left", right="right", output="output") @@ -1268,18 +1253,22 @@ def test_relational_operators_ignored_1(): def test_relational_operators_ignored_2(): model = Model() - model.extend( + model._extend( Less(), - left=IOKey("left", type=Tensor), - right=IOKey("right", type=Tensor), - output=IOKey("relational_out"), + { + "left": IOKey("left", type=Tensor), + "right": IOKey("right", type=Tensor), + "output": IOKey("relational_out"), + }, ) - model.extend( + model._extend( Where(), - cond=model.cout, - input1="inp1", - input2="inp2", - output=IOKey("where_out"), + { + "cond": model.cout, + "input1": "inp1", + "input2": "inp2", + "output": IOKey("where_out"), + }, ) pm = compile(model, NumpyBackend()) assert ( @@ 
-1661,15 +1650,23 @@ def test_geomean_evaluate(): model1 = Model() lin1 = Linear(dimension=10) lin12 = Linear(dimension=10) - model1.extend( - lin1, input="input", weight="weight", bias="bias", output=IOKey("output1") + model1._extend( + lin1, + { + "input": "input", + "weight": "weight", + "bias": "bias", + "output": IOKey("output1"), + }, ) - model1.extend( + model1._extend( lin12, - input=lin1.output, - weight="weight1", - bias="bias1", - output=IOKey("output2"), + { + "input": lin1.output, + "weight": "weight1", + "bias": "bias1", + "output": IOKey("output2"), + }, ) model1.set_shapes({"input": [10, 10, 10]}) lin1.input.set_differentiable(True) @@ -1687,15 +1684,23 @@ def test_geomean_evaluate(): model2 = Model() lin2 = Linear() lin22 = Linear(dimension=10) - model2.extend( - lin2, input="input", weight="weight", bias="bias", output=IOKey("output1") + model2._extend( + lin2, + { + "input": "input", + "weight": "weight", + "bias": "bias", + "output": IOKey("output1"), + }, ) - model2.extend( + model2._extend( lin22, - input=lin2.output, - weight="weight1", - bias="bias1", - output=IOKey("output2"), + { + "input": lin2.output, + "weight": "weight1", + "bias": "bias1", + "output": IOKey("output2"), + }, ) lin2.input.set_differentiable(True) ctx2 = TrainModel(model2) @@ -2062,7 +2067,8 @@ def test_static_anlaysis_4(): comp_model = mithril.compile(model=model, backend=NumpyBackend()) models = {add1, add2, sum1, sub1, mul1, mat1} - assert (models - comp_model.flat_graph.nodes.keys()) == {mat1} + _models = {model.submodel for model in models} + assert (_models - comp_model.flat_graph.nodes.keys()) == {mat1.submodel} def test_prune_1(): @@ -3163,23 +3169,23 @@ def test_empy_out_grad(): def geomean_multigpu_test(): model = Model() model.extend(l1 := Linear(16), input="input1") - model.extend(l2 := Linear(32), w="w", input=l1.output) - model.extend(l3 := Linear(32), w="w", input=l1.output) + model.extend(l2 := Linear(32), w="w", input=l1.output.data) + model.extend(l3 := Linear(32), w="w", input=l1.output.data) # Classification - model.extend(add := Add(), left=l3.output, right=l2.output) - model.extend(pow := Power(), base=add.output, exponent=2) - model.extend(mul := Multiply(), left=pow.output) - model.extend(abs := Absolute(), input=mul.output) - model.extend(sqrt := Sqrt(), input=abs.output) - model.extend(mul2 := Multiply(), left=sqrt.output, right="input2") - model.extend(div := Divide(), numerator=mul2.output, denominator=1.0) - model.extend(Softmax(), input=div.output, output="out1") + model.extend(add := Add(), left=l3.output.data, right=l2.output.data) + model.extend(pow := Power(), base=add.output.data, exponent=2) + model.extend(mul := Multiply(), left=pow.output.data) + model.extend(abs := Absolute(), input=mul.output.data) + model.extend(sqrt := Sqrt(), input=abs.output.data) + model.extend(mul2 := Multiply(), left=sqrt.output.data, right="input2") + model.extend(div := Divide(), numerator=mul2.output.data, denominator=1.0) + model.extend(Softmax(), input=div.output.data, output="out1") # Regression - model.extend(mul := Multiply(), left=l2.output, right=l3.output) - model.extend(add2 := Add(), left=mul.output, right="input3") - model.extend(Divide(), numerator=add2.output, denominator=40.0, output="out2") + model.extend(mul := Multiply(), left=l2.output.data, right=l3.output.data) + model.extend(add2 := Add(), left=mul.output.data, right="input3") + model.extend(Divide(), numerator=add2.output.data, denominator=40.0, output="out2") context = TrainModel(model) context.add_loss( @@ 
-3648,7 +3654,7 @@ def test_mlp_last_dimension_prop(): mlp_model = MLP(activations=[Relu(), Relu(), Relu()], dimensions=[12, 24, None]) ctx = TrainModel(mlp_model) loss_model = SquaredError() - loss_model.set_shapes(loss_model.safe_shapes) + loss_model.set_shapes(loss_model.submodel.safe_shapes) ctx.add_loss( loss_model, input=mlp_model.cout, @@ -3890,22 +3896,6 @@ def test_infer_static_register_fn(): def my_adder(left, right): return left + right - class MyAdder(PrimitiveModel): - def __init__(self) -> None: - super().__init__( - formula_key="my_adder", - output=BaseKey(shape=[("Var_out", ...)], type=Tensor), - left=BaseKey(shape=[("Var_1", ...)], type=Tensor), - right=BaseKey(shape=[("Var_2", ...)], type=Tensor), - ) - self.add_constraint( - fn=bcast, keys=[PrimitiveModel.output_key, "left", "right"] - ) - - def __call__(self, left, right, output): # type: ignore[override] - kwargs = {"left": left, "right": right, "output": output} - return ExtendInfo(self, kwargs) - JaxBackend.register_primitive(my_adder) model = Model() @@ -6825,7 +6815,6 @@ def __init__( } super().__init__(formula_key="einsum", name=name, **kwargs) - self._freeze() def __call__( # type: ignore[override] self, @@ -6909,7 +6898,6 @@ def __init__( } super().__init__(formula_key="einsum", name=name, **kwargs) - self._freeze() def __call__( # type: ignore[override] self, @@ -6947,3 +6935,19 @@ def test_empty_call_vs_direct_model_extending(): model2 += LeakyRelu()() assert_models_equal(model1, model2) + + +def test_extending_operator(): + model1 = BufferOp() + with pytest.raises(NotImplementedError) as err: + model1.extend(BufferOp()) + + assert str(err.value) == "Operators cannot be extended!" + + +def test_extending_operator_model(): + model1 = Buffer() + with pytest.raises(RuntimeError) as err: + model1 += Buffer() + + assert str(err.value) == "Primitive models cannot have submodels." diff --git a/tests/scripts/test_set_outputs.py b/tests/scripts/test_set_outputs.py index c9cf7eae..9c32e132 100644 --- a/tests/scripts/test_set_outputs.py +++ b/tests/scripts/test_set_outputs.py @@ -410,7 +410,7 @@ def test_9_error(): model.set_outputs(output=lin.output) error_text = str(err_info.value).strip('"') - assert error_text == "'Connection with name output already exists!'" + assert error_text == "Key 'output' is already used!" def test_10_error(): diff --git a/tests/scripts/test_set_shapes.py b/tests/scripts/test_set_shapes.py index 15dc8942..5af9e0d9 100644 --- a/tests/scripts/test_set_shapes.py +++ b/tests/scripts/test_set_shapes.py @@ -13,8 +13,7 @@ # limitations under the License. 
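The hunks above change how graphs are built: composite models are extended through `_extend`, which now takes the submodel plus a dict mapping the submodel's local keys to connections, and the two new tests pin down that neither a bare `Operator` nor a primitive wrapper model accepts submodels. A minimal sketch of these semantics, assuming only imports already shown in this diff (illustrative, not canonical API):

import pytest

from mithril.framework.logical.operators import BufferOp
from mithril.models import Buffer, IOKey, Model, Tensor

# Composite models: `_extend` takes the submodel and a key-to-connection dict.
model = Model()
model._extend(
    Buffer(),
    {"input": IOKey("input", type=Tensor), "output": IOKey("output")},
)

# Operators are leaf nodes of the graph and refuse extension outright.
with pytest.raises(NotImplementedError):
    BufferOp().extend(BufferOp())  # "Operators cannot be extended!"

# Primitive wrapper models likewise cannot take submodels.
with pytest.raises(RuntimeError):
    buf = Buffer()
    buf += Buffer()  # "Primitive models cannot have submodels."
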
-from mithril.framework.common import IOKey
-from mithril.models import Add, Model, Sigmoid
+from mithril.models import Add, IOKey, Model, Sigmoid

 from .test_utils import check_shapes_semantically
diff --git a/tests/scripts/test_shapes.py b/tests/scripts/test_shapes.py
index 8009ef76..d4109c3e 100644
--- a/tests/scripts/test_shapes.py
+++ b/tests/scripts/test_shapes.py
@@ -28,10 +28,7 @@
     DNF,
     NOT_GIVEN,
     BaseKey,
-    Connection,
-    ConnectionType,
     Equivalences,
-    IOKey,
     PossibleValues,
     ShapeNode,
     ShapeRepr,
@@ -43,6 +40,7 @@
     Variadic,
 )
 from mithril.framework.constraints import reverse_constraints
+from mithril.framework.logical.primitive import OperatorModel, PrimitiveModel
 from mithril.models import (
     AUC,
     MLP,
@@ -59,6 +57,8 @@
     Cast,
     Cholesky,
     Concat,
+    Connection,
+    ConnectionType,
     ConstraintSolver,
     Convolution1D,
     Convolution2D,
@@ -73,6 +73,7 @@
     Gelu,
     GPRAlpha,
     GPRVOuter,
+    IOKey,
     IsNan,
     Layer,
     LeakyRelu,
@@ -88,10 +89,10 @@
     Multiply,
     NanToNum,
     NormModifier,
+    Operator,
     Pad,
     PermuteTensor,
     PositionalEncoding,
-    PrimitiveModel,
     PrimitiveUnion,
     Relu,
     Reshape,
@@ -132,7 +133,7 @@ def assert_shapes(
-    model: BaseModel,
+    model: Model,
     logical_ref: Mapping[str, Sequence[Sequence[int | str] | int | str] | None],
     physical_ref: Mapping[str, Sequence[Sequence[int | str] | int | str] | None]
     | None = None,
@@ -662,6 +663,6 @@ def test_linear_1_set_shapes():
     ctx = TrainModel(model)
     loss_model = SquaredError()
-    loss_model.set_shapes(loss_model.safe_shapes)
+    loss_model.set_shapes(loss_model.submodel.safe_shapes)
     ctx.add_loss(
         loss_model=loss_model, reduce_steps=[Mean()], input="output", target="target"
     )
@@ -699,7 +701,7 @@ def test_linear_1_static_shapes():
     shapes = {"input": [100, 4], "target": [100, 1]}
     ctx = TrainModel(model)
     loss_model = SquaredError()
-    loss_model.set_shapes(loss_model.safe_shapes)
+    loss_model.set_shapes(loss_model.submodel.safe_shapes)
     ctx.add_loss(
         loss_model=loss_model, reduce_steps=[Mean()], input="output", target="target"
     )
@@ -750,7 +752,7 @@ def test_linear_1_static_inputs():
     }
     ctx = TrainModel(model)
     loss_model = SquaredError()
-    loss_model.set_shapes(loss_model.safe_shapes)
+    loss_model.set_shapes(loss_model.submodel.safe_shapes)
     ctx.add_loss(
         loss_model=loss_model, reduce_steps=[Mean()], input="output", target="target"
     )
@@ -2581,7 +2583,7 @@ def test_mlp_1_static_shapes():
     model = MLP(activations=[Softplus(), Buffer(), Buffer()], dimensions=[5, 10, 1])
     ctx = TrainModel(model)
     loss_model = SquaredError()
-    loss_model.set_shapes(loss_model.safe_shapes)
+    loss_model.set_shapes(loss_model.submodel.safe_shapes)
     ctx.add_loss(loss_model, input=model.output, target="target", reduce_steps=[Mean()])
     static_input_shapes = {"input": [100, 4], "target": [100, 1]}
     logical_ref: dict[str, list | None] = {
@@ -2694,7 +2696,7 @@ def test_mlp_1_static_inputs():
     model = MLP(activations=[Softplus(), Buffer(), Buffer()], dimensions=[5, 10, 1])
     ctx = TrainModel(model)
     loss_model = SquaredError()
-    loss_model.set_shapes(loss_model.safe_shapes)
+    loss_model.set_shapes(loss_model.submodel.safe_shapes)
     ctx.add_loss(loss_model, input=model.output, target="target", reduce_steps=[Mean()])

     static_inputs = {
@@ -3119,7 +3121,7 @@ class Model9(PrimitiveModel):

     def __init__(self) -> None:
         super().__init__(
-            formula_key="buffer",
+            formula_key="relu",
             input=BaseKey(shape=["u1", ("Var1", ...)], type=Tensor),
             output=BaseKey(shape=["u2", "u1", ("Var1", ...)], type=Tensor),
         )
@@ -6646,9 +6648,11 @@ def find_all_reprs(repr: ShapeRepr, repr_cache=None) -> set[ShapeRepr]:
         CrossEntropy,
         CustomPrimitiveModel,
ToTuple, - PrimitiveModel, + Operator, ToList, ScaledDotProduct, + PrimitiveModel, + OperatorModel, } ref_counts = { Exponential: 1, @@ -6690,10 +6694,10 @@ def find_all_reprs(repr: ShapeRepr, repr_cache=None) -> set[ShapeRepr]: ZerosLike: 1, } # find all primitives that are defined in primitives.py - _all_primitives_dict = ( - primitives.__dict__ | mithril.framework.essential_primitives.__dict__ # type: ignore - ) - all_primitives = primitives.__all__ + mithril.framework.essential_primitives.__all__ # type: ignore + + u_primitives = mithril.framework.logical.primitive + _all_primitives_dict = primitives.__dict__ | u_primitives.__dict__ + all_primitives = primitives.__all__ + u_primitives.__all__ # type: ignore all_primitives_dict = { value for key, value in _all_primitives_dict.items() if key in all_primitives } @@ -6709,11 +6713,11 @@ def find_all_reprs(repr: ShapeRepr, repr_cache=None) -> set[ShapeRepr]: param: default_args.get(param, TBD) for param in model_init_params } # kwargs = {param: ... for param in model_init_params} - model: PrimitiveModel = primitive_model(**kwargs) + model: Model = primitive_model(**kwargs) # Set all untyped connections to Tensor type. model.set_types( { - conn.conn: Tensor + model.connection_map[conn]: Tensor for conn in model.conns.input_connections if conn.metadata.edge_type is ToBeDetermined } diff --git a/tests/scripts/test_summary.py b/tests/scripts/test_summary.py index 06938199..462d078f 100644 --- a/tests/scripts/test_summary.py +++ b/tests/scripts/test_summary.py @@ -23,8 +23,6 @@ from mithril import JaxBackend, NumpyBackend, TorchBackend from mithril.framework.common import ( NOT_GIVEN, - Connection, - IOKey, ShapeTemplateType, Table, Tensor, @@ -32,7 +30,6 @@ Variadic, get_summary_shapes, ) -from mithril.framework.utils import define_unique_names from mithril.models import ( L1, L2, @@ -40,11 +37,13 @@ Add, Buffer, Concat, + Connection, Convolution1D, Convolution2D, CrossEntropy, Divide, Flatten, + IOKey, KernelizedSVM, LeakyRelu, Linear, @@ -65,6 +64,7 @@ Tanh, ToTensor, TrainModel, + define_unique_names, ) # TODO: Remove dependency to examples folder (Create a model zoo and include ResNets)! 
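The recurring `set_shapes` edits in the test_shapes.py hunks above all follow from the same layering change: user-facing primitives such as `SquaredError` are now thin `Model` wrappers around an `Operator` stored in `.submodel`, so operator-level metadata like `safe_shapes` is read through that attribute. A hedged sketch of the pattern, mirroring the call sites above (the `Linear` scaffolding is illustrative only):

from mithril.models import IOKey, Linear, Mean, Model, SquaredError, TrainModel

model = Model()
model += Linear(dimension=1)(input="input", output=IOKey(name="output"))

loss_model = SquaredError()
# safe_shapes now lives on the wrapped Operator, not the wrapper model.
loss_model.set_shapes(loss_model.submodel.safe_shapes)

ctx = TrainModel(model)
ctx.add_loss(
    loss_model=loss_model, reduce_steps=[Mean()], input="output", target="target"
)
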
@@ -96,7 +96,7 @@ def test_extract_logical_connections_1(): ) model1 += lin2(input=lin1.output, weight=lin1.output, output=IOKey(name="output2")) model1 += lin3(input=lin1.weight, weight=lin1.weight, output=IOKey(name="output3")) - name_mappings = define_unique_names(model1.dag) + name_mappings = define_unique_names(model1.dag.keys()) conns = model1.extract_connection_info(name_mappings) assert conns == { "Linear_0": ( @@ -142,7 +142,7 @@ def test_extract_logical_connections_2(): model2 += model() model2 += buff3(input=model.output1, output=model.input2) # type: ignore model2.set_cin(model.input1) # type: ignore - name_mappings = define_unique_names(model2.dag) + name_mappings = define_unique_names(model2.dag.keys()) conns = model2.extract_connection_info(name_mappings) ref_conns = { "Model": ( @@ -162,7 +162,7 @@ def test_extract_logical_connections_3(): model += buff2(output=IOKey(name="output")) model += buff1(output=buff2.input, input="input") - name_mappings = define_unique_names(model.dag) + name_mappings = define_unique_names(model.dag.keys()) conns = model.extract_connection_info(name_mappings) ref_conns = { "Buffer_0": ({"input": ["Buffer_1.output"]}, {"output": ["'output'"]}), @@ -194,7 +194,7 @@ def test_extract_logical_connections_4(): input2="in2", ) - name_mappings = define_unique_names(model.dag) + name_mappings = define_unique_names(model.dag.keys()) conns = model.extract_connection_info(name_mappings) ref_conns = { @@ -222,7 +222,7 @@ def test_extract_logical_connections_5(): model += Linear(1000) model += Linear(1) - name_mappings = define_unique_names(model.dag) + name_mappings = define_unique_names(model.dag.keys()) conns = model.extract_connection_info(name_mappings) ref_conns = { @@ -274,7 +274,7 @@ def test_extract_logical_connections_6(): model += Linear(dimension=3)(input="input", output=IOKey(name="output")) model += Flatten() model += Mean(keepdim=True) - name_mappings = define_unique_names(model.dag) + name_mappings = define_unique_names(model.dag.keys()) conns = model.extract_connection_info(name_mappings) ref_conns = { "Linear": ( @@ -308,7 +308,7 @@ def test_extract_logical_connections_7(): model += model_2 model += Flatten() model += Buffer() - name_mappings = define_unique_names(model.dag) + name_mappings = define_unique_names(model.dag.keys()) conns = model.extract_connection_info(name_mappings) ref_conns = { @@ -359,7 +359,7 @@ def test_extract_logical_connections_8(): model_2 = Model() model_2 += model_1 model_2 += buff_1 - name_mappings = define_unique_names(model_2.dag) + name_mappings = define_unique_names(model_2.dag.keys()) conns = model_2.extract_connection_info(name_mappings) ref_conns = { "Model": ( @@ -384,7 +384,7 @@ def test_extract_logical_connections_9(): for model in (deepcopy(model_n) for m in range(3)): model_nm += model - name_mappings = define_unique_names(model_nm.dag) + name_mappings = define_unique_names(model_nm.dag.keys()) conns = model_nm.extract_connection_info(name_mappings) ref_conns = { "Model_0": ({"$input": ["'$input'"]}, {"$output2": ["Model_1.$input"]}), @@ -414,7 +414,7 @@ def test_extract_logical_connections_10(): model += model_2 model += model_3 - name_mappings = define_unique_names(model.dag) + name_mappings = define_unique_names(model.dag.keys()) conns = model.extract_connection_info(name_mappings) ref_conns = { @@ -463,7 +463,7 @@ def test_extract_logical_connections_11(): model += model_0 model += model_1 model += model_2 - name_mappings = define_unique_names(model.dag) + name_mappings = 
define_unique_names(model.dag.keys()) conns = model.extract_connection_info(name_mappings) ref_conns = { "Model_0": ({"$input": ["'$input'"]}, {"$output": ["Model_1.$input"]}), @@ -480,7 +480,7 @@ def test_extract_logical_connections_12(): model += model_1 model += model_2 model += model_3 - name_mappings = define_unique_names(model.dag) + name_mappings = define_unique_names(model.dag.keys()) conns = model.extract_connection_info(name_mappings) ref_conns = { "Sigmoid": ({"input": ["'$input'"]}, {"output": ["Model_0.$input"]}), @@ -500,7 +500,7 @@ def test_extract_logical_connections_13(): model += Linear(1000) model += Linear(1) - name_mappings = define_unique_names(model.dag) + name_mappings = define_unique_names(model.dag.keys()) conns = model.extract_connection_info(name_mappings) ref_conns = { "Model_0": ( @@ -552,7 +552,7 @@ def test_extract_shapes_logical_1(): buff2 = Buffer() model += buff1(input="input") model += buff2(input=buff1.output) - name_mappings = define_unique_names(model.dag) + name_mappings = define_unique_names(model.dag.keys()) uni_cache: dict[UniadicRecord, str] = {} var_cache: dict[Variadic, str] = {} conn_info = model.extract_connection_info(name_mappings) @@ -574,7 +574,7 @@ def test_extract_shapes_logical_2(): model += buff1(input="input") model += buff2(input=buff1.output) model.set_shapes({"input": [45, 96, 2]}) - name_mappings = define_unique_names(model.dag) + name_mappings = define_unique_names(model.dag.keys()) uni_cache: dict[UniadicRecord, str] = {} var_cache: dict[Variadic, str] = {} conn_info = model.extract_connection_info(name_mappings) @@ -605,7 +605,7 @@ def test_extract_shapes_logical_3(): model += linear_3 model += relu_3 relu_2.set_shapes({"input": [4, 2]}) - name_mappings = define_unique_names(model.dag) + name_mappings = define_unique_names(model.dag.keys()) uni_cache: dict[UniadicRecord, str] = {} var_cache: dict[Variadic, str] = {} conn_info = model.extract_connection_info(name_mappings) @@ -648,7 +648,7 @@ def test_extract_shapes_logical_4(): model += relu_2 model += conv_3 model += relu_3 - name_mappings = define_unique_names(model.dag) + name_mappings = define_unique_names(model.dag.keys()) uni_cache: dict[UniadicRecord, str] = {} var_cache: dict[Variadic, str] = {} conn_info = model.extract_connection_info(name_mappings) @@ -722,7 +722,7 @@ def test_extract_shapes_logical_5(): model += linear_3 model += relu_3 relu_2.set_shapes({"input": [None, None]}) - name_mappings = define_unique_names(model.dag) + name_mappings = define_unique_names(model.dag.keys()) uni_cache: dict[UniadicRecord, str] = {} var_cache: dict[Variadic, str] = {} conn_info = model.extract_connection_info(name_mappings) @@ -774,7 +774,7 @@ def test_define_unique_names_1(): model |= KernelizedSVM_1(input1=model.cout) lin_0.input.set_differentiable(True) - name_dict = define_unique_names(model.dag) + name_dict = define_unique_names(model.dag.keys()) assert name_dict == { lin_0: "Linear_0", lin_1: "Linear_1", @@ -795,7 +795,7 @@ def test_define_unique_names_2(): model += (rel_1 := Relu()) model += (lin2 := Linear()) model += (rel_2 := Relu()) - name_dict = define_unique_names(model.dag) + name_dict = define_unique_names(model.dag.keys()) assert name_dict == { lin1: "Linear_0", lin2: "Linear_1", diff --git a/tests/scripts/test_tuple_list_args_in_extend.py b/tests/scripts/test_tuple_list_args_in_extend.py index 00816e2f..9940efe6 100644 --- a/tests/scripts/test_tuple_list_args_in_extend.py +++ b/tests/scripts/test_tuple_list_args_in_extend.py @@ -15,8 +15,8 @@ import pytest 
from mithril import JaxBackend, TorchBackend, compile -from mithril.framework.common import IOKey, Tensor -from mithril.models import Add, MatrixMultiply, Model, ToTensor +from mithril.framework.common import Tensor +from mithril.models import Add, IOKey, MatrixMultiply, Model, ToTensor from .test_utils import assert_results_equal diff --git a/tests/scripts/test_type_coercion.py b/tests/scripts/test_type_coercion.py index a51f060a..78e974cd 100644 --- a/tests/scripts/test_type_coercion.py +++ b/tests/scripts/test_type_coercion.py @@ -23,15 +23,13 @@ from mithril.framework.common import ( NOT_GIVEN, BaseKey, - Connection, - ConnectionType, IOHyperEdge, - IOKey, ShapeTemplateType, Tensor, Updates, ) from mithril.framework.constraints import set_edge_type +from mithril.framework.logical.model import Connection, ConnectionType, IOKey from mithril.models import ( MLP, TBD, @@ -46,7 +44,7 @@ MatrixMultiply, Mean, Model, - PrimitiveModel, + Operator, PrimitiveUnion, Relu, Shape, @@ -57,6 +55,7 @@ ToTensor, ToTuple, ) +from mithril.models.primitives import PrimitiveModel from ..utils import compare_models from .test_utils import assert_results_equal @@ -701,7 +700,7 @@ def __init__(self, type) -> None: input=BaseKey(shape=[("Var2", ...)], type=type), ) self._add_constraint( - fn=self.artificial_constraint, keys=[PrimitiveModel.output_key, "input"] + fn=self.artificial_constraint, keys=[Operator.output_key, "input"] ) def __call__( # type: ignore[override] @@ -888,8 +887,9 @@ def test_connect_type_conv_handling_1(): value=Tensor([[2.0]]), expose=True, ) - model.extend( - mat_mul := MatrixMultiply(), left=con_object, output=IOKey(name="output") + model._extend( + mat_mul := MatrixMultiply(), + {"left": con_object, "output": IOKey(name="output")}, ) mat_mul.set_shapes({"output": output_shape}) model_1 = model @@ -904,8 +904,9 @@ def test_connect_type_conv_handling_1(): name="abcd", expose=True, ) - model.extend( - (mat_mul := MatrixMultiply()), left=con_object, output=IOKey(name="output") + model._extend( + (mat_mul := MatrixMultiply()), + {"left": con_object, "output": IOKey(name="output")}, ) mat_mul.set_shapes({"output": output_shape}) model_2 = model @@ -920,8 +921,9 @@ def test_connect_type_conv_handling_1(): name="abcd", expose=True, ) - model.extend( - (mat_mul := MatrixMultiply()), left=con_object, output=IOKey(name="output") + model._extend( + (mat_mul := MatrixMultiply()), + {"left": con_object, "output": IOKey(name="output")}, ) mat_mul.set_shapes({"output": output_shape}) model_3 = model diff --git a/tests/scripts/test_type_consistencies.py b/tests/scripts/test_type_consistencies.py index 3d49e3a0..98acd42a 100644 --- a/tests/scripts/test_type_consistencies.py +++ b/tests/scripts/test_type_consistencies.py @@ -24,10 +24,10 @@ from mithril.framework.common import ( NOT_GIVEN, BaseKey, - ConnectionType, ToBeDetermined, find_intersection_type, ) +from mithril.framework.logical.model import ConnectionType from mithril.framework.utils import ( find_type, infer_all_possible_types, @@ -42,12 +42,12 @@ Mean, Model, Multiply, - PrimitiveModel, PrimitiveUnion, Shape, Sigmoid, Tensor, ) +from mithril.models.primitives import PrimitiveModel from mithril.utils.utils import find_dominant_type from .test_constant_inputs import ReduceMult diff --git a/tests/scripts/test_utils.py b/tests/scripts/test_utils.py index bcb8111a..6c79185f 100644 --- a/tests/scripts/test_utils.py +++ b/tests/scripts/test_utils.py @@ -28,7 +28,7 @@ Uniadic, find_intersection_type, ) -from mithril.framework.logical import 
BaseModel, Model, PrimitiveModel +from mithril.framework.logical import BaseModel, Model, Operator from mithril.framework.physical import PhysicalModel from mithril.framework.utils import find_type from mithril.models.train_model import TrainModel @@ -271,7 +271,7 @@ def assert_metadata_equal(*args): def get_all_data(model: BaseModel) -> set[IOHyperEdge]: # recursively gets the all data in the model (Tensor or Scalar) - if isinstance(model, PrimitiveModel): + if isinstance(model, Operator): return {model.conns.get_data(key) for key in model.conns.all} assert isinstance(model, Model) data = set() @@ -282,7 +282,7 @@ def get_all_data(model: BaseModel) -> set[IOHyperEdge]: def get_all_metadata(model: BaseModel) -> set[IOHyperEdge | None]: # recursively gets the all metadata in the model (IOHyperEdge) - if isinstance(model, PrimitiveModel): + if isinstance(model, Operator): return {model.conns.get_metadata(key) for key in model.conns.all} assert isinstance(model, Model) data = set() diff --git a/tests/utils.py b/tests/utils.py index e344e3e5..9e9b48b1 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -18,7 +18,16 @@ import mithril from mithril import Backend, DataType -from mithril.models import Model, PhysicalModel +from mithril.framework.common import BaseKey +from mithril.framework.constraints import bcast +from mithril.models import ( + ExtendInfo, + Model, + Operator, + PhysicalModel, + Tensor, +) +from mithril.models.primitives import PrimitiveModel def check_logical_models(model_1: Model, model_2: Model): @@ -29,7 +38,11 @@ def check_logical_models(model_1: Model, model_2: Model): dag_1.items(), dag_2.items(), strict=False ): # Check dag keys of each model. - assert key_1.__class__.__name__ == key_2.__class__.__name__ + assert ( + key_1.__class__.__name__ == key_2.__class__.__name__ + or key_1.__class__.__name__ + "Op" == key_2.__class__.__name__ + or key_1.__class__.__name__ == key_2.__class__.__name__ + "Op" + ) for (in_1, conn_1), (in_2, conn_2) in zip( value_1.items(), value_2.items(), strict=False ): @@ -160,3 +173,18 @@ def wrapper(*args, **kwargs): return wrapper return decorator + + +class MyAdder(PrimitiveModel): + def __init__(self) -> None: + super().__init__( + formula_key="my_adder", + output=BaseKey(shape=[("Var_out", ...)], type=Tensor), + left=BaseKey(shape=[("Var_1", ...)], type=Tensor), + right=BaseKey(shape=[("Var_2", ...)], type=Tensor), + ) + self.add_constraint(fn=bcast, keys=[Operator.output_key, "left", "right"]) + + def __call__(self, left, right, output): # type: ignore[override] + kwargs = {"left": left, "right": right, "output": output} + return ExtendInfo(self, kwargs)
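

With the duplicated in-test definitions removed, the shared `MyAdder` above is consumed roughly as follows. The sketch mirrors the call sites in test_scripts.py; the absolute `tests.utils` import path and the bare `compile` call are assumptions made for a self-contained example:

import mithril
from mithril import JaxBackend
from mithril.models import Model

from tests.utils import MyAdder


def my_adder(left, right):
    return left + right


# Register a backend kernel so formula_key="my_adder" resolves at compile time.
JaxBackend.register_primitive(my_adder)

model = Model()
model += MyAdder()(left="left", right="right", output="output")
pm = mithril.compile(model=model, backend=JaxBackend())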