refactor: Update types of all models #222

Open · wants to merge 20 commits into main

Conversation

mehmetozsoy-synnada (Collaborator)

Description

Closes #221.

What is Changed

This PR introduces the following changes:

  • cutoff is renamed to threshold (see the sketch after this list).
  • The types of all primitives are updated.
  • All instances of general_tensor_type_constraint are replaced with general_type_constraint.
  • Values previously taken by __call__() are moved to __init__().
  • Tests are updated accordingly.
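
A hedged sketch of the rename and the __call__-to-__init__ move (RobustSqrt matches the robust_sqrt op touched below, but the exact signatures are illustrative, not copied from the codebase):

    # Before: the value is named `cutoff` and supplied at call time
    # (`data` is a placeholder tensor).
    model = RobustSqrt()
    out = model(input=data, cutoff=1e-12)

    # After: renamed to `threshold` and moved to the constructor.
    model = RobustSqrt(threshold=1e-12)
    out = model(input=data)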

Checklist:

  • Tests that cover the added code are included.
  • Corresponding changes are documented.
  • All tests pass.
  • The code is linted and styled (pre-commit run --all-files has passed).


codecov bot commented Mar 5, 2025

Codecov Report

Attention: Patch coverage is 92.85714% with 7 lines in your changes missing coverage. Please review.

Project coverage is 89.31%. Comparing base (8938a63) to head (0350620).

Files with missing lines                 Patch %    Lines
mithril/models/primitives.py             87.80%     5 Missing ⚠️
mithril/cores/python/numpy/ops_grad.py   92.85%     1 Missing ⚠️
mithril/cores/python/torch/ops.py        87.50%     1 Missing ⚠️
@@           Coverage Diff           @@
##             main     #222   +/-   ##
=======================================
  Coverage   89.30%   89.31%           
=======================================
  Files          68       68           
  Lines       16666    16578   -88     
=======================================
- Hits        14884    14807   -77     
+ Misses       1782     1771   -11     
Files with missing lines                     Coverage Δ
mithril/cores/python/common_primitives.py    67.18% <100.00%> (+0.52%) ⬆️
mithril/cores/python/jax/ops.py              85.51% <100.00%> (-0.09%) ⬇️
mithril/cores/python/mlx/ops.py              86.40% <100.00%> (ø)
mithril/cores/python/numpy/ops.py            81.40% <100.00%> (ø)
mithril/framework/constraints.py             94.04% <100.00%> (+0.40%) ⬆️
mithril/framework/logical/operators.py      100.00% <100.00%> (ø)
mithril/models/train_model.py                91.88% <ø> (ø)
mithril/cores/python/numpy/ops_grad.py       88.12% <92.85%> (ø)
mithril/cores/python/torch/ops.py            80.73% <87.50%> (-0.11%) ⬇️
mithril/models/primitives.py                 97.91% <87.80%> (-0.11%) ⬇️

... and 2 files with indirect coverage changes

-def robust_sqrt(input: jax.Array, cutoff: jax.Array) -> jax.Array:
-    # v_mapped_func= jax.vmap(partial(robust_log_helper, threshold = cutoff))
+def robust_sqrt(input: jax.Array, threshold: jax.Array) -> jax.Array:
+    # v_mapped_func= jax.vmap(partial(robust_log_helper, threshold = threshold))
Collaborator:

we can remove unnecessary comments
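
Applying the suggestion would presumably leave only the signature rename (a sketch; the function body is elided in the diff above):

    def robust_sqrt(input: jax.Array, threshold: jax.Array) -> jax.Array:
        ...  # stale v_mapped_func comment removed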

-def robust_log(input: jax.Array, cutoff: jax.Array) -> jax.Array:
-    # v_mapped_func= jax.vmap(partial(robust_log_helper, threshold = cutoff))
+def robust_log(input: jax.Array, threshold: jax.Array) -> jax.Array:
+    # v_mapped_func= jax.vmap(partial(robust_log_helper, threshold = threshold))
Collaborator:

we can remove unnecessary comments

@@ -142,6 +143,14 @@ def is_repr_known(repr: ShapeRepr) -> bool:
return repr.root is None and all([uni.value is not None for uni in repr.prefix])


def two_add(input1: Any, input2: Any, input3: Any) -> Any:
Collaborator:

we can convert this function to accept *args as input and sum them all for more generic purposes
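
A minimal sketch of that suggestion (add_all is a hypothetical name; this PR keeps two_add with exactly three operands):

    from typing import Any

    def add_all(*inputs: Any) -> Any:
        # Sum an arbitrary number of operands instead of exactly three;
        # assumes at least one operand is given.
        result = inputs[0]
        for operand in inputs[1:]:
            result = result + operand
        return result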

-input=BaseKey(shape=[("Var_2", ...)], type=Tensor, value=input),
-target=BaseKey(shape=[("Var_3", ...)], type=Tensor, value=target),
-cutoff=BaseKey(shape=[], type=Tensor, value=cutoff),
+input=BaseKey(
Collaborator:

no need to set shapes explicitly if they are different Variadics. It is the default behaviour
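
A sketch of what that implies, assuming distinct variadic shapes are the default when shape is omitted (keyword names follow the snippet above):

    # No explicit [("Var_2", ...)] / [("Var_3", ...)] needed; each key
    # gets its own variadic shape by default.
    input=BaseKey(type=Tensor, value=input),
    target=BaseKey(type=Tensor, value=target),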

@@ -890,6 +901,7 @@ def __init__(
approximate=BaseKey(value=approximate, type=bool),
name=name,
input=input,
output=BaseKey(shape=[("Var", ...)], type=Tensor[float]),
Collaborator:

no need to set shapes

name=name,
axis=axis_key,
input=input,
output=BaseKey(shape=[("Var", ...)], type=Tensor[float]),
Collaborator:

no need to set shapes explicitly

slope=BaseKey(
value=slope, type=int | float | bool | Tensor[int | float | bool]
),
output=BaseKey(shape=[("Var", ...)], type=Tensor[float]),
Collaborator:

no need to set shapes explicitly

@@ -2109,7 +2164,7 @@ def __init__(
super().__init__(
formula_key="randn",
name=name,
-output=BaseKey(shape=[("output", ...)], type=Tensor),
+output=BaseKey(shape=[("output", ...)], type=Tensor[float]),
Collaborator:

no need to set shapes explicitly

@kberat-synnada (Collaborator)

We should be able to remove the wrapping operation in PhysicalModel, this part:

        constant_keys = {  # type: ignore
            key: StaticDataStore.convert_data_to_physical(value, backend)  # type: ignore
            for key, value in model().connections.items()
            if value is not NOT_GIVEN
        } | constant_keys

If we can get rid of this, we can delete convert_data_to_physical too.
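
A speculative sketch of the simplified form (assuming connection values can be consumed as-is once the type updates land; structure follows the snippet above):

        # Speculative: merge raw connection values without wrapping them,
        # which would make convert_data_to_physical unnecessary.
        constant_keys = {
            key: value
            for key, value in model().connections.items()
            if value is not NOT_GIVEN
        } | constant_keys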

Successfully merging this pull request may close these issues:

  [OTHER] Update types of all models

3 participants