Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: Refactor Differentiability #185

Merged
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 52 additions & 18 deletions mithril/framework/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -827,13 +827,15 @@
value: TensorValueType | ToBeDetermined = TBD,
type: _TensorTypes = int | float | bool,
shape: ShapeNode | None = None,
differentiable: bool = False,
):
if shape is None:
# If shape is not provided, create a new shape with a Variadic root.
shape = ShapeRepr(root=Variadic()).node
self.shape: ShapeNode = shape
self.type: _TensorTypes = type
self.referees: set[IOHyperEdge] = set()
self.differentiable = differentiable
# Initialize value as TBD and then set if any value is provided.
self.value: TensorValueType | ToBeDetermined = TBD
if not isinstance(value, ToBeDetermined):
Expand Down Expand Up @@ -891,6 +893,7 @@
# Transfer all referees of other to self and update all
# Tensors in all edges of other with self.
self.referees |= other.referees
self.differentiable |= other.differentiable
for edge in other.referees:
# TODO: Update here when we have list of tensors in an edge.
edge._value = self
Expand Down Expand Up @@ -927,7 +930,6 @@
type: set() for type in UpdateType
}
self._temp_shape: ShapeRepr | None = None # set random repr
self.differentiable: bool = False
self.interval: list[float | int] | None = interval
# Initially set type and value as not determined yet.
self._type = ToBeDetermined
Expand All @@ -938,6 +940,13 @@
if value is not TBD:
self.set_value(value)

@property
def differentiable(self) -> bool:
    """Report whether this edge is differentiable.

    The flag is read from the ``Tensor`` held by the edge; an edge that is
    not a tensor edge is always reported as non-differentiable.
    """
    if not self.is_tensor:
        return False
    # Narrow the value's type before reading the tensor-only attribute.
    assert isinstance(self._value, Tensor)
    return self._value.differentiable

@property
def is_polymorphic(self) -> bool:
# Returns if the edge is of polymorphic type or not.
Expand Down Expand Up @@ -1020,6 +1029,13 @@
return self.value is TBD or self.value == _other_value
return True

def set_differentiability(self, differentiable: bool) -> None:
    """Write the differentiability flag onto the tensor held by this edge.

    For a non-tensor edge, requesting ``False`` is a no-op and requesting
    ``True`` is an error.

    Raises:
        ValueError: If ``differentiable`` is true and the edge is not a
            tensor edge.
    """
    if not self.is_tensor:
        if differentiable:
            raise ValueError("Non-tensor edges cannot be differentiable.")
        return
    # Narrow the value's type before writing the tensor-only attribute.
    assert isinstance(self._value, Tensor)
    self._value.differentiable = differentiable

Check warning on line 1037 in mithril/framework/common.py

View check run for this annotation

Codecov / codecov/patch

mithril/framework/common.py#L1037

Added line #L1037 was not covered by tests

def set_type(self, typ: type[Tensor[int | float | bool]] | ScalarType) -> Updates:
updates = Updates()
if self._type != typ:
Expand Down Expand Up @@ -1048,9 +1064,7 @@
# Add self as type update, set new type and update differentiability.
updates.add(self, UpdateType.TYPE)
self._type = new_type
self.differentiable = (self.value is TBD) and bool(
find_intersection_type(Tensor[float], self._type)
)

return updates

def set_value(
Expand Down Expand Up @@ -1100,8 +1114,11 @@
self._value = value
updates.add(self, UpdateType.VALUE)
updates.value_updates.add(self)
# Add self to updates.
self.differentiable = self.value is TBD

# A determined (non-TBD) value makes the edge non-differentiable.
if self.value != TBD:
self.set_differentiability(False)

return updates

def match(self, other: IOHyperEdge) -> Updates:
Expand Down Expand Up @@ -1130,11 +1147,6 @@
self.constraints[type] |= other.constraints[type]
other.constraints[type] = set()

# Update differentiability.
if isinstance(self._value, Tensor) and self._value.value is TBD:
is_diff = self.differentiable | other.differentiable
# TODO: Is it required to set other as well?
self.differentiable = other.differentiable = is_diff
return updates

def add_constraint(self, constraint: Constraint) -> None:
Expand All @@ -1161,6 +1173,7 @@
| ScalarType
| None = None,
expose: bool | None = None,
differentiable: bool = False,
interval: list[float | int] | None = None,
connections: set[ConnectionData | str] | None = None,
) -> None:
Expand All @@ -1176,6 +1189,9 @@
# Convert to generic Tensor type if Tensor type is provided.
type = Tensor[int | float | bool]

if differentiable is True:
type = Tensor[float]

self.name = name
self.expose = expose
if connections is None:
Expand All @@ -1193,6 +1209,9 @@
f"Got {shape}."
)

if differentiable and value is not TBD:
raise ValueError("Scalar values can not be set as differentiable.")

Check warning on line 1213 in mithril/framework/common.py

View check run for this annotation

