Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: Unify PrimitiveSlice and Scalaritem #128

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
5f9193c
feat: Now People will be assigned based on applied labels, github act…
mehmetozsoy-synnada Nov 14, 2024
8309838
feat: Now actions bot requests a change in the case of tests failed
mehmetozsoy-synnada Nov 14, 2024
aea5265
Merge remote-tracking branch 'upstream/main'
mehmetozsoy-synnada Nov 19, 2024
5916d8f
fix: bug fix attempt 1
mehmetozsoy-synnada Nov 21, 2024
f3e9083
fix: resolve conflicts
mehmetozsoy-synnada Nov 22, 2024
6365adb
Merge remote-tracking branch 'upstream/main'
mehmetozsoy-synnada Nov 27, 2024
78ccbfa
merged with upstream
mehmetozsoy-synnada Dec 2, 2024
7c2ce69
Merge remote-tracking branch 'upstream/main'
mehmetozsoy-synnada Dec 4, 2024
74ce570
Merge remote-tracking branch 'upstream/main'
mehmetozsoy-synnada Dec 5, 2024
7835062
Merge remote-tracking branch 'upstream/main'
mehmetozsoy-synnada Dec 12, 2024
6533572
Merge remote-tracking branch 'upstream/main'
mehmetozsoy-synnada Dec 17, 2024
202366e
Merge remote-tracking branch 'upstream/main'
mehmetozsoy-synnada Dec 20, 2024
8e47414
slice model is added to mithril
mehmetozsoy-synnada Dec 20, 2024
85ca767
TensorSlice model is removed
mehmetozsoy-synnada Dec 23, 2024
fea39af
Merge remote-tracking branch 'upstream/main'
mehmetozsoy-synnada Dec 25, 2024
98687f9
merged with main
mehmetozsoy-synnada Dec 25, 2024
4658244
minor bug fixed in physical shape
mehmetozsoy-synnada Dec 25, 2024
4ebe35a
Merge remote-tracking branch 'upstream/main'
mehmetozsoy-synnada Dec 26, 2024
6a8a9d4
tests are added for item + slice
mehmetozsoy-synnada Dec 26, 2024
4bc57d2
merged with upstream
mehmetozsoy-synnada Dec 26, 2024
a10bfbe
resolve conflicts
mehmetozsoy-synnada Dec 26, 2024
d935a43
PrimitiveSlice is removed
mehmetozsoy-synnada Dec 27, 2024
796a447
Merge remote-tracking branch 'upstream/main'
mehmetozsoy-synnada Dec 27, 2024
d767d0f
merged with main
mehmetozsoy-synnada Dec 27, 2024
7293b6e
Merge remote-tracking branch 'upstream/main'
mehmetozsoy-synnada Dec 27, 2024
5fac647
Merge branch 'main' into add-slice-support-to-item-models
kberat-synnada Dec 27, 2024
408cddf
Merge remote-tracking branch 'upstream/main'
mehmetozsoy-synnada Dec 27, 2024
f4c0597
Merge branch 'main' of https://github.com/mehmetozsoy-synnada/mithril…
mehmetozsoy-synnada Dec 27, 2024
a50e111
commented tests about PrimitiveSlice are adited and re-opened
mehmetozsoy-synnada Dec 30, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 38 additions & 15 deletions mithril/framework/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -869,19 +869,41 @@ def __getitem__(
key: slice
| int
| EllipsisType
| tuple[slice | int | None | EllipsisType, ...]
| tuple[slice | int | None | EllipsisType | TemplateBase, ...]
| IOKey
| TemplateBase
| None,
):
if key is ...:
key = slice(None)
if isinstance(key, slice):
start, stop, step = key.start, key.stop, key.step
return ExtendTemplate(connections=[self, start, stop, step], model="slice")
elif isinstance(key, int | tuple):
return ExtendTemplate(connections=[self, key], model="get_item")
else:
raise TypeError(f"Unsupported key type: {type(key)}")
match key:
case slice():
slice_output = ExtendTemplate(
connections=[key.start, key.stop, key.step], model="slice"
)
output = ExtendTemplate(connections=[self, slice_output], model="index")

case int() | EllipsisType() | None:
output = ExtendTemplate(connections=[self, key], model="index")

case tuple():
connections: list[TemplateBase | int | None | EllipsisType] = []
for item in key:
if isinstance(item, slice):
slice_output = ExtendTemplate(
connections=[item.start, item.stop, item.step],
model="slice",
)
connections.append(slice_output)
else:
connections.append(item)
tuple_template = ExtendTemplate(
connections=connections, # type: ignore
model="to_tuple",
defaults={"n": len(key)},
)
output = ExtendTemplate(
connections=[self, tuple_template], model="index"
)
return output

def __add__(self, other: TemplateConnectionType):
return ExtendTemplate(connections=[self, other], model="add")
Expand Down Expand Up @@ -915,12 +937,12 @@ def __rfloordiv__(self, other: TemplateConnectionType):

