Skip to content

Commit

Permalink
Fix pyre-fixmes in captum/attr/_utils/input_layer_wrapper.py
Browse files Browse the repository at this point in the history
Summary: Fix Pyre fixmes in the input layer wrapper python file.

Reviewed By: cyrjano

Differential Revision: D66972149
  • Loading branch information
styusuf authored and facebook-github-bot committed Dec 9, 2024
1 parent ad77e79 commit 1c5b7ae
Showing 1 changed file with 10 additions and 12 deletions.
22 changes: 10 additions & 12 deletions captum/attr/_utils/input_layer_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
# pyre-strict

import inspect
from typing import Any
from typing import Dict, List

import torch.nn as nn
from torch import Tensor


class InputIdentity(nn.Module):
Expand All @@ -21,9 +22,7 @@ def __init__(self, input_name: str) -> None:
super().__init__()
self.input_name = input_name

# pyre-fixme[3]: Return type must be annotated.
# pyre-fixme[2]: Parameter must be annotated.
def forward(self, x):
def forward(self, x: Tensor) -> Tensor:
return x


Expand Down Expand Up @@ -64,18 +63,17 @@ def __init__(self, module_to_wrap: nn.Module) -> None:
self.module = module_to_wrap

# ignore self
# pyre-fixme[4]: Attribute must be annotated.
self.arg_name_list = inspect.getfullargspec(module_to_wrap.forward).args[1:]
self.arg_name_list: List[str] = inspect.getfullargspec(
module_to_wrap.forward
).args[1:]
self.input_maps = nn.ModuleDict(
{arg_name: InputIdentity(arg_name) for arg_name in self.arg_name_list}
)

# pyre-fixme[3]: Return annotation cannot be `Any`.
# pyre-fixme[2]: Parameter must be annotated.
def forward(self, *args, **kwargs) -> Any:
args = list(args)
for idx, (arg_name, arg) in enumerate(zip(self.arg_name_list, args)):
args[idx] = self.input_maps[arg_name](arg)
def forward(self, *args: str, **kwargs: Dict[str, str]) -> object:
args_list = list(args)
for idx, (arg_name, arg) in enumerate(zip(self.arg_name_list, args_list)):
args_list[idx] = self.input_maps[arg_name](arg)

for arg_name in kwargs.keys():
kwargs[arg_name] = self.input_maps[arg_name](kwargs[arg_name])
Expand Down

0 comments on commit 1c5b7ae

Please sign in to comment.