Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Generalize edge type polymorphism. #170

Merged
merged 2 commits into from
Jan 30, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions mithril/framework/codegen/numpy_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@

from ...backends.with_manualgrad.numpy_backend import NumpyBackend
from ...framework.physical.model import PhysicalModel
from ...framework.utils import find_intersection_type
from ...utils.func_utils import is_make_array_required, prepare_function_args
from ..common import (
DataEvalType,
Expand All @@ -33,7 +32,7 @@
IOHyperEdge,
LossKey,
ParamsEvalType,
Tensor,
find_intersection_type,
is_type_adjustment_required,
)
from ..logical import PrimitiveModel
Expand Down Expand Up @@ -318,7 +317,7 @@ def generate_evaluate_gradients(
key
for key in all_ignored_keys
if key in self.pm.data
and self.pm.data[key].edge_type is Tensor
and self.pm.data[key].is_tensor
and find_intersection_type(self.pm.data[key].value_type, float)
}

Expand Down
3 changes: 1 addition & 2 deletions mithril/framework/codegen/python_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
EvaluateGradientsType,
EvaluateType,
ParamsEvalType,
Tensor,
)
from ..logical import PrimitiveModel
from ..physical.model import PhysicalModel
Expand Down Expand Up @@ -301,7 +300,7 @@ def generate_imports(self) -> list[ast.stmt]:
def is_static_scalar(self, key: str) -> bool:
return (
key in self.pm.data_store.cached_data
and self.pm.data[key].edge_type != Tensor
and not self.pm.data[key].is_tensor
and self.pm.data[key].edge_type != Dtype
and not isinstance(self.pm.data_store.cached_data[key], enum.Enum)
)
Expand Down
Loading
Loading