Skip to content

Commit

Permalink
feat: Update set_values, set_shapes and set_types methods of `B…
Browse files Browse the repository at this point in the history
…aseModel` (#21)
  • Loading branch information
norhan-synnada authored Nov 22, 2024
1 parent f133a1d commit b98d025
Show file tree
Hide file tree
Showing 7 changed files with 319 additions and 82 deletions.
120 changes: 74 additions & 46 deletions mithril/framework/logical/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import abc
from collections.abc import Callable, Mapping
from dataclasses import dataclass
from itertools import chain
from types import UnionType
from typing import Any

Expand Down Expand Up @@ -213,38 +214,78 @@ def create_connection(self, metadata: IOHyperEdge, key: str) -> ConnectionData:
return con.data

def _set_shapes(
self, shapes: ShapesType, *, trace: bool = False, updates: Updates | None = None
self,
shapes: ShapesType,
trace: bool = False,
updates: Updates | None = None,
**kwargs: ShapeTemplateType,
) -> None:
if trace:
self.assigned_shapes.append(shapes)
# Initialize assigned shapes dictionary to store assigned shapes.
assigned_shapes = {}

if updates is None:
updates = Updates()

model = self._get_outermost_parent()
metadatas: OrderedSet[IOHyperEdge] = OrderedSet()
used_keys: dict[str | int, ShapeType] = {}
shape_nodes = {}
for key, shape in shapes.items():
# TODO: Can this be refactored to use a single loop?
for key, shape in chain(shapes.items(), kwargs.items()):
metadata = self.conns.extract_metadata(key)
if metadata is None:
raise KeyError("Requires valid IO connection to set shapes!")
if metadata in metadatas:
raise KeyError("shape of same connection has already given")
metadatas.add(metadata)
outer_conn = next(iter(model.conns.metadata_dict[metadata]))
given_repr = create_shape_repr(shape, model.constraint_solver, used_keys)
shape_nodes[outer_conn.key] = given_repr.node
for key, node in shape_nodes.items():
shape_node = model.conns.get_shape_node(key)
# Get inner string representation of the metadata and save
# use this name in order to merge .
conn = self.conns.get_con_by_metadata(metadata)
assert conn is not None
inner_key = conn.key
shape_nodes[key] = (given_repr.node, inner_key)
assigned_shapes[inner_key] = shape
# Apply updates to the shape nodes.
for key in chain(shapes, kwargs):
node, _inner_key = shape_nodes[key]
shape_node = self.conns.get_shape_node(_inner_key)
assert shape_node is not None
updates |= shape_node.merge(node)

if trace:
self.assigned_shapes.append(assigned_shapes)

model.constraint_solver(updates)

def set_shapes(self, shapes: ShapesType) -> None:
self._set_shapes(shapes, trace=True)
def _set_value(self, key: ConnectionData, value: MainValueType) -> Updates:
"""
Set value for the given connection.
Args:
key (str | Connection): Connection key or Connection object to set value.
value (MainValueType): Value to set for the given connection.
def set_values(self, values: dict[str | Connection, MainValueType | str]) -> None:
Raises:
KeyError: If the provided key is not a valid IO key.
"""

if key.key not in self.conns.input_keys:
raise KeyError("Internal or output keys' values cannot be set.")
# Data is scalar, set the value directly.
return key.metadata.data.set_value(value)

def set_shapes(
self, config: ShapesType | None = None, **kwargs: ShapeTemplateType
) -> None:
if config is None:
config = {}
self._set_shapes(config, trace=True, updates=None, **kwargs)

def set_values(
self,
config: Mapping[str | Connection, MainValueType | str]
| Mapping[Connection, MainValueType | str]
| Mapping[str, MainValueType | str]
| None = None,
**kwargs: MainValueType | str,
) -> None:
"""
Set multiple values in the model.
This method updates values in the outermost model by traversing up the
Expand All @@ -253,23 +294,24 @@ def set_values(self, values: dict[str | Connection, MainValueType | str]) -> Non
with the updated values.
Args:
values (dict[str | Connection, MainValueType]): A dictionary where
config (dict[str | Connection, MainValueType]): A dictionary where
keys are either strings or Connection objects, and values are
of type MainValueType.
**kwargs (MainValueType): Key-value pairs where keys are string names
of connections present in this model.
Raises:
KeyError: If a valid key or Connection is not provided in the values
dictionary.
"""

# Make all value updates in the outermost model.
if config is None:
config = {}
# Make all value updates in the outermost model.s
model = self._get_outermost_parent()

updates = Updates()

# TODO: Currently Setting values in fozen models are prevented only for Tensors.
# Scalar and Tensors should not be operated differently. This should be fixed.
for key in values:
for key in chain(config, kwargs):
metadata = self.conns.extract_metadata(key)
if isinstance(metadata.data, Tensor) and model.is_frozen:
conn_data = model.conns.get_con_by_metadata(metadata)
Expand All @@ -278,24 +320,24 @@ def set_values(self, values: dict[str | Connection, MainValueType | str]) -> Non
f"Model is frozen, can not set the key: {conn_data.key}!"
)

for key, value in values.items():
for key, value in chain(config.items(), kwargs.items()):
# Perform metadata extraction process on self.
metadata = self.conns.extract_metadata(key)

# Perform validity check and updates on model.
if (conn_data := model.conns.get_con_by_metadata(metadata)) is None:
raise KeyError("Requires valid key or Connection to set values!")

updates |= model._set_value(conn_data, value)

# Solve constraints with the updated values.
model.constraint_solver(updates)