def __pow__(self, other: TemplateConnectionType):
return ExtendTemplate(
connections=[self, other], model="pow", defaults={"robust", "threshold"}
connections=[self, other], model="pow", defaults={"robust": False}
)

def __rpow__(self, other: TemplateConnectionType):
return ExtendTemplate(
connections=[other, self], model="pow", defaults={"robust", "threshold"}
connections=[other, self], model="pow", defaults={"robust": False}
)

def __matmul__(self, other: TemplateConnectionType):
Expand Down Expand Up @@ -1070,7 +1092,7 @@ def var(

def sqrt(self):
return ExtendTemplate(
connections=[self], model="sqrt", defaults={"robust", "cutoff"}
connections=[self], model="sqrt", defaults={"robust": False}
)

def exp(self):
Expand All @@ -1093,7 +1115,7 @@ def __init__(
self,
connections: list[TemplateConnectionType],
model: str,
defaults: set[str] | None = None,
defaults: dict[str, Any] | None = None,
) -> None:
for connection in connections:
if isinstance(connection, str):
Expand All @@ -1105,7 +1127,7 @@ def __init__(
self.model = model

if defaults is None:
defaults = set()
defaults = {}
self.defaults = defaults
self.output_connection = None

Expand Down Expand Up @@ -1217,6 +1239,7 @@ def set_differentiable(self, differentiable: bool = True) -> None:
| int
| float
| list[int | float]
| EllipsisType
| tuple[slice | int | None | EllipsisType | TemplateBase, ...]
| None
)
Expand Down
85 changes: 59 additions & 26 deletions mithril/framework/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,7 +384,7 @@


def scalar_item_type_constraint_forward_helper(
input_type: GenericAlias | UnionType | type, index_val: int | ToBeDetermined
input_type: GenericAlias | UnionType | type, index_val: int | slice | ToBeDetermined
) -> type | UnionType | GenericAlias:
# forward inference of scalar item type constraint:
# Examples:
Expand All @@ -393,34 +393,49 @@

new_type = input_type
if isinstance(input_type, GenericAlias):
if input_type.__origin__ is tuple:
origin = get_origin(input_type)
if origin is tuple:
if ... in input_type.__args__:
variadic_required = True
# if second value is ellipsis, directly take first value
# (tuple[int, ...] -> int)
new_type = input_type.__args__[0]
else:
# case when type of tuple is exact (ex: tuple[int, float])
if not isinstance(index_val, ToBeDetermined):
variadic_required = False
# if index val is specified, directly take the corresponding item
# of type
new_type = input_type.__args__[index_val]
if isinstance(index_val, int):
new_type = input_type.__args__[index_val]
else:
new_type = tuple[*input_type.__args__[index_val]] # type: ignore
else:
variadic_required = True
# if not specified this means it can be all of them,
# take union of all types inside tuple
new_type = create_union_type(*input_type.__args__)

elif input_type.__origin__ is list:
# if list, directly take first argument (list[int | float] -> int | float)
new_type = input_type.__args__[0]
if variadic_required and isinstance(index_val, slice):
new_type = tuple[new_type, ...] # type: ignore

elif origin is list:
if isinstance(index_val, slice):
new_type = input_type
else:
new_type = input_type.__args__[0]
elif input_type is list or input_type is tuple:
new_type = input_type | int | float | list
if isinstance(index_val, slice):
new_type = input_type

Check warning on line 429 in mithril/framework/constraints.py

View check run for this annotation

Codecov / codecov/patch

mithril/framework/constraints.py#L429

Added line #L429 was not covered by tests
else:
new_type = input_type | int | float | list

return new_type


def check_index_type_compatibility(
_type: type,
index: int | ToBeDetermined,
index: int | ToBeDetermined | slice,
is_variadic: bool,
raise_error: bool = False,
) -> bool:
Expand All @@ -431,7 +446,7 @@
and not is_variadic
):
args_len = len(_type.__args__)
if not (-args_len <= index <= args_len - 1):
if isinstance(index, int) and not (-args_len <= index <= args_len - 1):

Check warning on line 449 in mithril/framework/constraints.py

View check run for this annotation

Codecov / codecov/patch

mithril/framework/constraints.py#L449

