
Commit

feat!: Update Model Architecture with User-interacting Logic (synnada…
kberat-synnada authored and mehmetozsoy-synnada committed Feb 12, 2025
1 parent c80f1bf commit ad9dc9a
Showing 50 changed files with 3,887 additions and 2,968 deletions.
3 changes: 1 addition & 2 deletions benchmarks/speed_benchmarks/speed_helper.py
@@ -18,7 +18,6 @@

from mithril.models import (
    MLP,
-    BaseModel,
    Convolution2D,
    Flatten,
    MaxPool2D,
@@ -84,7 +83,7 @@ def create_compl_conv(
def create_compl_mlp(
    input_size: int,
    dimensions: Sequence[int | None],
-    activations: list[type[BaseModel]],
+    activations: list[type[Model]],
):
    """Mithril's MLP wrapper with input size
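With BaseModel gone from the public surface, the benchmark helper's activations parameter is now annotated against Model. A hypothetical call site under the new annotation; Relu and Sigmoid are assumed Model subclasses exported by mithril.models and do not appear in this hunk:

# Hypothetical call site for the updated helper; Relu and Sigmoid are assumed
# Model subclasses from mithril.models and are not part of this diff.
from mithril.models import Relu, Sigmoid

mlp = create_compl_mlp(
    input_size=128,
    dimensions=[64, 32, 1],
    activations=[Relu, Relu, Sigmoid],  # list[type[Model]] per the new annotation
)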
7 changes: 4 additions & 3 deletions mithril/__init__.py
@@ -35,9 +35,10 @@
    short,
)
from .framework.codegen import code_gen_map
-from .framework.common import TBD, Connection, IOKey
+from .framework.common import TBD
+from .framework.logical import Connection, IOKey
from .framework.physical.model import PhysicalConstantType, PhysicalShapeType
-from .models import BaseModel, PhysicalModel
+from .models import Model, PhysicalModel
from .models.train_model import TrainModel

__all__ = [
@@ -97,7 +98,7 @@


def compile(
-    model: Model,
+    model: Model,
    backend: Backend[DataType],
    *,
    constant_keys: PhysicalConstantType[DataType] | None = None,
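For users of the top-level package the exported names stay available; the visible change is that compile() is annotated with Model instead of BaseModel, and Connection/IOKey are now re-exported from the logical subpackage. A minimal sketch of the updated entry point, assuming MLP, Relu, Sigmoid, and TorchBackend are available as before (they are not part of this diff):

# Minimal sketch; MLP, Relu, Sigmoid, and TorchBackend are assumed exports used
# only to exercise the compile() signature shown above.
import mithril as ml
from mithril.models import MLP, Model, Relu, Sigmoid

model: Model = MLP(dimensions=[64, 1], activations=[Relu(), Sigmoid()])
pm = ml.compile(model=model, backend=ml.TorchBackend())  # model: Model, not BaseModel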
23 changes: 11 additions & 12 deletions mithril/framework/codegen/numpy_gen.py
@@ -35,7 +35,7 @@
    find_intersection_type,
    is_type_adjustment_required,
)
-from ..logical import PrimitiveModel
+from ..logical import Operator
from .python_gen import PythonCodeGen, RawGradientType
from .utils import check_repr_inequality

@@ -212,11 +212,11 @@ def evaluate_gradients_wrapper_manualgrad(

    def get_primitive_details(
        self, output_key: str
-    ) -> tuple[PrimitiveModel, list[str], list[str]]:
+    ) -> tuple[Operator, list[str], list[str]]:
        model = self.pm.flat_graph.get_model(output_key)

        global_input_keys = self.pm.flat_graph.get_source_keys(output_key)
-        global_input_keys += [self.get_cache_name(output_key, model)]
+        global_input_keys += [self.get_cache_name(output_key)]
        local_input_keys = list(model.input_keys) + ["cache"]

        return model, global_input_keys, local_input_keys
@@ -229,7 +229,7 @@ def is_static_scalar(self, key: str) -> bool:

    def call_primitive(
        self,
-        model: PrimitiveModel,
+        model: Operator,
        fn: Callable[..., Any],
        l_input_keys: list[str],
        g_input_keys: list[str],
@@ -259,7 +259,7 @@ def call_primitive(
        return ast.Assign(targets, generated_fn), used_keys | _used_keys

    def create_primitive_call_targets(
-        self, output_key: str, model: PrimitiveModel, inference: bool
+        self, output_key: str, model: Operator, inference: bool
    ) -> tuple[list[ast.expr | ast.Name], set[str]]:
        targets: list[ast.expr | ast.Name] = []

@@ -271,27 +271,26 @@

        if not self.pm.inference:
            # TODO: Change this with cache refactor
-            cache_name = output_key + f"_{model.cache_name}"
+            cache_name = output_key + f"_{Operator.cache_name}"
            used_keys.add(cache_name)
            targets.append(
                ast.Subscript(
                    value=ast.Name(id=cache_name, ctx=ast.Load()),
-                    slice=ast.Constant(value=PrimitiveModel.output_key),
+                    slice=ast.Constant(value=Operator.output_key),
                    ctx=ast.Store(),
                )
            )

        return targets, used_keys

-    def get_cache_name(self, output_key: str, model: PrimitiveModel) -> str:
-        cache_name = "_".join([output_key, model.cache_name])
+    def get_cache_name(self, output_key: str) -> str:
+        cache_name = "_".join([output_key, Operator.cache_name])
        if cache_name not in self.pm.flat_graph.all_data:
-            self.add_cache(model, output_key)
+            self.add_cache(output_key, cache_name)

        return cache_name

-    def add_cache(self, model: PrimitiveModel, output_key: str) -> None:
-        cache_name = "_".join([output_key, model.cache_name])
+    def add_cache(self, output_key: str, cache_name: str) -> None:
        cache_value: dict[str, Any] | None = None if self.pm.inference else {}
        # Create a scalar for caches in manualgrad backend.
        self.pm.flat_graph.update_data(
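The net effect of the cache refactor is that cache key names are derived from the class-level Operator.cache_name rather than from each model instance, so get_cache_name only needs the output key and hands the precomputed name to add_cache. A standalone sketch of the naming flow, with a plain set standing in for flat_graph.all_data and the literal suffix "cache" standing in for Operator.cache_name (both are assumptions for illustration):

# Standalone sketch of the new cache-naming flow. The real methods live on
# NumpyCodeGen and register an entry in self.pm.flat_graph; here a set and a
# hard-coded "cache" suffix stand in for those pieces.
all_data: set[str] = set()
CACHE_SUFFIX = "cache"  # stands in for Operator.cache_name

def add_cache(output_key: str, cache_name: str) -> None:
    # The real add_cache builds a cache container (a dict, or None in inference
    # mode) and registers it in the flat graph; here we only record the name.
    all_data.add(cache_name)

def get_cache_name(output_key: str) -> str:
    cache_name = "_".join([output_key, CACHE_SUFFIX])
    if cache_name not in all_data:
        add_cache(output_key, cache_name)
    return cache_name

print(get_cache_name("matmul_output"))  # -> "matmul_output_cache"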
9 changes: 4 additions & 5 deletions mithril/framework/codegen/python_gen.py
@@ -31,7 +31,7 @@
    EvaluateType,
    ParamsEvalType,
)
-from ..logical import PrimitiveModel
+from ..logical import Operator
from ..physical.model import PhysicalModel
from ..utils import GeneratedFunction
from .code_gen import CodeGen
@@ -307,7 +307,7 @@ def is_static_scalar(self, key: str) -> bool:

    def get_primitive_details(
        self, output_key: str
-    ) -> tuple[PrimitiveModel, list[str], list[str]]:
+    ) -> tuple[Operator, list[str], list[str]]:
        model = self.pm.flat_graph.get_model(output_key)

        global_input_keys = self.pm.flat_graph.get_source_keys(output_key)
@@ -317,7 +317,7 @@

    def call_primitive(
        self,
-        model: PrimitiveModel,
+        model: Operator,
        fn: Callable[..., Any],
        l_input_keys: list[str],
        g_input_keys: list[str],
@@ -354,7 +354,6 @@ def generate_evaluate(self) -> ast.FunctionDef:
        for output_key in self.pm.flat_graph.topological_order:
            model, g_input_keys, l_input_keys = self.get_primitive_details(output_key)
            formula_key = model.formula_key
-
            primitive_function = (
                self.pm.backend.primitive_function_dict[formula_key]
                if formula_key in self.pm.backend.primitive_function_dict
@@ -568,7 +567,7 @@ def create_primitive_call(
        return generated_fn, used_keys

    def create_primitive_call_targets(
-        self, output_key: str, model: PrimitiveModel, inference: bool
+        self, output_key: str, model: Operator, inference: bool
    ) -> tuple[list[ast.expr], set[str]]:
        if (
            keyword.iskeyword(output_key)
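generate_evaluate walks the flat graph in topological order and resolves each operator's primitive implementation by its formula_key, preferring the backend's primitive_function_dict. A reduced, self-contained sketch of that dispatch step; the fallback source named registered_primitives is an assumption, since the else-branch is cut off in this hunk:

from collections.abc import Callable
from typing import Any

def resolve_primitive(
    formula_key: str,
    primitive_function_dict: dict[str, Callable[..., Any]],
    registered_primitives: dict[str, Callable[..., Any]],
) -> Callable[..., Any]:
    # Prefer the backend's built-in primitive table, otherwise fall back to
    # user-registered primitives (assumed name, not visible in this diff).
    return (
        primitive_function_dict[formula_key]
        if formula_key in primitive_function_dict
        else registered_primitives[formula_key]
    )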
4 changes: 2 additions & 2 deletions mithril/framework/codegen/torch_gen.py
@@ -19,7 +19,7 @@

import torch

from ...backends.with_autograd.torch_backend import TorchBackend
-from ..logical import PrimitiveModel
+from ..logical import Operator
from ..physical.model import PhysicalModel
from .python_gen import PythonCodeGen

@@ -34,7 +34,7 @@ def __init__(self, pm: PhysicalModel[torch.Tensor]) -> None:

    def call_primitive(
        self,
-        model: PrimitiveModel,
+        model: Operator,
        fn: Callable[..., Any],
        l_input_keys: list[str],
        g_input_keys: list[str],