def set_types(
self,
types: Mapping[str | Connection, type | UnionType]
config: Mapping[str | Connection, type | UnionType]
| Mapping[Connection, type | UnionType]
| Mapping[str, type | UnionType],
| Mapping[str, type | UnionType]
| None = None,
**kwargs: type | UnionType,
):
"""
Set types of any connection in the Model
Expand All @@ -312,33 +354,19 @@ def set_types(
of type of type or UnionType objects.
"""
# get the outermost parent as all the updates will happen here
if config is None:
config = {}

# Get the outermost parent as all the updates will happen here.
model = self._get_outermost_parent()
updates = Updates()
for key, key_type in types.items():
for key, key_type in chain(config.items(), kwargs.items()):
metadata = self.conns.extract_metadata(key)
data = metadata.data
updates |= data.set_type(key_type)
# run the constraints for updating affected connections
# Run the constraints for updating affected connections.
model.constraint_solver(updates)

def _set_value(self, key: ConnectionData, value: MainValueType) -> Updates:
"""
Set value for the given connection.
Args:
key (str | Connection): Connection key or Connection object to set value.
value (MainValueType): Value to set for the given connection.
Raises:
KeyError: If the provided key is not a valid IO key.
"""

if key.key not in self.conns.input_keys:
raise KeyError("Internal or output keys' values cannot be set.")
# Data is scalar, set the value directly.
return key.metadata.data.set_value(value)

def get_shapes(
self, uni_keys=None, var_keys=None, symbolic=True, verbose=False
) -> _ShapesType:
Expand Down
6 changes: 3 additions & 3 deletions mithril/utils/dict_conversions.py
Original file line number Diff line number Diff line change
Expand Up @@ -549,7 +549,7 @@ def shape_to_dict(shapes):
if isinstance(item, tuple): # variadic
shape_list.append(f"{item[0]},...")
else:
shape_list.append(str(item))
shape_list.append(item)
shape_dict[key] = shape_list
return shape_dict

Expand All @@ -559,10 +559,10 @@ def dict_to_shape(shape_dict):
for key, shape_list in shape_dict.items():
shapes[key] = []
for shape in shape_list:
if "..." in shape:
if isinstance(shape, str) and "..." in shape:
shapes[key].append((shape.split(",")[0], ...))
else:
shapes[key].append(int(shape))
shapes[key].append(shape)

return shapes

Expand Down
24 changes: 11 additions & 13 deletions tests/scripts/test_constant_inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1752,13 +1752,12 @@ def test_unused_cached_values_1_set_values():
model = Model()
linear_model = Linear(dimension=2)
model += linear_model()
model.set_values(
{
linear_model.w: [[1.0, 2.0]],
linear_model.b: [3.0, 1.0],
linear_model.input: [[3.0], [2.0]],
}
)
config: dict[Connection, list] = {
linear_model.w: [[1.0, 2.0]],
linear_model.b: [3.0, 1.0],
linear_model.input: [[3.0], [2.0]],
}
model.set_values(config)
comp_model = mithril.compile(model=model, backend=(backend := NumpyBackend()))
dtype = backend.get_backend_array_type()
cache = comp_model.data_store.data_values
Expand Down Expand Up @@ -1823,12 +1822,11 @@ def test_unused_cached_values_2_set_values():
model = Model()
linear_model = Linear(dimension=2)
model += linear_model()
model.set_values(
{
linear_model.w: [[1.0, 2.0]],
linear_model.b: [3.0, 1.0],
}
)
config: dict[Connection, list] = {
linear_model.w: [[1.0, 2.0]],
linear_model.b: [3.0, 1.0],
}
model.set_values(config)
comp_model = mithril.compile(
model=model, backend=(backend := NumpyBackend()), safe_names=False
)
Expand Down
47 changes: 47 additions & 0 deletions tests/scripts/test_model_to_dict_rtt.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

import re

import pytest

import mithril
from mithril import JaxBackend, TorchBackend
from mithril.framework.common import TBD, IOKey
Expand Down Expand Up @@ -58,6 +60,51 @@ def test_linear_expose():
)


def test_linear_expose_set_shapes():
model = Model()
lin_1 = Linear()
lin_2 = Linear()
model += lin_1(input="input", w="w")
model += lin_2(input=lin_1.output, w="w1", output=IOKey(name="output2"))
model.set_shapes({lin_1.b: [42]})
model.set_shapes({lin_2.b: [21]})
model_dict_created = dict_conversions.model_to_dict(model)
model_recreated = dict_conversions.dict_to_model(model_dict_created)
model_dict_recreated = dict_conversions.model_to_dict(model_recreated)

assert model_dict_created == model_dict_recreated
assert model.shapes == model_recreated.shapes
assert_models_equal(model, model_recreated)

backend = JaxBackend(precision=64)
assert_evaluations_equal(
model, model_recreated, backend, static_keys={"input": backend.ones([4, 256])}
)


@pytest.mark.skip(reason="Dict conversion does not support extend from inputs yet")
def test_linear_expose_set_shapes_extend_from_inputs():
model = Model()
lin_1 = Linear()
lin_2 = Linear()
model += lin_2(w="w1", output=IOKey(name="output2"))
model += lin_1(input="input", w="w", output=lin_2.input)
model.set_shapes({lin_1.b: [42]})
model.set_shapes({lin_2.b: [21]})
model_dict_created = dict_conversions.model_to_dict(model)
model_recreated = dict_conversions.dict_to_model(model_dict_created)
model_dict_recreated = dict_conversions.model_to_dict(model_recreated)

assert model_dict_created == model_dict_recreated
assert model.shapes == model_recreated.shapes
assert_models_equal(model, model_recreated)

backend = JaxBackend(precision=64)
assert_evaluations_equal(
model, model_recreated, backend, static_keys={"input": backend.ones([4, 256])}
)


def test_linear_set_diff():
model = Model()
linear = Linear(dimension=42)
Expand Down
Loading

0 comments on commit b98d025

Please sign in to comment.