Codecov / codecov/patch

mithril/framework/common.py#L1213

Added line #L1213 was not covered by tests

if value is not TBD and type is not None:
value_type = find_type(value)
if find_intersection_type(value_type, type) is None:
Expand All @@ -1204,6 +1223,7 @@
self.value_shape = shape
self.type = type
self.interval = interval
self.differentiable = differentiable


@dataclass
Expand All @@ -1226,14 +1246,18 @@
def __eq__(self, other: object) -> bool:
return id(self) == id(other)

def set_differentiable(self, differentiable: bool = True) -> Updates:
    """Set the differentiability of the edge behind this connection.

    The flag is written through ``metadata.set_differentiability`` because
    ``differentiable`` on the edge is a getter-only property and cannot be
    assigned directly.

    Args:
        differentiable: Requested differentiability. Defaults to ``True``.

    Returns:
        The updates collected while changing the edge; this is only
        non-empty when the edge type had to be set to ``Tensor[float]``.

    Raises:
        ValueError: If ``differentiable`` is true and the edge already has
            a determined, non-tensor type.
    """
    updates = Updates()
    # TODO: Move this method to Model class as set_shapes, set_types etc.
    if self.metadata.is_tensor:
        assert isinstance(self.metadata._value, Tensor)
        self.metadata.set_differentiability(differentiable)
    elif differentiable:
        if self.metadata.edge_type is not ToBeDetermined:
            raise ValueError("Scalar data can not be set as differentiable.")
        # The edge type is still undetermined: fix it to a float tensor
        # first so that the flag has a Tensor to live on.
        updates |= self.metadata.set_type(Tensor[float])
        assert isinstance(self.metadata._value, Tensor)
        self.metadata.set_differentiability(differentiable)

    return updates


ShapesType = (
Expand Down Expand Up @@ -1332,6 +1356,14 @@
self._connection_dict[KeyType.INPUT] | self._connection_dict[KeyType.OUTPUT]
).keys()

@property
def valued_input_keys(self) -> set[str]:
    """Keys of the input connections whose metadata reports ``is_valued``."""
    inputs = self._connection_dict[KeyType.INPUT]
    return {name for name, entry in inputs.items() if entry.metadata.is_valued}