Added line #L449 was not covered by tests
if raise_error:
raise TypeError(
f"Index value {index} is out of range for type {_type}!"
Expand All @@ -443,7 +458,7 @@
def scalar_item_reduce_input_type(
output_type: type | UnionType | GenericAlias,
input_type: type | UnionType | GenericAlias,
index: int | ToBeDetermined,
index: int | slice | ToBeDetermined,
):
possible_types = []
out_origin: type[list] | type[tuple] | type[UnionType] | None = get_origin(
Expand Down Expand Up @@ -490,19 +505,28 @@
input_origin, index, is_variadic, raise_error=True
):
if index == ... or input_origin is list:
for arg in input_type.__args__:
if find_intersection_type(output_type, arg):
return input_type
if isinstance(index, int):
for arg in input_type.__args__:
if find_intersection_type(output_type, arg):
return input_type
else:
return input_type

elif input_origin is tuple:
possible_types = [
arg if idx != index else find_intersection_type(arg, output_type)
for idx, arg in enumerate(input_type.__args__)
]
return (
input_origin[*possible_types, ...] # type: ignore
if is_variadic
else input_origin[*possible_types] # type: ignore
)
if isinstance(index, int):
possible_types = [
arg
if idx != index
else find_intersection_type(arg, output_type)
for idx, arg in enumerate(input_type.__args__)
]
return (
input_origin[*possible_types, ...] # type: ignore
if is_variadic
else input_origin[*possible_types] # type: ignore
)
else:
return input_type
else:
return input_type

Expand All @@ -514,7 +538,11 @@
input_type = input.type
output_type = output.type
index_value = index.value
assert isinstance(index_value, ToBeDetermined) or type(index_value) is int
assert (
isinstance(index_value, ToBeDetermined)
or type(index_value) is int
or type(index_value) is slice
)

if not (
isinstance(input_type, UnionType)
Expand Down Expand Up @@ -3289,7 +3317,11 @@
or type(input.value) is list
)

assert isinstance(index.value, ToBeDetermined) or type(index.value) is int
assert (
isinstance(index.value, ToBeDetermined)
or type(index.value) is int
or type(index.value) is slice
)

updates = Updates()
status = False
Expand All @@ -3299,7 +3331,9 @@
):
updates |= output.set_value(input.value[index.value])
status = True
elif not isinstance(input.value, ToBeDetermined) and isinstance(output.value, int):
elif not isinstance(input.value, ToBeDetermined) and isinstance(
output.value, int | float | bool
):
# Try to infer index value from input-output values. If
# output value appears only once in input sequence, write its
# index as the value of index argument.
Expand Down Expand Up @@ -3427,7 +3461,6 @@
input_shape.root,
(output_suffix + input_shape.reverse[idx_suf:])[::-1],
)

return status, updated_symbols


Expand Down
63 changes: 9 additions & 54 deletions mithril/framework/logical/essential_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@
reverse_constraints,
scalar_item_constraints,
scalar_item_type_constraint,
scalar_slice_type_constraint,
shape_constraints,
size_constraints,
slice_constraints,
Expand Down Expand Up @@ -74,7 +73,6 @@
"Length",
"Size",
"Exponential",
"PrimitiveSlice",
"Item",
"ScalarItem",
"ToTensor",
Expand Down Expand Up @@ -147,10 +145,16 @@ def __init__(
) -> None:
self.factory_args = {"n": n}
key_definitions = {
"output": BaseKey(type=tuple[int | float | bool | list | tuple, ...])
"output": BaseKey(
type=tuple[
int | float | bool | list | tuple | slice | EllipsisType | None, ...
]
)
}
key_definitions |= {
f"input{idx+1}": BaseKey(type=int | float | bool | list | tuple)
f"input{idx+1}": BaseKey(
type=int | float | bool | list | tuple | slice | EllipsisType | None
)
for idx in range(n)
}
self.factory_inputs = kwargs # type: ignore
Expand Down Expand Up @@ -555,55 +559,6 @@ def __call__( # type: ignore[override]
return super().__call__(input=input, dim=dim, output=output)


class PrimitiveSlice(PrimitiveModel):
input: Connection
start: Connection
stop: Connection
step: Connection
output: Connection

def __init__(
self,
name: str | None = None,
start: int | None | ToBeDetermined = None,
stop: int | None | ToBeDetermined = None,
step: int | None | ToBeDetermined = None,
input: TensorValueType | ToBeDetermined = TBD,
) -> None:
self.factory_args = {"start": start, "stop": stop, "step": step}
super().__init__(
formula_key="sequence_slice",
name=name,
output=BaseKey(
type=tuple[int | float | bool, ...] | list[int | float | bool]
),
input=BaseKey(
type=tuple[int | float | bool, ...] | list[int | float | bool]
),
start=BaseKey(type=int | None, value=start),
stop=BaseKey(type=int | None, value=stop),
step=BaseKey(type=int | None, value=step),
)

self.factory_inputs = {"input": input}
self._set_constraint(
fn=scalar_slice_type_constraint,
keys=[PrimitiveModel.output_key, "input", "start", "stop", "step"],
)

def __call__( # type: ignore[override]
self,
input: ConnectionType = NOT_GIVEN,
start: ConnectionType = NOT_GIVEN,
stop: ConnectionType = NOT_GIVEN,
step: ConnectionType = NOT_GIVEN,
output: ConnectionType = NOT_GIVEN,
) -> ExtendInfo:
return super().__call__(
input=input, start=start, stop=stop, step=step, output=output
)


class Item(PrimitiveModel):
input: Connection
output: Connection
Expand Down Expand Up @@ -646,7 +601,7 @@ def __init__(
name=name,
output=BaseKey(type=int | float | list | tuple),
input=BaseKey(type=list | tuple),
index=BaseKey(type=int, value=index),
index=BaseKey(type=int | slice, value=index),
)
self.factory_inputs = {"input": input, "index": index}

Expand Down
Loading
Loading