Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Update set_values, set_shapes and set_types methods of BaseModel #21

Merged
merged 4 commits into from
Nov 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 74 additions & 46 deletions mithril/framework/logical/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import abc
from collections.abc import Callable, Mapping
from dataclasses import dataclass
from itertools import chain
from types import UnionType
from typing import Any

Expand Down Expand Up @@ -213,38 +214,78 @@ def create_connection(self, metadata: IOHyperEdge, key: str) -> ConnectionData:
return con.data

def _set_shapes(
self, shapes: ShapesType, *, trace: bool = False, updates: Updates | None = None
self,
shapes: ShapesType,
trace: bool = False,
updates: Updates | None = None,
**kwargs: ShapeTemplateType,
) -> None:
if trace:
self.assigned_shapes.append(shapes)
# Initialize assigned shapes dictionary to store assigned shapes.
assigned_shapes = {}

if updates is None:
updates = Updates()

model = self._get_outermost_parent()
metadatas: OrderedSet[IOHyperEdge] = OrderedSet()
used_keys: dict[str | int, ShapeType] = {}
shape_nodes = {}
for key, shape in shapes.items():
# TODO: Can this be refactored to use a single loop?
for key, shape in chain(shapes.items(), kwargs.items()):
metadata = self.conns.extract_metadata(key)
if metadata is None:
raise KeyError("Requires valid IO connection to set shapes!")
if metadata in metadatas:
raise KeyError("shape of same connection has already given")
metadatas.add(metadata)
outer_conn = next(iter(model.conns.metadata_dict[metadata]))
given_repr = create_shape_repr(shape, model.constraint_solver, used_keys)
shape_nodes[outer_conn.key] = given_repr.node
for key, node in shape_nodes.items():
shape_node = model.conns.get_shape_node(key)
# Get inner string representation of the metadata and save
# use this name in order to merge .
conn = self.conns.get_con_by_metadata(metadata)
assert conn is not None
inner_key = conn.key
shape_nodes[key] = (given_repr.node, inner_key)
assigned_shapes[inner_key] = shape
# Apply updates to the shape nodes.
for key in chain(shapes, kwargs):
node, _inner_key = shape_nodes[key]
shape_node = self.conns.get_shape_node(_inner_key)
assert shape_node is not None
updates |= shape_node.merge(node)

if trace:
self.assigned_shapes.append(assigned_shapes)

model.constraint_solver(updates)

def set_shapes(self, shapes: ShapesType) -> None:
self._set_shapes(shapes, trace=True)
def _set_value(self, key: ConnectionData, value: MainValueType) -> Updates:
"""
Set value for the given connection.

Args:
key (str | Connection): Connection key or Connection object to set value.
value (MainValueType): Value to set for the given connection.

def set_values(self, values: dict[str | Connection, MainValueType | str]) -> None:
Raises:
KeyError: If the provided key is not a valid IO key.
"""

if key.key not in self.conns.input_keys:
raise KeyError("Internal or output keys' values cannot be set.")
# Data is scalar, set the value directly.
return key.metadata.data.set_value(value)

def set_shapes(
self, config: ShapesType | None = None, **kwargs: ShapeTemplateType
) -> None:
if config is None:
config = {}
self._set_shapes(config, trace=True, updates=None, **kwargs)

def set_values(
self,
config: Mapping[str | Connection, MainValueType | str]
| Mapping[Connection, MainValueType | str]
| Mapping[str, MainValueType | str]
| None = None,
**kwargs: MainValueType | str,
) -> None:
"""
Set multiple values in the model.
This method updates values in the outermost model by traversing up the
Expand All @@ -253,23 +294,24 @@ def set_values(self, values: dict[str | Connection, MainValueType | str]) -> Non
with the updated values.

Args:
values (dict[str | Connection, MainValueType]): A dictionary where
config (dict[str | Connection, MainValueType]): A dictionary where
keys are either strings or Connection objects, and values are
of type MainValueType.
**kwargs (MainValueType): Key-value pairs where keys are string names
of connections present in this model.

Raises:
KeyError: If a valid key or Connection is not provided in the values
dictionary.
"""

# Make all value updates in the outermost model.
if config is None:
config = {}
# Make all value updates in the outermost model.s
model = self._get_outermost_parent()

updates = Updates()

# TODO: Currently Setting values in fozen models are prevented only for Tensors.
# Scalar and Tensors should not be operated differently. This should be fixed.
for key in values:
for key in chain(config, kwargs):
metadata = self.conns.extract_metadata(key)
if isinstance(metadata.data, Tensor) and model.is_frozen:
conn_data = model.conns.get_con_by_metadata(metadata)
Expand All @@ -278,24 +320,24 @@ def set_values(self, values: dict[str | Connection, MainValueType | str]) -> Non
f"Model is frozen, can not set the key: {conn_data.key}!"
)

for key, value in values.items():
for key, value in chain(config.items(), kwargs.items()):
# Perform metadata extraction process on self.
metadata = self.conns.extract_metadata(key)

# Perform validity check and updates on model.
if (conn_data := model.conns.get_con_by_metadata(metadata)) is None:
raise KeyError("Requires valid key or Connection to set values!")

updates |= model._set_value(conn_data, value)

# Solve constraints with the updated values.
model.constraint_solver(updates)

def set_types(
self,
types: Mapping[str | Connection, type | UnionType]
config: Mapping[str | Connection, type | UnionType]
| Mapping[Connection, type | UnionType]
| Mapping[str, type | UnionType],
| Mapping[str, type | UnionType]
| None = None,
**kwargs: type | UnionType,
):
"""
Set types of any connection in the Model
Expand All @@ -312,33 +354,19 @@ def set_types(
of type of type or UnionType objects.

"""
# get the outermost parent as all the updates will happen here
if config is None:
config = {}

