Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat!: Update Model Architecture with User-interacting Logic #174

Merged
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions benchmarks/speed_benchmarks/speed_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@

from mithril.models import (
MLP,
BaseModel,
Convolution2D,
Flatten,
MaxPool2D,
Expand Down Expand Up @@ -84,7 +83,7 @@ def create_compl_conv(
def create_compl_mlp(
input_size: int,
dimensions: Sequence[int | None],
activations: list[type[BaseModel]],
activations: list[type[Model]],
):
"""Mithril's MLP wrapper with input size

Expand Down
7 changes: 4 additions & 3 deletions mithril/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,10 @@
short,
)
from .framework.codegen import code_gen_map
from .framework.common import TBD, Connection, IOKey
from .framework.common import TBD
from .framework.logical import Connection, IOKey
from .framework.physical.model import PhysicalConstantType, PhysicalShapeType
from .models import BaseModel, PhysicalModel
from .models import Model, PhysicalModel
from .models.train_model import TrainModel

__all__ = [
Expand Down Expand Up @@ -97,7 +98,7 @@


def compile(
model: BaseModel,
model: Model,
backend: Backend[DataType],
*,
constant_keys: PhysicalConstantType[DataType] | None = None,
Expand Down
6 changes: 3 additions & 3 deletions mithril/framework/codegen/numpy_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@

if not self.pm.inference:
# TODO: Change this with cache refactor
cache_name = output_key + f"_{model.cache_name}"
cache_name = output_key + f"_{PrimitiveModel.cache_name}"
used_keys.add(cache_name)
targets.append(
ast.Subscript(
Expand All @@ -284,14 +284,14 @@
return targets, used_keys

def get_cache_name(self, output_key: str, model: PrimitiveModel) -> str:
cache_name = "_".join([output_key, model.cache_name])
cache_name = "_".join([output_key, PrimitiveModel.cache_name])
if cache_name not in self.pm.data_store._all_data:
self.add_cache(model, output_key)

return cache_name

def add_cache(self, model: PrimitiveModel, output_key: str) -> None:
cache_name = "_".join([output_key, model.cache_name])
cache_name = "_".join([output_key, PrimitiveModel.cache_name])

Check warning on line 294 in mithril/framework/codegen/numpy_gen.py

View check run for this annotation

Codecov / codecov/patch

mithril/framework/codegen/numpy_gen.py#L294

Added line #L294 was not covered by tests
cache_value: dict[str, Any] | None = None if self.pm.inference else {}
# Create a scalar for caches in manualgrad backend.
self.pm.data_store.update_data(
Expand Down
1 change: 0 additions & 1 deletion mithril/framework/codegen/python_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,6 @@ def generate_evaluate(self) -> ast.FunctionDef:

model, g_input_keys, l_input_keys = self.get_primitive_details(output_key)
formula_key = model.formula_key

primitive_function = (
self.pm.backend.primitive_function_dict[formula_key]
if formula_key in self.pm.backend.primitive_function_dict
Expand Down
Loading