Prepare for "Fix type-safety of torch.nn.Module instances": wave 2
Summary:
X-link: facebook/Ax#3099

See D52890934

I might absorb these into earlier diffs in the stack; these are the more recent ones
that I got autofixes for from the buildall.

Differential Revision: D66245100
ezyang authored and facebook-github-bot committed Nov 21, 2024
1 parent 478942d commit 889ee51
Showing 5 changed files with 14 additions and 0 deletions.
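Context for the suppressions below: the parent change (see D52890934) prepares pyre to type dynamic attribute access on torch.nn.Module as Union[Module, Tensor] (via Module.__getattr__). Any call site that expects a concrete type, or that calls such an attribute, is then flagged, and this diff silences the new errors with # pyre-fixme comments. A minimal sketch of the pattern, with a hypothetical helper that is not part of this commit, showing explicit narrowing as the alternative to suppression:

# Sketch only (hypothetical helper, not part of this commit): the error
# pattern these pyre-fixme comments suppress, and an explicit-narrowing
# alternative.
import torch
from torch import nn


def scale_weight(model: nn.Module) -> torch.Tensor:
    # `weight` is not declared on nn.Module itself, so the access resolves
    # through Module.__getattr__ and is typed Union[Module, Tensor].
    weight = model.weight
    # Narrowing with isinstance satisfies the checker; this diff instead adds
    # a pyre-fixme at each flagged site.
    assert isinstance(weight, torch.Tensor)
    return weight * 2.0


print(scale_weight(nn.Linear(3, 3)).shape)  # torch.Size([3, 3])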
1 change: 1 addition & 0 deletions captum/_utils/transformers_typing.py
@@ -116,4 +116,5 @@ def supports_caching(model: nn.Module) -> bool:
# Cache is mandatory
return True
# Fallback on _supports_cache_class attribute
# pyre-fixme[7]: Expected `bool` but got `Union[Module, Tensor]`.
return getattr(model, "_supports_cache_class", False)
4 changes: 4 additions & 0 deletions captum/attr/_core/deep_lift.py
@@ -407,6 +407,7 @@ def _forward_pre_hook(
set necessary hooks on inputs there.
"""
inputs = _format_tensor_into_tuples(inputs)
# pyre-fixme[16]: `Module` has no attribute `input`.
module.input = inputs[0].clone().detach()

def _forward_hook(
@@ -420,6 +421,7 @@ def _forward_hook(
outputs of a neuron
"""
outputs = _format_tensor_into_tuples(outputs)
# pyre-fixme[16]: `Module` has no attribute `output`.
module.output = outputs[0].clone().detach()

def _backward_hook(
@@ -536,6 +538,8 @@ def forward_hook(
):
return [
self.model.module.register_forward_pre_hook(pre_hook), # type: ignore
# pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no
# attribute `register_forward_hook`.
self.model.module.register_forward_hook(forward_hook),
] # type: ignore
else:
3 changes: 3 additions & 0 deletions captum/attr/_core/llm_attr.py
@@ -404,6 +404,7 @@ def _get_target_tokens(
gen_args = DEFAULT_GEN_ARGS

model_inp = self._format_model_input(inp.to_model_input())
# pyre-fixme[29]: `Union[Module, Tensor]` is not a function.
output_tokens = self.model.generate(model_inp, **gen_args)
target_tokens = output_tokens[0][model_inp.size(1) :]
else:
@@ -558,9 +559,11 @@ def _forward_func(
outputs.past_key_values = DynamicCache.from_legacy_cache(
outputs.past_key_values
)
# pyre-fixme[29]: `Union[Module, Tensor]` is not a function.
model_kwargs = self.model._update_model_kwargs_for_generation(
outputs, model_kwargs
)
# pyre-fixme[29]: `Union[Module, Tensor]` is not a function.
model_inputs = self.model.prepare_inputs_for_generation(
model_inp, **model_kwargs
)
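The pyre-fixme[29] comments above all suppress the same issue: generate, _update_model_kwargs_for_generation, and prepare_inputs_for_generation are not declared on nn.Module, so the attribute resolves to Union[Module, Tensor], which is not callable. An alternative, not used in this diff, is to narrow self.model at the call site; a minimal hypothetical sketch, assuming the wrapped model is a Hugging Face PreTrainedModel (which exposes generate):

# Hypothetical sketch, not part of this commit: narrow the nn.Module to a
# concrete type that exposes the generation API before calling it.
from typing import Any, cast

from torch import nn
from transformers import PreTrainedModel


def generate_tokens(model: nn.Module, model_inp: Any, **gen_args: Any) -> Any:
    # cast() is a no-op at runtime; it only informs the type checker.
    hf_model = cast(PreTrainedModel, model)
    return hf_model.generate(model_inp, **gen_args)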
2 changes: 2 additions & 0 deletions captum/insights/attr_vis/app.py
@@ -407,6 +407,8 @@ def _calculate_vis_output(
else self.models
)
results = []
# pyre-fixme[6]: For 1st argument expected `Iterable[_T]` but got
# `Union[List[Any], Module]`.
for model_index, model in enumerate(models_used):
# Get list of model visualizations for each input
actual_label_output = None
4 changes: 4 additions & 0 deletions tests/attr/layer/test_layer_gradient_x_activation.py
@@ -174,6 +174,8 @@ def _layer_activation_test_assert(
self, attributions, expected_activation, delta=0.01
)
else:
# pyre-fixme[6]: For 1st argument expected
# `pyre_extensions.PyreReadOnly[Sized]` but got `ModuleOrModuleList`.
for i in range(len(target_layer)):
assertTensorTuplesAlmostEqual(
self, attributions[i], expected_activation[i], delta=0.01
@@ -196,6 +198,8 @@ def _layer_activation_test_assert(
delta=0.01,
)
else:
# pyre-fixme[6]: For 1st argument expected
# `pyre_extensions.PyreReadOnly[Sized]` but got `ModuleOrModuleList`.
for i in range(len(target_layer)):
assertTensorTuplesAlmostEqual(
self,
