diff --git a/tests/helpers/evaluate_linear_model.py b/tests/helpers/evaluate_linear_model.py index 19f1207466..15dc24c8c1 100644 --- a/tests/helpers/evaluate_linear_model.py +++ b/tests/helpers/evaluate_linear_model.py @@ -6,11 +6,13 @@ from typing import cast, Dict import torch + +from captum._utils.models.linear_model.model import LinearModel from torch import Tensor +from torch.utils.data import DataLoader -# pyre-fixme[2]: Parameter must be annotated. -def evaluate(test_data, classifier) -> Dict[str, Tensor]: +def evaluate(test_data: DataLoader, classifier: LinearModel) -> Dict[str, Tensor]: classifier.eval() l1_loss = 0.0 diff --git a/tests/metrics/test_sensitivity.py b/tests/metrics/test_sensitivity.py index 14c04fba54..9dbb4e4749 100644 --- a/tests/metrics/test_sensitivity.py +++ b/tests/metrics/test_sensitivity.py @@ -3,7 +3,7 @@ # pyre-strict import typing -from typing import Callable, cast, List, Optional, Tuple, Union +from typing import Callable, List, Optional, Tuple, Union import torch from captum._utils.typing import BaselineType, TargetType, TensorOrTupleOfTensorsGeneric @@ -28,19 +28,15 @@ @typing.overload -# pyre-fixme[43]: The implementation of `_perturb_func` does not accept all possible -# arguments of overload defined on line `32`. def _perturb_func(inputs: Tuple[Tensor, ...]) -> Tuple[Tensor, ...]: ... @typing.overload -# pyre-fixme[43]: The implementation of `_perturb_func` does not accept all possible -# arguments of overload defined on line `28`. def _perturb_func(inputs: Tensor) -> Tensor: ... def _perturb_func( - inputs: TensorOrTupleOfTensorsGeneric, + inputs: Union[Tensor, Tuple[Tensor, ...]], ) -> Union[Tensor, Tuple[Tensor, ...]]: def perturb_ratio(input: Tensor) -> Tensor: return ( @@ -55,7 +51,7 @@ def perturb_ratio(input: Tensor) -> Tensor: input1 = inputs[0] input2 = inputs[1] else: - input1 = cast(Tensor, inputs) + input1 = inputs perturbed_input1 = input1 + perturb_ratio(input1) @@ -283,12 +279,13 @@ def test_classification_sensitivity_tpl_target_w_baseline(self) -> None: def sensitivity_max_assert( self, - # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. - expl_func: Callable, + expl_func: Callable[..., Union[Tensor, Tuple[Tensor, ...]]], inputs: TensorOrTupleOfTensorsGeneric, expected_sensitivity: Tensor, - # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. - perturb_func: Callable = _perturb_func, + perturb_func: Union[ + Callable[[Tensor], Tensor], + Callable[[Tuple[Tensor, ...]], Tuple[Tensor, ...]], + ] = _perturb_func, n_perturb_samples: int = 5, max_examples_per_batch: Optional[int] = None, baselines: Optional[BaselineType] = None,