refactor: Refactor Tensors' temporary value storage logic (#18)
kberat-synnada authored Nov 21, 2024
1 parent 378c534 commit 39e3eb0
Showing 25 changed files with 873 additions and 773 deletions.
mithril/framework/common.py (8 additions, 45 deletions)
@@ -113,13 +113,13 @@ class NotAvailable(SingletonObject):
     pass
 
 
-# class ToBeDetermined(SingletonObject):
-#     """
-#     A singleton class representing a null data indicating
-#     that no data is provided.
-#     """
+class Auto(SingletonObject):
+    """
+    A singleton class representing a configuration
+    setting of automatically handled arguments.
+    """
 
-#     pass
+    pass
 
 
 class ToBeDetermined(SingletonObject):
@@ -134,9 +134,7 @@ class ToBeDetermined(SingletonObject):
 NOT_GIVEN = NullConnection()
 NOT_AVAILABLE = NotAvailable()
 TBD = ToBeDetermined()
-# TBD = ToBeDetermined()
-# TBD = TBD
-# ToBeDetermined = ToBeDetermined
+AUTO = Auto()
 
 
 class UpdateType(Enum):
@@ -645,7 +643,6 @@ def match(self, other: Tensor | Scalar) -> Updates:
 
 class Tensor(BaseData, GenericDataType[DataType]):
     _type: type[float] | type[int] | type[bool] | UnionType
-    temp_value: TensorValueType | ToBeDetermined
     value: DataType | ToBeDetermined
 
     def __init__(
@@ -657,12 +654,10 @@ def __init__(
     ) -> None:
         super().__init__(type=possible_types)
         self._differentiable: bool = value is TBD
-        self.temp_value = TBD
         self.value = TBD
         # Update type if any value is given.
         if not isinstance(value, ToBeDetermined):
             self.set_type(find_dominant_type(value))
-            self.temp_value = value
         else:
             self.value = value
         self.shape: ShapeNode = shape
@@ -681,21 +676,11 @@ def is_valued(self) -> bool:
     def _set_as_physical(self):
         super()._set_as_physical()
 
-    def _convert_value(self, backend: Backend) -> DataType | ToBeDetermined:
-        if isinstance(self.temp_value, Constant):
-            self.value = backend.array(
-                epsilon_table[backend.precision][self.temp_value]
-            )
-        elif self.temp_value is not TBD:
-            self.value = backend.array(self.temp_value)
-        return self.value
 
     def make_physical(self, backend: Backend, memo: dict[int, Tensor | Scalar]):
         physical_tensor = deepcopy(self, memo)
         # Update data as physical data.
         physical_tensor._set_as_physical()
-        # Update value of physical data taking backend into account.
-        physical_tensor._convert_value(backend)
         return physical_tensor
 
     def __deepcopy__(self, memo: dict[int, Tensor | Scalar]):
@@ -732,31 +717,9 @@ def match_shapes(self, other: Tensor):  # type: ignore[override]
     # which requires handling of interval arithmetic in logical level also.
 
     def set_value(self, value: DataType | TensorValueType) -> Updates:  # type: ignore[override]
-        if self._logical_data:
-            assert isinstance(value, TensorValueType)
-            return self._set_logical_value(value)
-        else:
+        if not self._logical_data:
             assert self.is_tensor_type(value)
             return self._set_physical_value(value)
 
-    def _set_logical_value(self, value: TensorValueType) -> Updates:
-        if isinstance(value, ToBeDetermined):
-            if self.temp_value is not TBD:
-                raise ValueError(
-                    f"Value is set before as {self.temp_value}. Can not be reset."
-                )
-            # if self.value is TBD and value is None:
-            #     raise ValueError(
-            #         "Already set as non-differentiable. Can not be reverted \
-            #             to a differentiable state."
-            #     )
-            self.value = value
-        else:
-            if self.temp_value is not TBD and self.temp_value != value:
-                raise ValueError(
-                    f"Value is set before as {self.temp_value}. Can not be reset."
-                )
-            self.temp_value = value
-        return Updates()
 
     def _set_physical_value(self, value: DataType) -> Updates:
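Note: Auto joins ToBeDetermined and NullConnection as module-level sentinels, so callers can rely on identity checks such as "value is TBD" or "arg is AUTO". A minimal standalone sketch of the sentinel-singleton pattern (simplified; mithril's actual SingletonObject implementation may differ):

class SingletonObject:
    _instance = None

    def __new__(cls):
        # Each subclass caches exactly one instance, so identity checks
        # like `x is TBD` are reliable everywhere in the framework.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

class ToBeDetermined(SingletonObject):
    pass

class Auto(SingletonObject):
    pass

TBD = ToBeDetermined()
AUTO = Auto()

assert ToBeDetermined() is TBD and Auto() is AUTO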
mithril/framework/constraints.py (29 additions, 0 deletions)
@@ -3495,6 +3495,35 @@ def tuple_converter_constraint(output: Scalar, input: Scalar) -> ConstrainResultType:
     return status, updates
 
 
+def cross_entropy_constraint(
+    categorical: Scalar, input: Tensor, target: Tensor
+) -> ConstrainResultType:
+    assert input._temp_shape is not None, "Input shape of cross entropy is not set!"
+    assert target._temp_shape is not None, "Target shape of cross entropy is not set!"
+
+    status = False
+    updates = Updates()
+    categorical_value = categorical.value
+
+    input_shape: ShapeRepr = input._temp_shape
+    target_shape: ShapeRepr = target._temp_shape
+
+    if categorical_value is not TBD:
+        if not categorical_value:
+            updates |= target_shape._match(input_shape)
+        else:
+            N = Uniadic()
+            C = Uniadic()
+            var = Variadic()
+            in_repr = ShapeRepr([N, C], var)
+            target_repr = ShapeRepr([N], var)
+            updates |= input_shape._match(in_repr)
+            updates |= target_shape._match(target_repr)
+
+        status = True
+    return status, updates
+
+
 type_constraints = {
     general_tensor_type_constraint,
     floor_divide_type_constraint,
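Note: the new constraint encodes the standard cross-entropy shape rule. For categorical targets the input carries a class axis, input (N, C, *rest) versus target (N, *rest), tied together by a shared variadic; otherwise (e.g. one-hot or probability targets) the two shapes must match exactly. A concrete-shape sketch of the same rule (standalone helper with hypothetical names, not mithril API):

def check_cross_entropy_shapes(
    input_shape: tuple[int, ...],
    target_shape: tuple[int, ...],
    categorical: bool = True,
) -> bool:
    if not categorical:
        # Non-categorical targets must match the input shape exactly.
        return input_shape == target_shape
    # Categorical targets drop the class axis C at position 1:
    # input (N, C, *rest) vs. target (N, *rest).
    return len(input_shape) >= 2 and target_shape == (
        input_shape[0],
        *input_shape[2:],
    )

assert check_cross_entropy_shapes((8, 10, 32), (8, 32))
assert check_cross_entropy_shapes((8, 10), (8, 10), categorical=False)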
mithril/framework/logical/base.py (14 additions, 2 deletions)
@@ -93,6 +93,7 @@ def __init__(self, enforce_jit: bool = True) -> None:
         self._jittable = True
         self.constraint_solver: ConstraintSolver = ConstraintSolver()
         self.safe_shapes: dict[str, ShapeTemplateType] = {}
+        self.is_frozen = False
 
     @abc.abstractmethod
     def summary(
@@ -174,8 +175,8 @@ def __setattr__(self, name: str, value: Any):
         else:
             super().__setattr__(name, value)
 
-    @abc.abstractmethod
-    def _freeze(self) -> None: ...
+    def _freeze(self) -> None:
+        self.is_frozen = True
 
     @abc.abstractmethod
     def extract_connection_info(
@@ -266,6 +267,17 @@ def set_values(self, values: dict[str | Connection, MainValueType | str]) -> None:
 
         updates = Updates()
 
+        # TODO: Currently setting values in frozen models is prevented only for
+        # Tensors. Scalars and Tensors should not be treated differently; fix this.
+        for key in values:
+            metadata = self.conns.extract_metadata(key)
+            if isinstance(metadata.data, Tensor) and model.is_frozen:
+                conn_data = model.conns.get_con_by_metadata(metadata)
+                assert conn_data is not None
+                raise ValueError(
+                    f"Model is frozen, can not set the key: {conn_data.key}!"
+                )
+
         for key, value in values.items():
             # Perform metadata extraction process on self.
             metadata = self.conns.extract_metadata(key)
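Note: together with the new is_frozen flag and the concrete _freeze implementation above, frozen models now fail fast: every requested key is checked before any value is written. A standalone sketch of the guard (ToyModel is hypothetical, not mithril's BaseModel):

class ToyModel:
    def __init__(self) -> None:
        self.is_frozen = False
        self._values: dict[str, object] = {}

    def _freeze(self) -> None:
        self.is_frozen = True

    def set_values(self, values: dict[str, object]) -> None:
        # Validate the whole update before mutating anything.
        if self.is_frozen:
            key = next(iter(values))
            raise ValueError(f"Model is frozen, can not set the key: {key}!")
        self._values.update(values)

model = ToyModel()
model.set_values({"weight": 1.0})  # accepted
model._freeze()
# model.set_values({"weight": 2.0})  # would raise: Model is frozen, ...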
mithril/framework/logical/essential_primitives.py (21 additions, 35 deletions)
@@ -129,14 +129,12 @@ def __call__(  # type: ignore[override]
         return super().__call__(input=input, output=output)
 
 
-ToTupleOutputType = tuple[int | float | bool | list | tuple, ...]
-
-
 class ToTuple(PrimitiveModel):
     def __init__(self, n: int) -> None:
         self.factory_args = {"n": n}
-        key_definitions = {}
-        key_definitions["output"] = Scalar(ToTupleOutputType)
+        key_definitions = {
+            "output": Scalar(tuple[int | float | bool | list | tuple, ...])
+        }
         key_definitions |= {
             f"input{idx+1}": Scalar(int | float | bool | list | tuple)
             for idx in range(n)
@@ -182,15 +180,14 @@ def __call__(  # type: ignore[override]
 class Power(PrimitiveModel):
     base: Connection
     exponent: Connection
-    threshold: Connection
     output: Connection
 
     def __init__(
         self,
         robust: bool = False,
-        threshold: ConstantType | ToBeDetermined = Constant.MIN_POSITIVE_NORMAL,
     ) -> None:
-        self.factory_args = {"threshold": threshold, "robust": robust}
+        self.robust = robust
+        self.factory_args = {"robust": robust}
         assert isinstance(robust, bool), "Robust must be a boolean value!"
 
         if robust:
@@ -199,16 +196,10 @@ def __init__(
                 output=TensorType([("Var_out", ...)]),
                 base=TensorType([("Var_1", ...)]),
                 exponent=TensorType([("Var_2", ...)]),
-                threshold=TensorType([], ConstantType, threshold),
+                threshold=TensorType([], ConstantType),
             )
-
+            self.threshold.set_differentiable(False)  # type: ignore
         else:
-            if threshold != Constant.MIN_POSITIVE_NORMAL:
-                raise KeyError(
-                    "Threshold cannot be specified \
-                        when robust mode is off"
-                )
-
             super().__init__(
                 formula_key="power",
                 output=TensorType([("Var_out", ...)]),
@@ -228,13 +219,15 @@ def __call__(  # type: ignore[override]
         self,
         base: ConnectionType = NOT_GIVEN,
         exponent: ConnectionType = NOT_GIVEN,
-        threshold: ConnectionType = NOT_GIVEN,
         output: ConnectionType = NOT_GIVEN,
+        *,
+        threshold: ConnectionType = Constant.MIN_POSITIVE_NORMAL,
     ) -> ExtendInfo:
         kwargs = {"base": base, "exponent": exponent, "output": output}
-        if "threshold" in self._input_keys:
+        is_constant = isinstance(threshold, Constant)
+        if self.robust:
             kwargs["threshold"] = threshold
-        elif threshold != NOT_GIVEN:
+        elif not (is_constant and threshold is Constant.MIN_POSITIVE_NORMAL):
             raise ValueError("Threshold cannot be specified when robust mode is off")
 
         return super().__call__(**kwargs)
@@ -941,26 +934,21 @@ def __init__(self) -> None:
 
 class Sqrt(PrimitiveModel):
     input: Connection
-    cutoff: Connection
     output: Connection
 
     def __init__(
         self,
         robust: bool = False,
-        cutoff: ConstantType | ToBeDetermined = Constant.MIN_POSITIVE_NORMAL,
     ) -> None:
-        self.factory_args = {"robust": robust, "cutoff": cutoff}
+        self.robust = robust
+        self.factory_args = {"robust": robust}
 
         if robust:
-            if isinstance(cutoff, str) and cutoff != Constant.MIN_POSITIVE_NORMAL:
-                raise ValueError(f"cutoff can only be set to 'min_positive_normal' \
-                    in string format, got {cutoff}")
-
             super().__init__(
                 formula_key="robust_sqrt",
                 output=TensorType([("Var", ...)], float),
                 input=TensorType([("Var", ...)]),
-                cutoff=TensorType([], ConstantType, cutoff),
+                cutoff=TensorType([], ConstantType),
             )
         else:
             super().__init__(
@@ -972,19 +960,17 @@ def __call__(  # type: ignore[override]
     def __call__(  # type: ignore[override]
         self,
         input: ConnectionType = NOT_GIVEN,
-        cutoff: ConnectionType = NOT_GIVEN,
         output: ConnectionType = NOT_GIVEN,
+        *,
+        cutoff: ConnectionType = Constant.MIN_POSITIVE_NORMAL,
     ) -> ExtendInfo:
         kwargs = {"input": input, "output": output}
 
-        if self.formula_key == "sqrt" and cutoff != NOT_GIVEN:
-            raise ValueError(
-                "Sqrt does not accept cutoff argument \
-                    when initialized with robust = False."
-            )
-
-        if self.formula_key == "robust_sqrt":
+        is_constant = isinstance(cutoff, Constant)
+        if self.robust:
             kwargs["cutoff"] = cutoff
+        elif not (is_constant and cutoff == Constant.MIN_POSITIVE_NORMAL):
+            raise ValueError("Cutoff cannot be specified when robust mode is off")
 
         return super().__call__(**kwargs)
 
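Note: Power and Sqrt now store the robust flag on the instance, and threshold/cutoff become keyword-only call arguments whose default, Constant.MIN_POSITIVE_NORMAL, means "not specified"; passing anything else while robust mode is off raises ValueError. A condensed standalone sketch of that calling convention (toy classes, not mithril's actual primitives):

from enum import Enum

class Constant(Enum):
    MIN_POSITIVE_NORMAL = "min_positive_normal"

class Sqrt:
    def __init__(self, robust: bool = False) -> None:
        self.robust = robust

    def __call__(self, input, *, cutoff=Constant.MIN_POSITIVE_NORMAL):
        # The keyword-only default doubles as a "not specified" sentinel.
        if not self.robust and cutoff is not Constant.MIN_POSITIVE_NORMAL:
            raise ValueError("Cutoff cannot be specified when robust mode is off")
        return ("robust_sqrt" if self.robust else "sqrt", input, cutoff)

Sqrt(robust=True)("x", cutoff=1e-20)   # accepted
Sqrt()("x")                            # accepted; default cutoff is ignored
# Sqrt()("x", cutoff=1e-20)            # would raise ValueError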
(Diffs for the remaining 21 changed files are not shown.)