refactor: Update types of all models #222
base: main
Conversation
Commits:
- …Left and ShiftRight are made polymorphic, initial_indexer_type_constraint is updated
- …into update-types-of-all-models
- …are updated, tests are fixed
- …into update-types-of-all-models
Codecov Report
Attention: Patch coverage is …
@@            Coverage Diff            @@
##             main     #222      +/-  ##
==========================================
  Coverage   89.30%   89.31%   +0.01%
==========================================
  Files          68       68
  Lines       16666    16578      -88
==========================================
- Hits        14884    14807      -77
+ Misses       1782     1771      -11
-def robust_sqrt(input: jax.Array, cutoff: jax.Array) -> jax.Array:
-    # v_mapped_func= jax.vmap(partial(robust_log_helper, threshold = cutoff))
+def robust_sqrt(input: jax.Array, threshold: jax.Array) -> jax.Array:
+    # v_mapped_func= jax.vmap(partial(robust_log_helper, threshold = threshold))
we can remove unnecessary comments
-def robust_log(input: jax.Array, cutoff: jax.Array) -> jax.Array:
-    # v_mapped_func= jax.vmap(partial(robust_log_helper, threshold = cutoff))
+def robust_log(input: jax.Array, threshold: jax.Array) -> jax.Array:
+    # v_mapped_func= jax.vmap(partial(robust_log_helper, threshold = threshold))
we can remove unnecessary comments
@@ -142,6 +143,14 @@ def is_repr_known(repr: ShapeRepr) -> bool:
     return repr.root is None and all([uni.value is not None for uni in repr.prefix])


+def two_add(input1: Any, input2: Any, input3: Any) -> Any:
we can convert this function to accept *args as input and sum them all, for more generic use; see the sketch below
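A minimal sketch of that suggestion, assuming nothing beyond the standard library; the name add_all and the use of functools.reduce are illustrative choices, not code from this PR:

    from functools import reduce
    from typing import Any

    def add_all(*inputs: Any) -> Any:
        # Fold the inputs pairwise with +, so a single primitive covers
        # two_add's three-argument case as well as any other arity.
        # Requires at least one input; reduce raises TypeError otherwise.
        return reduce(lambda acc, item: acc + item, inputs)

    # Usage: add_all(1, 2, 3) == 6, matching two_add(1, 2, 3).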
-    input=BaseKey(shape=[("Var_2", ...)], type=Tensor, value=input),
-    target=BaseKey(shape=[("Var_3", ...)], type=Tensor, value=target),
-    cutoff=BaseKey(shape=[], type=Tensor, value=cutoff),
+    input=BaseKey(
no need to set shapes explicitly if they are different Variadics; it is the default behaviour (see the sketch below)
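A hedged sketch of the simplification being suggested, reusing the keys from the diff above and assuming, as the comment states, that BaseKey gives each tensor key its own fresh variadic shape when shape is omitted; the scalar cutoff keeps its explicit shape=[], since a rank-0 shape is not the default:

    input=BaseKey(type=Tensor, value=input),    # defaults to its own variadic shape
    target=BaseKey(type=Tensor, value=target),  # a variadic distinct from input's
    cutoff=BaseKey(shape=[], type=Tensor, value=cutoff),  # rank-0: still explicit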
@@ -890,6 +901,7 @@ def __init__(
     approximate=BaseKey(value=approximate, type=bool),
     name=name,
     input=input,
+    output=BaseKey(shape=[("Var", ...)], type=Tensor[float]),
no need to set shapes
     name=name,
     axis=axis_key,
     input=input,
+    output=BaseKey(shape=[("Var", ...)], type=Tensor[float]),
no need to set shapes explicitly
     slope=BaseKey(
         value=slope, type=int | float | bool | Tensor[int | float | bool]
     ),
+    output=BaseKey(shape=[("Var", ...)], type=Tensor[float]),
no need to set shapes explicitly
@@ -2109,7 +2164,7 @@ def __init__(
     super().__init__(
         formula_key="randn",
         name=name,
-        output=BaseKey(shape=[("output", ...)], type=Tensor),
+        output=BaseKey(shape=[("output", ...)], type=Tensor[float]),
no need to set shapes explicitly
We should be able to remove the wrapping operation in PhysicalModel, this part: … If we can get rid of this, we can delete convert_data_to_physical too.
Description
Closes #221.
What is Changed
- cutoff -> threshold
- general_tensor_type_constraint is replaced with general_type_constraint
- __call__() carried to __init__() (sketched after this list)
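One plausible illustration of the "__call__() carried to __init__()" item, based on the Gelu diff above; both usage lines are kept as comments because they are hypothetical, not confirmed Mithril API:

    # before: connection keys were supplied when the model object was called,
    # e.g. when composing it into a parent model:
    #     parent |= Gelu()(input="x", output="y")
    # after: the same keys are accepted directly by the constructor:
    #     parent |= Gelu(input="x", output="y")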