Jacobian code is removed from the codebase
mehmetozsoy-synnada committed Dec 31, 2024
1 parent b32eecc commit 941b091
Showing 6 changed files with 18 additions and 129 deletions.
3 changes: 0 additions & 3 deletions mithril/__init__.py
@@ -100,7 +100,6 @@ def compile(
constant_keys: PhysicalConstantType[DataType] | None = None,
data_keys: Iterable[str | Connection] | None = None,
discard_keys: Iterable[str | Connection] | None = None,
jacobian_keys: Iterable[str | Connection] | None = None,
trainable_keys: Iterable[str | Connection] | None = None,
shapes: PhysicalShapeType | None = None,
inference: builtins.bool = False,
@@ -135,7 +134,6 @@ def compile(
constant_keys = constant_keys if constant_keys is not None else dict()
data_keys = set(data_keys) if data_keys is not None else set()
discard_keys = set(discard_keys) if discard_keys is not None else set()
jacobian_keys = set(jacobian_keys) if jacobian_keys is not None else set()
shapes = shapes if shapes is not None else dict()
trainable_keys = set(trainable_keys) if trainable_keys is not None else set()

@@ -146,7 +144,6 @@
data_keys=data_keys,
constant_keys=constant_keys,
trainable_keys=trainable_keys,
jacobian_keys=jacobian_keys,
discard_keys=discard_keys,
shapes=shapes,
inference=inference,
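Note: with jacobian_keys dropped from the compile signature, a call to ml.compile after this commit only passes the remaining key arguments, and passing jacobian_keys now fails with an unexpected-keyword TypeError. A minimal sketch, assuming the top-level import paths and the Linear(1, True) constructor that appear in the tests below; none of this is part of the diff itself:

import mithril as ml
from mithril import TorchBackend  # assumed top-level backend import
from mithril.models import Linear  # assumed model import path

model = Linear(1, True)
backend = TorchBackend()

# jacobian_keys is no longer accepted by compile.
pm = ml.compile(
    model,
    backend,
    data_keys={"input"},          # assumed input key of Linear
    shapes={"input": [8, 4]},
    inference=False,
)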
99 changes: 2 additions & 97 deletions mithril/framework/physical/model.py
@@ -15,10 +15,9 @@
import math
import random
import warnings
from collections.abc import Callable, Mapping, Sequence
from collections.abc import Mapping, Sequence
from copy import deepcopy
from dataclasses import dataclass
from functools import partial, reduce

from ...backends.backend import Backend, ParallelBackend
from ...core import DataType, GenericDataType
@@ -87,7 +86,6 @@ def __init__(
data_keys: StringOrConnectionSetType,
constant_keys: PhysicalConstantType[DataType],
trainable_keys: StringOrConnectionSetType,
jacobian_keys: StringOrConnectionSetType,
shapes: PhysicalShapeType,
inference: bool,
safe_shapes: bool,
@@ -155,7 +153,6 @@ def __init__(
_trainable_keys = {self._convert_key(model, key) for key in trainable_keys}
_discard_keys = {self._convert_key(model, key) for key in discard_keys}
_shapes = {self._convert_key(model, k): v for k, v in shapes.items()}
_jacobian_keys = {self._convert_key(model, key) for key in jacobian_keys}

# Check provided constant and data_keys do not have
# any preset value. Note that this check is done after key conversions.
@@ -164,9 +161,7 @@
self._check_overridden_nontrainable_keys(model, constant_keys, data_keys)

# Final validation process of provided keys.
self._validate_keys(
_constant_keys, _data_keys, _trainable_keys, _discard_keys, _jacobian_keys
)
self._validate_keys(_constant_keys, _data_keys, _trainable_keys, _discard_keys)

# Set provided non-differentiable and trainable tensor keys.
self._non_differentiable_keys: set[str] = _constant_keys.keys() | _data_keys
@@ -249,7 +244,6 @@ def __init__(
self._pre_compile(
constant_keys=_constant_keys,
data_keys=_data_keys,
jacobian_keys=_jacobian_keys,
shapes=_shapes,
)

@@ -327,7 +321,6 @@ def _validate_keys(
data_keys: set[str],
trainable_keys: set[str],
discard_keys: set[str],
jacobian_keys: set[str],
) -> None:
# Make sure no common keys in constant_keys, data_keys, trainable_keys
# and discard_keys.
@@ -368,13 +361,6 @@ def _validate_keys(
f"Invalid keys: {', '.join(str(key) for key in internal_discards)}."
)

# Given jacobian keys must be subset of input keys.
if jacobian_diff := (jacobian_keys - self._input_keys):
raise KeyError(
"Provided jacobian keys must be subset of the input keys. "
f"Invalid keys: {', '.join(str(key) for key in jacobian_diff)}."
)

def get_shapes(
self,
model: BaseModel | None = None,
@@ -514,15 +500,7 @@ def _pre_compile(
constant_keys: dict[str, DataType | MainValueType],
data_keys: set[str],
shapes: PhysicalShapeType,
jacobian_keys: set[str],
):
if jacobian_keys and self.backend.is_manualgrad:
raise Exception(
"Jacobians are only calculated for the backends that have "
"autograd capability."
)

self.jacobian_keys = jacobian_keys
self.ignore_grad_keys: set[str] = set()

# Set given shapes.
@@ -615,79 +593,6 @@ def generate_functions(
)
self._generated_evaluate_all_fn: EvaluateAllType[DataType] | None = eval_all_fn

def create_jacobian_fn(self, generated_fn: Callable):
# TODO: Fix this method to make it picklable!
if self.backend.is_manualgrad:
raise (
NotImplementedError(
"Currently Jacobian is not supported for manuel grad!"
)
)

# TODO: Consider to JIT this function.
def multiplier(x, y):
return x * y

def jacobian_fn(
inputs: dict[str, DataType], data: dict[str, DataType] | None = None
):
# Function for calculating jacobians for the requested
# outputs stated in jacobian keys. We use the more efficient
# jacobian method considering input-output dimensionalities.
if data is None:
data = {}

def jacobian_wrapper(input, output):
total_inputs = inputs | input

return generated_fn(params=total_inputs, data=data)[output]

jacobians: dict[str, dict[str, DataType]] = {}

# Define default jacobian method as jacrev since
# output dimensionality is generally lower than input.
jacobian_method = self.backend.jacrev # type: ignore

# Iterate over all requested outputs for Jacobian calculations.
for out in self.jacobian_keys:
jacobians[out] = {}
# Iterate over all trainable inputs.

jacobian_par_fn = jacobian_method(partial(jacobian_wrapper, output=out))

for key in inputs:
# if all(isinstance(dim, int) for dim in self.shapes[out]) and all(
# isinstance(dim, int) for dim in self.shapes[key]
# ):
key_shp = self.shapes[key]
out_shp = self.shapes[out]
if (
isinstance(key_shp, list)
and isinstance(out_shp, list)
and is_list_int(key_shp)
and is_list_int(out_shp)
):
# If dimensions are known, jacrev is more efficient
# for wide Jacobian matrices where output dimensionality
# is lower than input dimensionality.
# jacfwd is more efficient in the opposite condition.
cond = reduce(multiplier, out_shp) >= reduce(
multiplier, key_shp
)
jacobian_method = [self.backend.jacrev, self.backend.jacfwd][ # type: ignore
cond
]
# Provide input in dict format in order to get jacobians in dict
# format since all inputs are originally provided in dict format.
input = {key: inputs[key]}
# jacobians[out] |= jacobian_method(
# partial(jacobian_wrapper, output=out)
# )(input)
jacobians[out] |= jacobian_par_fn(input)
return jacobians

return jacobian_fn

def infer_ignore(
self,
weak_keys: set[str],
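Note: with create_jacobian_fn removed, the compiled model no longer produces Jacobians itself. The sketch below shows how a user of an autograd-capable backend could still compute them externally, reusing the jacrev/jacfwd size heuristic from the removed code. It is only a sketch under assumptions: pm.evaluate(params, data) returning an output dict, the backend.jacrev / backend.jacfwd helpers, and statically known pm.shapes mirror what the removed code relied on, but are not guaranteed by this diff.

from functools import partial
from math import prod


def make_jacobian_fn(pm, backend, output_key):
    # Differentiates one compiled-model output with respect to each input key.
    def wrapper(single_input, inputs, data):
        # Merge the differentiated input back into the full parameter dict
        # and evaluate the compiled model, keeping only the requested output.
        return pm.evaluate(inputs | single_input, data=data)[output_key]

    def jacobian_fn(inputs, data=None):
        data = {} if data is None else data
        jacobians = {}
        for key in inputs:
            out_shp, key_shp = pm.shapes[output_key], pm.shapes[key]
            # jacrev suits wide Jacobians (output smaller than input);
            # jacfwd suits the opposite case, as in the removed heuristic.
            method = backend.jacfwd if prod(out_shp) >= prod(key_shp) else backend.jacrev
            # Wrap the single input in a dict so the result stays keyed by name.
            jacobians[key] = method(
                partial(wrapper, inputs=inputs, data=data)
            )({key: inputs[key]})
        return jacobians

    return jacobian_fn

For example, make_jacobian_fn(pm, backend, "output")(params) would return a dict mapping each input key to d(output)/d(input), matching the per-output dictionaries the removed jacobian_fn built.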
35 changes: 16 additions & 19 deletions tests/scripts/test_compile_keys_consistencies.py
@@ -34,7 +34,6 @@ def test_dollar_sign_str():
"constant_keys",
"data_keys",
"discard_keys",
"jacobian_keys",
"trainable_keys",
"shapes",
]:
@@ -73,7 +72,6 @@ def test_connection_not_found():
"constant_keys",
"data_keys",
"discard_keys",
"jacobian_keys",
"trainable_keys",
"shapes",
]:
@@ -105,7 +103,6 @@ def test_string_not_found():
"constant_keys",
"data_keys",
"discard_keys",
"jacobian_keys",
"trainable_keys",
"shapes",
]:
@@ -256,22 +253,22 @@ def test_discard_keys_input_and_outputs_only():
)


def test_jacobian_keys_inputs_only():
"""jacobian_keys can not include any keys
other than the inputs of the model.
"""
model = Model()
model += (lin_model := Linear(1, True))(input="input", output="lin_out")
model += Multiply()(output=IOKey(name="output"))

backend = TorchBackend()
with pytest.raises(KeyError) as err_info:
ml_compile(model, backend, jacobian_keys={lin_model.output, "input"})
assert (
str(err_info.value)
== "'Provided jacobian keys must be subset of the input keys. "
"Invalid keys: lin_out.'"
)
# def test_jacobian_keys_inputs_only():
# """jacobian_keys can not include any keys
# other than the inputs of the model.
# """
# model = Model()
# model += (lin_model := Linear(1, True))(input="input", output="lin_out")
# model += Multiply()(output=IOKey(name="output"))

# backend = TorchBackend()
# with pytest.raises(KeyError) as err_info:
# ml_compile(model, backend, jacobian_keys={lin_model.output, "input"})
# assert (
# str(err_info.value)
# == "'Provided jacobian keys must be subset of the input keys. "
# "Invalid keys: lin_out.'"
# )


def test_iterable_type_keys():
8 changes: 0 additions & 8 deletions tests/scripts/test_data_store.py
@@ -47,7 +47,6 @@ def test_data_store_1():
data_keys=set(),
constant_keys=dict(),
trainable_keys=set(),
jacobian_keys=set(),
shapes=dict(),
inference=False,
safe_shapes=True,
@@ -77,7 +76,6 @@ def test_data_store_1_numpy():
data_keys=set(),
constant_keys=dict(),
trainable_keys=set(),
jacobian_keys=set(),
shapes=dict(),
inference=False,
safe_shapes=True,
@@ -139,7 +137,6 @@ def test_data_store_4():
data_keys=set(),
constant_keys=dict(),
trainable_keys=set(),
jacobian_keys=set(),
shapes=dict(),
inference=False,
safe_shapes=True,
@@ -426,7 +423,6 @@ def test_data_store_16():
data_keys=set(),
constant_keys=dict(),
trainable_keys=set(),
jacobian_keys=set(),
shapes=dict(),
inference=False,
safe_shapes=True,
@@ -459,7 +455,6 @@ def test_data_store_17():
data_keys=set(),
constant_keys=dict(),
trainable_keys=set(),
jacobian_keys=set(),
shapes=dict(),
inference=False,
safe_shapes=True,
@@ -490,7 +485,6 @@ def test_data_store_18():
data_keys=set(),
constant_keys=dict(),
trainable_keys=set(),
jacobian_keys=set(),
shapes=dict(),
inference=False,
safe_shapes=True,
@@ -521,7 +515,6 @@ def test_data_store_19():
data_keys=set(),
constant_keys={"left": left, "right": right},
trainable_keys=set(),
jacobian_keys=set(),
shapes=dict(),
inference=False,
safe_shapes=True,
@@ -552,7 +545,6 @@ def test_data_store_20():
data_keys=set(),
constant_keys={"left": left, "right": right},
trainable_keys=set(),
jacobian_keys=set(),
shapes=dict(),
inference=False,
safe_shapes=True,
1 change: 0 additions & 1 deletion tests/scripts/test_inference.py
@@ -63,7 +63,6 @@ def test_discard_keys_inference(case: str) -> None:
data_keys=set(),
constant_keys=dict(),
trainable_keys=set(),
jacobian_keys=set(),
shapes=dict(),
inference=True,
safe_shapes=True,
1 change: 0 additions & 1 deletion tests/scripts/test_pm.py
@@ -73,7 +73,6 @@ def test_set_random_keys():
data_keys=set(),
constant_keys={},
trainable_keys=set(),
jacobian_keys=set(),
inference=False,
safe_shapes=False,
safe_names=False,
