diff --git a/mithril/framework/common.py b/mithril/framework/common.py index c18f5430..af3adf84 100644 --- a/mithril/framework/common.py +++ b/mithril/framework/common.py @@ -113,13 +113,13 @@ class NotAvailable(SingletonObject): pass -# class ToBeDetermined(SingletonObject): -# """ -# A singleton class representing a null data indicating -# that no data is provided. -# """ +class Auto(SingletonObject): + """ + A singleton class representing a configuration + setting of automatically handled arguments. + """ -# pass + pass class ToBeDetermined(SingletonObject): @@ -134,9 +134,7 @@ class ToBeDetermined(SingletonObject): NOT_GIVEN = NullConnection() NOT_AVAILABLE = NotAvailable() TBD = ToBeDetermined() -# TBD = ToBeDetermined() -# TBD = TBD -# ToBeDetermined = ToBeDetermined +AUTO = Auto() class UpdateType(Enum): @@ -645,7 +643,6 @@ def match(self, other: Tensor | Scalar) -> Updates: class Tensor(BaseData, GenericDataType[DataType]): _type: type[float] | type[int] | type[bool] | UnionType - temp_value: TensorValueType | ToBeDetermined value: DataType | ToBeDetermined def __init__( @@ -657,12 +654,10 @@ def __init__( ) -> None: super().__init__(type=possible_types) self._differentiable: bool = value is TBD - self.temp_value = TBD self.value = TBD # Update type if any value is given. if not isinstance(value, ToBeDetermined): self.set_type(find_dominant_type(value)) - self.temp_value = value else: self.value = value self.shape: ShapeNode = shape @@ -681,21 +676,11 @@ def is_valued(self) -> bool: def _set_as_physical(self): super()._set_as_physical() - def _convert_value(self, backend: Backend) -> DataType | ToBeDetermined: - if isinstance(self.temp_value, Constant): - self.value = backend.array( - epsilon_table[backend.precision][self.temp_value] - ) - elif self.temp_value is not TBD: - self.value = backend.array(self.temp_value) - return self.value - def make_physical(self, backend: Backend, memo: dict[int, Tensor | Scalar]): physical_tensor = deepcopy(self, memo) # Update data as physical data. physical_tensor._set_as_physical() # Update value of physical data taking backend into account. - physical_tensor._convert_value(backend) return physical_tensor def __deepcopy__(self, memo: dict[int, Tensor | Scalar]): @@ -732,31 +717,9 @@ def match_shapes(self, other: Tensor): # type: ignore[override] # which requires handling of interval arithmetic in logical level also. def set_value(self, value: DataType | TensorValueType) -> Updates: # type: ignore[override] - if self._logical_data: - assert isinstance(value, TensorValueType) - return self._set_logical_value(value) - else: + if not self._logical_data: assert self.is_tensor_type(value) return self._set_physical_value(value) - - def _set_logical_value(self, value: TensorValueType) -> Updates: - if isinstance(value, ToBeDetermined): - if self.temp_value is not TBD: - raise ValueError( - f"Value is set before as {self.temp_value}. Can not be reset." - ) - # if self.value is TBD and value is None: - # raise ValueError( - # "Already set as non-differentiable. Can not be reverted \ - # to a differentiable state." - # ) - self.value = value - else: - if self.temp_value is not TBD and self.temp_value != value: - raise ValueError( - f"Value is set before as {self.temp_value}. Can not be reset." 
-                )
-            self.temp_value = value
         return Updates()

     def _set_physical_value(self, value: DataType) -> Updates:
diff --git a/mithril/framework/constraints.py b/mithril/framework/constraints.py
index e19e9ac8..81cb4503 100644
--- a/mithril/framework/constraints.py
+++ b/mithril/framework/constraints.py
@@ -3495,6 +3495,35 @@ def tuple_converter_constraint(output: Scalar, input: Scalar) -> ConstrainResult
     return status, updates


+def cross_entropy_constraint(
+    categorical: Scalar, input: Tensor, target: Tensor
+) -> ConstrainResultType:
+    assert input._temp_shape is not None, "Input shape of cross entropy is not set!"
+    assert target._temp_shape is not None, "Target shape of cross entropy is not set!"
+
+    status = False
+    updates = Updates()
+    categorical_value = categorical.value
+
+    input_shape: ShapeRepr = input._temp_shape
+    target_shape: ShapeRepr = target._temp_shape
+
+    if categorical_value is not TBD:
+        if not categorical_value:
+            updates |= target_shape._match(input_shape)
+        else:
+            N = Uniadic()
+            C = Uniadic()
+            var = Variadic()
+            in_repr = ShapeRepr([N, C], var)
+            target_repr = ShapeRepr([N], var)
+            updates |= input_shape._match(in_repr)
+            updates |= target_shape._match(target_repr)
+
+        status = True
+    return status, updates
+
+
 type_constraints = {
     general_tensor_type_constraint,
     floor_divide_type_constraint,
diff --git a/mithril/framework/logical/base.py b/mithril/framework/logical/base.py
index cb39f014..f456247d 100644
--- a/mithril/framework/logical/base.py
+++ b/mithril/framework/logical/base.py
@@ -93,6 +93,7 @@ def __init__(self, enforce_jit: bool = True) -> None:
         self._jittable = True
         self.constraint_solver: ConstraintSolver = ConstraintSolver()
         self.safe_shapes: dict[str, ShapeTemplateType] = {}
+        self.is_frozen = False

     @abc.abstractmethod
     def summary(
@@ -174,8 +175,8 @@ def __setattr__(self, name: str, value: Any):
         else:
             super().__setattr__(name, value)

-    @abc.abstractmethod
-    def _freeze(self) -> None: ...
+    def _freeze(self) -> None:
+        self.is_frozen = True

     @abc.abstractmethod
     def extract_connection_info(
@@ -266,6 +267,18 @@ def set_values(self, values: dict[str | Connection, MainValueType | str]) -> Non

         updates = Updates()

+        # TODO: Currently, setting values in frozen models is prevented only for
+        # Tensors. Scalars and Tensors should not be treated differently; this
+        # should be fixed.
+        for key in values:
+            metadata = self.conns.extract_metadata(key)
+            if isinstance(metadata.data, Tensor) and model.is_frozen:
+                conn_data = model.conns.get_con_by_metadata(metadata)
+                assert conn_data is not None
+                raise ValueError(
+                    f"Model is frozen, cannot set the key: {conn_data.key}!"
+                )
+
         for key, value in values.items():
             # Perform metadata extraction process on self.
             metadata = self.conns.extract_metadata(key)
diff --git a/mithril/framework/logical/essential_primitives.py b/mithril/framework/logical/essential_primitives.py
index 1860ca0c..f79b7d24 100644
--- a/mithril/framework/logical/essential_primitives.py
+++ b/mithril/framework/logical/essential_primitives.py
@@ -129,14 +129,12 @@ def __call__(  # type: ignore[override]
         return super().__call__(input=input, output=output)


-ToTupleOutputType = tuple[int | float | bool | list | tuple, ...]
- - class ToTuple(PrimitiveModel): def __init__(self, n: int) -> None: self.factory_args = {"n": n} - key_definitions = {} - key_definitions["output"] = Scalar(ToTupleOutputType) + key_definitions = { + "output": Scalar(tuple[int | float | bool | list | tuple, ...]) + } key_definitions |= { f"input{idx+1}": Scalar(int | float | bool | list | tuple) for idx in range(n) @@ -182,15 +180,14 @@ def __call__( # type: ignore[override] class Power(PrimitiveModel): base: Connection exponent: Connection - threshold: Connection output: Connection def __init__( self, robust: bool = False, - threshold: ConstantType | ToBeDetermined = Constant.MIN_POSITIVE_NORMAL, ) -> None: - self.factory_args = {"threshold": threshold, "robust": robust} + self.robust = robust + self.factory_args = {"robust": robust} assert isinstance(robust, bool), "Robust must be a boolean value!" if robust: @@ -199,16 +196,10 @@ def __init__( output=TensorType([("Var_out", ...)]), base=TensorType([("Var_1", ...)]), exponent=TensorType([("Var_2", ...)]), - threshold=TensorType([], ConstantType, threshold), + threshold=TensorType([], ConstantType), ) - + self.threshold.set_differentiable(False) # type: ignore else: - if threshold != Constant.MIN_POSITIVE_NORMAL: - raise KeyError( - "Threshold cannot be specified \ - when robust mode is off" - ) - super().__init__( formula_key="power", output=TensorType([("Var_out", ...)]), @@ -228,13 +219,15 @@ def __call__( # type: ignore[override] self, base: ConnectionType = NOT_GIVEN, exponent: ConnectionType = NOT_GIVEN, - threshold: ConnectionType = NOT_GIVEN, output: ConnectionType = NOT_GIVEN, + *, + threshold: ConnectionType = Constant.MIN_POSITIVE_NORMAL, ) -> ExtendInfo: kwargs = {"base": base, "exponent": exponent, "output": output} - if "threshold" in self._input_keys: + is_constant = isinstance(threshold, Constant) + if self.robust: kwargs["threshold"] = threshold - elif threshold != NOT_GIVEN: + elif not (is_constant and threshold is Constant.MIN_POSITIVE_NORMAL): raise ValueError("Threshold cannot be specified when robust mode is off") return super().__call__(**kwargs) @@ -941,26 +934,21 @@ def __init__(self) -> None: class Sqrt(PrimitiveModel): input: Connection - cutoff: Connection output: Connection def __init__( self, robust: bool = False, - cutoff: ConstantType | ToBeDetermined = Constant.MIN_POSITIVE_NORMAL, ) -> None: - self.factory_args = {"robust": robust, "cutoff": cutoff} + self.robust = robust + self.factory_args = {"robust": robust} if robust: - if isinstance(cutoff, str) and cutoff != Constant.MIN_POSITIVE_NORMAL: - raise ValueError(f"cutoff can only be set to 'min_positive_normal' \ - in string format, got {cutoff}") - super().__init__( formula_key="robust_sqrt", output=TensorType([("Var", ...)], float), input=TensorType([("Var", ...)]), - cutoff=TensorType([], ConstantType, cutoff), + cutoff=TensorType([], ConstantType), ) else: super().__init__( @@ -972,19 +960,17 @@ def __init__( def __call__( # type: ignore[override] self, input: ConnectionType = NOT_GIVEN, - cutoff: ConnectionType = NOT_GIVEN, output: ConnectionType = NOT_GIVEN, + *, + cutoff: ConnectionType = Constant.MIN_POSITIVE_NORMAL, ) -> ExtendInfo: kwargs = {"input": input, "output": output} - if self.formula_key == "sqrt" and cutoff != NOT_GIVEN: - raise ValueError( - "Sqrt does not accept cutoff argument \ - when initialized with robust = False." 
- ) - - if self.formula_key == "robust_sqrt": + is_constant = isinstance(cutoff, Constant) + if self.robust: kwargs["cutoff"] = cutoff + elif not (is_constant and cutoff == Constant.MIN_POSITIVE_NORMAL): + raise ValueError("Cutoff cannot be specified when robust mode is off") return super().__call__(**kwargs) diff --git a/mithril/framework/logical/model.py b/mithril/framework/logical/model.py index 4f4e00ea..49516d12 100644 --- a/mithril/framework/logical/model.py +++ b/mithril/framework/logical/model.py @@ -17,7 +17,6 @@ from types import UnionType from typing import Self, TypeVar, overload -from ...core import Constant from ...utils.utils import OrderedSet, find_dominant_type from ..common import ( NOT_AVAILABLE, @@ -154,10 +153,8 @@ class Model(BaseModel): def __init__( self, formula_key: str | None = None, enforce_jit: bool = True ) -> None: - self.passive_output = None - self.main_primitive = None self.dag: dict[BaseModel, dict[str, ConnectionData]] = {} - self.inter_key_count = 0 + self.inter_key_count: int = 0 self.formula_key = formula_key super().__init__(enforce_jit=enforce_jit) @@ -1044,30 +1041,6 @@ def extend( shape_info: dict[str, ShapeTemplateType] = dict() type_info: dict[str, type | UnionType] = dict() - # Check if any Tensor type key is already initialized with a value. - # This occurs only when a Primitive model having any Tensor type key - # initialized with a default value is extending the model. - for input_key in model._input_keys: - input_conn = model.conns.all.get(input_key) - assert input_conn is not None, "Connection type is not found!" - input_data = input_conn.metadata.data - if isinstance(input_data, Tensor) and not isinstance( - input_data.temp_value, ToBeDetermined - ): - if ( - (given_value := kwargs.get(input_key)) is not None - and given_value is not NOT_GIVEN - and given_value != input_data.temp_value - and not isinstance(given_value, Constant) - ): - raise ValueError( - f"Value of {model.__class__.__name__}'s {input_key} given " - f"as {given_value}. But the value is already initialized as " - f"{input_data.temp_value}" - ) - kwargs[input_key] = input_data.temp_value - input_data.temp_value = TBD - for key, value in kwargs.items(): # Check if given keys are among model's keys. if key not in model._input_keys | model.conns.output_keys: @@ -1231,35 +1204,9 @@ def extend( # Update jittablity by using model's jittablity. self._jittable &= model.jittable - # def __add__(self, model: Model | PrimitiveModel): - # """This function allows models to be added sequentially via "+=" operator. - # There are several conditions for a model to be sequentially added: - # if added model has single input, connect that input directly. - - # Parameters - # ---------- - # model : Model - # Other model to be sequentially added. - # """ - # if not (isinstance(model, BaseModel) or isinstance(model, PrimitiveModel)): - # raise TypeError("Added element should be a Model type.") - # kwargs = {} - # if self.canonical_output: - # kwargs = {model._canonical_input.key: self.canonical_output} - - # self.extend(model, **kwargs) - # return self - - def __add__(self, info: ExtendInfo | PrimitiveModel | Model) -> Self: - """This function allows models to be added via "+=" operator. - There are several conditions for a model to be added: - if added model has single input, connect that input directly. - - Parameters - ---------- - model : Model - Other model to be added. 
- """ + def _extend(self, info: ExtendInfo | PrimitiveModel | Model) -> Self: + if self.is_frozen: + raise AttributeError("Model is frozen and can not be extended!") model, kwargs = ( (info._model, info._connections) @@ -1299,6 +1246,18 @@ def __add__(self, info: ExtendInfo | PrimitiveModel | Model) -> Self: self.extend(model, **kwargs) return self + def __add__(self, info: ExtendInfo | PrimitiveModel | Model) -> Self: + """This function allows models to be added via "+=" operator. + There are several conditions for a model to be added: + if added model has single input, connect that input directly. + + Parameters + ---------- + model : Model + Other model to be added. + """ + return self._extend(info) + __iadd__ = __add__ @staticmethod @@ -1466,7 +1425,7 @@ def _freeze(self) -> None: self.dependency_map.update_all_keys() - # Sort and freeze dag + # Sort dag self.dag = {m: self.dag[m] for m in self.get_models_in_topological_order()} if self.formula_key is not None: # Must be convertable to primitive. @@ -1474,6 +1433,7 @@ def _freeze(self) -> None: "Logical models have altenative primitive implementation must " "have only 1 output." ) + super()._freeze() def summary( self, diff --git a/mithril/framework/logical/primitive.py b/mithril/framework/logical/primitive.py index 415eada1..41b25e6f 100644 --- a/mithril/framework/logical/primitive.py +++ b/mithril/framework/logical/primitive.py @@ -118,6 +118,8 @@ def __init__(self, formula_key, **kwargs: Tensor | TensorType | Scalar) -> None: else: self._canonical_output = canonical_output_conn + self._freeze() + def __iadd__(self, other: BaseModel): raise Exception( f"Primitive '{self.__class__.__name__}' model can not be extended!" @@ -213,6 +215,3 @@ def summary( table._compile() table.display() - - def _freeze(self) -> None: - pass diff --git a/mithril/framework/physical/model.py b/mithril/framework/physical/model.py index fb736b91..cd706305 100644 --- a/mithril/framework/physical/model.py +++ b/mithril/framework/physical/model.py @@ -15,15 +15,18 @@ import math import warnings from collections.abc import Callable, Mapping, Sequence +from copy import deepcopy from functools import partial, reduce from ...backends.backend import Backend, ParallelBackend from ...core import DataType, GenericDataType from ...utils.type_utils import is_list_int from ..common import ( + NOT_GIVEN, TBD, Connection, ConnectionData, + IOKey, MainValueType, Scalar, Table, @@ -84,6 +87,21 @@ def __init__( if len(model._input_keys) == 0: raise ValueError("Model without input keys could not be compiled.") + if isinstance(model, PrimitiveModel): + # TODO: Remove wrapping with Model in the future. + model = deepcopy(model) + extend_info = model() + model_keys = {} + for key in model.external_keys: + value = extend_info._connections.get(key, NOT_GIVEN) + # NOTE: Do not set default value if it is given in constant_keys. 
+ value = (value, NOT_GIVEN)[key in constant_keys] + if value is NOT_GIVEN: + model_keys[key] = key + else: + model_keys[key] = IOKey(key, value) # type: ignore + model = Model() + model(**model_keys) + self.backend: Backend[DataType] = backend self._output_keys: set[str] = set(model.conns.output_keys) diff --git a/mithril/models/primitives.py b/mithril/models/primitives.py index afce21f6..f55e29a5 100644 --- a/mithril/models/primitives.py +++ b/mithril/models/primitives.py @@ -31,6 +31,7 @@ concat_constraints, conv_1d_constraints, conv_2d_constraints, + cross_entropy_constraint, eye_constraints, flatten_constrains, general_tensor_type_constraint, @@ -45,6 +46,7 @@ tuple_converter_constraint, where_constrains, ) +from ..framework.logical.base import BaseModel from ..framework.logical.essential_primitives import SingleInputOperation from ..models import ExtendInfo, PrimitiveModel, Scalar, TensorType from ..utils.utils import PaddingType @@ -214,14 +216,13 @@ class QuantileLoss(PrimitiveModel): quantile: Connection output: Connection - def __init__(self, quantile: QuantileType | ToBeDetermined = 0.5) -> None: - self.factory_args = {"quantile": quantile} + def __init__(self) -> None: super().__init__( formula_key="quantile_loss", output=TensorType([("Var_1", ...)]), input=TensorType([("Var_2", ...)]), target=TensorType([("Var_3", ...)]), - quantile=TensorType([], QuantileType, quantile), + quantile=TensorType([], QuantileType), ) self._set_constraint( @@ -242,7 +243,7 @@ def __call__( # type: ignore[override] self, input: ConnectionType = NOT_GIVEN, target: ConnectionType = NOT_GIVEN, - quantile: ConnectionType = NOT_GIVEN, + quantile: QuantileType | ConnectionType = 0.5, output: ConnectionType = NOT_GIVEN, ) -> ExtendInfo: return super().__call__( @@ -267,20 +268,9 @@ class CrossEntropy(PrimitiveModel): output: Connection def __init__( - self, - input_type: str = "logits", - weights: list[float] | str = "", - categorical: bool = True, - robust: bool | ToBeDetermined = False, - cutoff: ConstantType | ToBeDetermined = Constant.MIN_POSITIVE_NORMAL, + self, input_type: str = "logits", weights: list[float] | str = "" ) -> None: - self.factory_args = { - "input_type": input_type, - "weights": weights, - "categorical": categorical, - "robust": robust, - "cutoff": cutoff, - } + self.factory_args = {"input_type": input_type, "weights": weights} weights_type: type = list[float] if isinstance(weights, str): @@ -296,13 +286,11 @@ def __init__( kwargs: dict[str, TensorType | Scalar] = { "output": TensorType(["N", ("Var", ...)], float), "input": TensorType(["N", "C", ("Var", ...)]), - "target": TensorType( - ["N", ("Var", ...)] if categorical else ["N", "C", ("Var", ...)] - ), + "target": TensorType(["N", ("VarTarget", ...)]), "weights": Scalar(weights_type, final_weights), - "categorical": Scalar(bool, categorical), - "cutoff": TensorType([], ConstantType, cutoff), - "robust": Scalar(bool, robust), + "categorical": Scalar(bool), + "cutoff": TensorType([], ConstantType), + "robust": Scalar(bool), } if input_type == "logits": @@ -321,27 +309,33 @@ def __init__( super().__init__(formula_key=formula_key, **kwargs) + self._set_constraint( + fn=cross_entropy_constraint, keys=["categorical", "input", "target"] + ) + def __call__( # type: ignore[override] self, input: ConnectionType = NOT_GIVEN, target: ConnectionType = NOT_GIVEN, weights: ConnectionType = NOT_GIVEN, - cutoff: ConnectionType = NOT_GIVEN, - robust: ConnectionType = NOT_GIVEN, + categorical: bool | ConnectionType = True, + cutoff: ConstantType | 
ConnectionType = Constant.MIN_POSITIVE_NORMAL, + robust: bool | ConnectionType = False, output: ConnectionType = NOT_GIVEN, ) -> ExtendInfo: kwargs = { "input": input, "target": target, "weights": weights, + "categorical": categorical, "output": output, } # Check if the given argument set is valid. if self.formula_key == "cross_entropy_with_log_probs": args = [] - if robust != NOT_GIVEN: + if robust is not False: args.append("robust") - if cutoff != NOT_GIVEN: + if cutoff != Constant.MIN_POSITIVE_NORMAL: args.append("cutoff") if args: raise ValueError( @@ -363,17 +357,13 @@ class KLDivergence(PrimitiveModel): cutoff: Connection output: Connection - def __init__( - self, cutoff: ConstantType | ToBeDetermined = Constant.MIN_POSITIVE_NORMAL - ) -> None: - if cutoff != Constant.MIN_POSITIVE_NORMAL: - self.factory_args = {"cutoff": cutoff} + def __init__(self) -> None: super().__init__( formula_key="kl_divergence", output=TensorType([("Var_1", ...)], float), input=TensorType([("Var_2", ...)]), target=TensorType([("Var_3", ...)]), - cutoff=TensorType([], ConstantType, cutoff), + cutoff=TensorType([], ConstantType), ) self.safe_shapes = { @@ -389,7 +379,7 @@ def __call__( # type: ignore[override] self, input: ConnectionType = NOT_GIVEN, target: ConnectionType = NOT_GIVEN, - cutoff: ConnectionType = NOT_GIVEN, + cutoff: ConnectionType = Constant.MIN_POSITIVE_NORMAL, output: ConnectionType = NOT_GIVEN, ) -> ExtendInfo: return super().__call__( @@ -410,18 +400,9 @@ class BinaryCrossEntropy(PrimitiveModel): output: Connection def __init__( - self, - input_type: str = "logits", - pos_weight: float | str | ToBeDetermined = 1.0, - cutoff: ConstantType | ToBeDetermined = Constant.MIN_POSITIVE_NORMAL, - robust: bool | ToBeDetermined = False, + self, input_type: str = "logits", pos_weight: float | str | ToBeDetermined = 1.0 ) -> None: - self.factory_args = { - "input_type": input_type, - "pos_weight": pos_weight, - "cutoff": cutoff, - "robust": robust, - } + self.factory_args = {"input_type": input_type, "pos_weight": pos_weight} if isinstance(pos_weight, str): if pos_weight != "auto": @@ -441,8 +422,8 @@ def __init__( [("Var_out", ...)], int | float ), # NOTE: Target can also be probabilistic, so float is acceptable. 
"pos_weight": Scalar(pos_weight_type, pos_weight), - "cutoff": TensorType([], ConstantType, cutoff), - "robust": Scalar(bool, robust), + "cutoff": TensorType([], ConstantType), + "robust": Scalar(bool), } if input_type == "logits": @@ -465,8 +446,8 @@ def __call__( # type: ignore[override] input: ConnectionType = NOT_GIVEN, target: ConnectionType = NOT_GIVEN, pos_weight: ConnectionType = NOT_GIVEN, - cutoff: ConnectionType = NOT_GIVEN, - robust: ConnectionType = NOT_GIVEN, + cutoff: ConstantType | ConnectionType = Constant.MIN_POSITIVE_NORMAL, + robust: bool | ConnectionType = False, output: ConnectionType = NOT_GIVEN, ) -> ExtendInfo: return super().__call__( @@ -482,21 +463,20 @@ def __call__( # type: ignore[override] class Log(PrimitiveModel): input: Connection output: Connection - cutoff: Connection def __init__( self, robust: bool = False, - cutoff: ConstantType | ToBeDetermined = Constant.MIN_POSITIVE_NORMAL, ) -> None: - self.factory_args = {"cutoff": cutoff, "robust": robust} + self.robust = robust + self.factory_args = {"robust": robust} if robust: super().__init__( formula_key="robust_log", output=TensorType([("Var", ...)], float), input=TensorType([("Var", ...)]), - cutoff=TensorType([], ConstantType, cutoff), + cutoff=TensorType([], ConstantType), ) else: super().__init__( @@ -510,18 +490,15 @@ def __call__( # type: ignore[override] input: ConnectionType = NOT_GIVEN, output: ConnectionType = NOT_GIVEN, *, - cutoff: ConnectionType = NOT_GIVEN, + cutoff: ConnectionType = Constant.MIN_POSITIVE_NORMAL, ) -> ExtendInfo: kwargs = {"input": input, "output": output} - if self.formula_key == "log" and cutoff != NOT_GIVEN: - raise ValueError( - "Log does not accept cutoff argument \ - when initialized with robust = False." - ) - - if self.formula_key == "robust_log": + is_constant = isinstance(cutoff, Constant) + if self.robust: kwargs["cutoff"] = cutoff + elif not (is_constant and cutoff == Constant.MIN_POSITIVE_NORMAL): + raise ValueError("Cutoff cannot be specified when robust mode is off") return super().__call__(**kwargs) @@ -533,22 +510,18 @@ class StableReciprocal(PrimitiveModel): def __init__( self, - cutoff: ConstantType | ToBeDetermined = Constant.STABLE_RECIPROCAL_THRESHOLD, ) -> None: - if cutoff != Constant.STABLE_RECIPROCAL_THRESHOLD: - self.factory_args = {"cutoff": cutoff} - super().__init__( formula_key="stable_reciprocal", output=TensorType([("Var", ...)], float), input=TensorType([("Var", ...)]), - cutoff=TensorType([], ConstantType, cutoff), + cutoff=TensorType([], ConstantType), ) def __call__( # type: ignore[override] self, input: ConnectionType = NOT_GIVEN, - cutoff: ConnectionType = NOT_GIVEN, + cutoff: ConnectionType = Constant.STABLE_RECIPROCAL_THRESHOLD, output: ConnectionType = NOT_GIVEN, ) -> ExtendInfo: return super().__call__(input=input, cutoff=cutoff, output=output) @@ -647,9 +620,16 @@ def __init__(self) -> None: class Softmax(Activation): - def __init__(self, axis: int | None | ToBeDetermined = -1) -> None: - self.factory_args = {"axis": axis} - super().__init__(formula_key="softmax", axis=Scalar(int | None, axis)) + def __init__(self) -> None: + super().__init__(formula_key="softmax", axis=Scalar(int | None)) + + def __call__( # type: ignore[override] + self, + input: ConnectionType = NOT_GIVEN, + axis: ConnectionType = -1, + output: ConnectionType = NOT_GIVEN, + ) -> ExtendInfo: + return BaseModel.__call__(self, input=input, axis=axis, output=output) class Softplus(Activation): @@ -667,17 +647,13 @@ class LeakyRelu(Activation): output: Connection 
    slope: Connection

-    def __init__(self, slope: float | ToBeDetermined = 0.01) -> None:
-        self.factory_args = {"slope": slope}
-
-        super().__init__(
-            formula_key="leaky_relu", slope=TensorType([], int | float, slope)
-        )
+    def __init__(self) -> None:
+        super().__init__(formula_key="leaky_relu", slope=TensorType([], float))

     def __call__(  # type: ignore[override]
         self,
         input: ConnectionType = NOT_GIVEN,
-        slope: ConnectionType = NOT_GIVEN,
+        slope: ConnectionType = 0.01,
         output: ConnectionType = NOT_GIVEN,
     ) -> ExtendInfo:
         return PrimitiveModel.__call__(self, input=input, slope=slope, output=output)
@@ -1293,32 +1269,23 @@ class TsnePJoint(PrimitiveModel):

     disposable = True

-    def __init__(
-        self,
-        target_perplexity: float | ToBeDetermined = TBD,
-        threshold: ConstantType | ToBeDetermined = Constant.EPSILON,
-    ) -> None:
-        if threshold != Constant.EPSILON:
-            self.factory_args = {"threshold": threshold}
-        if target_perplexity is not TBD:
-            self.factory_args |= {"target_perplexity": target_perplexity}
-
+    def __init__(self) -> None:
         super().__init__(
             formula_key="tsne_p_joint",
             output=TensorType(["N", "M"], float),
             squared_distances=TensorType(
                 ["N", "M"]
             ),  # TODO: Can we say anything about the type of distances?
-            target_perplexity=TensorType([], float, target_perplexity),
-            threshold=TensorType([], ConstantType, threshold),
+            target_perplexity=TensorType([], float),
+            threshold=TensorType([], ConstantType),
         )

     def __call__(  # type: ignore[override]
         self,
         squared_distances: ConnectionType = NOT_GIVEN,
-        target_perplexity: ConnectionType = NOT_GIVEN,
-        threshold: ConnectionType = NOT_GIVEN,
+        target_perplexity: float | ConnectionType = NOT_GIVEN,
+        threshold: ConstantType | ConnectionType = Constant.EPSILON,
         output: ConnectionType = NOT_GIVEN,
     ) -> ExtendInfo:
         return super().__call__(
             squared_distances=squared_distances,
@@ -1563,25 +1530,20 @@ class Eigvalsh(PrimitiveModel):
     threshold: Connection
     output: Connection

-    def __init__(
-        self, threshold: ConstantType | ToBeDetermined = Constant.EPSILON
-    ) -> None:
-        if threshold != Constant.EPSILON:
-            self.factory_args = {"threshold": threshold}
-
+    def __init__(self) -> None:
         super().__init__(
             formula_key="eigvalsh",
             output=TensorType(["N", 1], float),  # TODO: Is it always float?
             K_term=TensorType(["N", "N"]),
             L=TensorType(["N", "N"]),
-            threshold=TensorType([], ConstantType, threshold),
+            threshold=TensorType([], ConstantType),
         )

     def __call__(  # type: ignore[override]
         self,
         K_term: ConnectionType = NOT_GIVEN,
         L: ConnectionType = NOT_GIVEN,
-        threshold: ConnectionType = NOT_GIVEN,
+        threshold: ConstantType | ConnectionType = Constant.EPSILON,
         output: ConnectionType = NOT_GIVEN,
     ) -> ExtendInfo:
         return super().__call__(K_term=K_term, L=L, threshold=threshold, output=output)
@@ -1709,13 +1671,18 @@ def __call__(  # type: ignore[override]
         query: ConnectionType = NOT_GIVEN,
         key: ConnectionType = NOT_GIVEN,
         value: ConnectionType = NOT_GIVEN,
-        attn_mask: ConnectionType = NOT_GIVEN,
         dropout_p: ConnectionType = NOT_GIVEN,
         is_causal: ConnectionType = NOT_GIVEN,
         scale: ConnectionType = NOT_GIVEN,
         output: ConnectionType = NOT_GIVEN,
+        *,
+        attn_mask: ConnectionType = NOT_GIVEN,
     ) -> ExtendInfo:
-        if not self.use_attn_mask and attn_mask != NOT_GIVEN:
+        if (
+            not self.use_attn_mask
+            and attn_mask != NOT_GIVEN
+            and not isinstance(attn_mask, str)
+        ):
             raise KeyError(
                 "Model does not have 'attn_mask' input. Got attn_mask argument!"
             )

@@ -1728,6 +1695,7 @@ def __call__(  # type: ignore[override]
             is_causal=is_causal,
             scale=scale,
             output=output,
+            attn_mask=attn_mask,
         )

diff --git a/mithril/models/train_model.py b/mithril/models/train_model.py
index ff9b3960..239d33e9 100644
--- a/mithril/models/train_model.py
+++ b/mithril/models/train_model.py
@@ -18,6 +18,7 @@
 from typing import Any

 from ..framework import (
+    NOT_GIVEN,
     BaseModel,
     Connection,
     ConnectionData,
@@ -148,8 +149,11 @@ def add_loss(
         **kwargs,
     ) -> None:
         # If provided key namings does not match with Loss model
-        keys = set(loss_model._input_keys) - loss_model.conns.get_non_diff_keys()
-        if set(kwargs.keys()) != keys:
+        if {
+            key
+            for key, value in loss_model(**kwargs)._connections.items()
+            if value is NOT_GIVEN and key in loss_model._input_keys
+        } - loss_model.conns.get_non_diff_keys():
             raise KeyError("The provided keys do not match the model's loss.")

         outputs_conns_metadata = set()
@@ -214,8 +219,7 @@ def add_loss(
         for key in kwargs:
             if key in loss_model.conns.output_keys:
                 raise KeyError("Output of the loss model cannot be defined!")
-        # self._extend(loss_model, **kwargs)
-        self.extend(loss_model, **kwargs)
+        self._extend(loss_model(**kwargs))
         prev_out_key = self.get_single_output(loss_model).data
         if (prev_con := self.conns.get_con_by_metadata(prev_out_key.metadata)) is None:
             raise KeyError("Given key does not belong to the Model!")
diff --git a/mithril/utils/dict_conversions.py b/mithril/utils/dict_conversions.py
index e5c43287..dca2287f 100644
--- a/mithril/utils/dict_conversions.py
+++ b/mithril/utils/dict_conversions.py
@@ -185,12 +185,12 @@ def dict_to_model(modelparams: dict[str, Any]) -> BaseModel:
             else:
                 connection_mappings[key] = value

-        m.set_values(constant_mappings)  # type: ignore[arg-type]
         if m.canonical_input.key in constant_mappings:
             connection_mappings.setdefault(m.canonical_input.key, "")

         assert isinstance(model, Model)
         model += m(**connection_mappings)
+        m.set_values(constant_mappings)  # type: ignore[arg-type]

     if "model" in canonical_keys:
         candidate_canonical_in = model.conns.get_connection(canonical_keys["model"][0])
diff --git a/tests/json_files/error_test.json b/tests/json_files/error_test.json
index 191147c3..51f6d256 100644
--- a/tests/json_files/error_test.json
+++ b/tests/json_files/error_test.json
@@ -67,9 +67,6 @@
             "model": {
                 "name": "Model"
             },
-            "static_keys": {
-                "input": [[1.0]]
-            },
             "inputs": {
                 "w0": [[1.0, 2, 3]],
                 "b0": [-2.0, -3, 0],
diff --git a/tests/json_files/integration_directed_test.json b/tests/json_files/integration_directed_test.json
index a03d5afa..aac468b8 100644
--- a/tests/json_files/integration_directed_test.json
+++ b/tests/json_files/integration_directed_test.json
@@ -313,7 +313,7 @@
             "output": {
                 "loss": {
                     "fn": "QuantileLoss",
-                    "params": {
+                    "call_kwargs": {
                         "quantile": 0.5
                     }
                 },
@@ -352,7 +352,7 @@
             "output": {
                 "loss": {
                     "fn": "QuantileLoss",
-                    "params": {
+                    "call_kwargs": {
                         "quantile": 0.9
                     }
                 },
@@ -392,7 +392,7 @@
             "output": {
                 "loss": {
                     "fn": "QuantileLoss",
-                    "params": {
+                    "call_kwargs": {
                         "quantile": 0.9
                     }
                 },
@@ -431,7 +431,7 @@
             "output": {
                 "loss": {
                     "fn": "QuantileLoss",
-                    "params": {
+                    "call_kwargs": {
                         "quantile": 1e-15
                     }
                 },
@@ -470,7 +470,7 @@
             "output": {
                 "loss": {
                     "fn": "QuantileLoss",
-                    "params": {
+                    "call_kwargs": {
                         "quantile": 0.2
                     }
                 },
@@ -509,7 +509,7 @@
             "output": {
                 "loss": {
                     "fn": "QuantileLoss",
-                    "params": {
+                    "call_kwargs": {
                         "quantile": 0.99999999999
                     }
                 },
@@ -555,7 +555,7 @@
             "output": {
                 "loss": {
                     "fn": "QuantileLoss",
-                    "params": {
+                    "call_kwargs": {
"quantile": 0.99999999999 } }, @@ -601,7 +601,7 @@ "output": { "loss": { "fn": "QuantileLoss", - "params": { + "call_kwargs": { "quantile": 0.99999999999 } }, diff --git a/tests/json_files/models_directed_test.json b/tests/json_files/models_directed_test.json index 434f80ac..7938553d 100644 --- a/tests/json_files/models_directed_test.json +++ b/tests/json_files/models_directed_test.json @@ -164,8 +164,8 @@ "sigma": [0.70710678118] }, "output_grads": { - "output": [[1.0, 2.0, 3.0], - [2.0, 3.0, 1.0], + "output": [[1.0, 2.0, 3.0], + [2.0, 3.0, 1.0], [2.0, 2.0, 2.0]] }, "results": { @@ -225,8 +225,8 @@ "degree": [2.0] }, "output_grads": { - "output": [[1.0, 2.0, 3.0], - [1.0, 3.0, 1.0], + "output": [[1.0, 2.0, 3.0], + [1.0, 3.0, 1.0], [2.0, 2.0, 2.0]] }, "results": { @@ -346,7 +346,7 @@ "input2": [[24.0], [6.0], [30]], "norm": [11.090354888959123] } - } + } }, "test_distance_matrix_2": { "model": { @@ -374,7 +374,7 @@ "input2": [[5.0], [2.0], [1.0]], "norm": [0.0] } - } + } }, "test_distance_matrix_3": { "model": { @@ -382,11 +382,11 @@ "differentiability_info": {"input1": true, "input2": true} }, "inputs": { - "input1": [[1.0, 2.0], - [2.0, 2.0], + "input1": [[1.0, 2.0], + [2.0, 2.0], [3.0, 1.0]], - "input2": [[1.0, 2.0], - [2.0, 2.0], + "input2": [[1.0, 2.0], + [2.0, 2.0], [3.0, 1.0]], "norm": [3.0] }, @@ -402,15 +402,15 @@ [2.080083823051904, 1.2599210498948732, 0.0]] }, "grad": { - "input1": [[-4.773445097402539, 0.6933612743506348], - [-2.7400789501051266, -1.2599210498948732], + "input1": [[-4.773445097402539, 0.6933612743506348], + [-2.7400789501051266, -1.2599210498948732], [-0.9244816991341797, 0.23112042478354491]], - "input2": [[4.92448169913418, -0.23112042478354491], - [2.0, 0.0], + "input2": [[4.92448169913418, -0.23112042478354491], + [2.0, 0.0], [1.513524047507666, 0.5665597755442384]], "norm": [0.03282460531963896] } - } + } }, "test_distance_matrix_4": { "model": { @@ -441,7 +441,7 @@ "input2": [[16.0], [90.0], [2]], "norm": [448.34546617728535] } - } + } }, "test_distance_matrix_5": { @@ -470,7 +470,7 @@ "input2": [[4.0], [8.0], [0]], "norm": [0.0] } - } + } }, "test_distance_matrix_6": { @@ -500,11 +500,11 @@ [4.0, -1.2599210498948736]], "input2": [[-4.46224084956708985, -0.1510366017316406], - [-1.879791042496866064099,7.256026687221401976891], + [-1.879791042496866064099,7.256026687221401976891], [-0.3700394750525632,0.6299605249474368]], "norm": [-0.9621774874394304] } - } + } }, "test_poly_features_1":{ "model": { @@ -526,7 +526,7 @@ "grad": { "input": [[25.0, 40.0]] } - } + } }, "test_poly_features_2":{ "model": { @@ -548,7 +548,7 @@ "grad": { "input": [[5.0], [6], [3], [18]] } - } + } }, "test_poly_features_3":{ "model": { @@ -561,22 +561,22 @@ "input": [[1.0, 1], [2, 3], [3, -1], [0, 10]] }, "output_grads": { - "output": [[1.0, 2, 0, 1, 2], - [0, 3, 2, 1, 1], - [1, -1, 1, -3, 0], + "output": [[1.0, 2, 0, 1, 2], + [0, 3, 2, 1, 1], + [1, -1, 1, -3, 0], [0, 1, 0, 0, 0]] }, "results": { "eval": { - "output": [[1.0, 1, 1, 1, 1], - [2, 3, 4, 6, 9], - [3, -1, 9, -3, 1], + "output": [[1.0, 1, 1, 1, 1], + [2, 3, 4, 6, 9], + [3, -1, 9, -3, 1], [0, 10, 0, 0, 100]] }, "grad": { "input": [[2.0, 7], [11, 11], [10, -10], [0, 1]] } - } + } }, "test_poly_features_4":{ "model": { @@ -589,18 +589,18 @@ "input": [[1.0, -2], [2, 1]] }, "output_grads": { - "output": [[0.0, 2, 0, 1, 2, 0, 1, 0, 0], + "output": [[0.0, 2, 0, 1, 2, 0, 1, 0, 0], [1.0, 3, 1, 0, 2, 2, 0, 0, 1]] }, "results": { "eval": { - "output": [[1.0, -2, 1, -2, 4, 1, -2, 4, -8], + "output": [[1.0, -2, 1, -2, 4, 1, -2, 4, 
-8], [2.0, 1, 4, 2, 1, 8, 4, 2, 1]] }, "grad": { "input": [[-6.0, -4], [29, 10]] } - } + } }, "test_mds_core_1": { "model": { @@ -639,7 +639,7 @@ [0, 0, 0, 0]], "norm": 0.0 } - } + } }, "test_mds_core_2": { "model": { @@ -674,7 +674,7 @@ [-0.08333333333333333, 0.05555555555555555, 0]], "norm": -1.0986122886681096 } - } + } }, "test_mds_core_3": { "model": { @@ -742,10 +742,10 @@ [0.0,0.0]], "norm": 1.32459464e-152 } - } + } }, "test_mds_2": { - "NOTE": "Prove the differences of diagonal entries of input gradient because of linearized models like power, log etc.", + "NOTE": "Prove the differences of diagonal entries of input gradient wrt composite-ml is because of linearized models like power, log etc.", "model": { "name": "MDS", "differentiability_info": {"input": true}, @@ -773,12 +773,12 @@ "input": [[-0.03333333333333333, -0.06666666666666667, -0.025], [-0.06666666666666667, -0.03333333333333333, -0.16666666666666666], [-0.025, -0.16666666666666666, -0.03333333333333333]], - "coords": [[-0.39999999999999997], - [0.9999999999999996], + "coords": [[-0.39999999999999997], + [0.9999999999999996], [-0.5999999999999996]], "norm": 0.2772588722239777 } - } + } }, "test_tsne_core_1": { "model": { @@ -792,7 +792,7 @@ "pred_distances": [[0.0, 1, 9], [1, 0, 4], [9, 4, 0]] }, "static_keys": { - "distances": [[0.0, 1, 9], [1, 0, 4], [9, 4, 0]], + "distances": [[0.0, 1, 9], [1, 0, 4], [9, 4, 0]], "TsnePJoint_0_output": [[0, 0.16666654, 0.16666673], [0.16666654, 0, 0.16666673], [0.16666673, 0.16666673, 0]] @@ -822,7 +822,7 @@ }, "test_tsne_core_2": { "model": { - "name": "TSNECore", + "name": "TSNECore", "args": { "exact_distances": false } @@ -831,7 +831,7 @@ "pred_distances": [[0.0, 16, 1], [16, 0, 9], [1, 9, 0]] }, "static_keys": { - "distances": [[0.0, 16, 1], [16, 0, 9], [1, 9, 0]], + "distances": [[0.0, 16, 1], [16, 0, 9], [1, 9, 0]], "TsnePJoint_0_output": [[0, 0.1, 0.2], [0.1, 0, 0.2], [0.2, 0.2, 0]] @@ -862,9 +862,9 @@ } }, "test_tsne_core_3": { - "NOTE": "In this test, gradients can not be obtained when P values are wrong and does not sum to 1.0. We can only use the short gradient formula used in there only when sum(P) = 1.0", + "NOTE": "In this test, gradients can not be obtained as in composite-ml. Because P values are given wrong in composite-ml which does not sum to 1.0. 
We can only use the short gradient formula used in there only when sum(P) = 1.0", "model": { - "name": "TSNECore", + "name": "TSNECore", "args": { "exact_distances": false } @@ -873,8 +873,8 @@ "pred_distances": [[0.0, 16], [16, 0]] }, "static_keys": { - "distances": [[0.0, 16], [16, 0]], - "TsnePJoint_0_output": [[0, 0.5], + "distances": [[0.0, 16], [16, 0]], + "TsnePJoint_0_output": [[0, 0.5], [0.5, 0]] }, "injected_static_keys": ["TsnePJoint_0_output"], @@ -914,8 +914,8 @@ }, "static_keys": { "input": [[5.0], [1]], - "norm": 2.0, - "TSNECore_2_TsnePJoint_0_output": [[0, 0.5], + "norm": 2.0, + "TSNECore_2_TsnePJoint_0_output": [[0, 0.5], [0.5, 0]] }, "injected_static_keys": ["TSNECore_2_TsnePJoint_0_output"], @@ -1358,7 +1358,7 @@ "bias": [2.0] } } - }, + }, "test_conv1d_5": { "model": { "name": "Convolution1D", @@ -2244,7 +2244,7 @@ [10, 11, 12, 12, 12], [15, 16, 17, 18, 19], [20, 21, 22, 23, 24]], - + [[ 0.0, 1, 2, 3, 4], [ 5, 6, 7, 8, 9], [10, 11, 12, 12, 12], @@ -2255,7 +2255,7 @@ }, "output_grads": { "output": [ - + [ [ [1.0, 2.0], @@ -2278,7 +2278,7 @@ [3.0, 4.0] ] ] - + ] }, "results": { @@ -2481,7 +2481,7 @@ [true, false, true], [false, true, false], [true, false, true] - ]]] + ]]] }, "inputs": { "input1": [[[ @@ -2530,7 +2530,7 @@ [true, false, true], [false, true, false], [true, false, true] - ]]] + ]]] }, "inputs": { "input1": [[ @@ -2819,7 +2819,7 @@ "name": "Multiply" }, "inputs": { - "left": [[1.0], [2.0], [3.0], [4.0]], + "left": [[1.0], [2.0], [3.0], [4.0]], "right": [[1.0], [2.0], [3.0], [5.0]] }, @@ -2831,7 +2831,7 @@ "output": [[1.0], [4.0], [9.0], [20.0]] }, "grad": { - "left": [[1.0], [4.0], [9.0], [30.0]], + "left": [[1.0], [4.0], [9.0], [30.0]], "right": [[1.0], [4.0], [9.0], [24.0]] } } @@ -2843,7 +2843,7 @@ "name": "Multiply" }, "inputs": { - "left": [2.0], + "left": [2.0], "right": [[1.0], [2.0], [3.0], [4.0]] }, @@ -2867,7 +2867,7 @@ "name": "Multiply" }, "inputs": { - "left": [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], + "left": [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], "right": [1.0, 2.0, 3.0] }, @@ -2915,7 +2915,7 @@ "name": "Multiply" }, "inputs": { - "left": [[1.0, 2.0], [3.0, 4.0]], + "left": [[1.0, 2.0], [3.0, 4.0]], "right": [[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]] }, @@ -2939,7 +2939,7 @@ "name": "Multiply" }, "inputs": { - "left": [1.0, 2.0], + "left": [1.0, 2.0], "right": [[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]] }, @@ -2961,7 +2961,7 @@ "name": "Multiply" }, "inputs": { - "left": [[1.0],[2],[3]], + "left": [[1.0],[2],[3]], "right": [[2.0,3]] }, "output_grads": { @@ -2982,7 +2982,7 @@ "name": "Divide" }, "inputs": { - "numerator": [[1.0], [2.0], [3.0], [4.0]], + "numerator": [[1.0], [2.0], [3.0], [4.0]], "denominator": [[1.0], [2.0], [3.0], [5.0]] }, @@ -3006,7 +3006,7 @@ "name": "Divide" }, "inputs": { - "numerator": [2.0], + "numerator": [2.0], "denominator": [[1.0], [2.0], [3.0], [4.0]] }, @@ -3030,7 +3030,7 @@ "name": "Divide" }, "inputs": { - "numerator": [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], + "numerator": [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], "denominator": [1.0, 2.0, 3.0] }, @@ -3078,7 +3078,7 @@ "name": "Divide" }, "inputs": { - "numerator": [[1.0, 2.0], [3.0, 4.0]], + "numerator": [[1.0, 2.0], [3.0, 4.0]], "denominator": [[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]] }, @@ -3100,7 +3100,7 @@ "name": "Divide" }, "inputs": { - "numerator": [1.0, 2.0], + "numerator": [1.0, 2.0], "denominator": [[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]] }, "output_grads": { @@ -3122,7 +3122,7 @@ "name": "Add" }, "inputs": { - "left": 
[[1.0], [2.0], [3.0], [4.0]], + "left": [[1.0], [2.0], [3.0], [4.0]], "right": [[1.0], [2.0], [3.0], [5.0]] }, @@ -3134,7 +3134,7 @@ "output": [[2.0], [4.0], [6.0], [9.0]] }, "grad": { - "left": [[1.0], [2.0], [3.0], [6.0]], + "left": [[1.0], [2.0], [3.0], [6.0]], "right": [[1.0], [2.0], [3.0], [6.0]] } } @@ -3146,7 +3146,7 @@ "name": "Add" }, "inputs": { - "left": [2.0], + "left": [2.0], "right": [[1.0], [2.0], [3.0], [4.0]] }, @@ -3170,7 +3170,7 @@ "name": "Add" }, "inputs": { - "left": [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], + "left": [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], "right": [1.0, 2.0, 3.0] }, @@ -3218,7 +3218,7 @@ "name": "Add" }, "inputs": { - "left": [[1.0, 2.0], [3.0, 4.0]], + "left": [[1.0, 2.0], [3.0, 4.0]], "right": [[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]] }, @@ -3242,7 +3242,7 @@ "name": "Add" }, "inputs": { - "left": [1.0, 2.0], + "left": [1.0, 2.0], "right": [[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]] }, @@ -3265,7 +3265,7 @@ "name": "Add" }, "inputs": { - "left": [[[[[2.0,3,4]]]]], + "left": [[[[[2.0,3,4]]]]], "right": [[[2.0]],[[3]],[[4]]] }, @@ -3308,7 +3308,7 @@ }, "test_sum_9": { "NOTE": "Input shape is (2, 2, 1, 1, 1, 1) and rhs shape is (2, 1, 3) and expected shape is (2, 2, 1, 2, 1, 3). This is a test for testing broadcast algorithm rather than a numeric test ", - + "model": { "name": "Add" }, @@ -3336,7 +3336,7 @@ "name": "Subtract" }, "inputs": { - "left": [[1.0], [2.0], [3.0], [4.0]], + "left": [[1.0], [2.0], [3.0], [4.0]], "right": [[1.0], [2.0], [3.0], [5.0]] }, @@ -3348,7 +3348,7 @@ "output": [[0.0], [0.0], [0.0], [-1.0]] }, "grad": { - "left": [[1.0], [2.0], [3.0], [6.0]], + "left": [[1.0], [2.0], [3.0], [6.0]], "right": [[-1.0], [-2.0], [-3.0], [-6.0]] } } @@ -3360,7 +3360,7 @@ "name": "Subtract" }, "inputs": { - "left": [2.0], + "left": [2.0], "right": [[1.0], [2.0], [3.0], [4.0]] }, @@ -3384,7 +3384,7 @@ "name": "Subtract" }, "inputs": { - "left": [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], + "left": [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], "right": [1.0, 2.0, 3.0] }, @@ -3432,7 +3432,7 @@ "name": "Subtract" }, "inputs": { - "left": [[1.0, 2.0], [3.0, 4.0]], + "left": [[1.0, 2.0], [3.0, 4.0]], "right": [[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]] }, @@ -3456,7 +3456,7 @@ "name": "Subtract" }, "inputs": { - "left": [1.0, 2.0], + "left": [1.0, 2.0], "right": [[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]] }, @@ -3480,7 +3480,7 @@ "name": "Power" }, "inputs": { - "base": [[1.0], [2.0], [3.0], [4.0]], + "base": [[1.0], [2.0], [3.0], [4.0]], "exponent": [[1.0], [2.0], [3.0], [5.0]] }, @@ -3504,7 +3504,7 @@ "name": "Power" }, "inputs": { - "base": [2.0], + "base": [2.0], "exponent": [[1.0], [2.0], [3.0], [4.0]] }, @@ -3528,7 +3528,7 @@ "name": "Power" }, "inputs": { - "base": [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], + "base": [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], "exponent": [1.0, 2.0, 3.0] }, @@ -3576,7 +3576,7 @@ "name": "Power" }, "inputs": { - "base": [[1.0, 2.0], [3.0, 4.0]], + "base": [[1.0, 2.0], [3.0, 4.0]], "exponent": [[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]] }, @@ -3599,7 +3599,7 @@ "name": "Power" }, "inputs": { - "base": [1.0, 2.0], + "base": [1.0, 2.0], "exponent": [[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]] }, @@ -3810,7 +3810,7 @@ } } }, - + "test_flatten2": { "model": { "name": "Flatten" @@ -4032,35 +4032,35 @@ "args": { "axis": -1, "n": 2 - + } }, "inputs": { "input1": [ - [1.0, 2.0], - [3.0, 4.0], + [1.0, 2.0], + [3.0, 4.0], [5.0, 6.0] - ], + ], "input2": [[1.0], [2.0], [3.0]] }, "output_grads": { "output": [ - 
[1.0, 2.0, 3.0], - [4.0, 5.0, 6.0], + [1.0, 2.0, 3.0], + [4.0, 5.0, 6.0], [7.0, 8.0, 9.0] ] }, "results": { "eval": { "output": [ - [1.0, 2.0, 1.0], - [3.0, 4.0, 2.0], + [1.0, 2.0, 1.0], + [3.0, 4.0, 2.0], [5.0, 6.0, 3.0] ] }, "grad": { - "input1": [[1.0, 2.0], [4.0, 5.0], [7.0, 8.0]], + "input1": [[1.0, 2.0], [4.0, 5.0], [7.0, 8.0]], "input2": [[3.0], [6.0], [9.0]] } } @@ -4077,44 +4077,44 @@ }, "inputs": { "input1": [ - [[1.0, 2.0], - [3.0, 4.0]], - [[5.0, 6.0], + [[1.0, 2.0], + [3.0, 4.0]], + [[5.0, 6.0], [7.0, 8.0]] - ], + ], "input2": [ - [[1.0, 2.0, 3.0, 4.0], - [3.0, 4.0, 5.0, 6.0]], - [[5.0, 6.0, 7.0, 8.0], + [[1.0, 2.0, 3.0, 4.0], + [3.0, 4.0, 5.0, 6.0]], + [[5.0, 6.0, 7.0, 8.0], [7.0, 8.0, 9.0, 10.0]] ] }, "output_grads": { "output": [ - [[1.0, 2.0, 1.0, 2.0, 1.0, 2.0], - [1.0, 2.0, 1.0, 2.0, 1.0, 2.0]], - [[2.0, 3.0, 2.0, 3.0, 2.0, 3.0], + [[1.0, 2.0, 1.0, 2.0, 1.0, 2.0], + [1.0, 2.0, 1.0, 2.0, 1.0, 2.0]], + [[2.0, 3.0, 2.0, 3.0, 2.0, 3.0], [3.0, 2.0, 3.0, 2.0, 3.0, 2.0]] ] }, "results": { "eval": { "output": [ - [[1.0, 2.0, 1.0, 2.0, 3.0, 4.0], - [3.0, 4.0, 3.0, 4.0, 5.0, 6.0]], - [[5.0, 6.0, 5.0, 6.0, 7.0, 8.0], + [[1.0, 2.0, 1.0, 2.0, 3.0, 4.0], + [3.0, 4.0, 3.0, 4.0, 5.0, 6.0]], + [[5.0, 6.0, 5.0, 6.0, 7.0, 8.0], [7.0, 8.0, 7.0, 8.0, 9.0, 10.0]] ] }, "grad": { - "input1": [[[1.0, 2.0], - [1.0, 2.0]], - [[2.0, 3.0], - [3.0, 2.0]]], - "input2": [[[1.0, 2.0, 1.0, 2.0], - [1.0, 2.0, 1.0, 2.0]], - [[2.0, 3.0, 2.0, 3.0], + "input1": [[[1.0, 2.0], + [1.0, 2.0]], + [[2.0, 3.0], + [3.0, 2.0]]], + "input2": [[[1.0, 2.0, 1.0, 2.0], + [1.0, 2.0, 1.0, 2.0]], + [[2.0, 3.0, 2.0, 3.0], [3.0, 2.0, 3.0, 2.0]]] } } @@ -4132,8 +4132,8 @@ }, "inputs": { "input": [ - [1.0, 2.0], - [3.0, 4.0], + [1.0, 2.0], + [3.0, 4.0], [5.0, 6.0] ] }, @@ -4150,8 +4150,8 @@ ] }, "grad": { - "input": [[5.0, 6.0], - [0.0, 0.0], + "input": [[5.0, 6.0], + [0.0, 0.0], [0.0, 0.0]] } } @@ -4169,8 +4169,8 @@ }, "inputs": { "input": [ - [[1.0, 2.0]], - [[3.0, 4.0]], + [[1.0, 2.0]], + [[3.0, 4.0]], [[5.0, 6.0]] ] }, @@ -4190,8 +4190,8 @@ }, "grad": { "input": [ - [[3.0, 0.0]], - [[2.0, 1.0]], + [[3.0, 0.0]], + [[2.0, 1.0]], [[0.0, 0.0]] ] } @@ -4307,7 +4307,7 @@ } } }, - + "test_sigmoid_1": { @@ -4440,7 +4440,7 @@ } } }, - + "test_softplus_1": { "model": { @@ -4535,8 +4535,8 @@ }, "inputs": { "input": [ - [1.0, 2.0], - [3.0, 4.0], + [1.0, 2.0], + [3.0, 4.0], [5.0, 6.0] ] }, @@ -4545,23 +4545,23 @@ }, "output_grads": { "output": [ - [-1.0, 3.0], - [5.0, 2.0], + [-1.0, 3.0], + [5.0, 2.0], [4.0, 1.0] ] }, "results": { "eval": { "output": [ - [5.0, 6.0], - [1.0, 2.0], + [5.0, 6.0], + [1.0, 2.0], [3.0, 4.0] ] }, "grad": { "input": [ - [5.0, 2.0], - [4.0, 1.0], + [5.0, 2.0], + [4.0, 1.0], [-1.0, 3.0] ] } @@ -4573,8 +4573,8 @@ }, "inputs": { "input": [ - [[1.0, 2.0]], - [[3.0, 4.0]], + [[1.0, 2.0]], + [[3.0, 4.0]], [[5.0, 6.0]] ] }, @@ -4583,23 +4583,23 @@ }, "output_grads": { "output": [ - [[3.0, 0.0]], - [[4.0, 1.0]], + [[3.0, 0.0]], + [[4.0, 1.0]], [[5.0, 2.0]] ] }, "results": { "eval": { "output": [ - [[1.0, 2.0]], - [[5.0, 6.0]], + [[1.0, 2.0]], + [[5.0, 6.0]], [[3.0, 4.0]] ] }, "grad": { "input":[ - [[3.0, 0.0]], - [[5.0, 2.0]], + [[3.0, 0.0]], + [[5.0, 2.0]], [[4.0, 1.0]] ] } @@ -4740,7 +4740,7 @@ }, "grad": { "input": [[-0.5], [0.5]] - + } } }, @@ -4762,7 +4762,7 @@ "output": [[1.0], [1.0]] }, "grad": { - "input": [[-1.0], [1.0]] + "input": [[-1.0], [1.0]] } } }, @@ -4771,7 +4771,7 @@ "name": "AbsoluteError" }, "inputs": { - "input": [[1.0], [2.0], [3.0], [4.0]], + "input": [[1.0], [2.0], [3.0], [4.0]], "target": [[0.7], 
[0.9], [1.1], [1.3]] }, "output_grads": { @@ -4786,7 +4786,7 @@ "jax": [[1.0], [1.0], [1.0], [1.0]], "torch": [[1.0], [1.0], [1.0], [1.0]], "numpy": [[1.0], [1.0], [1.0], [1.0]] - }, + }, "target": { "jax": [[-1.0], [-1.0], [-1.0], [-1.0]], "torch": [[-1.0], [-1.0], [-1.0], [-1.0]], @@ -4800,7 +4800,7 @@ "name": "AbsoluteError" }, "inputs": { - "input": [[1.0], [-1000000.0], [0.000000000001], [-1.5]], + "input": [[1.0], [-1000000.0], [0.000000000001], [-1.5]], "target": [[1.0], [-1000000.0], [0.000000000001], [-1.5]] }, "output_grads": { @@ -4815,7 +4815,7 @@ "jax": [[1.0], [1.0], [1.0], [1.0]], "torch": [[0.0], [0.0], [0.0], [0.0]], "numpy": [[0.0], [0.0], [0.0], [0.0]] - }, + }, "target": { "jax": [[-1.0], [-1.0], [-1.0], [-1.0]], "torch": [[0.0], [0.0], [0.0], [0.0]], @@ -4830,7 +4830,7 @@ "name": "AbsoluteError" }, "inputs": { - "input": [[0.1, 0.2], [1000000000.0, 0.0000000001]], + "input": [[0.1, 0.2], [1000000000.0, 0.0000000001]], "target": [[0.1, 0.2], [0.0000000001, 1000000000.0]] }, "output_grads": { @@ -4845,7 +4845,7 @@ "jax": [[1.0, 1.0], [1.0, -1.0]], "torch": [[0.0, 0.0], [1.0, -1.0]], "numpy": [[0.0, 0.0], [1.0, -1.0]] - }, + }, "target": { "jax": [[-1.0, -1.0], [-1.0, 1.0]], "torch": [[0.0, 0.0], [-1.0, 1.0]], @@ -4875,7 +4875,7 @@ "jax": [[-1.000000000000000000002, -1.000000000000000000002], [1.0, 1.0], [1.000000000000000000002, 1.000000000000000000002]], "torch": [[-1.000000000000000000002, -1.000000000000000000002], [0.0, 0.0], [1.000000000000000000002, 1.000000000000000000002]], "numpy": [[-1.000000000000000000002, -1.000000000000000000002], [0.0, 0.0], [1.000000000000000000002, 1.000000000000000000002]] - }, + }, "target": { "jax": [[1.000000000000000000002, 1.000000000000000000002], [-1.0, -1.0], [-1.000000000000000000002, -1.000000000000000000002]], "torch": [[1.000000000000000000002, 1.000000000000000000002], [0.0, 0.0], [-1.000000000000000000002, -1.000000000000000000002]], @@ -4903,7 +4903,7 @@ "output": [1.3862943611198906] }, "grad": { - "input": [[0.25, 0.25, -0.75, 0.25]] + "input": [[0.25, 0.25, -0.75, 0.25]] } } }, @@ -4957,14 +4957,14 @@ "model": { "name": "CrossEntropy", "args": { - "input_type": "probs", - "robust": true + "input_type": "probs" } }, "inputs": { "input":[[0.0, 1.0], [0.1, 0.9]] }, "static_keys": { + "robust": true, "target": [1, 0] }, "output_grads": { @@ -4984,14 +4984,14 @@ "model": { "name": "CrossEntropy", "args": { - "input_type": "probs", - "robust": true + "input_type": "probs" } }, "inputs": { "input":[[0.0, 1.0], [0.1, 0.9]] }, "static_keys": { + "robust": true, "target": [0, 0] }, "output_grads": { @@ -5015,14 +5015,14 @@ "model": { "name": "CrossEntropy", "args": { - "input_type": "probs", - "robust": true + "input_type": "probs" } }, "inputs": { "input":[[2.220446049250313e-16, 1.0], [0.1, 0.9]] }, "static_keys": { + "robust": true, "target": [0, 0] }, "output_grads": { @@ -5042,14 +5042,14 @@ "model": { "name": "CrossEntropy", "args": { - "input_type": "probs", - "robust": true + "input_type": "probs" } }, "inputs": { "input":[[0.0, 1.0]] }, "static_keys": { + "robust": true, "target": [0] }, "output_grads": { @@ -5073,14 +5073,14 @@ "model": { "name": "CrossEntropy", "args": { - "input_type": "probs", - "robust": true + "input_type": "probs" } }, "inputs": { "input":[[0.2, 1.1102230246251565e-16, 0.7999999999999998], [0.1, 0.6, 0.3]] }, "static_keys": { + "robust": true, "target": [1, 2] }, "output_grads": { @@ -5099,14 +5099,14 @@ "model": { "name": "CrossEntropy", "args": { - "input_type": "probs", - "robust": true + 
"input_type": "probs" } }, "inputs": { "input":[[0.5, 0.5], [0.1, 0.9]] }, "static_keys": { + "robust": true, "target": [0, 1] }, "output_grads": { @@ -5126,14 +5126,14 @@ "model": { "name": "CrossEntropy", "args": { - "input_type": "probs", - "robust": true + "input_type": "probs" } }, "inputs": { "input":[[0.5, 0.5], [0.1, 0.9]] }, "static_keys": { + "robust": true, "target": [0, 1] }, "output_grads": { @@ -5153,14 +5153,14 @@ "model": { "name": "CrossEntropy", "args": { - "input_type": "probs", - "robust": true + "input_type": "probs" } }, "inputs": { "input":[[0.0, 1.0], [0.1, 0.9]] }, "static_keys": { + "robust": true, "target": [1, 0] }, "output_grads": { @@ -5291,14 +5291,14 @@ "model": { "name": "BinaryCrossEntropy", "args":{ - "input_type": "probs", - "robust": true + "input_type": "probs" } }, "inputs": { "input": [[0.1, 0.2, 0.3], [0.5, 1.0, 0.2]] }, "static_keys": { + "robust": true, "target": [[0, 1, 0], [1, 1, 1]] }, "output_grads": { @@ -5405,7 +5405,7 @@ }, "inputs": { "input": [ - [-2.197224577336219, -1.3862943611198906, -0.8472978603872036], + [-2.197224577336219, -1.3862943611198906, -0.8472978603872036], [0.0, 1e100, -1.3862943611198906] ] }, @@ -5556,15 +5556,13 @@ }, "test_quantile_loss_1": { "model": { - "name": "QuantileLoss", - "args": { - "quantile": 0.1 - } + "name": "QuantileLoss" }, "inputs": { "input": [[1.0]] }, "static_keys": { + "quantile": 0.1, "target": [[0.0]] }, "output_grads": { @@ -5581,12 +5579,10 @@ }, "test_quantile_loss_2": { "model": { - "name": "QuantileLoss", - "args": { - "quantile": 0.1 - } + "name": "QuantileLoss" }, "static_keys": { + "quantile": 0.1, "target": [[1.0], [1e-5]] }, "inputs": { @@ -5606,15 +5602,13 @@ }, "test_quantile_loss_3": { "model": { - "name": "QuantileLoss", - "args": { - "quantile": 0.9 - } + "name": "QuantileLoss" }, "inputs": { "input": [[1.0], [1e-5]] }, "static_keys": { + "quantile": 0.9, "target": [[1.0], [1e-5]] }, "output_grads": { @@ -5631,15 +5625,13 @@ }, "test_quantile_loss_4": { "model": { - "name": "QuantileLoss", - "args": { - "quantile": 0.4 - } + "name": "QuantileLoss" }, "inputs": { "input":[[1.0, 2.0], [0.0, 1.0]] }, "static_keys": { + "quantile": 0.4, "target": [[1.0, 0.0], [2.0, -1.0]] }, "output_grads": {"output": [[1.0, 1.0], [1.0, 1.0]]}, @@ -5661,11 +5653,11 @@ "target": [[1], [-1]] }, "output_grads": { - "output": [[1.0], [1.0]] + "output": [[1.0], [1.0]] }, "results": { "eval": { - "output": [[0.0], [1.5625]] + "output": [[0.0], [1.5625]] }, "grad": { "input": [[0.0], [2.5]] @@ -5687,7 +5679,7 @@ }, "results": { "eval": { - "output": [[0.0], [0.0]] + "output": [[0.0], [0.0]] }, "grad": { "input": [[0.0], [0.0]] @@ -5710,7 +5702,7 @@ }, "results": { "eval": { - "output": [[1.0], [1.0]] + "output": [[1.0], [1.0]] }, "grad": { "input": [[-2.0], [2.0]] @@ -5861,11 +5853,9 @@ }, "test_leaky_relu_1": { "model": { - "name": "LeakyRelu", - "args":{ - "slope": 0.2 - } + "name": "LeakyRelu" }, + "static_keys": {"slope": 0.2}, "inputs": { "input": [[1.0, -2.0], [2.0, 0.0]] }, @@ -5887,11 +5877,9 @@ }, "test_leaky_relu_2": { "model": { - "name": "LeakyRelu", - "args":{ - "slope": 0.2 - } + "name": "LeakyRelu" }, + "static_keys": {"slope": 0.2}, "inputs": { "input": [[0.0]] }, @@ -5913,11 +5901,9 @@ }, "test_leaky_relu_3": { "model": { - "name": "LeakyRelu", - "args":{ - "slope": 0.2 - } + "name": "LeakyRelu" }, + "static_keys": {"slope": 0.2}, "inputs": { "input": [[-1.0, -2.0], [2.0, 0.0], [-0.04, -100]] }, @@ -6148,11 +6134,9 @@ }, "test_softmax_8": { "model": { - "name": "Softmax", - "args": { - 
"axis": 0 - } + "name": "Softmax" }, + "static_keys": {"axis": 0}, "inputs": { "input": [[1.0, -2.0], [2.0, 0.0]] }, @@ -6184,7 +6168,7 @@ "results": { "eval": { "output": [[0.6931471805599453094172321214, 0.6931471805599453094172321214], [1.09861228866810969139524523692252570464, 1.3862943611198906188344642429163531], [1.386294361119890618834464242916353136, 4.6051701859880913680359829093687284152022]] - }, + }, "grad": { "input": [[0.5, 0.5], [0.333333333333333333333, 0.25], [0.25, 0.01]] } @@ -6206,7 +6190,7 @@ "results": { "eval": { "output": [[[-709.396418532264106216811584991213718666567366540526]]] - }, + }, "grad": { "input": { "jax": [[[4.4942328371557897351686972308210038429885969661285748811057e307]]], @@ -6236,7 +6220,7 @@ "torch": [[[-709.3959691089803906378380681214906365661830676808299040690118894, -704.591038456177979309505385133415447525937055520404530]]], "numpy": [[[-709.3959691089803906378380681214906365661830676808299040690118894, -704.591038456177979309505385133415447525937055520404530]]] } - }, + }, "grad": { "input": [[[4.4942328371557897351686972308210038429885969661285748811057e307, 1e306]]] } @@ -6255,7 +6239,7 @@ "results": { "eval": { "output": [[[[[0.5, 0.5], [0.333333333333333333333, 0.25], [0.25, 0.01]]]]] - }, + }, "grad": { "input": [[[[[-0.25, -0.25], [-0.11111111111111111111111, -0.0625], [-0.0625, -0.0001]]]]] } @@ -6274,7 +6258,7 @@ "results": { "eval": { "output": [[1.34078079299425975369365231198393901835279232e153]] - }, + }, "grad": { "input": [[-4.494232837155790062526881505709689647141e305]] } @@ -6293,7 +6277,7 @@ "results": { "eval": { "output": [[1.33628656015710396363112543047822932870565049166274379e153, 1e145]] - }, + }, "grad": { "input": [[-4.494232837155790062526881505709689647141e305, -1e290]] } @@ -6315,7 +6299,7 @@ "results": { "eval": { "output": [[1.0, 1.4142135623730950488], [1.732050807568877293527446341505872, 2.0]] - }, + }, "grad": { "input": [[0.5 , 0.35355339059327373], [0.2886751345948129, 0.25]] } @@ -6337,7 +6321,7 @@ "results": { "eval": { "output": [[[[2.0, 4.0], [5.0, 10.0]]]] - }, + }, "grad": { "input": [[[[0.75 , 0.25], [0.5, 0.3]]]] } @@ -6359,7 +6343,7 @@ "results": { "eval": { "output": [100.0] - }, + }, "grad": { "input": [0.005] } @@ -6381,7 +6365,7 @@ "results": { "eval": { "output": [[0.0, 2.0], [1.0, 2.0]] - }, + }, "grad": { "input": { "jax": [[6.703903964971299e+153, -0.25], [-0.5, -0.25]], @@ -6395,10 +6379,10 @@ "model": { "name": "Sqrt", "args": { - "robust": true, - "cutoff": 1e-20 + "robust": true } }, + "static_keys": {"cutoff": 1e-20}, "inputs": { "input": [[0.9999999999999e-20, 1.0000000000001e-20, 0.0]] }, @@ -6408,7 +6392,7 @@ "results": { "eval": { "output": [[9.999999999999e-11, 1.0000000000000499999999999987500000000000624e-10, 0.0]] - }, + }, "grad": { "input": { "jax": [[1e10, 4.99999999999975000000000001874999999999843750000000013e9, 1e10]], @@ -6435,7 +6419,7 @@ "results": { "eval": { "output": [[0.0], [2.94760229696920019e-62]] - }, + }, "grad": { "base": [[0.0], [2.6494422067830456494626727070578342769220668517896442342632083875646713489594382293556230650365376585602492809663848600694036441316054962143018101466468388523103155742568140908731216503912613033423051394246710020933425593223230870218134328967168e+245]], "exponent": [[0.0], [-2.08807091043044e-59]] @@ -6463,7 +6447,7 @@ "output": [[4.0,8.0], [9.0,27.0], [16.0,64.0]] - }, + }, "grad": { "base": [[16.0],[33.0],[56.0]], "exponent": [[34.84080909817102,123.93054835019151]] @@ -6478,7 +6462,7 @@ } }, "inputs": { - "base": [2.0], + 
"base": [2.0], "exponent": [[1.0], [2.0], [3.0], [4.0]] }, "output_grads": { @@ -6493,9 +6477,9 @@ [4.0], [8.0], [16.0]] - }, + }, "grad": { - "base": [49.0], + "base": [49.0], "exponent": [[1.3862943611198906], [2.772588722239781], [5.545177444479562], [11.090354888959125]] } } @@ -6508,7 +6492,7 @@ } }, "inputs": { - "base": [2.0], + "base": [2.0], "exponent": [3.0] }, "output_grads": { @@ -6517,9 +6501,9 @@ "results": { "eval": { "output": [8.0] - }, + }, "grad": { - "base": [12.0], + "base": [12.0], "exponent": [5.545177444479562] } } @@ -6610,7 +6594,7 @@ "model": { "name": "Concat", "args": { - "n": 2, + "n": 2, "axis": null } }, @@ -6628,14 +6612,14 @@ "grad": { "input1": [[-1.0, -2.0, -3.0, -4.0]], "input2": [[-5.0], [-6.0], [7.0], [8.0]] - } + } } }, "test_concat_axis_none_2": { "model": { "name": "Concat", "args": { - "n": 3, + "n": 3, "axis": null } }, @@ -6655,14 +6639,14 @@ "input1": [[-1.0, -2.0, -3.0, -4.0]], "input2": [[-5.0], [-6.0], [7.0], [8.0]], "input3": [[[[[[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]]]]]] - } + } } }, "test_concat_axis_none_3": { "model": { "name": "Concat", "args": { - "n": 3, + "n": 3, "axis": null } }, @@ -6682,14 +6666,14 @@ "input1": [[-1.0, -2.0, -3.0, -4.0]], "input2": [[-5.0], [-6.0], [7.0], [8.0]], "input3": [[9.0, 8.0, 7.0], [6.0, 5.0, 4.0], [3.0, 2.0, 1.0]] - } + } } }, "test_concat_axis_none_4": { "model": { "name": "Concat", "args": { - "n": 9, + "n": 9, "axis": null } }, @@ -6722,14 +6706,14 @@ "input7": [[8.0, 9.0], [10.0, -11.0]], "input8": [[[[12.0]]],[[[13.0]]]], "input9": [14.0] - } + } } }, "test_concat_1": { "model": { "name": "Concat", "args": { - "n": 2, + "n": 2, "axis": 0 } }, @@ -6747,14 +6731,14 @@ "grad": { "input1": [[-1.0], [-2.0], [-3.0], [-4.0]], "input2": [[-5.0], [-6.0], [7.0], [8.0]] - } + } } }, "test_concat_2": { "model": { "name": "Concat", "args": { - "n": 2, + "n": 2, "axis": 1 } }, @@ -6772,14 +6756,14 @@ "grad": { "input1": [[-1.0], [-2.0], [-3.0], [-4.0]], "input2": [[-5.0], [-6.0], [7.0], [8.0]] - } + } } }, "test_concat_3": { "model": { "name": "Concat", "args": { - "n": 3, + "n": 3, "axis": -2 } }, @@ -6799,14 +6783,14 @@ "input1": [[-1.0], [-2.0], [-3.0], [-4.0]], "input2": [[-5.0], [-6.0], [7.0], [8.0]], "input3": [[1.0], [2.0], [3.0], [4.0]] - } + } } }, "test_concat_4": { "model": { "name": "Concat", "args": { - "n": 3, + "n": 3, "axis": -1 } }, @@ -6826,7 +6810,7 @@ "input1": [[-1.0], [-2.0], [-3.0], [-4.0]], "input2": [[-5.0], [-6.0], [7.0], [8.0]], "input3": [[1.0], [2.0], [3.0], [4.0]] - } + } } }, "test_var_1": { @@ -6845,7 +6829,7 @@ }, "grad": { "input": [[-0.8], [-0.4], [0.0], [0.4], [0.8]] - } + } } }, "test_var_2": { @@ -6868,7 +6852,7 @@ }, "grad": { "input": [[-0.8], [-0.4], [0.0], [0.4], [0.8]] - } + } } }, "test_var_3": { @@ -6891,7 +6875,7 @@ }, "grad": { "input": [[0.0], [0.0], [0.0], [0.0], [0.0]] - } + } } }, "test_var_4": { @@ -6913,7 +6897,7 @@ }, "grad": { "input": [[-1.0], [-0.5], [0.0], [0.5], [1.0]] - } + } } }, "test_reduce_sum_1": { @@ -6932,7 +6916,7 @@ }, "grad": { "input": [[[[[[[[[[[[[[[7.0,7.0,7.0],[7.0,7.0,7.0]],[[7.0,7.0,7.0],[7.0,7.0,7.0]]]]]]]]]]]]]]] - } + } } }, "test_reduce_sum_2": { @@ -6951,7 +6935,7 @@ }, "grad": { "input": [[3.0, 3.0], [3.0, 3.0], [3.0, 3.0]] - } + } } }, "test_reduce_sum_3": { @@ -6973,7 +6957,7 @@ }, "grad": { "input": [[1.0, 1.0], [1.0, 1.0], [1.0, 1.0]] - } + } } }, "test_reduce_sum_4": { @@ -6995,7 +6979,7 @@ }, "grad": { "input": [[4.0, 4.0], [3.0, 3.0], [1.0, 1.0]] - } + } } }, "test_reduce_sum_5": { @@ -7017,7 +7001,7 @@ }, "grad": { "input": 
[[3.0, 7.0], [3.0, 7.0], [3.0, 7.0]] - } + } } }, "test_reduce_sum_6": { @@ -7040,7 +7024,7 @@ }, "grad": { "input": [[[3.0,3.0,3.0],[7.0,7.0,7.0]],[[3.0,3.0,3.0],[7.0,7.0,7.0]],[[3.0,3.0,3.0],[7.0,7.0,7.0]]] - } + } } }, "test_reduce_sum_7": { @@ -7063,7 +7047,7 @@ }, "grad": { "input": [[[[[[[[1.0,2.0,4.0],[5.0,6.0,3.0]],[[2.0,7.0,9.0],[1.0,2.0,6.0]],[[2.0,8.0,9.0],[11.0,17.0,1e10]]]]]]]] - } + } } }, "test_reduce_mean_1": { @@ -7082,7 +7066,7 @@ }, "grad": { "input": [[[[[[[[[[[[[[[1.0,1.0,1.0],[1.0,1.0,1.0]],[[1.0,1.0,1.0],[1.0,1.0,1.0]]]]]]]]]]]]]]] - } + } } }, "test_reduce_mean_2": { @@ -7101,7 +7085,7 @@ }, "grad": { "input": [[0.5, 0.5], [0.5, 0.5], [0.5, 0.5]] - } + } } }, "test_reduce_mean_3": { @@ -7123,7 +7107,7 @@ }, "grad": { "input": [[0.5, 0.5], [0.5, 0.5], [0.5, 0.5]] - } + } } }, "test_reduce_mean_4": { @@ -7145,7 +7129,7 @@ }, "grad": { "input": [[2.0, 2.0], [1.5, 1.5], [0.5, 0.5]] - } + } } }, "test_reduce_mean_5": { @@ -7167,7 +7151,7 @@ }, "grad": { "input": [[1.0, 3.0], [1.0, 3.0], [1.0, 3.0]] - } + } } }, "test_reduce_mean_6": { @@ -7190,7 +7174,7 @@ }, "grad": { "input": [[[1.0,1.0,1.0],[3.0,3.0,3.0]],[[1.0,1.0,1.0],[3.0,3.0,3.0]],[[1.0,1.0,1.0],[3.0,3.0,3.0]]] - } + } } }, "test_reduce_mean_7": { @@ -7213,7 +7197,7 @@ }, "grad": { "input": [[[[[[[[1.0,2.0,4.0],[5.0,6.0,3.0]],[[2.0,7.0,9.0],[1.0,2.0,6.0]],[[2.0,8.0,9.0],[11.0,17.0,1e10]]]]]]]] - } + } } }, "test_reduce_max_1": { @@ -7232,7 +7216,7 @@ }, "grad": { "input": [[[[[[[[[[[[[[[0.0,0.0,0.0],[0.0,0.0,0.0]],[[0.0,0.0,0.0],[17.0,0.0,0.0]]]]]]]]]]]]]]] - } + } } }, "test_reduce_max_2": { @@ -7251,7 +7235,7 @@ }, "grad": { "input": [[0.0, 0.0], [3.0, 3.0], [3.0, 0.0]] - } + } } }, "test_reduce_max_3": { @@ -7273,7 +7257,7 @@ }, "grad": { "input": [[1.0, 0.0], [1.0, 0.0], [1.0, 0.0]] - } + } } }, "test_reduce_max_4": { @@ -7295,7 +7279,7 @@ }, "grad": { "input": [[4.0, 0.0], [3.0, 0.0], [1.0, 0.0]] - } + } } }, "test_reduce_max_5": { @@ -7317,7 +7301,7 @@ }, "grad": { "input": [[0.0, 0.0], [3.0, 9.0], [0.0, 0.0]] - } + } } }, "test_reduce_max_6": { @@ -7340,7 +7324,7 @@ }, "grad": { "input": [[[0.0,0.0,0.0],[0.0,0.0,0.0]],[[0.0,0.0,0.0],[0.0,0.0,0.0]],[[9.0,0.0,0.0],[27.0,0.0,0.0]]] - } + } } }, "test_reduce_max_7": { @@ -7363,7 +7347,7 @@ }, "grad": { "input": [[[[[[[[1.0,2.0,4.0],[5.0,6.0,3.0]],[[2.0,7.0,9.0],[1.0,2.0,6.0]],[[2.0,8.0,9.0],[11.0,17.0,1e10]]]]]]]] - } + } } }, "test_reduce_min_1": { @@ -7382,7 +7366,7 @@ }, "grad": { "input": [[[[[[[[[[[[[[[3.0,0.0,0.0],[0.0,0.0,3.0]],[[0.0,0.0,3.0],[0.0,0.0,0.0]]]]]]]]]]]]]]] - } + } } }, "test_reduce_min_2": { @@ -7401,7 +7385,7 @@ }, "grad": { "input": [[0.0, 4.0], [0.0, 0.0], [0.0, 0.0]] - } + } } }, "test_reduce_min_3": { @@ -7423,7 +7407,7 @@ }, "grad": { "input": [[0.0, 1.0], [0.0, 1.0], [0.0, 1.0]] - } + } } }, "test_reduce_min_4": { @@ -7445,7 +7429,7 @@ }, "grad": { "input": [[0.0, 4.0], [0.0, 3.0], [0.0, 1.0]] - } + } } }, "test_reduce_min_5": { @@ -7467,7 +7451,7 @@ }, "grad": { "input": [[3.0, 9.0], [0.0, 0.0], [0.0, 0.0]] - } + } } }, "test_reduce_min_6": { @@ -7490,7 +7474,7 @@ }, "grad": { "input": [[[4.0,0.0,0.0],[0.0,0.0,0.0]],[[0.0,0.0,4.0],[8.0,0.0,0.0]],[[0.0,0.0,0.0],[0.0,0.0,8.0]]] - } + } } }, "test_reduce_min_7": { @@ -7513,7 +7497,7 @@ }, "grad": { "input": [[[[[[[[[[1.0,2.0,4.0],[5.0,6.0,3.0]],[[2.0,7.0,9.0],[1.0,2.0,6.0]],[[2.0,8.0,9.0],[11.0,17.0,1e10]]]]]]]]]] - } + } } }, "test_composite_1": { @@ -7528,9 +7512,9 @@ }, "m3": { "name": "Add" - } + } }, - "connections" : { + "connections" : { "m1": { "left": "left", "right": "right" 
@@ -7562,7 +7546,7 @@ "grad": { "left": [[[[18.0, 32.0, 50.0]]]], "right": [[[[12.0, 32.0, 60.0]]]] - } + } } }, "test_composite_2": { @@ -7580,9 +7564,9 @@ }, "m3": { "name": "Add" - } + } }, - "connections" : { + "connections" : { "m1": { "left": "left", "right": "right" @@ -7595,7 +7579,7 @@ "left": {"connect": [["m2","output"]]}, "right": {"connect": [["m2","output"]]}, "output": { - "name": "output", + "name": "output", "expose": true } } @@ -7612,9 +7596,9 @@ }, "m3": { "name": "Add" - } + } }, - "connections" : { + "connections" : { "m1": { "left": "left", "right": "right" @@ -7627,7 +7611,7 @@ "left": {"connect": [["m2","output"]]}, "right": {"connect": [["m2","output"]]}, "output": { - "name": "output", + "name": "output", "expose": true } } @@ -7644,9 +7628,9 @@ }, "m3": { "name": "Add" - } + } }, - "connections" : { + "connections" : { "m1": { "left": "left", "right": "right" @@ -7659,14 +7643,14 @@ "left": {"connect": [["m2","output"]]}, "right": {"connect": [["m2","output"]]}, "output": { - "name": "output", + "name": "output", "expose": true } } } } }, - "connections" : { + "connections" : { "m1": { "left": "left", "right": "right" @@ -7679,7 +7663,7 @@ "left": {"connect": [["m2","output"]]}, "right": {"connect": [["m2","output"]]}, "output": { - "name": "output", + "name": "output", "expose": true } } @@ -7699,7 +7683,7 @@ "grad": { "left": [[[[839808.0, 16777216.0, 150000000.0]]]], "right": [[[[1119744.0, 33554432.0, 360000000.0]]]] - } + } } }, "test_composite_3": { @@ -7722,13 +7706,13 @@ "name": "Add" } }, - "connections" : { + "connections" : { "m1": { "left": "left", "right": "right", "output": "output" } - } + } }, "m2": { "name": "Add" @@ -7809,7 +7793,7 @@ "grad": { "left": [[[[5.0, 5.0, 5.0]]]], "right": [[[[1.0, 1.0, 1.0]]]] - } + } } }, "test_composite_4": { @@ -7824,9 +7808,9 @@ }, "m3": { "name": "Add" - } + } }, - "connections" : { + "connections" : { "m1": { "left": "left", "right": "right" @@ -7856,7 +7840,7 @@ "grad": { "left": [[[[18.0, 32.0, 50.0]]]], "right": [[[[12.0, 32.0, 60.0]]]] - } + } } }, "test_composite_5": { @@ -7871,14 +7855,14 @@ }, "m3": { "name": "Add" - } + } }, - "connections" : { + "connections" : { "m1": { "left": "left", "right": "right", "output": { - "name": "output1", + "name": "output1", "expose": true } }, @@ -7890,7 +7874,7 @@ "left": {"connect": [["m2","output"]]}, "right": {"connect": [["m2","output"]]}, "output": { - "name": "output", + "name": "output", "expose": true } } @@ -7912,7 +7896,7 @@ "grad": { "left": [24.0, 40.0, 60.0], "right": [14.0, 36.0, 66.0] - } + } } }, "test_composite_1_extend_from_inputs": { @@ -7927,9 +7911,9 @@ }, "m1": { "name": "Multiply" - } + } }, - "connections" : { + "connections" : { "m3": { "output": "output" }, @@ -7958,7 +7942,7 @@ "grad": { "left": [[[[18.0, 32.0, 50.0]]]], "right": [[[[12.0, 32.0, 60.0]]]] - } + } } }, "test_composite_2_extend_from_inputs": { @@ -7976,12 +7960,12 @@ }, "m1": { "name": "Multiply" - } + } }, - "connections" : { + "connections" : { "m3": { "output": { - "name": "output", + "name": "output", "expose": true } }, @@ -8007,12 +7991,12 @@ }, "m1": { "name": "Multiply" - } + } }, - "connections" : { + "connections" : { "m3": { "output": { - "name": "output", + "name": "output", "expose": true } }, @@ -8038,12 +8022,12 @@ }, "m1": { "name": "Add" - } + } }, - "connections" : { + "connections" : { "m3": { "output": { - "name": "output", + "name": "output", "expose": true } }, @@ -8059,7 +8043,7 @@ } } }, - "connections" : { + "connections" : { "m1": { "left": "left", 
"right": "right" @@ -8072,7 +8056,7 @@ "left": {"connect": [["m2","output"]]}, "right": {"connect": [["m2","output"]]}, "output": { - "name": "output", + "name": "output", "expose": true } } @@ -8092,7 +8076,7 @@ "grad": { "left": [[[[839808.0, 16777216.0, 150000000.0]]]], "right": [[[[1119744.0, 33554432.0, 360000000.0]]]] - } + } } }, "test_composite_3_extend_from_inputs": { @@ -8115,16 +8099,16 @@ "name": "Add" } }, - "connections" : { + "connections" : { "m1": { "left": "left", "right": "right", "output": { - "name": "output", + "name": "output", "expose": true } } - } + } }, "m1": { "name": "Add" @@ -8134,7 +8118,7 @@ "m2": { "left": "left", "output": { - "name": "output", + "name": "output", "expose": true } }, @@ -8153,7 +8137,7 @@ "m2": { "left": "left", "output": { - "name": "output", + "name": "output", "expose": true } }, @@ -8172,7 +8156,7 @@ "m2": { "left": "left", "output": { - "name": "output", + "name": "output", "expose": true } }, @@ -8192,7 +8176,7 @@ "m2": { "left": "left", "output": { - "name": "output", + "name": "output", "expose": true } }, @@ -8217,7 +8201,7 @@ "grad": { "left": [[[[5.0, 5.0, 5.0]]]], "right": [[[[1.0, 1.0, 1.0]]]] - } + } } }, "test_composite_4_extend_from_inputs": { @@ -8232,15 +8216,15 @@ }, "m2": { "name": "Relu" - }, + }, "m1": { "name": "Relu" - } + } }, - "connections" : { + "connections" : { "m4": { "output": { - "name": "output", + "name": "output", "expose": true }, "input": "my_input" @@ -8271,7 +8255,7 @@ }, "grad": { "input1": [[[[1.0, 1.0, 1.0]]]] - } + } } }, "test_composite_5_extend_from_inputs": { @@ -8294,10 +8278,10 @@ "connections" : { "m4": { "output": { - "name": "output", + "name": "output", "expose": true } - }, + }, "m3": { "input": {"connect": [["m4","input"]]} }, @@ -8322,7 +8306,7 @@ }, "grad": { "my_input": [[[[1.0, 1.0, 1.0]]]] - } + } } }, "test_composite_6_extend_from_inputs": { @@ -8345,10 +8329,10 @@ "connections" : { "m4": { "output": { - "name": "output", + "name": "output", "expose": true } - }, + }, "m3": { "input": {"connect": [["m4","input"]]} }, @@ -8373,7 +8357,7 @@ }, "grad": { "my_input": [[[[1.0, 1.0, 1.0]]]] - } + } } }, "test_composite_extend_numbers": { @@ -8381,16 +8365,16 @@ "name": "Model", "submodels": { "m1": { - "name": "Multiply" + "name": "Multiply" }, "m2": { - "name": "Multiply" + "name": "Multiply" }, "m3": { "name": "Add" } }, - "connections" : { + "connections" : { "m1": { "left": "input1", "right": 1 @@ -8420,7 +8404,7 @@ }, "grad": { "input1": [4.0] - } + } } }, "test_embeddig_1": { @@ -8629,4 +8613,4 @@ } } } -} +} \ No newline at end of file diff --git a/tests/json_files/randomized_model_tests_all_backends.json b/tests/json_files/randomized_model_tests_all_backends.json index f7c5cdff..34ed6750 100644 --- a/tests/json_files/randomized_model_tests_all_backends.json +++ b/tests/json_files/randomized_model_tests_all_backends.json @@ -422,7 +422,7 @@ "regular_args": { "robust": true }, - "randomized_args": { + "static_input_info": { "threshold": [1, 1] } }, @@ -667,7 +667,6 @@ "name": "CrossEntropy", "regular_args":{ "input_type": "logits", - "categorical": true, "weights": [1.0,0.2,1.0,1.0,1.0,1.0,0.7,1.0,1.0,1.0,1.0,0.1,1.0,1.0,1.0] } }, @@ -681,7 +680,8 @@ "mode": "int", "shapes": [[15, 15], [15, 15], [15, 15]], "interval": [[0, 0], [9, 9]] - } + }, + "categorical": true }, "iterations": 20 }, @@ -690,7 +690,6 @@ "name": "CrossEntropy", "regular_args":{ "input_type": "probs", - "categorical": true, "weights": [1.0,0.2,1.0,1.0,1.0,1.0,1.0,0.7,1.0,0.2,1.0,1.0,1.0,1.0,1.0] } }, @@ -705,7 
+704,8 @@ "mode": "int", "shapes": [[15, 15], [15, 15], [15, 15]], "interval": [[0, 0], [9, 9]] - } + }, + "categorical": true }, "iterations": 20 }, @@ -714,7 +714,6 @@ "name": "CrossEntropy", "regular_args":{ "input_type": "log_probs", - "categorical": true, "weights": [1.0,0.2,1.0,1.0,1.0,1.0,1.0,0.7,1.0,0.2,1.0,1.0,1.0,1.0,1.0] } }, @@ -728,7 +727,8 @@ "mode": "int", "shapes": [[15, 15], [15, 15], [15, 15]], "interval": [[0, 0], [9, 9]] - } + }, + "categorical": true }, "iterations": 20 }, @@ -736,8 +736,7 @@ "model": { "name": "CrossEntropy", "regular_args":{ - "weights": [1.0,0.2,1.0,1.0,1.0,1.0,1.0,0.7,1.0,0.2,1.0,1.0,1.0,1.0,1.0], - "categorical": false + "weights": [1.0,0.2,1.0,1.0,1.0,1.0,1.0,0.7,1.0,0.2,1.0,1.0,1.0,1.0,1.0] } }, "input_info": { @@ -749,7 +748,8 @@ "target": { "shapes": [[15, 15], [15,15], [15,15]], "is_positive": true - } + }, + "categorical": false }, "iterations": 20, "64bit_relative_tolerance": 1e-12 @@ -759,9 +759,7 @@ "name": "CrossEntropy", "regular_args": { "weights": [1.0,0.2,1.0,1.0,1.0,1.0,1.0,0.7,1.0,0.2,1.0,1.0,1.0,1.0,1.0], - "input_type": "probs", - "categorical": false, - "robust": true + "input_type": "probs" } }, "input_info": { @@ -774,7 +772,9 @@ "target": { "shapes": [[15, 15], [15,15], [15, 15]], "is_positive": true - } + }, + "robust": true, + "categorical": false }, "iterations": 20, "64bit_relative_tolerance": 1e-12 @@ -784,8 +784,7 @@ "name": "CrossEntropy", "regular_args": { "weights": [1.0,0.2,1.0,1.0,1.0,1.0,1.0,0.7,1.0,0.2,1.0,1.0,1.0,1.0,1.0], - "input_type": "log_probs", - "categorical": false + "input_type": "log_probs" } }, "input_info": { @@ -797,7 +796,8 @@ "target": { "shapes": [[15, 15], [15,15], [15, 15]], "is_positive": true - } + }, + "categorical": false }, "iterations": 20 }, @@ -806,8 +806,7 @@ "name": "BinaryCrossEntropy", "regular_args":{ "input_type": "probs", - "pos_weight": 0.625, - "robust": true + "pos_weight": 0.625 } }, "static_input_info": { @@ -815,7 +814,8 @@ "mode": "int", "shapes": [[1, 20], [1, 1]], "interval": [[0, 0], [1, 1]] - } + }, + "robust": true }, "iterations": 5 }, @@ -824,8 +824,7 @@ "name": "BinaryCrossEntropy", "regular_args":{ "input_type": "probs", - "pos_weight": "auto", - "robust": true + "pos_weight": "auto" } }, "static_input_info": { @@ -833,7 +832,8 @@ "mode": "int", "shapes": [[1, 20], [1, 1]], "interval": [[0, 0], [1, 1]] - } + }, + "robust": true }, "iterations": 5 }, @@ -841,8 +841,7 @@ "model": { "name": "BinaryCrossEntropy", "regular_args":{ - "input_type": "probs", - "robust": true + "input_type": "probs" } }, "static_input_info": { @@ -850,7 +849,8 @@ "mode": "int", "shapes": [[1, 20], [1, 1]], "interval": [[0, 0], [1, 1]] - } + }, + "robust": true }, "iterations": 5 }, @@ -859,8 +859,7 @@ "name": "BinaryCrossEntropy", "regular_args":{ "input_type": "probs", - "pos_weight": 0.625, - "robust": false + "pos_weight": 0.625 } }, "static_input_info": { @@ -868,7 +867,8 @@ "mode": "int", "shapes": [[1, 20], [1, 1]], "interval": [[0, 0], [1, 1]] - } + }, + "robust": false }, "iterations": 5 }, @@ -877,8 +877,7 @@ "name": "BinaryCrossEntropy", "regular_args":{ "input_type": "probs", - "pos_weight": "auto", - "robust": false + "pos_weight": "auto" } }, "static_input_info": { @@ -886,7 +885,8 @@ "mode": "int", "shapes": [[1, 20], [1, 1]], "interval": [[0, 0], [1, 1]] - } + }, + "robust": false }, "iterations": 5 }, @@ -894,8 +894,7 @@ "model": { "name": "BinaryCrossEntropy", "regular_args":{ - "input_type": "probs", - "robust": false + "input_type": "probs" } }, "static_input_info": { 
@@ -903,21 +902,20 @@ "mode": "int", "shapes": [[1, 20], [1, 1]], "interval": [[0, 0], [1, 1]] - } + }, + "robust": false }, "iterations": 5 }, "test_binary_cross_entropy_logits_1": { "model": { - "name": "BinaryCrossEntropy", - "regular_args":{ - "robust": true - } + "name": "BinaryCrossEntropy" }, "static_input_info": { "target": { "shapes": [[1, 20], [1, 1]] - } + }, + "robust": true }, "iterations": 5 }, @@ -925,14 +923,14 @@ "model": { "name": "BinaryCrossEntropy", "regular_args":{ - "pos_weight": "auto", - "robust": true + "pos_weight": "auto" } }, "static_input_info": { "target": { "shapes": [[1, 20], [1, 1]] - } + }, + "robust": true }, "iterations": 5 }, @@ -940,30 +938,26 @@ "model": { "name": "BinaryCrossEntropy", "regular_args":{ - "pos_weight": 0.625, - "robust": true + "pos_weight": 0.625 } }, "static_input_info": { "target": { "shapes": [[1, 20], [1, 1]] - } + }, + "robust": true }, "iterations": 5 }, "test_binary_cross_entropy_logits_4": { "model": { - "name": "BinaryCrossEntropy", - "regular_args":{ - - "robust": false - } - + "name": "BinaryCrossEntropy" }, "static_input_info": { "target": { "shapes": [[1, 20], [1, 1]] - } + }, + "robust": false }, "iterations": 5 }, @@ -971,15 +965,14 @@ "model": { "name": "BinaryCrossEntropy", "regular_args":{ - "pos_weight": "auto", - "robust": false + "pos_weight": "auto" } - }, "static_input_info": { "target": { "shapes": [[1, 20], [1, 1]] - } + }, + "robust": false }, "iterations": 5 }, @@ -987,24 +980,23 @@ "model": { "name": "BinaryCrossEntropy", "regular_args":{ - "pos_weight": 0.625, - "robust": false + "pos_weight": 0.625 } }, "static_input_info": { "target": { "shapes": [[1, 20], [1, 1]] - } + }, + "robust": false }, "iterations": 5 }, "test_quantile_loss": { "model": { - "name": "QuantileLoss", - - "randomized_args": { - "quantile": [1,1] - } + "name": "QuantileLoss" + }, + "static_input_info": { + "quantile": [1, 1] }, "input_info": { "input" : { @@ -1129,10 +1121,10 @@ }, "test_softmax_2d_2": { "model": { - "name": "Softmax", - "regular_args": { - "axis": 0 - } + "name": "Softmax" + }, + "static_input_info": { + "axis": 0 }, "input_info": { "input" : { @@ -1155,11 +1147,7 @@ "test_sqrt": { "model": { "name": "Sqrt", - "randomized_args": { - "robust": [false, true], - "cutoff": [1, 1] - } - + "randomized_args": {"robust": [false, true]} }, "input_info": { "input" : { @@ -1206,8 +1194,7 @@ "model": { "name": "Log", "randomized_args": { - "robust": [false, true], - "cutoff": [1, 1] + "robust": [false, true] } }, "input_info": { diff --git a/tests/scripts/helper.py b/tests/scripts/helper.py index 447a03fa..a562fe24 100644 --- a/tests/scripts/helper.py +++ b/tests/scripts/helper.py @@ -43,7 +43,8 @@ def evaluate_case( inputs = convert_to_array(backend, current_case.get("inputs", {})) results = convert_to_array(backend, current_case.get("results", {})) discard_keys = set(current_case.get("discard_keys", [])) - static_keys = convert_to_array(backend, current_case.get("static_keys", {})) + # static_keys = convert_to_array(backend, current_case.get("static_keys", {})) + static_keys = dict(current_case.get("static_keys", {})) reference_outputs = results["eval"] reference_gradients = results["grad"] reference_shapes = { @@ -53,7 +54,15 @@ def evaluate_case( assert_shapes_flag = True models: list[BaseModel] = [] - models.append(finalize_model(current_case)) + model = finalize_model(current_case) + # Convert static keys to array if they are not scalar. 
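+    # Note (clarifying comment, matching the check below): keys backed by Scalar
+    # metadata (e.g. "axis", "slope", "robust") keep their raw Python values,
+    # while Tensor-backed keys are converted to backend arrays.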
+ for key, value in static_keys.items(): + if isinstance(model.conns._get_metadata(key).data, Scalar): + static_keys[key] = value + else: + static_keys[key] = convert_to_array(backend, value) + models.append(model) + if test_rtt: model_dict = model_to_dict(models[0]) models.append(dict_to_model(model_dict)) diff --git a/tests/scripts/test_constant_inputs.py b/tests/scripts/test_constant_inputs.py index 221bcc58..f360a9cd 100644 --- a/tests/scripts/test_constant_inputs.py +++ b/tests/scripts/test_constant_inputs.py @@ -243,7 +243,6 @@ def test_default_in_numpy_error(): backend=NumpyBackend(), constant_keys=constant_keys, data_keys=data_keys, - # file_path = "constant_inputs" ) assert ( str(err_info.value) @@ -322,10 +321,7 @@ def test_default_given_compile_numpy(): static_inputs: dict[str, np.ndarray | int] = {"input": np_input, "axis": 0} expected_result = (np_input.mean(0) * 2).mean(0) compiled_model = mithril.compile( - model=model, - backend=NumpyBackend(), - constant_keys=static_inputs, - # file_path = "constant_inputs.py" + model=model, backend=NumpyBackend(), constant_keys=static_inputs ) inputs = compiled_model.randomize_params() data = {"axis": None} @@ -410,10 +406,7 @@ def test_constant_given_data_numpy(): } expected_result = (np_input.mean(0) * 2).mean(0) compiled_model = mithril.compile( - model=model, - backend=NumpyBackend(), - constant_keys=static_inputs, - # file_path = "constant_inputs.py" + model=model, backend=NumpyBackend(), constant_keys=static_inputs ) inputs = compiled_model.randomize_params() @@ -481,9 +474,9 @@ def test_constant_numpy_set_values(): def test_axis(): model = Model() - relu = LeakyRelu(slope=2.3) - rob_pow = Power(robust=True, threshold=TBD) - model += relu(input="input") + relu = LeakyRelu() + rob_pow = Power(robust=True) + model += relu(input="input", slope=2.3) model += rob_pow(base=relu.output, exponent="exponent", threshold=relu.slope) backend = NumpyBackend() @@ -511,10 +504,10 @@ def test_axis(): def test_axis_1(): model = Model() - relu = LeakyRelu(slope=TBD) - rob_pow = Power(robust=True, threshold=2.3) - model += rob_pow(base="base", exponent="exponent") - model += relu(input=rob_pow.output, slope=rob_pow.threshold) + relu = LeakyRelu() + rob_pow = Power(robust=True) + model += rob_pow(base="base", threshold=2.3, exponent="exponent") + model += relu(input=rob_pow.output, slope=rob_pow.threshold) # type: ignore # Check required value transfer occured in logical model # assert relu.conns.get_data("slope").value == 2.3 @@ -630,23 +623,16 @@ def test_scalar_mean_2_1(): def test_scalar_mean_2_2(): mean_model = Model() - rob_pow = Power(robust=True, threshold=1.3) - with pytest.raises(ValueError) as err_info: - mean_model += rob_pow(threshold=1.5, base="input") - assert ( - str(err_info.value) == "Value of Power's threshold given as 1.5. " - "But the value is already initialized as 1.3" + rob_pow = Model() + rob_pow += Power(robust=True)( + threshold=IOKey(name="threshold", value=1.3), base="base" ) - -def test_scalar_threshold(): - model = Model() - rob_pow = Power(robust=True) with pytest.raises(ValueError) as err_info: - model += rob_pow(threshold=1.5, base="input") + mean_model += rob_pow(threshold=1.5, base="input") assert ( - str(err_info.value) == "Value of Power's threshold given as 1.5. " - "But the value is already initialized as Constant.MIN_POSITIVE_NORMAL" + str(err_info.value) == "Value is set before as 1.3. A scalar" + " value can not be reset." 
) @@ -1493,7 +1479,7 @@ def test_composite_1_set_values(): def test_composite_2(): model = Model() conv1 = Convolution2D(kernel_size=2, out_channels=4) - leaky_relu = LeakyRelu(slope=TBD) + leaky_relu = LeakyRelu() model += conv1(input="input") conv1.input.set_differentiable(True) model += leaky_relu(input=conv1.output, output=IOKey(name="output"), slope=0.3) @@ -1504,10 +1490,12 @@ def test_composite_2(): def test_composite_2_set_values(): model = Model() conv1 = Convolution2D(kernel_size=2, out_channels=4) - leaky_relu = LeakyRelu(slope=TBD) + leaky_relu = LeakyRelu() model += conv1(input="input") conv1.input.set_differentiable(True) - model += leaky_relu(input=conv1.output, output=IOKey(name="output")) + model += leaky_relu( + input=conv1.output, output=IOKey(name="output"), slope=NOT_GIVEN + ) model.set_values({leaky_relu.slope: 0.3}) model.set_shapes({"input": [1, 1, 4, 4]}) assert_all_backends_device_precision(model) @@ -1516,7 +1504,7 @@ def test_composite_2_set_values(): def test_composite_3(): model = Model() conv1 = Convolution2D(kernel_size=2, out_channels=1, stride=TBD) - leaky_relu = LeakyRelu(slope=TBD) + leaky_relu = LeakyRelu() mean_model = Mean(axis=TBD) model += conv1(input="input", stride=(2, 3)) conv1.input.set_differentiable(True) @@ -1531,12 +1519,12 @@ def test_composite_3(): def test_composite_3_set_values(): model = Model() conv1 = Convolution2D(kernel_size=2, out_channels=1, stride=TBD) - leaky_relu = LeakyRelu(slope=TBD) + leaky_relu = LeakyRelu() mean_model = Mean(axis=TBD) model += conv1(input="input", stride=(2, 3)) conv1.input.set_differentiable(True) model.set_values({conv1.stride: (2, 3)}) - model += leaky_relu(input=conv1.output) + model += leaky_relu(input=conv1.output, slope=NOT_GIVEN) model.set_values({leaky_relu.slope: 0.3}) model += mean_model(axis=conv1.stride) assert not isinstance(conv1.canonical_output, NotAvailable) @@ -1549,7 +1537,7 @@ def test_composite_3_set_values(): def test_composite_4(): model = Model() conv1 = Convolution2D(kernel_size=2, out_channels=1, stride=TBD) - leaky_relu = LeakyRelu(slope=TBD) + leaky_relu = LeakyRelu() mean_model = Mean(axis=TBD) model += conv1(input="input", stride=(2, 3)) conv1.input.set_differentiable(True) @@ -1564,12 +1552,12 @@ def test_composite_4(): def test_composite_4_set_values(): model = Model() conv1 = Convolution2D(kernel_size=2, out_channels=1, stride=TBD) - leaky_relu = LeakyRelu(slope=TBD) + leaky_relu = LeakyRelu() mean_model = Mean(axis=TBD) model += conv1(input="input") conv1.input.set_differentiable(True) model.set_values({conv1.stride: (2, 3)}) - model += leaky_relu(input=conv1.output) + model += leaky_relu(input=conv1.output, slope=NOT_GIVEN) model.set_values({leaky_relu.slope: 0.3}) model += mean_model(axis=conv1.stride) model.set_shapes({"input": [1, 1, 8, 8]}) @@ -2047,7 +2035,7 @@ def test_static_shape_model_3(): def test_static_shape_model_4(): model = Model() model += Relu()(input="input") - model += Log(robust=True, cutoff=TBD) + model += Log(robust=True)(cutoff=NOT_GIVEN) model += Shape() model += ToTensor() model += Relu() @@ -2078,7 +2066,7 @@ def test_static_shape_model_4(): def test_static_shape_model_5(): model = Model() model += Relu()(input="input") - model += (log := Log(robust=True, cutoff=TBD))(cutoff="cutoff") + model += (log := Log(robust=True))(cutoff="cutoff") model += Shape() model += ToTensor() model += Relu()(input=model.canonical_output, output=IOKey(name="output1")) @@ -2389,7 +2377,7 @@ def test_eye_ellipsis_2(): def test_cross_entropy_robust_ellipsis(): backend 
= TorchBackend() model = Model() - ce_model = CrossEntropy(robust=TBD, input_type="probs") + ce_model = CrossEntropy(input_type="probs") model += ce_model( input="input", target="target", output=IOKey(name="output"), robust="robust" ) @@ -2414,9 +2402,7 @@ def test_cross_entropy_robust_ellipsis(): def test_bce_ellipsis(): backend = NumpyBackend() model_1 = Model() - ce_model_1 = BinaryCrossEntropy( - robust=TBD, pos_weight=TBD, cutoff=TBD, input_type="probs" - ) + ce_model_1 = BinaryCrossEntropy(pos_weight=TBD, input_type="probs") model_1 += ce_model_1( input="input", target="target", diff --git a/tests/scripts/test_freeze.py b/tests/scripts/test_freeze.py new file mode 100644 index 00000000..1e6fc45d --- /dev/null +++ b/tests/scripts/test_freeze.py @@ -0,0 +1,75 @@ +# Copyright 2022 Synnada, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from mithril.models import Add, Linear, Model, ScalarItem + + +def test_freeze_set_values_primitive(): + model = Add() + assert model.is_frozen is True + + model._freeze() + assert model.is_frozen is True + + with pytest.raises(ValueError) as error_info: + model.set_values({"left": 1.0}) + assert str(error_info.value) == "Model is frozen, can not set the key: left!" + + +def test_freeze_set_values_extend_defined_logical(): + model = Linear() + assert model.is_frozen is True + + model._freeze() + assert model.is_frozen is True + + with pytest.raises(ValueError) as error_info: + model.set_values({"input": 1.0}) + assert str(error_info.value) == "Model is frozen, can not set the key: input!" + + with pytest.raises(AttributeError) as attr_error_info: + model += Add() + assert str(attr_error_info.value) == "Model is frozen and can not be extended!" + + +def test_freeze_set_values_extend_logical(): + model = Model() + model += Add()(left="left", right="right") + assert model.is_frozen is False + + model.set_values({"left": 1.0}) + model._freeze() + assert model.is_frozen is True + + with pytest.raises(ValueError) as error_info: + model.set_values({"right": 1.0}) + assert str(error_info.value) == "Model is frozen, can not set the key: right!" + + with pytest.raises(AttributeError) as attr_error_info: + model += Add() + assert str(attr_error_info.value) == "Model is frozen and can not be extended!" + + +def test_freeze_set_values_scalar(): + model = Model() + model += ScalarItem()(input="input") + assert model.is_frozen is False + + model._freeze() + model.set_values({"input": [1.0]}) + assert model.is_frozen is True + + assert model.input.data.metadata.data.value == [1.0] # type: ignore diff --git a/tests/scripts/test_primitive_calls.py b/tests/scripts/test_primitive_calls.py new file mode 100644 index 00000000..5281eb8f --- /dev/null +++ b/tests/scripts/test_primitive_calls.py @@ -0,0 +1,95 @@ +# Copyright 2022 Synnada, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +import mithril as ml +from mithril.models import Model, Power + + +def test_power_call_threshold_iokey(): + model = Model() + pow = Power(robust=True) + model += pow(threshold=ml.IOKey("t", 0.1)) + assert model.t.data.metadata.data.value == 0.1 # type: ignore + + +def test_error_not_robust_power_call_threshold_iokey(): + pow = Power(robust=False) + + with pytest.raises(ValueError) as error_info: + pow(threshold=ml.IOKey("t", 0.1)) + + error_msg = str(error_info.value) + assert error_msg == "Threshold cannot be specified when robust mode is off" + + +def test_error_not_robust_power_call_threshold_str(): + pow = Power(robust=False) + + with pytest.raises(ValueError) as error_info: + pow(threshold="t") + + error_msg = str(error_info.value) + assert error_msg == "Threshold cannot be specified when robust mode is off" + + +def test_error_not_robust_power_call_threshold_float(): + pow = Power(robust=False) + + with pytest.raises(ValueError) as error_info: + pow(threshold=0.1) + + error_msg = str(error_info.value) + assert error_msg == "Threshold cannot be specified when robust mode is off" + + +def test_compile_robust_power_call_with_default_threshold(): + backend = ml.TorchBackend() + pow = Power(robust=True) + pm = ml.compile(pow, backend) + pm.evaluate(params={"base": backend.ones(3, 3), "exponent": backend.ones(3, 3)}) + + +@pytest.mark.skip( + reason="This test is not yet implemented. Naming convention bugs " + "should be fixed when ToTensor-like auto-added models are created." +) +def test_error_robust_power_call_threshold_re_set_value(): + rob_pow = Model() + primitive_pow = Power(robust=True) + rob_pow += primitive_pow(threshold="threshold") + primitive_pow.set_values({"threshold": 1.3}) + from mithril.core import Constant + + mean_model = Model() + with pytest.raises(ValueError): + mean_model += rob_pow(threshold=Constant.MIN_POSITIVE_SUBNORMAL) + + +@pytest.mark.skip( + reason="This test is not yet implemented. Naming convention bugs " + "should be fixed when ToTensor-like auto-added models are created."
+) +def test_error_robust_power_call_threshold_input_keys(): + model1 = Model() + pow1 = Power(robust=True) + model1 += pow1(threshold=ml.IOKey("thres", 0.1)) + + model2 = Model() + pow2 = Power(robust=True) + model2 += pow2(threshold="thres") + model2.set_values({"thres": 0.1}) + + assert model1._input_keys == model2._input_keys diff --git a/tests/scripts/test_randomized_models_all_backends.py b/tests/scripts/test_randomized_models_all_backends.py index 5440b8f1..fcb78758 100644 --- a/tests/scripts/test_randomized_models_all_backends.py +++ b/tests/scripts/test_randomized_models_all_backends.py @@ -20,14 +20,8 @@ import numpy as np import pytest -from mithril import ( - JaxBackend, - MlxBackend, - NumpyBackend, - TorchBackend, - compile, - models, -) +from mithril import JaxBackend, MlxBackend, NumpyBackend, TorchBackend, compile, models +from mithril.framework.common import Tensor from mithril.utils.dict_conversions import dict_to_model from tests.scripts.test_utils import ( dict_to_random, @@ -183,6 +177,13 @@ def test_randomized(case: str) -> None: static_inputs[init_key] = shapes_to_random( dict_to_random(static_input_info, random_shapes), init_backend ) + static_inputs[init_key] = { + key: init_backend.array(value) + if isinstance(model.conns._get_metadata(key).data, Tensor) + else value + for key, value in static_inputs[init_key].items() + } + shapes: dict[str, list[int]] = {} for key, value in input_info.items(): shape = value["shapes"] @@ -227,8 +228,11 @@ def test_randomized(case: str) -> None: } static_inputs[backend.type] = { key: backend.array(value) + if isinstance(model.conns._get_metadata(key).data, Tensor) + else value for key, value in static_inputs[init_key].items() } + gradients[init_key] = compiled_model.evaluate_gradients( inputs[init_key], output_gradients=output_gradients[init_key] ) diff --git a/tests/scripts/test_scripts.py b/tests/scripts/test_scripts.py index fedc1822..91f8b278 100644 --- a/tests/scripts/test_scripts.py +++ b/tests/scripts/test_scripts.py @@ -1502,8 +1502,8 @@ def test_cyclic_extension(): def test_canonic_example(): model = Model() - model += LeakyRelu() - model += LeakyRelu() + model += LeakyRelu()() + model += LeakyRelu()() comp_model = compile(model=model, backend=NumpyBackend()) assert set(comp_model._input_keys) == {"input", "_input_0"} assert set(comp_model.output_keys) == {"output"} @@ -3759,29 +3759,37 @@ def test_add_loss_unknown_key(): context = TrainModel(model) # Wrong keyword for loss - with pytest.raises(KeyError) as err_info: + with pytest.raises(TypeError) as err_info: context.add_loss(SquaredError(), inpu2t="output", target="target") - assert str(err_info.value) == '"The provided keys do not match the model\'s loss."' + assert ( + str(err_info.value) + == "SupervisedLoss.__call__() got an unexpected keyword argument 'inpu2t'" + ) - with pytest.raises(KeyError) as err_info: + with pytest.raises(TypeError) as err_info: context.add_loss(SquaredError(), input="output", targe2t="target") - assert str(err_info.value) == '"The provided keys do not match the model\'s loss."' + assert ( + str(err_info.value) + == "SupervisedLoss.__call__() got an unexpected keyword argument 'targe2t'" + ) # Wrong keyword for model - with pytest.raises(KeyError) as err_info: + with pytest.raises(KeyError) as key_err_info: context.add_loss(SquaredError(), input="output1", target="target") - assert str(err_info.value) == ( + assert str(key_err_info.value) == ( "'The provided keys are not valid; at least one of the keys must belong " "to the model!'" ) - 
with pytest.raises(KeyError) as err_info: + with pytest.raises(KeyError) as key_err_info: context.add_loss(SquaredError(), target="output") - assert str(err_info.value) == '"The provided keys do not match the model\'s loss."' + assert ( + str(key_err_info.value) == '"The provided keys do not match the model\'s loss."' + ) # Successfully add loss context.add_loss( @@ -6322,13 +6330,13 @@ def test_multi_write_2(): def test_multi_write_3(): model = Model() - l_relu1 = LeakyRelu(slope=0.85) + l_relu = Model() + l_relu += LeakyRelu()(slope=IOKey("slope", 0.85)) with pytest.raises(ValueError) as err_info: - model += l_relu1(input="input", output="output", slope=0.75) + model += l_relu(slope=0.75) assert str(err_info.value) == ( - "Value of LeakyRelu's slope given as 0.75. But the value is already " - "initialized as 0.85" + "Value is set before as 0.85. A scalar value can not be reset." ) @@ -6398,7 +6406,7 @@ def test_multi_write_8(): def test_leaky_relu_trainable_slope(): backend = JaxBackend() model = Model() - model += LeakyRelu(slope=TBD)(input="input", output="output", slope="slope") + model += LeakyRelu()(input="input", output="output", slope="slope") pm = mithril.compile(model=model, backend=backend) params = {"input": backend.array([-2.0, 2.0]), "slope": backend.array(0.2)} diff --git a/tests/scripts/test_shapes.py b/tests/scripts/test_shapes.py index 6bb294bd..cdab5509 100644 --- a/tests/scripts/test_shapes.py +++ b/tests/scripts/test_shapes.py @@ -1950,9 +1950,11 @@ def test_variadic_contradiction(): def test_cross_entropy_shapes_1(): model = Model() - ce = CrossEntropy(categorical=True) + ce = CrossEntropy() ce.set_shapes({"input": [8, 10], "target": [8]}) - model += ce(input="input", target="target", output=IOKey(name="output")) + model += ce( + input="input", target="target", categorical=True, output=IOKey(name="output") + ) logical_ref = { "$_ToTensor_0_output": [], "$_input": None, @@ -1980,9 +1982,11 @@ def test_cross_entropy_shapes_1(): def test_cross_entropy_shapes_2(): model = Model() - ce = CrossEntropy(categorical=False) + ce = CrossEntropy() ce.set_shapes({"input": [8, 10]}) - model += ce(input="input", target="target", output=IOKey(name="output")) + model += ce( + input="input", target="target", categorical=False, output=IOKey(name="output") + ) logical_ref = { "$_ToTensor_0_output": [], "$_input": None, @@ -2009,9 +2013,11 @@ def test_cross_entropy_shapes_2(): def test_cross_entropy_shapes_3(): model = Model() - ce = CrossEntropy(categorical=True) + ce = CrossEntropy() ce.set_shapes({"input": [8, 16, 32, 64], "target": [8, 32, 64]}) - model += ce(input="input", target="target", output=IOKey(name="output")) + model += ce( + input="input", target="target", categorical=True, output=IOKey(name="output") + ) logical_ref = { "input": [8, 16, 32, 64], "output": [8, 32, 64], @@ -2037,10 +2043,12 @@ def test_cross_entropy_shapes_3(): def test_cross_entropy_shapes_5(): model = Model() - ce = CrossEntropy(categorical=True) + ce = CrossEntropy() shapes: dict[str, list] = {"input": [8, 16, ("V1", ...), 64], "target": [8, 32, 64]} ce.set_shapes(shapes) - model += ce(input="input", target="target", output=IOKey(name="output")) + model += ce( + input="input", target="target", categorical=True, output=IOKey(name="output") + ) logical_ref = { "input": [8, 16, 32, 64], "output": [8, 32, 64], @@ -2066,10 +2074,12 @@ def test_cross_entropy_shapes_5(): def test_cross_entropy_shapes_6(): model = Model() - ce = CrossEntropy(categorical=True) + ce = CrossEntropy() shapes: dict[str, list] = 
{"input": [8, 16, ("V1", ...), 64], "output": [8, 32, 64]} ce.set_shapes(shapes) - model += ce(input="input", target="target", output=IOKey(name="output")) + model += ce( + input="input", target="target", categorical=True, output=IOKey(name="output") + ) logical_ref = { "input": [8, 16, 32, 64], "output": [8, 32, 64], @@ -2096,9 +2106,11 @@ def test_cross_entropy_shapes_6(): def test_cross_entropy_shapes_7(): model = Model() shapes: dict[str, list] = {"input": [("V1", ...), 64], "target": [8, 16, 32, 64]} - ce = CrossEntropy(categorical=True) + ce = CrossEntropy() ce.set_shapes(shapes) - model += ce(input="input", target="target", output=IOKey(name="output")) + model += ce( + input="input", target="target", categorical=True, output=IOKey(name="output") + ) logical_ref: Mapping[str, list | None] = { "input": [8, "u1", 16, 32, 64], "output": [8, 16, 32, 64], @@ -2125,9 +2137,11 @@ def test_cross_entropy_shapes_7(): def test_cross_entropy_shapes_8(): model = Model() shapes: dict[str, list] = {"input": [("V1", ...), 64], "target": [8, 16, 32, 64]} - ce = CrossEntropy(categorical=False) + ce = CrossEntropy() ce.set_shapes(shapes) - model += ce(input="input", target="target", output=IOKey(name="output")) + model += ce( + input="input", target="target", categorical=False, output=IOKey(name="output") + ) logical_ref = { "input": [8, 16, 32, 64], "output": [8, 32, 64], @@ -2154,9 +2168,11 @@ def test_cross_entropy_shapes_8(): def test_cross_entropy_shapes_9(): model = Model() shapes: dict[str, list] = {"input": [8, 16, ("V1", ...), 64]} - ce = CrossEntropy(categorical=True) + ce = CrossEntropy() ce.set_shapes(shapes) - model += ce(input="input", target="target", output=IOKey(name="output")) + model += ce( + input="input", target="target", categorical=True, output=IOKey(name="output") + ) logical_ref: dict[str, list | None] = { "input": [8, 16, "(V1, ...)", 64], "output": [8, "(V1, ...)", 64], diff --git a/tests/scripts/test_summary.py b/tests/scripts/test_summary.py index f3a8afc5..f3f9de41 100644 --- a/tests/scripts/test_summary.py +++ b/tests/scripts/test_summary.py @@ -862,9 +862,10 @@ def test_table_5(): def test_physical_summary_1(): model = Model() model += Linear(dimension=5)(input="input") - model += LeakyRelu() + model += LeakyRelu()() model += (lin1 := Linear(dimension=3)) - model += LeakyRelu(slope=1e-1) + model += (l_relu := LeakyRelu()) + l_relu.set_values({"slope": 1e-1}) model += Relu() lin1.set_shapes({"input": [3, 5]}) comp_model = mithril.compile( @@ -886,7 +887,7 @@ def test_physical_summary_2(): model1 += Sigmoid() model = Model() model += Linear(dimension=5)(input="input") - model += LeakyRelu() + model += LeakyRelu()() model += Linear(dimension=3) model += model1 assert isinstance(model.canonical_input, Connection) @@ -1289,7 +1290,7 @@ def test_logical_model_summary_2(): model += Convolution2D(kernel_size=4, out_channels=4) model += Relu() model += Convolution2D(kernel_size=4, out_channels=4) - model += LeakyRelu() + model += LeakyRelu()() model += Flatten(start_dim=1) model += Linear(dimension=1) model += Sum() diff --git a/tests/scripts/test_train_context.py b/tests/scripts/test_train_context.py index e84b755f..0167fced 100644 --- a/tests/scripts/test_train_context.py +++ b/tests/scripts/test_train_context.py @@ -205,7 +205,7 @@ def test_add_loss_case_5(): output=IOKey(name="output1"), ) - assert str(err_info.value) == '"The provided keys do not match the model\'s loss."' + assert str(err_info.value) == "'Output of the loss model cannot be defined!'" def 
test_add_loss_case_6(): @@ -224,11 +224,8 @@ def test_add_loss_case_6(): ctx1.add_loss( Relu(), [Min(axis=-1)], input=relu3.output, output=IOKey(name="output1") ) - ctx1.add_loss( - Relu(), [Min(axis=-1)], input="output1", output=IOKey(name="output2") - ) - assert str(err_info.value) == '"The provided keys do not match the model\'s loss."' + assert str(err_info.value) == "'Output of the loss model cannot be defined!'" def test_add_loss_case_7(): diff --git a/tests/scripts/test_type_coercion.py b/tests/scripts/test_type_coercion.py index 977cecd6..f8dcee32 100644 --- a/tests/scripts/test_type_coercion.py +++ b/tests/scripts/test_type_coercion.py @@ -862,9 +862,10 @@ def test_connect_type_conv_handling_1(): def test_type_initialization_1(): - model = LeakyRelu(slope=0.5) + model = Model() + model += LeakyRelu()(slope=IOKey("slope", 0.5)) - assert model.slope.metadata.data._type is float + assert model.slope.metadata.data._type is float # type: ignore def test_connect_1(): @@ -1355,7 +1356,7 @@ def test_coercion_2(): model = Model() reduce_model_1 = Sum(axis=TBD) reduce_model_2 = Sum(axis=TBD) - l_relu = LeakyRelu(slope=TBD) + l_relu = LeakyRelu() model += reduce_model_1(input="input1", axis="axis1") model += reduce_model_2(input="input2", axis="axis2") axis1 = reduce_model_1.axis.sum() diff --git a/tests/scripts/test_utils.py b/tests/scripts/test_utils.py index 5210c418..82b9d212 100644 --- a/tests/scripts/test_utils.py +++ b/tests/scripts/test_utils.py @@ -73,6 +73,7 @@ def dict_to_output_specs(specs_dict: dict[str, dict]) -> dict[str, dict]: result[key]["loss"] = model_dict[loss["fn"].lower()]( **loss.get("params", {}) ) + result[key]["loss_kwargs"] = loss.get("call_kwargs", {}) else: raise TypeError("Unsupported Loss type!") # Set reduce steps. @@ -122,7 +123,7 @@ def finalize_model(params: dict[str, Any]): if (target := spec.get("target_key")) is not None else {} ) - loss_kwargs = loss_input | loss_target + loss_kwargs = spec.get("loss_kwargs", {}) | loss_input | loss_target train_model.add_loss( loss_model=loss_model, reduce_steps=spec["reduce_steps"],