refactor: Refactor Differentiability #185

Merged
2 changes: 1 addition & 1 deletion examples/gpt/model.py
@@ -101,5 +101,5 @@ def create_gpt(bias, block_size, dims, num_heads, num_layers, vocab_size):
gpt = Model()
gpt += transformer(input="input")
gpt += Linear(vocab_size, use_bias=False, name="lm_head")(output=IOKey("output"))
gpt.input.set_differentiable(False) # type: ignore
gpt.set_differentiability({gpt.input: False}) # type: ignore
return gpt
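
The change above is the user-facing side of this refactor: the removed per-connection Connection.set_differentiable call is replaced by the model-level Model.set_differentiability API added later in this diff. A minimal sketch of the two accepted calling styles for the gpt model from this example (the keyword form assumes the input connection is exposed under the name "input", as in the surrounding script):

# New API: configure differentiability through the model, not the connection.
gpt.set_differentiability({gpt.input: False})  # Connection-keyed config dict
gpt.set_differentiability(input=False)         # equivalent keyword form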
70 changes: 50 additions & 20 deletions mithril/framework/common.py
@@ -827,13 +827,15 @@
value: TensorValueType | ToBeDetermined = TBD,
type: _TensorTypes = int | float | bool,
shape: ShapeNode | None = None,
differentiable: bool = False,
):
if shape is None:
# If shape is not provided, create a new shape with a Variadic root.
shape = ShapeRepr(root=Variadic()).node
self.shape: ShapeNode = shape
self.type: _TensorTypes = type
self.referees: set[IOHyperEdge] = set()
self.differentiable = differentiable
# Initialize value as TBD and then set if any value is provided.
self.value: TensorValueType | ToBeDetermined = TBD
if not isinstance(value, ToBeDetermined):
@@ -860,6 +862,7 @@
raise ValueError(
f"Value is set before as {self.value}. A value can not be reset."
)

updates = Updates()
# Set value.
if self.value is TBD:
@@ -874,6 +877,8 @@
updates.add(edge, update_type=UpdateType.VALUE)
updates.add(edge, update_type=UpdateType.SHAPE)
self.value = val

self.differentiable = False
return updates

def match(self, other: Tensor[int | float | bool]) -> Updates:
@@ -888,9 +893,16 @@
)
assert not isinstance(valued.value, ToBeDetermined)
updates |= non_valued.set_value(valued.value)

self.differentiable = False
other.differentiable = False
else:
self.differentiable |= other.differentiable

# Transfer all referees of other to self and update all
# Tensors in all edges of other with self.
self.referees |= other.referees

for edge in other.referees:
# TODO: Update here when we have list of tensors in an edge.
edge._value = self
@@ -927,7 +939,6 @@
type: set() for type in UpdateType
}
self._temp_shape: ShapeRepr | None = None # set random repr
self.differentiable: bool = False
self.interval: list[float | int] | None = interval
# Initially set type and value as not determined yet.
self._type = ToBeDetermined
@@ -938,6 +949,13 @@
if value is not TBD:
self.set_value(value)

@property
def differentiable(self) -> bool:
if self.is_tensor:
assert isinstance(self._value, Tensor)
return self._value.differentiable
return False

@property
def is_polymorphic(self) -> bool:
# Returns if the edge is of polymorphic type or not.
@@ -1020,6 +1038,13 @@
return self.value is TBD or self.value == _other_value
return True

def set_differentiability(self, differentiable: bool) -> None:
if self.is_tensor:
assert isinstance(self._value, Tensor)
self._value.differentiable = differentiable
elif differentiable:
raise ValueError("Non-tensor edges cannot be differentiable.")

Codecov / codecov/patch warning: added line mithril/framework/common.py#L1046 was not covered by tests.

def set_type(self, typ: type[Tensor[int | float | bool]] | ScalarType) -> Updates:
updates = Updates()
if self._type != typ:
@@ -1048,9 +1073,7 @@
# Add self as type update, set new type and update differentiability.
updates.add(self, UpdateType.TYPE)
self._type = new_type
self.differentiable = (self.value is TBD) and bool(
find_intersection_type(Tensor[float], self._type)
)

return updates

def set_value(
@@ -1100,8 +1123,11 @@
self._value = value
updates.add(self, UpdateType.VALUE)
updates.value_updates.add(self)
# Add self to updates.
self.differentiable = self.value is TBD

# A concretely set value makes the edge non-differentiable.
if self.value is not TBD:
self.set_differentiability(False)

return updates

def match(self, other: IOHyperEdge) -> Updates:
@@ -1130,11 +1156,6 @@
self.constraints[type] |= other.constraints[type]
other.constraints[type] = set()

# Update differentiability.
if isinstance(self._value, Tensor) and self._value.value is TBD:
is_diff = self.differentiable | other.differentiable
# TODO: Is it required to set other as well?
self.differentiable = other.differentiable = is_diff
return updates

def add_constraint(self, constraint: Constraint) -> None:
@@ -1161,6 +1182,7 @@
| ScalarType
| None = None,
expose: bool | None = None,
differentiable: bool = False,
interval: list[float | int] | None = None,
connections: set[ConnectionData | str] | None = None,
) -> None:
@@ -1176,6 +1198,9 @@
# Convert to generic Tensor type if Tensor type is provided.
type = Tensor[int | float | bool]

if differentiable:
type = Tensor[float]

self.name = name
self.expose = expose
if connections is None:
@@ -1193,6 +1218,9 @@
f"Got {shape}."
)

if differentiable and value is not TBD:
raise ValueError("Scalar values can not be set as differentiable.")

Codecov / codecov/patch warning: added line mithril/framework/common.py#L1222 was not covered by tests.

if value is not TBD and type is not None:
value_type = find_type(value)
if find_intersection_type(value_type, type) is None:
@@ -1204,6 +1232,7 @@
self.value_shape = shape
self.type = type
self.interval = interval
self.differentiable = differentiable


@dataclass
@@ -1226,14 +1255,16 @@
def __eq__(self, other: object) -> bool:
return id(self) == id(other)

def set_differentiable(self, differentiable: bool = True) -> None:
def set_differentiability(self, differentiable: bool = True) -> Updates:
updates = Updates()
# TODO: Move this method to Model class as set_shapes, set_types etc.
if self.metadata.is_tensor:
self.metadata.differentiable = differentiable
self.metadata.set_differentiability(differentiable)
elif differentiable:
if self.metadata.edge_type is not ToBeDetermined:
raise ValueError("Scalar data can not be set as differentiable.")
self.metadata.differentiable = differentiable
updates |= self.metadata.set_type(Tensor[float])
self.metadata.set_differentiability(differentiable)

return updates


ShapesType = (
@@ -1377,10 +1408,9 @@
raise ValueError("No matching key type found!")

def get_non_diff_keys(self) -> set[str]:
return {key for key, conn in self.all.items() if conn.metadata.is_non_diff}

def is_key_non_diff(self, key: str) -> bool:
return self.get_data(key).is_non_diff
return {
key for key, conn in self.all.items() if not conn.metadata.differentiable
}

def get_connection(self, key: str) -> ConnectionData | None:
internals = self._connection_dict[KeyType.INTERNAL]
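Taken together, the common.py changes move the differentiability flag from IOHyperEdge onto Tensor: the edge now only exposes a read-only differentiable property that forwards to its tensor value, set_differentiability rejects non-tensor edges, and assigning a concrete value clears the flag. A rough sketch of the resulting behaviour, assuming Tensor and IOHyperEdge are constructed directly with value arguments as the operator code later in this diff does (not something end users normally do):

from mithril.framework.common import IOHyperEdge, Tensor

t = Tensor(differentiable=True)      # the flag now lives on the Tensor itself
edge = IOHyperEdge(value=t)          # a tensor-typed hyperedge
assert edge.differentiable           # property forwards to t.differentiable

edge.set_differentiability(False)    # allowed for tensor edges; forwards to t
assert not t.differentiable

scalar_edge = IOHyperEdge(value=3)   # non-tensor edge: never differentiable
assert not scalar_edge.differentiable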
26 changes: 26 additions & 0 deletions mithril/framework/logical/base.py
@@ -231,6 +231,7 @@
type=connection.type,
shape=connection.value_shape,
value=connection.value,
differentiable=connection.differentiable,
)

return _connection
@@ -250,6 +251,7 @@
d_map = self.dependency_map.local_output_dependency_map
expose = given_connection.expose
outer_key = given_connection.name

con_obj = None
set_value: (
ToBeDetermined
@@ -347,6 +349,9 @@
if not isinstance(set_value, NullConnection):
updates |= con_obj.metadata.set_value(set_value)

if given_connection.differentiable:
updates |= con_obj.set_differentiability(True)

# Check multi-write error for con_obj.
self._check_multi_write(is_input, local_connection, con_obj)

@@ -1072,6 +1077,27 @@
self.conns.add(con)
return con

def _set_differentiability(
self, config: dict[str | ConnectionData, bool], **kwargs: bool
) -> None:
updates = Updates()

for key, value in chain(config.items(), kwargs.items()):
if isinstance(key, str):
if key not in self.conns.all:
raise KeyError(f"Connection {key} is not found in the model.")

Codecov / codecov/patch warning: added line mithril/framework/logical/base.py#L1088 was not covered by tests.

conn_data = self.conns.all[key]
updates |= conn_data.set_differentiability(value)
elif isinstance(key, ConnectionData):
if key not in self.conns.all.values():
raise KeyError(f"Connection {key} is not found in the model.")

Codecov / codecov/patch warning: added line mithril/framework/logical/base.py#L1094 was not covered by tests.

updates |= key.set_differentiability(value)

model = self._get_outermost_parent()
model.constraint_solver(updates)

def _set_shapes(
self,
shapes: ShapesType,
18 changes: 15 additions & 3 deletions mithril/framework/logical/model.py
@@ -390,9 +390,6 @@ def key(self) -> str:
def metadata(self) -> IOHyperEdge:
return self.data.metadata

def set_differentiable(self, differentiable: bool = True) -> None:
self.data.set_differentiable(differentiable)

def __hash__(self) -> int:
return hash(id(self))

@@ -412,6 +409,7 @@ def __init__(
| ScalarType
| None = None,
expose: bool | None = None,
differentiable: bool = False,
interval: list[float | int] | None = None,
connections: set[Connection | str] | None = None,
) -> None:
@@ -428,6 +426,7 @@
expose=expose,
interval=interval,
connections=_connections,
differentiable=differentiable,
)


@@ -591,6 +590,7 @@ def _prepare_keys(
type=connection.type,
shape=connection.value_shape,
value=connection.value,
differentiable=connection.differentiable,
)
case _:
_connection = connection # type: ignore
@@ -725,6 +725,18 @@ def __or__(self, info: ExtendInfo | Model) -> Self:
| Mapping[Connection, ShapeTemplateType]
)

def set_differentiability(
self, config: dict[str | Connection, bool] | None = None, **kwargs: bool
) -> None:
if config is None:
config = {}

_config: dict[str | ConnectionData, bool] = {
key.data if isinstance(key, Connection) else key: value
for key, value in config.items()
}
self._set_differentiability(_config, **kwargs)

def set_shapes(
self, config: ShapeType | None = None, **kwargs: ShapeTemplateType
) -> None:
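Model.set_differentiability normalizes Connection keys to ConnectionData and forwards to BaseModel._set_differentiability above, which also accepts plain key names and re-runs the constraint solver afterwards. A short usage sketch (the model instance and the "weight"/"bias" key names are hypothetical, chosen only to show the accepted key forms):

# Keys may be Connection objects, key-name strings, or keyword arguments.
model.set_differentiability({model.weight: True})  # Connection attribute key
model.set_differentiability({"weight": True})      # string key
model.set_differentiability(bias=True)             # keyword form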
8 changes: 8 additions & 0 deletions mithril/framework/logical/operator.py
@@ -64,6 +64,7 @@ def __init__(
tensor = Tensor(
type=get_args(value.type)[0],
shape=shapes[key].node,
differentiable=value.differentiable,
)
edge = IOHyperEdge(value=tensor, interval=value.interval)
data_set.add(edge)
@@ -144,3 +145,10 @@ def extend(
**kwargs: ConnectionDataType,
) -> None:
raise NotImplementedError("Operators cannot be extended!")

def infer_differentiability(self, *inputs: bool) -> bool:
# Function to infer differentiability of the operator
# based on the differentiability of its inputs

# If any of the inputs are differentiable, the output is differentiable
return any(inputs)
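
The default rule simply ORs the input flags, so an operator's output is considered differentiable as soon as any input is; operators whose outputs are inherently non-differentiable override it, as FloorDivideOp does in the next file. For illustration, with op being any Operator instance:

op.infer_differentiability(False, False)  # -> False: no differentiable inputs
op.infer_differentiability(True, False)   # -> True: one differentiable input suffices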
7 changes: 6 additions & 1 deletion mithril/framework/logical/operators.py
@@ -431,7 +431,9 @@ def __init__(
super().__init__(
formula_key="floor_divide",
name=name,
output=BaseKey(type=Tensor[int | float] | int | float),
output=BaseKey(
type=Tensor[int | float] | int | float, differentiable=False
),
numerator=BaseKey(value=numerator),
denominator=BaseKey(value=denominator),
)
@@ -464,6 +466,9 @@ def __init__(
dependencies={bcast_constraint},
)

def infer_differentiability(self, *inputs: bool) -> bool:
return False


class MatrixMultiplyOp(Operator):
_model_name: str = "MatrixMultiply"
12 changes: 11 additions & 1 deletion mithril/framework/logical/primitive.py
@@ -31,7 +31,17 @@ def __init__(
name: str | None = None,
) -> None:
super().__init__(name=name, enforce_jit=model._jittable)
self._extend(model, {k: IOKey(k, expose=True) for k in model.external_keys})
self._extend(
model,
{
k: IOKey(
k,
expose=True,
differentiable=model.conns.all[k].metadata.differentiable,
)
for k in model.external_keys
},
)

@property
def submodel(self) -> Operator: