Skip to content

Commit

Permalink
Fix summarizer pyre fix me issues (#1479)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #1479

Fixing unresolved pyre fixme issues in corresponding file

Reviewed By: cyrjano

Differential Revision: D67707848

fbshipit-source-id: cddf89f0611c7acce367f89a6417c9165f85602b
  • Loading branch information
Vivek Miglani authored and facebook-github-bot committed Dec 31, 2024
1 parent 1e7d5ff commit 7c7a477
Showing 1 changed file with 15 additions and 23 deletions.
38 changes: 15 additions & 23 deletions captum/attr/_utils/summarizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ class Summarizer:
>>>print(summ.summary['mean'])
"""

_stats: List[Stat]
_summary_stats_indicies: List[int]

@log_usage()
def __init__(self, stats: List[Stat]) -> None:
r"""
Expand All @@ -37,11 +40,9 @@ def __init__(self, stats: List[Stat]) -> None:
"""
self._summarizers: List[SummarizerSingleTensor] = []
self._is_inputs_tuple: Optional[bool] = None
# pyre-fixme[4]: Attribute must be annotated.
self._stats, self._summary_stats_indicies = _reorder_stats(stats)

# pyre-fixme[3]: Return type must be annotated.
def _copy_stats(self):
def _copy_stats(self) -> List[Stat]:
import copy

return copy.deepcopy(self._stats)
Expand Down Expand Up @@ -125,48 +126,37 @@ def _reorder_stats(stats: List[Stat]) -> Tuple[List[Stat], List[int]]:
dep_order = [StdDev, Var, MSE, Mean, Count]

# remove dupe stats
# pyre-fixme[9]: stats has type `List[Stat]`; used as `Set[Stat]`.
stats = set(stats)
stats_set = set(stats)
summary_stats = set(stats)

from collections import defaultdict

# pyre-fixme[24]: Generic type `type` expects 1 type parameter, use
# `typing.Type[<base type>]` to avoid runtime subscripting errors.
stats_by_module: Dict[Type, List[Stat]] = defaultdict(list)
for stat in stats:
stats_by_module: Dict[Type[Stat], List[Stat]] = defaultdict(list)
for stat in stats_set:
stats_by_module[stat.__class__].append(stat)

# StdDev is an odd case since it is parameterized, thus
# for each StdDev(order) we must ensure there is an associated Var(order)
for std_dev in stats_by_module[StdDev]:
stat_to_add = Var(order=std_dev.order) # type: ignore
# pyre-fixme[16]: `List` has no attribute `add`.
stats.add(stat_to_add)
stats_set.add(stat_to_add)
stats_by_module[stat_to_add.__class__].append(stat_to_add)

# For the other modules (deps[1:n-1]): if i exists =>
# we want to ensure i...n-1 exists
for i, dep in enumerate(dep_order[1:]):
if dep in stats_by_module:
# pyre-fixme[16]: `List` has no attribute `update`.
stats.update([mod() for mod in dep_order[i + 1 :]])
stats_set.update([mod() for mod in dep_order[i + 1 :]])
break

# Step 2: get the correct order
# NOTE: we are sorting via a given topological order
sort_order = {mod: i for i, mod in enumerate(dep_order)}
# pyre-fixme[6]: For 1st argument expected `Type[Union[Count, MSE, Mean, StdDev,
# Var]]` but got `Type[Min]`.
sort_order: Dict[Type[Stat], int] = {mod: i for i, mod in enumerate(dep_order)}
sort_order[Min] = -1
# pyre-fixme[6]: For 1st argument expected `Type[Union[Count, MSE, Mean, StdDev,
# Var]]` but got `Type[Max]`.
sort_order[Max] = -1
# pyre-fixme[6]: For 1st argument expected `Type[Union[Count, MSE, Mean, StdDev,
# Var]]` but got `Type[Sum]`.
sort_order[Sum] = -1

stats = list(stats)
stats = list(stats_set)
stats.sort(key=lambda x: sort_order[x.__class__], reverse=True)

# get the summary stat indices
Expand All @@ -185,6 +175,10 @@ class SummarizerSingleTensor:
If possible use `Summarizer` instead.
"""

_stats: List[Stat]
_stat_to_stat: Dict[Stat, Stat]
_summary_stats: List[Stat]

def __init__(self, stats: List[Stat], summary_stats_indices: List[int]) -> None:
r"""
Args:
Expand All @@ -196,9 +190,7 @@ def __init__(self, stats: List[Stat], summary_stats_indices: List[int]) -> None:
does not require any specific order.
"""
self._stats = stats
# pyre-fixme[4]: Attribute must be annotated.
self._stat_to_stat = {stat: stat for stat in self._stats}
# pyre-fixme[4]: Attribute must be annotated.
self._summary_stats = [stats[i] for i in summary_stats_indices]

for stat in stats:
Expand Down

0 comments on commit 7c7a477

Please sign in to comment.