Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: Remove factory inputs #130

Merged
merged 5 commits into from
Dec 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions mithril/backends/with_autograd/jax_backend/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,9 +174,7 @@ def array(
_dtype = utils.determine_dtype(input, dtype, self.precision)

with jax.default_device(self.device):
array = jax.numpy.array(
input, dtype=utils.dtype_map[_dtype], device=self.device
)
array = jax.numpy.array(input, dtype=utils.dtype_map[_dtype])

if self._parallel_manager is not None:
array = self._parallel_manager.parallelize(array, device_mesh)
Expand Down
2 changes: 1 addition & 1 deletion mithril/framework/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -722,7 +722,7 @@ def match(self, other: BaseData[T]) -> Updates:
self.differentiable = other.differentiable = is_diff

if self.is_valued or other.is_valued:
valued, non_valued = (self, other) if self.is_valued else (other, self)
valued, non_valued = (other, self) if other.is_valued else (self, other)
updates |= non_valued.set_value(valued.value)
if non_valued is other:
if other.is_tensor:
Expand Down
45 changes: 0 additions & 45 deletions mithril/framework/logical/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,19 +24,14 @@
from ...utils.utils import OrderedSet
from ..common import (
NOT_AVAILABLE,
NOT_GIVEN,
TBD,
Connection,
ConnectionData,
Connections,
ConnectionType,
Constraint,
ConstraintFunctionType,
ConstraintSolver,
ExtendTemplate,
IOHyperEdge,
IOKey,
MainValueInstance,
MainValueType,
NestedListType,
NotAvailable,
Expand All @@ -46,8 +41,6 @@
ShapeTemplateType,
ShapeType,
Tensor,
TensorValueType,
ToBeDetermined,
UniadicRecord,
Updates,
UpdateType,
Expand Down Expand Up @@ -95,50 +88,12 @@ class BaseModel(abc.ABC):
factory_args: dict[str, Any] = {}

def __call__(self, **kwargs: ConnectionType) -> ExtendInfo:
for key, val in self.factory_inputs.items():
if val is not TBD:
if key not in kwargs or (con := kwargs[key]) is NOT_GIVEN:
kwargs[key] = val # type: ignore
continue
match con:
case Connection():
kwargs[key] = IOKey(value=val, connections={con})
# TODO: Maybe we could check con's value if matches with val
case item if isinstance(item, MainValueInstance) and con != val:
raise ValueError(
f"Given value {con} for local key: '{key}' "
f"has already being set to {val}!"
)
case str():
kwargs[key] = IOKey(name=con, value=val, expose=False)
case IOKey():
if con.data.value is not TBD and con.data.value != val:
raise ValueError(
f"Given IOKey for local key: '{key}' is not valid!"
)
else:
kwargs[key] = IOKey(
name=con.name,
expose=con.expose,
connections=con.connections,
type=con.data.type,
shape=con.data.shape,
value=val,
)
case ExtendTemplate():
raise ValueError(
"Multi-write detected for a valued "
f"local key: '{key}' is not valid!"
)
return ExtendInfo(self, kwargs)

def __init__(self, name: str | None = None, enforce_jit: bool = True) -> None:
self.parent: BaseModel | None = (
None # TODO: maybe set it only to PrimitiveModel / Model.
)
self.factory_inputs: dict[
str, TensorValueType | MainValueType | ToBeDetermined
] = {}
self.assigned_shapes: list[ShapesType] = []
self.assigned_constraints: list[dict[str, str | list[str]]] = []
self.conns = Connections()
Expand Down
Loading
Loading