chore: Linted Physical folder #138

Merged

Changes from all commits (55 commits)
5f9193c
feat: Now People will be assigned based on applied labels, github act…
mehmetozsoy-synnada Nov 14, 2024
8309838
feat: Now actions bot requests a change in the case of tests failed
mehmetozsoy-synnada Nov 14, 2024
aea5265
Merge remote-tracking branch 'upstream/main'
mehmetozsoy-synnada Nov 19, 2024
5916d8f
fix: bug fix attempt 1
mehmetozsoy-synnada Nov 21, 2024
f3e9083
fix: resolve conflicts
mehmetozsoy-synnada Nov 22, 2024
dd6535b
style: fix some function annotations
mehmetozsoy-synnada Nov 24, 2024
913eeea
style: eye constraint bug is fixed, logical part of frameworks is fin…
mehmetozsoy-synnada Nov 25, 2024
f88a7bb
fix: pypi errors are fixed in logical, physical, common, constraints,…
mehmetozsoy-synnada Nov 27, 2024
6365adb
Merge remote-tracking branch 'upstream/main'
mehmetozsoy-synnada Nov 27, 2024
80b038c
fix type annotations in specified folders
mehmetozsoy-synnada Nov 27, 2024
f43801b
fix: reviews are applied
mehmetozsoy-synnada Nov 28, 2024
d40da6b
apply reviews
mehmetozsoy-synnada Nov 28, 2024
4f83459
evaluate return type is updated
mehmetozsoy-synnada Dec 2, 2024
78ccbfa
merged with upstream
mehmetozsoy-synnada Dec 2, 2024
9764084
merged with main
mehmetozsoy-synnada Dec 2, 2024
7c2ce69
Merge remote-tracking branch 'upstream/main'
mehmetozsoy-synnada Dec 4, 2024
a5fd6ad
Merge branch 'main' of https://github.com/mehmetozsoy-synnada/mithril…
mehmetozsoy-synnada Dec 4, 2024
db37cf8
reviews are applied except ones that includes codegen
mehmetozsoy-synnada Dec 4, 2024
3be2f7d
chore: type annotations in codegen are fixed
mehmetozsoy-synnada Dec 5, 2024
74ce570
Merge remote-tracking branch 'upstream/main'
mehmetozsoy-synnada Dec 5, 2024
989a192
chore: Type annotations in codegen is fixed
mehmetozsoy-synnada Dec 5, 2024
4bf0184
overload decorator is added to evaluate_gradients_wrapper_manualgrad …
mehmetozsoy-synnada Dec 5, 2024
ac7051d
updated based on reviews
mehmetozsoy-synnada Dec 9, 2024
cdda878
CI pre-commit fix attempt
mehmetozsoy-synnada Dec 9, 2024
7c3c644
codegen typehints are fixed
mehmetozsoy-synnada Dec 9, 2024
8938418
codegen typehints are fixed
mehmetozsoy-synnada Dec 9, 2024
1d9ec33
codegen typehints are fixed
mehmetozsoy-synnada Dec 9, 2024
47f285b
numpy 2.2.0 support is added
mehmetozsoy-synnada Dec 9, 2024
1866707
updated based on reviews
mehmetozsoy-synnada Dec 9, 2024
5367463
Merge branch 'main' into solve-pypi-ignores
kberat-synnada Dec 9, 2024
0450bb2
type annotations in backend files are fixed partially
mehmetozsoy-synnada Dec 12, 2024
7835062
Merge remote-tracking branch 'upstream/main'
mehmetozsoy-synnada Dec 12, 2024
859a021
backend is typed
mehmetozsoy-synnada Dec 17, 2024
6533572
Merge remote-tracking branch 'upstream/main'
mehmetozsoy-synnada Dec 17, 2024
ca5d468
merged with dev
mehmetozsoy-synnada Dec 17, 2024
202366e
Merge remote-tracking branch 'upstream/main'
mehmetozsoy-synnada Dec 20, 2024
fea39af
Merge remote-tracking branch 'upstream/main'
mehmetozsoy-synnada Dec 25, 2024
4ebe35a
Merge remote-tracking branch 'upstream/main'
mehmetozsoy-synnada Dec 26, 2024
7a3dadc
Merge branch 'main' of https://github.com/mehmetozsoy-synnada/mithril…
mehmetozsoy-synnada Dec 26, 2024
3ad4dfe
underscored attributes are edited
mehmetozsoy-synnada Dec 26, 2024
4bc57d2
merged with upstream
mehmetozsoy-synnada Dec 26, 2024
b3c8e28
underscored attributes are edited
mehmetozsoy-synnada Dec 26, 2024
c1d22b5
some missing changes are added
mehmetozsoy-synnada Dec 27, 2024
a05b88a
Merge branch 'main' into solve-pypi-ignores
kberat-synnada Dec 27, 2024
4b99c38
Update PhysicalModel _flat_graph attribute
kberat-synnada Dec 27, 2024
796a447
Merge remote-tracking branch 'upstream/main'
mehmetozsoy-synnada Dec 27, 2024
7293b6e
Merge remote-tracking branch 'upstream/main'
mehmetozsoy-synnada Dec 27, 2024
fbe3d00
resolved conflicts
mehmetozsoy-synnada Dec 27, 2024
1612113
Merge branch 'solve-pypi-ignores' of https://github.com/mehmetozsoy-s…
mehmetozsoy-synnada Dec 27, 2024
408cddf
Merge remote-tracking branch 'upstream/main'
mehmetozsoy-synnada Dec 27, 2024
73d8e0b
type errors are partially cleared
mehmetozsoy-synnada Dec 31, 2024
b32eecc
Merge remote-tracking branch 'upstream/main'
mehmetozsoy-synnada Dec 31, 2024
941b091
Jacobian codes are removed from codebase
mehmetozsoy-synnada Dec 31, 2024
82997d9
Physical folder strictly linted
mehmetozsoy-synnada Jan 2, 2025
4d48eda
Merge branch 'main' into solve-pypi-ignores
kberat-synnada Jan 2, 2025
2 changes: 1 addition & 1 deletion benchmarks/speed_benchmarks/jax_fns.py
@@ -66,7 +66,7 @@ def setup(self):
     def __call__(self, inputs):
         x = inputs
         for lyr, actv in zip(self.layers, self.jax_activations, strict=False):
-            x = lyr(x) # type: ignore
+            x = lyr(x)
             x = actv(x) # type: ignore
         return x

2 changes: 1 addition & 1 deletion examples/model_api/linear_regression_jax_training.py
@@ -58,5 +58,5 @@
 for i in range(num_epochs):
     outputs, gradients = pm.evaluate_all(params)
     updates, opt_state = optimizer.update(gradients, opt_state)
-    params = optax.apply_updates(params, updates) # type: ignore
+    params = optax.apply_updates(params, updates)
     print(f"Epoch: {i} / {num_epochs} -> ", outputs["final_cost"])
10 changes: 5 additions & 5 deletions examples/model_api/variable_length_many_to_one_lstm.py
@@ -67,8 +67,8 @@
     target_end = int(input_end + target_lengths[idx])

     # NOTE: Pylance sees int, int type arguments but throws an error.
-    single_input = backend.arange(start, input_end).reshape(-1, input_dim) # type: ignore
-    single_target = backend.arange(input_end, target_end).reshape(-1, output_dim) # type: ignore
+    single_input = backend.arange(start, input_end).reshape(-1, input_dim)
+    single_target = backend.arange(input_end, target_end).reshape(-1, output_dim)

     single_data = (single_input, single_target)
     train_data.append(single_data)
@@ -150,7 +150,7 @@
 # Prepare the test input data.
 test_input = backend.arange(
     starting_number,
-    starting_number + inference_max_input, # type: ignore
+    starting_number + inference_max_input,
 ).reshape(-1, input_dim)

 # Prepare the test data.
@@ -172,7 +172,7 @@

 # Prepare target values.
 test_target_values = backend.arange(
-    starting_number + inference_max_input, # type: ignore
+    starting_number + inference_max_input,
     starting_number + inference_max_input + inference_max_target_length,
 )

@@ -204,4 +204,4 @@
 )

 # Measure test error.
-error = backend.abs(unpacked_output_data.squeeze() - test_target_values).sum() # type: ignore
+error = backend.abs(unpacked_output_data.squeeze() - test_target_values).sum()
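
Every ignore removed in this file sat on a backend.arange call that Pylance flagged (see the NOTE above). A hedged sketch of the likely fix, using a simplified stand-in for the real backend signature:

import numpy as np

class Backend:
    # Simplified stand-in: accepting int | float covers literal ints as
    # well as computed bounds like starting_number + inference_max_input.
    def arange(self, start: int | float, stop: int | float) -> np.ndarray:
        return np.arange(start, stop)

backend = Backend()
single_input = backend.arange(0, 12).reshape(-1, 3)  # checks clean, no ignore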
15 changes: 8 additions & 7 deletions mithril/backends/backend.py
@@ -57,26 +57,27 @@ def __init__(self, precision: int = 32, device: str = "cpu") -> None:
         # setattr(self, key, value)

     @property
-    def precision(self):
+    def precision(self) -> int:
         return self._precision

+    #!!
     @property
-    def device(self):
+    def device(self) -> Any:
         return self._device

-    def get_device(self):
+    def get_device(self) -> str:
         return self._device

     @property
     def inf(self) -> DataType | float:
         raise NotImplementedError("inf is not implemented")

     @property
-    def pi(self):
+    def pi(self) -> float:
         return math.pi

     @property
-    def e(self):
+    def e(self) -> float:
         return math.e

     @property
@@ -104,7 +105,7 @@ def to_device(
     def block_until_ready(self, data: DataType) -> DataType | None:
         raise RuntimeError("Backend does not support block_until_ready method!")

-    def empty_cache(self): # noqa: B027
+    def empty_cache(self) -> None: # noqa: B027
         pass
         # print("Warning: empty_cache is not supported!")

@@ -126,7 +127,7 @@ def cast(self, value: Any) -> Any:

         return value

-    def __del__(self):
+    def __del__(self) -> None:
         self.empty_cache()

     @overload
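
With the Backend properties annotated, call sites infer concrete types instead of falling back to Any. A short usage sketch (assuming a concrete subclass such as JaxBackend; exact constructor arguments may differ):

from mithril import JaxBackend

backend = JaxBackend(precision=32)
bits: int = backend.precision  # inferred as int after this change
ratio: float = backend.pi / 2  # float, no longer Any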
4 changes: 2 additions & 2 deletions mithril/backends/parallel.py
@@ -37,10 +37,10 @@ def run_callable(self, *primals: Any, fn_name: str) -> dict[str, Any]:
     @abstractmethod
     def parallelize(
         self, tensor: DataType, device_mesh: tuple[int, ...] | None = None
-    ) -> dict[str, Any]:
+    ) -> DataType:
         raise NotImplementedError()

-    def clean_up(self):
+    def clean_up(self) -> None:
         self.callables = dict()
         self.device_mesh = None
         self.n_devices = -1
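
Returning DataType instead of dict[str, Any] from parallelize only type-checks if Parallel is generic over the backend's array type. A sketch of the presumed shape of the base class, simplified to just this method:

from abc import ABC, abstractmethod
from typing import Generic, TypeVar

DataType = TypeVar("DataType")

class Parallel(ABC, Generic[DataType]):
    @abstractmethod
    def parallelize(
        self, tensor: DataType, device_mesh: tuple[int, ...] | None = None
    ) -> DataType:
        # Each backend shards its own array type and returns that same type.
        raise NotImplementedError()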
98 changes: 50 additions & 48 deletions mithril/backends/with_autograd/common_primitives.py
@@ -63,80 +63,80 @@
 ]


-def greater(left: DataType, right: DataType):
+def greater(left: DataType, right: DataType) -> DataType:
     return left > right


-def greater_equal(left: DataType, right: DataType):
+def greater_equal(left: DataType, right: DataType) -> DataType:
     return left >= right


-def less(left: DataType, right: DataType):
+def less(left: DataType, right: DataType) -> DataType:
     return left < right


-def less_equal(left: DataType, right: DataType):
+def less_equal(left: DataType, right: DataType) -> DataType:
     return left <= right


-def equal(left: DataType, right: DataType):
-    return left == right
+def equal(left: DataType, right: DataType) -> DataType:
+    return left == right # type: ignore


-def not_equal(left: DataType, right: DataType):
-    return left != right
+def not_equal(left: DataType, right: DataType) -> DataType:
+    return left != right # type: ignore


-def logical_not(input: DataType):
+def logical_not(input: DataType) -> DataType:
     return ~input


-def logical_or(left: DataType, right: DataType):
-    return left | right
+def logical_or(left: DataType, right: DataType) -> DataType:
+    return left | right # type: ignore


-def logical_and(left: DataType, right: DataType):
-    return left & right
+def logical_and(left: DataType, right: DataType) -> DataType:
+    return left & right # type: ignore


-def matrix_multiplication(left: DataType, right: DataType):
-    return left @ right
+def matrix_multiplication(left: DataType, right: DataType) -> DataType:
+    return left @ right # type: ignore


-def add(left: DataType, right: DataType):
-    return left + right
+def add(left: DataType, right: DataType) -> DataType:
+    return left + right # type: ignore


-def subtract(left: DataType, right: DataType):
-    return left - right
+def subtract(left: DataType, right: DataType) -> DataType:
+    return left - right # type: ignore


-def multiplication(left: DataType, right: DataType):
-    return left * right
+def multiplication(left: DataType, right: DataType) -> DataType:
+    return left * right # type: ignore


-def divide(numerator: DataType, denominator: DataType):
-    return numerator / denominator
+def divide(numerator: DataType, denominator: DataType) -> DataType:
+    return numerator / denominator # type: ignore


-def floor_divide(numerator: DataType, denominator: DataType):
-    return numerator // denominator
+def floor_divide(numerator: DataType, denominator: DataType) -> DataType:
+    return numerator // denominator # type: ignore


-def shift_left(input: DataType, shift: DataType):
-    return input << shift
+def shift_left(input: DataType, shift: DataType) -> DataType:
+    return input << shift # type: ignore


-def shift_right(input: DataType, shift: DataType):
-    return input >> shift
+def shift_right(input: DataType, shift: DataType) -> DataType:
+    return input >> shift # type: ignore


-def power(base: DataType, exponent: DataType):
-    return base**exponent
+def power(base: DataType, exponent: DataType) -> DataType:
+    return base**exponent # type: ignore


-def squared_error(input: DataType, target: DataType):
-    return (input - target) ** 2
+def squared_error(input: DataType, target: DataType) -> DataType:
+    return (input - target) ** 2 # type: ignore


 def minus(input: DataType) -> DataType:
@@ -148,18 +148,18 @@ def transpose(
 ) -> DataType:
     if not axes:
         return input.T
-    return input.transpose(*axes)
+    return input.transpose(*axes) # type: ignore


-def swapaxes(input: DataType, axis1: int, axis2: int):
+def swapaxes(input: DataType, axis1: int, axis2: int) -> DataType:
     return input.swapaxes(axis1, axis2)


-def square(input: DataType):
-    return input * input
+def square(input: DataType) -> DataType:
+    return input * input # type: ignore


-def buffer(input: DataType):
+def buffer(input: DataType) -> DataType:
     return input


@@ -168,27 +168,29 @@ def permute_tensor(input: DataType, indices: DataType) -> DataType:


 def reshape(input: DataType, shape: tuple[int, ...]) -> DataType:
-    return input.reshape(shape)
+    return input.reshape(shape) # type: ignore


 def item(input: DataType) -> int | float | bool:
     return input.item() # type: ignore


-def tensor_item(input: DataType, index: int | slice | tuple[int | slice, ...]):
-    return input[index]
+def tensor_item(
+    input: DataType, index: int | slice | tuple[int | slice, ...]
+) -> DataType:
+    return input[index] # type: ignore


-def primitive_slice(start: int | None, stop: int | None, step: int | None):
+def primitive_slice(start: int | None, stop: int | None, step: int | None) -> slice:
     return slice(start, stop, step)


 def length(input: DataType) -> int:
     return len(input)


-def cartesian_diff(left: DataType, right: DataType):
-    return left[:, None, :] - right[None, :, :]
+def cartesian_diff(left: DataType, right: DataType) -> DataType:
+    return left[:, None, :] - right[None, :, :] # type: ignore


 def primitive_embedding(input: DataType, weight: DataType) -> DataType:
@@ -218,11 +220,11 @@ def union(*inputs: int | float | tuple[int | float, ...]) -> tuple[int | float,
     return result


-def to_tuple(*args: tuple[int | float | bool, ...]):
+def to_tuple(*args: int | float | bool) -> tuple[int | float | bool, ...]:
     return tuple(args)


-def to_list(*args: tuple[int | float | bool, ...]):
+def to_list(*args: int | float | bool) -> list[int | float | bool]:
     return list(args)


@@ -291,7 +293,7 @@ def padding_converter_2d(
 def stride_converter(
     input: int | PaddingType | tuple[int, int] | None,
     kernel_size: int | tuple[int, int],
-):
+) -> int | tuple[int, int] | PaddingType:
     if input is None:
         return kernel_size
     else:
@@ -303,7 +305,7 @@ def tuple_converter(
     | PaddingType
     | tuple[int, int]
     | tuple[tuple[int, int], tuple[int, int]],
-):
+) -> tuple[int, int] | tuple[tuple[int, int], tuple[int, int]] | PaddingType:
     if isinstance(input, int):
         return (input, input)
     else:
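
A pattern worth noting in this file: signatures gained -> DataType while many bodies simultaneously gained # type: ignore. That is expected when DataType is a TypeVar constrained to several array types, because the checker cannot prove that every constraint's operator returns the self type. A reduced reproduction with an illustrative two-type constraint (not mithril's real DataType):

from typing import TypeVar

import numpy as np

DataType = TypeVar("DataType", np.ndarray, list[int])

def equal(left: DataType, right: DataType) -> DataType:
    # For list[int], == falls back to object.__eq__ and returns bool, which
    # the checker cannot reconcile with the declared DataType return.
    return left == right  # type: ignore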
20 changes: 11 additions & 9 deletions mithril/backends/with_autograd/jax_backend/backend.py
@@ -75,14 +75,14 @@ def is_manualgrad(self) -> bool:
         return False

     @property
-    def inf(self):
+    def inf(self) -> float:
         return jax.numpy.inf

     @property
-    def nan(self):
+    def nan(self) -> float:
         return jax.numpy.nan

-    def get_backend_array_type(self):
+    def get_backend_array_type(self) -> type[jax.Array]:
         return jax.Array

     @property
@@ -98,7 +98,7 @@ def DataType(self): # noqa: N802
         return utils.ArrayType

     @staticmethod
-    def get_available_devices():
+    def get_available_devices() -> list[str]:
         """Static method to get a list of available devices.

         Parameters
@@ -112,7 +112,7 @@ def get_available_devices():
     def register_primitive(fn: Callable[..., Any]) -> None:
         JaxBackend.registered_primitives[fn.__name__] = fn

-    def set_seed(self, seed: int):
+    def set_seed(self, seed: int) -> None:
         self.seed = seed
         self.prng_key = jax.random.PRNGKey(seed)

@@ -145,23 +145,23 @@ def block_until_ready(self, data: jax.Array) -> jax.Array | None:

     def register_callable(
         self, fn: Callable[..., Any], fn_name: str, jit: bool = False
-    ):
+    ) -> None:
         assert (
             self._parallel_manager is not None
         ), "Parallel manager is not initialized!"

         fn_name = str(id(self)) + fn_name
         return self._parallel_manager.register_callable(fn, fn_name, jit)

-    def _run_callable(self, *primals: jax.Array, fn_name: str):
+    def _run_callable(self, *primals: jax.Array, fn_name: str) -> Any:
         assert (
             self._parallel_manager is not None
         ), "Parallel manager is not initialized!"

         fn_name = str(id(self)) + fn_name
         return self._parallel_manager.run_callable(*primals, fn_name=fn_name)

-    def _create_parallel(self, device_mesh: tuple[int, ...]):
+    def _create_parallel(self, device_mesh: tuple[int, ...]) -> None:
         self._parallel_manager = JaxParallel(math.prod(device_mesh), self._device)

     def array(
@@ -538,7 +538,9 @@ def multinomial(

         return samples

-    def jit(self, *args: Any, **kwargs: Any):
+    def jit( # type: ignore[override]
+        self, *args: Any, **kwargs: Any
+    ) -> Callable[..., jax.Array | tuple[jax.Array, ...]] | dict[str, jax.Array]:
         return jax.jit(*args, **kwargs)

     def grad(
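
The type: ignore[override] on jit is the usual cost of a thin forwarding wrapper: the subclass signature differs from whatever the abstract Backend declares. A standalone sketch of just the forwarding behavior, independent of the class hierarchy:

from typing import Any

import jax

def jit(*args: Any, **kwargs: Any) -> Any:
    # Thin pass-through: staging and compilation are delegated to jax.jit.
    return jax.jit(*args, **kwargs)

@jit
def double(x: jax.Array) -> jax.Array:
    return x * 2

print(double(jax.numpy.arange(3)))  # [0 2 4]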