Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: unify TensorItem and TensorSlice models #121

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
5f9193c
feat: Now People will be assigned based on applied labels, github act…
mehmetozsoy-synnada Nov 14, 2024
8309838
feat: Now actions bot requests a change in case tests fail
mehmetozsoy-synnada Nov 14, 2024
aea5265
Merge remote-tracking branch 'upstream/main'
mehmetozsoy-synnada Nov 19, 2024
5916d8f
fix: bug fix attempt 1
mehmetozsoy-synnada Nov 21, 2024
f3e9083
fix: resolve conflicts
mehmetozsoy-synnada Nov 22, 2024
6365adb
Merge remote-tracking branch 'upstream/main'
mehmetozsoy-synnada Nov 27, 2024
78ccbfa
merged with upstream
mehmetozsoy-synnada Dec 2, 2024
7c2ce69
Merge remote-tracking branch 'upstream/main'
mehmetozsoy-synnada Dec 4, 2024
74ce570
Merge remote-tracking branch 'upstream/main'
mehmetozsoy-synnada Dec 5, 2024
7835062
Merge remote-tracking branch 'upstream/main'
mehmetozsoy-synnada Dec 12, 2024
6533572
Merge remote-tracking branch 'upstream/main'
mehmetozsoy-synnada Dec 17, 2024
202366e
Merge remote-tracking branch 'upstream/main'
mehmetozsoy-synnada Dec 20, 2024
8e47414
slice model is added to mithril
mehmetozsoy-synnada Dec 20, 2024
85ca767
TensorSlice model is removed
mehmetozsoy-synnada Dec 23, 2024
fea39af
Merge remote-tracking branch 'upstream/main'
mehmetozsoy-synnada Dec 25, 2024
98687f9
merged with main
mehmetozsoy-synnada Dec 25, 2024
4658244
minor bug fixed in physical shape
mehmetozsoy-synnada Dec 25, 2024
4ebe35a
Merge remote-tracking branch 'upstream/main'
mehmetozsoy-synnada Dec 26, 2024
6a8a9d4
tests are added for item + slice
mehmetozsoy-synnada Dec 26, 2024
4bc57d2
merged with upstream
mehmetozsoy-synnada Dec 26, 2024
a10bfbe
resolve conflicts
mehmetozsoy-synnada Dec 26, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 0 additions & 7 deletions mithril/backends/with_autograd/common_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@
"transpose",
"swapaxes",
"square",
"tensor_slice",
"primitive_slice",
"buffer",
"permute_tensor",
Expand Down Expand Up @@ -160,12 +159,6 @@ def square(input: DataType):
return input * input


def tensor_slice(
input: DataType, start: int | None, stop: int | None, step: int | None
):
return input[start:stop:step]


def buffer(input: DataType):
return input

Expand Down
2 changes: 0 additions & 2 deletions mithril/backends/with_autograd/jax_backend/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,6 @@
subtract,
swapaxes,
tensor_item,
tensor_slice,
to_list,
to_tuple,
transpose,
Expand Down Expand Up @@ -192,7 +191,6 @@
"transpose",
"swapaxes",
"square",
"tensor_slice",
"buffer",
"permute_tensor",
"reshape",
Expand Down
2 changes: 0 additions & 2 deletions mithril/backends/with_autograd/mlx_backend/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,6 @@
subtract,
swapaxes,
tensor_item,
tensor_slice,
to_list,
to_tuple,
transpose,
Expand Down Expand Up @@ -164,7 +163,6 @@
"transpose",
"swapaxes",
"square",
"tensor_slice",
"buffer",
"permute_tensor",
"reshape",
Expand Down
2 changes: 0 additions & 2 deletions mithril/backends/with_autograd/torch_backend/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@
subtract,
swapaxes,
tensor_item,
tensor_slice,
to_list,
to_tuple,
tuple_converter,
Expand Down Expand Up @@ -181,7 +180,6 @@
"transpose",
"swapaxes",
"square",
"tensor_slice",
"buffer",
"permute_tensor",
"reshape",
Expand Down
11 changes: 0 additions & 11 deletions mithril/backends/with_manualgrad/common_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@
"squared_error",
"transpose",
"square",
"tensor_slice",
"buffer",
"permute_tensor",
"reshape",
Expand Down Expand Up @@ -165,16 +164,6 @@ def square(input: DataType, cache: CacheType = None):
return input * input


def tensor_slice(
input: DataType,
start: int | None,
stop: int | None,
step: int | None,
cache: CacheType = None,
):
return input[start:stop:step]


def buffer(input: DataType, cache: CacheType = None):
return input

