Skip to content

Commit

Permalink
Add BaseKey and use BaseKey for primitive models, update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
kberat-synnada committed Dec 25, 2024
1 parent 9fefff3 commit b21b9a9
Show file tree
Hide file tree
Showing 13 changed files with 562 additions and 527 deletions.
15 changes: 15 additions & 0 deletions mithril/framework/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -1100,6 +1100,15 @@ def __init__(
self.output_connection = None


@dataclass
class BaseKey:
value: TensorValueType | MainValueType | ToBeDetermined | str = TBD
shape: ShapeTemplateType | None = None
type: NestedListType | UnionType | type | None = None
interval: list[float | int] | None = None


@dataclass
class IOKey(TemplateBase):
def __init__(
self,
Expand Down Expand Up @@ -1141,6 +1150,12 @@ def __init__(
conn = item.data if isinstance(item, Connection) else item
self._connections.add(conn)

def __eq__(self, other: object):
if isinstance(other, int | float | bool | list | Connection | IOKey | tuple):
return ExtendTemplate(connections=[self, other], model="eq")
else:
raise ValueError("Unsupported type for equality operation.")

Check warning on line 1157 in mithril/framework/common.py

View check run for this annotation

Codecov / codecov/patch

mithril/framework/common.py#L1157

Added line #L1157 was not covered by tests


class Connection(TemplateBase):
def __init__(self, key: str, metadata: IOHyperEdge, is_key_autogenerated: bool):
Expand Down
234 changes: 119 additions & 115 deletions mithril/framework/logical/essential_primitives.py

Large diffs are not rendered by default.

33 changes: 17 additions & 16 deletions mithril/framework/logical/primitive.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@
from ..common import (
NOT_AVAILABLE,
TBD,
BaseKey,
Connection,
IOHyperEdge,
IOKey,
KeyType,
NotAvailable,
Scalar,
Expand Down Expand Up @@ -54,17 +54,17 @@ def __init__(
self,
formula_key: str,
name: str | None = None,
**kwargs: IOKey | Tensor | Scalar,
**kwargs: BaseKey | Tensor | Scalar,
) -> None:
self._formula_key = formula_key
self.grad_formula = formula_key + "_grad"

super().__init__(name=name)
# Get shape_templates of TensorTypes and create corresponding shapes.
shape_templates = {
key: value._shape
key: value.shape
for key, value in kwargs.items()
if isinstance(value, IOKey) and value._shape is not None
if isinstance(value, BaseKey) and value.shape is not None
}
shapes = create_shape_map(shape_templates, self.constraint_solver)
data_set: set[Tensor] = set()
Expand All @@ -73,8 +73,9 @@ def __init__(
for key, value in kwargs.items():
# TODO: The first if block is temporary. All if else blocks will be
# removed after the implementation of the new type system.
if get_origin(value._type) is Union:
args = get_args(value._type)
value_type = value.type if isinstance(value, BaseKey) else value._type
if get_origin(value_type) is Union:
args = get_args(value_type)
types = []
for _type in args:
# TODO: assertion will be removed,
Expand All @@ -83,30 +84,30 @@ def __init__(
types.append(get_mytensor_subtype(_type))
possible_types = reduce(lambda x, y: x | y, types) # type: ignore

assert isinstance(value, IOKey)
assert isinstance(value, BaseKey)
_value: Tensor | Scalar = Tensor(
shape=shapes[key].node,
possible_types=possible_types,
value=value._value, # type: ignore
interval=value._interval,
value=value.value, # type: ignore
interval=value.interval,
)
assert isinstance(_value, Tensor)
data_set.add(_value)
elif is_mytensor_type(value._type):
assert isinstance(value, IOKey)
elif is_mytensor_type(value_type):
assert isinstance(value, BaseKey)
_value = Tensor(
shape=shapes[key].node,
possible_types=get_mytensor_subtype(value._type), # type: ignore
value=value._value, # type: ignore
interval=value._interval,
possible_types=get_mytensor_subtype(value_type), # type: ignore
value=value.value, # type: ignore
interval=value.interval,
)
data_set.add(_value)
elif isinstance(value, Tensor | Scalar):
_value = value
else:
_value = Scalar(
possible_types=value._type, # type: ignore
value=value._value, # type: ignore
possible_types=value_type, # type: ignore
value=value.value, # type: ignore
)

conn_data = self.create_connection(IOHyperEdge(_value), key)
Expand Down
Loading

0 comments on commit b21b9a9

Please sign in to comment.