def add(
self,
connection: ConnectionData,
Expand Down Expand Up @@ -1377,10 +1409,12 @@
raise ValueError("No matching key type found!")

def get_non_diff_keys(self) -> set[str]:
    """Return the keys of every connection whose metadata is not differentiable."""
    return {
        key for key, conn in self.all.items() if not conn.metadata.differentiable
    }

def is_key_non_diff(self, key: str) -> bool:
    """Return ``True`` when the data registered under ``key`` is not differentiable."""
    return not self.get_data(key).differentiable

Check warning on line 1417 in mithril/framework/common.py

View check run for this annotation

Codecov / codecov/patch

mithril/framework/common.py#L1417

Added line #L1417 was not covered by tests

def get_connection(self, key: str) -> ConnectionData | None:
internals = self._connection_dict[KeyType.INTERNAL]
Expand Down
19 changes: 17 additions & 2 deletions mithril/framework/logical/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,9 +204,18 @@ def _prepare_keys(
case NullConnection():
_connection = BaseKey()
case str():
_connection = BaseKey(name=connection)
given_conn = self.conns.get_connection(connection)
diff = (
given_conn.metadata.differentiable
if given_conn is not None
else local_connection.metadata.differentiable
)
_connection = BaseKey(name=connection, differentiable=diff)
case ConnectionData():
_connection = BaseKey(connections={connection})
_connection = BaseKey(
connections={connection},
differentiable=connection.metadata.differentiable,
)
case _ if isinstance(connection, MainValueInstance | Tensor):
# find_dominant_type returns the dominant type in a container.
# If a container has a value of type Connection or ExtendTemplate
Expand All @@ -231,6 +240,7 @@ def _prepare_keys(
type=connection.type,
shape=connection.value_shape,
value=connection.value,
differentiable=connection.differentiable,
)

return _connection
Expand All @@ -250,6 +260,9 @@ def _add_connection(
d_map = self.dependency_map.local_output_dependency_map
expose = given_connection.expose
outer_key = given_connection.name
differentiable = (
given_connection.differentiable | local_connection.metadata.differentiable
) and given_connection.value is TBD
con_obj = None
set_value: (
ToBeDetermined
Expand Down Expand Up @@ -347,6 +360,8 @@ def _add_connection(
if not isinstance(set_value, NullConnection):
updates |= con_obj.metadata.set_value(set_value)

updates |= con_obj.set_differentiable(differentiable)

# Check multi-write error for con_obj.
self._check_multi_write(is_input, local_connection, con_obj)

Expand Down
3 changes: 3 additions & 0 deletions mithril/framework/logical/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,6 +412,7 @@ def __init__(
| ScalarType
| None = None,
expose: bool | None = None,
differantiable: bool = False,
interval: list[float | int] | None = None,
connections: set[Connection | str] | None = None,
) -> None:
Expand All @@ -428,6 +429,7 @@ def __init__(
expose=expose,
interval=interval,
connections=_connections,
differentiable=differantiable,
)


Expand Down Expand Up @@ -591,6 +593,7 @@ def _prepare_keys(
type=connection.type,
shape=connection.value_shape,
value=connection.value,
differentiable=connection.differentiable,
)
case _:
_connection = connection # type: ignore
Expand Down
4 changes: 4 additions & 0 deletions mithril/framework/logical/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ def __init__(
tensor = Tensor(
type=get_args(value.type)[0],
shape=shapes[key].node,
differentiable=value.differentiable,
)
edge = IOHyperEdge(value=tensor, interval=value.interval)
data_set.add(edge)
Expand Down Expand Up @@ -144,3 +145,6 @@ def extend(
**kwargs: ConnectionDataType,
) -> None:
raise NotImplementedError("Operators cannot be extended!")

def infer_differentiability(self, output: bool, *inputs: bool) -> bool:
    """Infer the output's differentiability from the inputs' flags.

    The output is differentiable when at least one input is. The current
    ``output`` flag is accepted but not consulted by this implementation.
    """
    return any(inputs)
7 changes: 6 additions & 1 deletion mithril/framework/logical/operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,7 +431,9 @@ def __init__(
super().__init__(
formula_key="floor_divide",
name=name,
output=BaseKey(type=Tensor[int | float] | int | float),
output=BaseKey(
type=Tensor[int | float] | int | float, differentiable=False
),
numerator=BaseKey(value=numerator),
denominator=BaseKey(value=denominator),
)
Expand Down Expand Up @@ -464,6 +466,9 @@ def __init__(
dependencies={bcast_constraint},
)

def infer_differentiability(self, output: bool, *inputs: bool) -> bool:
    """Always report the output as non-differentiable.

    Neither the current ``output`` flag nor any of the ``inputs`` flags
    influences the result.
    """
    return False


class MatrixMultiplyOp(Operator):
_model_name: str = "MatrixMultiply"
Expand Down
12 changes: 11 additions & 1 deletion mithril/framework/logical/primitive.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,17 @@ def __init__(
name: str | None = None,
) -> None:
super().__init__(name=name, enforce_jit=model._jittable)
self._extend(model, {k: IOKey(k, expose=True) for k in model.external_keys})
self._extend(
model,
{
k: IOKey(
k,
expose=True,
differantiable=model.conns.all[k].metadata.differentiable,
)
for k in model.external_keys
},
)

@property
def submodel(self) -> Operator:
Expand Down
29 changes: 16 additions & 13 deletions mithril/framework/physical/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ def __init__(

if global_key in self._non_differentiable_keys:
# TODO: Create an API for setting differentiability of a tensor.
physical_data.differentiable = False
physical_data.set_differentiability(False)
elif global_key in self._trainable_tensor_inputs:
# if physical_data.edge_type not in (Tensor, ToBeDetermined):
if not (
Expand All @@ -204,7 +204,7 @@ def __init__(
raise ValueError(
f"Valued data can not be trainable: {global_key}"
)
physical_data.differentiable = True
physical_data.set_differentiability(True)

model_data[key] = physical_data
self.flat_graph.data_memo[id(logical_data)] = physical_data
Expand All @@ -224,7 +224,7 @@ def __init__(
output = Operator.output_key
_data_dict: dict[str, IOHyperEdge] = {}

self._infer_differentiability(model_data)
self._infer_differentiability(p_model, model_data)
for inner_key in p_model.external_keys:
outer_key = mappings[inner_key]
if outer_key not in self.data:
Expand Down Expand Up @@ -414,21 +414,24 @@ def output_keys(self) -> list[str]:
def input_keys(self) -> set[str]:
return self._input_keys

def _infer_differentiability(
    self, p_model: Operator, model_data: dict[str, IOHyperEdge]
) -> None:
    """Set the output edge's differentiability for one operator.

    The decision is delegated to ``p_model.infer_differentiability``, which
    receives the output's current flag followed by the flags of all other
    edges in ``model_data``. Nothing is changed when the output edge is not
    a tensor edge.

    Args:
        p_model: Operator whose inference rule is applied.
        model_data: Edges of the operator keyed by its local key names;
            must contain ``Operator.output_key``.
    """
    # Infer output differentiability only for the models
    # that have a Tensor type output.
    output_key = Operator.output_key
    output_edge = model_data[output_key]
    input_diffs = [
        value.differentiable
        for key, value in model_data.items()
        if key != output_key
    ]

    if output_edge.is_tensor:
        diff = p_model.infer_differentiability(
            output_edge.differentiable, *input_diffs
        )
        # ``differentiable`` is read-only on the edge; write via the setter.
        output_edge.set_differentiability(diff)

def randomize_params(
self,
Expand Down
Loading