Skip to content

Commit

Permalink
Fix test sensitivity pyre fix me issues (#1481)
Browse files Browse the repository at this point in the history
Summary:

Fixing unresolved pyre fixme issues in corresponding file

Reviewed By: banne01

Differential Revision: D67726420
  • Loading branch information
Vivek Miglani authored and facebook-github-bot committed Dec 31, 2024
1 parent 3f48e79 commit fc11fbf
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 13 deletions.
6 changes: 4 additions & 2 deletions tests/helpers/evaluate_linear_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@
from typing import cast, Dict

import torch

from captum._utils.models.linear_model.model import LinearModel
from torch import Tensor
from torch.utils.data import DataLoader


# pyre-fixme[2]: Parameter must be annotated.
def evaluate(test_data, classifier) -> Dict[str, Tensor]:
def evaluate(test_data: DataLoader, classifier: LinearModel) -> Dict[str, Tensor]:
classifier.eval()

l1_loss = 0.0
Expand Down
19 changes: 8 additions & 11 deletions tests/metrics/test_sensitivity.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# pyre-strict

import typing
from typing import Callable, cast, List, Optional, Tuple, Union
from typing import Callable, List, Optional, Tuple, Union

import torch
from captum._utils.typing import BaselineType, TargetType, TensorOrTupleOfTensorsGeneric
Expand All @@ -28,19 +28,15 @@


@typing.overload
# pyre-fixme[43]: The implementation of `_perturb_func` does not accept all possible
# arguments of overload defined on line `32`.
def _perturb_func(inputs: Tuple[Tensor, ...]) -> Tuple[Tensor, ...]: ...


@typing.overload
# pyre-fixme[43]: The implementation of `_perturb_func` does not accept all possible
# arguments of overload defined on line `28`.
def _perturb_func(inputs: Tensor) -> Tensor: ...


def _perturb_func(
inputs: TensorOrTupleOfTensorsGeneric,
inputs: Union[Tensor, Tuple[Tensor, ...]],
) -> Union[Tensor, Tuple[Tensor, ...]]:
def perturb_ratio(input: Tensor) -> Tensor:
return (
Expand All @@ -55,7 +51,7 @@ def perturb_ratio(input: Tensor) -> Tensor:
input1 = inputs[0]
input2 = inputs[1]
else:
input1 = cast(Tensor, inputs)
input1 = inputs

perturbed_input1 = input1 + perturb_ratio(input1)

Expand Down Expand Up @@ -283,12 +279,13 @@ def test_classification_sensitivity_tpl_target_w_baseline(self) -> None:

def sensitivity_max_assert(
self,
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
expl_func: Callable,
expl_func: Callable[..., Union[Tensor, Tuple[Tensor, ...]]],
inputs: TensorOrTupleOfTensorsGeneric,
expected_sensitivity: Tensor,
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
perturb_func: Callable = _perturb_func,
perturb_func: Union[
Callable[[Tensor], Tensor],
Callable[[Tuple[Tensor, ...]], Tuple[Tensor, ...]],
] = _perturb_func,
n_perturb_samples: int = 5,
max_examples_per_batch: Optional[int] = None,
baselines: Optional[BaselineType] = None,
Expand Down

0 comments on commit fc11fbf

Please sign in to comment.