Skip to content

Commit

Permalink
feat: Enable Connection type keys in compile (#13)
Browse files Browse the repository at this point in the history
Co-authored-by: kberat-synnada <97015093+kberat-synnada@users.noreply.github.com>
  • Loading branch information
norhan-synnada and kberat-synnada authored Nov 18, 2024
1 parent 2a98f88 commit 422fdcb
Show file tree
Hide file tree
Showing 17 changed files with 646 additions and 202 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ numpy_model = ml.compile(

# Compile different logical models with the same backend
other_model = Model()
other_model += Linear(dimension=32)
other_model += Linear(dimension=32)(input = "input")
jax_model1 = ml.compile(
model=other_model,
backend=backend_jax,
Expand Down
54 changes: 21 additions & 33 deletions mithril/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

import builtins
import platform
from collections.abc import Mapping, Sequence
from collections.abc import Iterable

from .backends.backend import Backend, UnavailableBackend
from .core import (
Expand All @@ -33,7 +33,8 @@
short,
)
from .framework.codegen import code_gen_map
from .framework.common import TBD, Connect, Connection, Constant, IOKey, MainValueType
from .framework.common import TBD, Connect, Connection, Constant, IOKey
from .framework.physical.model import PhysicalConstantType, PhysicalShapeType
from .models import BaseModel, PhysicalModel
from .models.train_model import TrainModel

Expand Down Expand Up @@ -96,19 +97,18 @@
def compile(
model: BaseModel,
backend: Backend[DataType],
*,
constant_keys: PhysicalConstantType | None = None,
data_keys: Iterable[str | Connection] | None = None,
discard_keys: Iterable[str | Connection] | None = None,
jacobian_keys: Iterable[str | Connection] | None = None,
trainable_keys: Iterable[str | Connection] | None = None,
shapes: PhysicalShapeType | None = None,
inference: builtins.bool = False,
discard_keys: set[str] | tuple[str] | list[str] | str | None = None,
jacobian_keys: set[str] | None = None,
shapes: Mapping[str, Sequence[builtins.int | None]]
| Mapping[Connection, Sequence[builtins.int | None]]
| Mapping[str | Connection, Sequence[builtins.int | None]]
| None = None,
data_keys: set[str] | None = None,
constant_keys: Mapping[str, DataType | MainValueType] | None = None,
trainable_keys: set[str] | None = None,
jit: builtins.bool = True,
file_path: str | None = None,
safe_shapes: builtins.bool = True,
safe_names: builtins.bool = True,
) -> PhysicalModel[DataType]:
"""Compilation of Logical Model.
Expand All @@ -134,29 +134,15 @@ def compile(
if model.parent is not None:
raise ValueError("Model with a parent could not be compiled!")

if discard_keys is None:
discard_keys = set()
elif isinstance(discard_keys, tuple | list):
discard_keys = set(discard_keys)
elif isinstance(discard_keys, str):
discard_keys = set([discard_keys])

if jacobian_keys is None:
jacobian_keys = set()
if shapes is None:
shapes = dict()
if data_keys is None:
data_keys = set()
if constant_keys is None:
constant_keys = dict()
if trainable_keys is None:
trainable_keys = set()

assert isinstance(discard_keys, set), (
f"Expected discard_keys to be of type 'set', but got type "
f"'{type(discard_keys).__name__}' instead."
)
# Convert keys to required types.
constant_keys = constant_keys if constant_keys is not None else dict()
data_keys = set(data_keys) if data_keys is not None else set()
discard_keys = set(discard_keys) if discard_keys is not None else set()
jacobian_keys = set(jacobian_keys) if jacobian_keys is not None else set()
shapes = shapes if shapes is not None else dict()
trainable_keys = set(trainable_keys) if trainable_keys is not None else set()

# Initialize Physical Model.
pm = PhysicalModel[DataType](
model=model,
backend=backend,
Expand All @@ -168,6 +154,7 @@ def compile(
shapes=shapes,
inference=inference,
safe_shapes=safe_shapes,
safe_names=safe_names,
)

if jit and file_path is not None:
Expand All @@ -177,6 +164,7 @@ def compile(
"'jit' flag to 'False'"
)

# Pick code generator based on backend and generate code.
CodeGen_Cls = code_gen_map[backend.__class__]
codegen = CodeGen_Cls(pm)
codegen.generate_code(file_path=file_path)
Expand Down
2 changes: 1 addition & 1 deletion mithril/framework/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -1175,7 +1175,7 @@ def __init__(self, key, metadata, is_key_autogenerated: bool):
self.data = ConnectionData(key, metadata, is_key_autogenerated, self)

@property
def key(self):
def key(self) -> str:
return self.data.key

@property
Expand Down
2 changes: 1 addition & 1 deletion mithril/framework/logical/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ def _get_outermost_parent(self):

def _generate_keys(
self, symbolic=True, include_internals=True, include_outputs=False
):
) -> dict[str, str]:
return {}

def __setattr__(self, name: str, value: Any):
Expand Down
2 changes: 1 addition & 1 deletion mithril/framework/logical/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1329,7 +1329,7 @@ def _update_key_name(

def _generate_keys(
self, symbolic=True, include_internals=True, include_outputs=False
):
) -> dict[str, str]:
key_mappings: dict[str, str] = {}
raw_keys: dict[str, list[str]] = {}
underscored_keys = set[str]()
Expand Down
Loading

0 comments on commit 422fdcb

Please sign in to comment.