Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add string support to IOKey values #11

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions .github/workflows/pr-label-and-assign.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@ jobs:

- name: Check out the repository
uses: actions/checkout@v4
with:
ref: ${{ github.event.pull_request.head.ref }}

- id: pr-labeler # label the Pull Request based on change in files
uses: actions/labeler@v5
Expand Down
8 changes: 4 additions & 4 deletions mithril/framework/codegen/python_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,7 +472,7 @@ def compute_evaluate(
self,
params: dict[str, DataType] | None = None,
data: dict[str, DataType] | None = None,
cache: dict[str, DataType | MainValueType] | None = None,
cache: dict[str, DataType | MainValueType | str] | None = None,
*,
fn: Callable,
):
Expand Down Expand Up @@ -508,7 +508,7 @@ def compute_gradients(
params: dict[str, DataType],
data: dict[str, DataType | MainValueType] | None = None,
output_gradients: dict[str, DataType] | None = None,
cache: Mapping[str, DataType | MainValueType] | None = None,
cache: Mapping[str, DataType | MainValueType | str] | None = None,
include_output: bool = False,
*,
raw_evaluate_fn: Callable,
Expand Down Expand Up @@ -577,8 +577,8 @@ def compute_gradients(
def filter_ignored_outputs(
self,
params: dict[str, DataType] | None = None,
data: Mapping[str, MainValueType | DataType] | None = None,
cache: Mapping[str, MainValueType | DataType] | None = None,
data: Mapping[str, MainValueType | DataType | str] | None = None,
cache: Mapping[str, MainValueType | DataType | str] | None = None,
ignore_grad_keys=None,
*,
raw_evaluate_fn: Callable,
Expand Down
6 changes: 4 additions & 2 deletions mithril/framework/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@ class KeyType(Enum):
| Constant
| tuple[int | None, ...]
| dict
| str
)
type DeferredValueType = (
float | tuple[float, ...] | list[float] | EllipsisType | ToBeDetermined
Expand All @@ -178,6 +179,7 @@ class KeyType(Enum):
| type[PaddingType]
| type[EllipsisType]
| type[ToBeDetermined]
| type[str]
| NestedListType
| type[None]
| UnionType
Expand Down Expand Up @@ -791,7 +793,7 @@ class Scalar(BaseData):
def __init__(
self,
possible_types: ScalarType | UnionType | None = None,
value: MainValueType | ToBeDetermined = TBD,
value: MainValueType | ToBeDetermined | str = TBD,
) -> None:
if possible_types is None:
if isinstance(value, ToBeDetermined):
Expand Down Expand Up @@ -1137,7 +1139,7 @@ class IOKey(TemplateBase):
def __init__(
self,
name: str | None = None,
value: MainValueType | ToBeDetermined = TBD,
value: MainValueType | ToBeDetermined | str = TBD,
shape: ShapeTemplateType | None = None,
type: UnionType | type | None = None,
expose: bool = True,
Expand Down
10 changes: 6 additions & 4 deletions mithril/framework/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -1758,10 +1758,12 @@ def reverse_constraints(
raise ValueError("Shape mismatch in Transpose model")
status = True

elif isinstance(axes_val, int | Sequence):
axes_val = [axes_val] if isinstance(axes_val, int) else axes_val
in_unis = [Uniadic() for idx in range(len(axes_val))]
out_unis = [in_unis[axis] for axis in axes_val]
elif isinstance(axes_val, int | tuple | list):
a_val: list[int] | tuple[int, ...] = (
[axes_val] if isinstance(axes_val, int) else axes_val
)
in_unis = [Uniadic() for idx in range(len(a_val))]
out_unis = [in_unis[axis] for axis in a_val]

updates |= input_shape._update_uniadics(input_shape.prefix, in_unis)
updates |= input_shape._update_uniadics(input_shape.reverse, in_unis[::-1])
Expand Down
2 changes: 1 addition & 1 deletion mithril/framework/logical/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ def _set_shapes(
def set_shapes(self, shapes: ShapesType) -> None:
self._set_shapes(shapes, trace=True)

def set_values(self, values: dict[str | Connection, MainValueType]) -> None:
def set_values(self, values: dict[str | Connection, MainValueType | str]) -> None:
"""
Set multiple values in the model.
This method updates values in the outermost model by traversing up the
Expand Down
4 changes: 2 additions & 2 deletions mithril/framework/logical/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -666,12 +666,12 @@ def handle_auto_conversion(

# Create data object based on given_value or given key_type.
if is_value_given:
assert isinstance(set_value, MainValueType)
assert isinstance(set_value, MainValueType | str)
data = Scalar(value=set_value)

elif key_type == Scalar:
if set_type is None:
set_type = MainValueType
set_type = MainValueType | type[str]
data = Scalar(possible_types=set_type)

else:
Expand Down
2 changes: 1 addition & 1 deletion mithril/framework/physical/data_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def __init__(
self._runtime_static_keys: set[str] = set()
self._unused_keys: set[str] = set()
# Final tensor values of data store.
self.data_values: dict[str, DataType | MainValueType] = dict()
self.data_values: dict[str, DataType | MainValueType | str] = dict()
self.constraint_solver: ConstraintSolver = deepcopy(solver, memo=memo)

@property
Expand Down
2 changes: 1 addition & 1 deletion mithril/framework/physical/flat_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,7 +435,7 @@ def _is_duplicate(
continue

# Extract value from data or static_keys
value: DataType | MainValueType | ToBeDetermined
value: DataType | MainValueType | ToBeDetermined | str
if conn.key in data and data[conn.key].value is not TBD:
value = data[conn.key].value
else:
Expand Down
167 changes: 167 additions & 0 deletions tests/scripts/test_scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,11 @@
from mithril.core import Constant, epsilon_table
from mithril.framework.common import (
NOT_AVAILABLE,
NOT_GIVEN,
TBD,
ConnectionType,
NotAvailable,
ToBeDetermined,
UniadicRecord,
Variadic,
create_shape_map,
Expand Down Expand Up @@ -7012,3 +7015,167 @@ def test_output_keys_canonical_output_2():
model2 += model(output=IOKey("output"))

assert set(model2.output_keys) == set(["output", "#canonical_output"])


def test_string_iokey_value_1():
    # This test checks that a string value supplied for an IOKey flows
    # correctly through compilation and evaluation.
    #
    # NOTE(review): the original comment said the string is "given in init",
    # but the code below constructs ReduceEinsum(equation=TBD) and supplies
    # the string via IOKey(value="ij->i") — identical to
    # test_string_iokey_value_2. Confirm whether this test was meant to pass
    # the equation in __init__ instead (i.e. ReduceEinsum(equation="ij->i")).
    #
    # A dummy einsum primitive is introduced for this purpose, since einsum
    # naturally takes a string (the equation) as one of its inputs.
    #
    # This test comprises four steps:
    # 1. Register the einsum primitive
    # 2. Create a model that uses the einsum primitive and compile it
    # 3. Evaluate the model
    # 4. Compare the results

    import torch

    backend = TorchBackend()

    # Backend-level implementation of the einsum primitive.
    def einsum(input, equation):
        return torch.einsum(equation, input)

    # Logical-model wrapper for the einsum primitive.
    class ReduceEinsum(PrimitiveModel):
        # Minimal einsum model written for test purposes only.
        # Currently it supports a single tensor input and a single output.

        def __init__(self, equation: str | ToBeDetermined) -> None:
            if not isinstance(equation, ToBeDetermined):
                # Parse the equation, e.g. "ij->i" -> input "ij", output "i".
                input, output = equation.replace(" ", "").split("->")
                # Each subscript letter becomes one symbolic shape entry.
                all_input_shapes = list(input)
                all_output_shapes = list(output)
                # Create TensorType and Scalar inputs.
                # Note that the equation is carried as a string-typed Scalar.
                tensor_input = TensorType(all_input_shapes)
                tensor_output = TensorType(all_output_shapes)
                scalar_equation = Scalar(str, equation)

            else:
                # Equation is not yet known (TBD): leave shapes variadic and
                # constrain the equation only by its type (str).
                tensor_input = TensorType([("Var1", ...)])
                tensor_output = TensorType([("Var2", ...)])
                scalar_equation = Scalar(str)

            kwargs: dict[str, TensorType | Scalar] = {
                "output": tensor_output,
                "input": tensor_input,
                "equation": scalar_equation,
            }

            super().__init__(formula_key="einsum", **kwargs)
            self._freeze()

        def __call__(  # type: ignore[override]
            self,
            input: ConnectionType = NOT_GIVEN,
            equation: ConnectionType = NOT_GIVEN,
            output: ConnectionType = NOT_GIVEN,
        ) -> ExtendInfo:
            return super().__call__(input=input, equation=equation, output=output)

    TorchBackend.register_primitive(einsum)

    # Create the container model and add the einsum primitive.
    model = Model()

    # The equation string is provided as an IOKey value at extend time.
    a = ReduceEinsum(equation=TBD)(
        input="input", equation=IOKey(value="ij->i"), output="output"
    )
    model += a

    # Compile the model and assert the results: "ij->i" sums over axis 1,
    # so a (7, 6) tensor of ones reduces to a length-7 vector of sixes.
    pm = mithril.compile(model=model, backend=backend)
    input = backend.ones((7, 6))
    trainable_keys = {"input": input}
    outputs = pm.evaluate(trainable_keys)
    ref_outputs = {"output": backend.ones(7) * 6}
    assert_results_equal(outputs, ref_outputs)


def test_string_iokey_value_2():
    # This test checks that IOKey handles a string value correctly: the
    # primitive is built with equation=TBD and the concrete string is then
    # supplied through IOKey(value="ij->i") at extend time.
    #
    # A dummy einsum primitive is introduced for this purpose, since einsum
    # naturally takes a string (the equation) as one of its inputs.
    #
    # This test comprises four steps:
    # 1. Register the einsum primitive
    # 2. Create a model that uses the einsum primitive and compile it
    # 3. Evaluate the model
    # 4. Compare the results

    import torch

    backend = TorchBackend()

    # Backend-level implementation of the einsum primitive.
    def einsum(input, equation):
        return torch.einsum(equation, input)

    # Logical-model wrapper for the einsum primitive.
    class ReduceEinsum(PrimitiveModel):
        # Minimal einsum model written for test purposes only.
        # Currently it supports a single tensor input and a single output.

        def __init__(self, equation: str | ToBeDetermined) -> None:
            if not isinstance(equation, ToBeDetermined):
                # Parse the equation, e.g. "ij->i" -> input "ij", output "i".
                input, output = equation.replace(" ", "").split("->")
                # Each subscript letter becomes one symbolic shape entry.
                all_input_shapes = list(input)
                all_output_shapes = list(output)
                # Create TensorType and Scalar inputs.
                # Note that the equation is carried as a string-typed Scalar.
                tensor_input = TensorType(all_input_shapes)
                tensor_output = TensorType(all_output_shapes)
                scalar_equation = Scalar(str, equation)

            else:
                # Equation is not yet known (TBD): leave shapes variadic and
                # constrain the equation only by its type (str).
                tensor_input = TensorType([("Var1", ...)])
                tensor_output = TensorType([("Var2", ...)])
                scalar_equation = Scalar(str)

            kwargs: dict[str, TensorType | Scalar] = {
                "output": tensor_output,
                "input": tensor_input,
                "equation": scalar_equation,
            }

            super().__init__(formula_key="einsum", **kwargs)
            self._freeze()

        def __call__(  # type: ignore[override]
            self,
            input: ConnectionType = NOT_GIVEN,
            equation: ConnectionType = NOT_GIVEN,
            output: ConnectionType = NOT_GIVEN,
        ) -> ExtendInfo:
            return super().__call__(input=input, equation=equation, output=output)

    TorchBackend.register_primitive(einsum)

    # Create the container model and add the einsum primitive.
    model = Model()

    # In __init__ the equation is TBD; the string arrives as an IOKey value.
    a = ReduceEinsum(equation=TBD)(
        input="input", equation=IOKey(value="ij->i"), output="output"
    )
    model += a

    # Compile the model and assert the results: "ij->i" sums over axis 1,
    # so a (7, 6) tensor of ones reduces to a length-7 vector of sixes.
    pm = mithril.compile(model=model, backend=backend)
    input = backend.ones((7, 6))
    trainable_keys = {"input": input}
    outputs = pm.evaluate(trainable_keys)
    ref_outputs = {"output": backend.ones(7) * 6}
    assert_results_equal(outputs, ref_outputs)
Loading