Expand Down
2 changes: 0 additions & 2 deletions mithril/backends/with_manualgrad/numpy_backend/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@
subtract,
swapaxes,
tensor_item,
tensor_slice,
to_list,
to_tuple,
transpose,
Expand Down Expand Up @@ -182,7 +181,6 @@
"squared_error",
"transpose",
"square",
"tensor_slice",
"buffer",
"permute_tensor",
"reshape",
Expand Down
14 changes: 0 additions & 14 deletions mithril/backends/with_manualgrad/numpy_backend/ops_grad.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,6 @@
"softplus_grad",
"gelu_grad",
"stop_gradient_grad",
"tensor_slice_grad",
"tensor_item_grad",
"permute_tensor_grad",
"transpose_grad",
Expand Down Expand Up @@ -810,19 +809,6 @@ def stop_gradient_grad(
return np.zeros_like(output_gradient)


def tensor_slice_grad(
output_gradient: np.ndarray[Any, Any],
cache: CacheType,
idx: int,
*inputs: np.ndarray[Any, Any],
) -> np.ndarray[Any, Any]:
verify_shapes(inputs, idx, non_differentiables=[1, 2, 3, 4])
input, start, stop, step = inputs
grad = np.zeros_like(input)
grad[start:stop:step] = output_gradient
return grad


def tensor_item_grad(
output_gradient: np.ndarray[Any, Any],
cache: CacheType,
Expand Down
31 changes: 0 additions & 31 deletions mithril/framework/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,6 @@
"scalar_item_constraints",
"to_tuple_constraints",
"tensor_item_constraints",
"tensor_slice_constraints",
"tensor_to_list_type_constraint",
"reduce_type_constraint",
"type_constraints",
Expand Down Expand Up @@ -3466,36 +3465,6 @@ def tensor_item_constraint_helper(
return input_unis, output_unis, status, current_index


def tensor_slice_constraints(
output: Tensor, input: Tensor, start: Scalar, stop: Scalar, step: Scalar
) -> ConstrainResultType:
assert output._temp_shape is not None, "Output shape of TensorSlice is not set!"
assert input._temp_shape is not None, "Input shape of TensorSlice is not set!"
output_shape: ShapeRepr = output._temp_shape
input_shape: ShapeRepr = input._temp_shape
updated_symbols = Updates()
status = False

if input_shape.prefix and output_shape.prefix:
in_uni, out_uni = input_shape[0], output_shape[0]
if in_uni.value is not None and out_uni.value is not None:
status = True
else:
if (
start.value is not TBD
and stop.value is not TBD
and step.value is not TBD
and in_uni.value is not None
):
slc = slice(start.value, stop.value, step.value)
out_val = len(list(range(in_uni.value))[slc])
out_uni.set_value(out_val)
updated_symbols.add(out_uni)
status = True

return status, updated_symbols


def split_constraints(output: Tensor, input: Tensor, split_size: Scalar, axis: Scalar):
status = False
split_size_val = split_size.value
Expand Down
144 changes: 47 additions & 97 deletions mithril/framework/logical/essential_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@
slice_constraints,
split_constraints,
tensor_item_constraints,
tensor_slice_constraints,
tensor_to_list_constraints,
tensor_to_list_type_constraint,
to_list_constraints,
Expand Down Expand Up @@ -101,7 +100,6 @@
"ShiftLeft",
"ShiftRight",
"TensorItem",
"TensorSlice",
"ArgMax",
"ArgMin",
"Cast",
Expand Down Expand Up @@ -606,54 +604,6 @@ def __call__( # type: ignore[override]
)


class TensorSlice(PrimitiveModel):
input: Connection
start: Connection
stop: Connection
step: Connection
output: Connection

def __init__(
self,
name: str | None = None,
start: int | None | ToBeDetermined = None,
stop: int | None | ToBeDetermined = None,
step: int | None | ToBeDetermined = None,
input: TensorValueType | ToBeDetermined = TBD,
) -> None:
self.factory_args = {"start": start, "stop": stop, "step": step}
super().__init__(
formula_key="tensor_slice",
name=name,
output=BaseKey(shape=["a", ("Var1", ...)], type=GenericTensorType),
input=BaseKey(shape=["b", ("Var1", ...)], type=GenericTensorType),
start=BaseKey(type=int | None, value=start),
stop=BaseKey(type=int | None, value=stop),
step=BaseKey(type=int | None, value=step),
)
self.factory_inputs = {"input": input}

self._set_constraint(
fn=tensor_slice_constraints,
keys=[PrimitiveModel.output_key, "input", "start", "stop", "step"],
)
self._set_constraint(
fn=general_tensor_type_constraint, keys=[PrimitiveModel.output_key, "input"]
)

def __call__( # type: ignore[override]
self,
input: ConnectionType = NOT_GIVEN,
start: ConnectionType = NOT_GIVEN,
stop: ConnectionType = NOT_GIVEN,
step: ConnectionType = NOT_GIVEN,
output: ConnectionType = NOT_GIVEN,
) -> ExtendInfo:
return super().__call__(
input=input, start=start, stop=stop, step=step, output=output
)


class Item(PrimitiveModel):
input: Connection
output: Connection
Expand Down Expand Up @@ -718,50 +668,6 @@ def __call__( # type: ignore[override]
return super().__call__(input=input, index=index, output=output)


class TensorItem(PrimitiveModel):
input: Connection
index: Connection
output: Connection

def __init__(
self,
name: str | None = None,
index: int | ToBeDetermined = TBD,
input: TensorValueType | ToBeDetermined = TBD,
) -> None:
super().__init__(
formula_key="tensor_item",
name=name,
output=BaseKey(shape=[("Var2", ...)], type=GenericTensorType),
input=BaseKey(shape=[("Var1", ...)], type=GenericTensorType),
index=BaseKey(
type=int
| slice
| EllipsisType
| None
| tuple[int | slice | EllipsisType | None, ...],
value=index,
),
)
self.factory_inputs = {"input": input, "index": index}

self._set_constraint(
fn=tensor_item_constraints,
keys=[PrimitiveModel.output_key, "input", "index"],
)
self._set_constraint(
fn=general_tensor_type_constraint, keys=[PrimitiveModel.output_key, "input"]
)

def __call__( # type: ignore[override]
self,
input: ConnectionType = NOT_GIVEN,
index: ConnectionType = NOT_GIVEN,
output: ConnectionType = NOT_GIVEN,
) -> ExtendInfo:
return super().__call__(input=input, index=index, output=output)


class ToTensor(PrimitiveModel):
input: Connection
output: Connection
Expand Down Expand Up @@ -1523,9 +1429,9 @@ class Slice(PrimitiveModel):

def __init__(
self,
start: int | None | ToBeDetermined = 0,
stop: int | None | ToBeDetermined = None,
step: int | None | ToBeDetermined = None,
start: int | None | ToBeDetermined = TBD,
stop: int | None | ToBeDetermined = TBD,
step: int | None | ToBeDetermined = TBD,
name: str | None = None,
):
super().__init__(
Expand All @@ -1550,3 +1456,47 @@ def __call__( # type: ignore[override]
output: ConnectionType = NOT_GIVEN,
) -> ExtendInfo:
return super().__call__(start=start, stop=stop, step=step, output=output)


class TensorItem(PrimitiveModel):
input: Connection
index: Connection
output: Connection

def __init__(
self,
name: str | None = None,
index: int | ToBeDetermined = TBD,
input: TensorValueType | ToBeDetermined = TBD,
) -> None:
super().__init__(
formula_key="tensor_item",
name=name,
output=BaseKey(shape=[("Var2", ...)], type=GenericTensorType),
input=BaseKey(shape=[("Var1", ...)], type=GenericTensorType),
index=BaseKey(
type=int
| slice
| EllipsisType
| None
| tuple[int | slice | EllipsisType | None, ...],
value=index,
),
)
self.factory_inputs = {"input": input, "index": index}

self._set_constraint(
fn=tensor_item_constraints,
keys=[PrimitiveModel.output_key, "input", "index"],
)
self._set_constraint(
fn=general_tensor_type_constraint, keys=[PrimitiveModel.output_key, "input"]
)

def __call__( # type: ignore[override]
self,
input: ConnectionType = NOT_GIVEN,
index: ConnectionType = NOT_GIVEN,
output: ConnectionType = NOT_GIVEN,
) -> ExtendInfo:
return super().__call__(input=input, index=index, output=output)
2 changes: 0 additions & 2 deletions mithril/framework/logical/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,6 @@
Subtract,
Sum,
TensorItem,
TensorSlice,
TensorToList,
ToList,
ToTensor,
Expand Down Expand Up @@ -143,7 +142,6 @@
coercion_table: dict[tuple[str, type[Tensor] | type[Scalar]], type[PrimitiveModel]] = {
("get_item", Tensor): TensorItem,
("get_item", Scalar): ScalarItem,
("slice", Tensor): TensorSlice,
("slice", Scalar): PrimitiveSlice,
}

Expand Down
Loading
Loading