# Get the outermost parent as all the updates will happen here.
model = self._get_outermost_parent()
updates = Updates()
for key, key_type in types.items():
for key, key_type in chain(config.items(), kwargs.items()):
metadata = self.conns.extract_metadata(key)
data = metadata.data
updates |= data.set_type(key_type)
# run the constraints for updating affected connections
# Run the constraints for updating affected connections.
model.constraint_solver(updates)

def _set_value(self, key: ConnectionData, value: MainValueType) -> Updates:
"""
Set value for the given connection.

Args:
key (str | Connection): Connection key or Connection object to set value.
value (MainValueType): Value to set for the given connection.

Raises:
KeyError: If the provided key is not a valid IO key.
"""

if key.key not in self.conns.input_keys:
raise KeyError("Internal or output keys' values cannot be set.")
# Data is scalar, set the value directly.
return key.metadata.data.set_value(value)

def get_shapes(
self, uni_keys=None, var_keys=None, symbolic=True, verbose=False
) -> _ShapesType:
Expand Down
6 changes: 3 additions & 3 deletions mithril/utils/dict_conversions.py
Original file line number Diff line number Diff line change
Expand Up @@ -549,7 +549,7 @@ def shape_to_dict(shapes):
if isinstance(item, tuple): # variadic
shape_list.append(f"{item[0]},...")
else:
shape_list.append(str(item))
shape_list.append(item)
shape_dict[key] = shape_list
return shape_dict

Expand All @@ -559,10 +559,10 @@ def dict_to_shape(shape_dict):
for key, shape_list in shape_dict.items():
shapes[key] = []
for shape in shape_list:
if "..." in shape:
if isinstance(shape, str) and "..." in shape:
shapes[key].append((shape.split(",")[0], ...))
else:
shapes[key].append(int(shape))
shapes[key].append(shape)

return shapes

Expand Down
24 changes: 11 additions & 13 deletions tests/scripts/test_constant_inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1752,13 +1752,12 @@ def test_unused_cached_values_1_set_values():
model = Model()
linear_model = Linear(dimension=2)
model += linear_model()
model.set_values(
{
linear_model.w: [[1.0, 2.0]],
linear_model.b: [3.0, 1.0],
linear_model.input: [[3.0], [2.0]],
}
)
config: dict[Connection, list] = {
linear_model.w: [[1.0, 2.0]],
linear_model.b: [3.0, 1.0],
linear_model.input: [[3.0], [2.0]],
}
model.set_values(config)
comp_model = mithril.compile(model=model, backend=(backend := NumpyBackend()))
dtype = backend.get_backend_array_type()
cache = comp_model.data_store.data_values
Expand Down Expand Up @@ -1823,12 +1822,11 @@ def test_unused_cached_values_2_set_values():
model = Model()
linear_model = Linear(dimension=2)
model += linear_model()
model.set_values(
{
linear_model.w: [[1.0, 2.0]],
linear_model.b: [3.0, 1.0],
}
)
config: dict[Connection, list] = {
linear_model.w: [[1.0, 2.0]],
linear_model.b: [3.0, 1.0],
}
model.set_values(config)
comp_model = mithril.compile(
model=model, backend=(backend := NumpyBackend()), safe_names=False
)
Expand Down
47 changes: 47 additions & 0 deletions tests/scripts/test_model_to_dict_rtt.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

import re

import pytest

import mithril
from mithril import JaxBackend, TorchBackend
from mithril.framework.common import TBD, IOKey
Expand Down Expand Up @@ -58,6 +60,51 @@ def test_linear_expose():
)


def test_linear_expose_set_shapes():
model = Model()
lin_1 = Linear()
lin_2 = Linear()
model += lin_1(input="input", w="w")
model += lin_2(input=lin_1.output, w="w1", output=IOKey(name="output2"))
model.set_shapes({lin_1.b: [42]})
model.set_shapes({lin_2.b: [21]})
model_dict_created = dict_conversions.model_to_dict(model)
model_recreated = dict_conversions.dict_to_model(model_dict_created)
model_dict_recreated = dict_conversions.model_to_dict(model_recreated)

assert model_dict_created == model_dict_recreated
assert model.shapes == model_recreated.shapes
assert_models_equal(model, model_recreated)

backend = JaxBackend(precision=64)
assert_evaluations_equal(
model, model_recreated, backend, static_keys={"input": backend.ones([4, 256])}
)


@pytest.mark.skip(reason="Dict conversion does not support extend from inputs yet")
def test_linear_expose_set_shapes_extend_from_inputs():
model = Model()
lin_1 = Linear()
lin_2 = Linear()
model += lin_2(w="w1", output=IOKey(name="output2"))
model += lin_1(input="input", w="w", output=lin_2.input)
model.set_shapes({lin_1.b: [42]})
model.set_shapes({lin_2.b: [21]})
model_dict_created = dict_conversions.model_to_dict(model)
model_recreated = dict_conversions.dict_to_model(model_dict_created)
model_dict_recreated = dict_conversions.model_to_dict(model_recreated)

assert model_dict_created == model_dict_recreated
assert model.shapes == model_recreated.shapes
assert_models_equal(model, model_recreated)

backend = JaxBackend(precision=64)
assert_evaluations_equal(
model, model_recreated, backend, static_keys={"input": backend.ones([4, 256])}
)


def test_linear_set_diff():
model = Model()
linear = Linear(dimension=42)
Expand Down
Loading
Loading