-
Notifications
You must be signed in to change notification settings - Fork 507
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Create TARGETS for captum/_utils (#1250)
Summary: Pull Request resolved: #1250 Create separate TARGETS files for different part of Captum project. Start with a relatively simple one: captum/_utils. Mostly TARGETS file change, with more exception to correct import, split test helper function to separate file etc. Reviewed By: cyrjano Differential Revision: D55091069 fbshipit-source-id: 83cbd8a632c8ba71d60d14859bbc549f7ae7b511
- Loading branch information
1 parent
fabac35
commit 5c9d0bf
Showing
5 changed files
with
61 additions
and
75 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,25 +1,5 @@ | ||
from captum._utils.models.linear_model import ( | ||
LinearModel, | ||
SGDLasso, | ||
SGDLinearModel, | ||
SGDLinearRegression, | ||
SGDRidge, | ||
SkLearnLasso, | ||
SkLearnLinearModel, | ||
SkLearnLinearRegression, | ||
SkLearnRidge, | ||
) | ||
from captum._utils.models.model import Model | ||
|
||
__all__ = [ | ||
"Model", | ||
"LinearModel", | ||
"SGDLinearModel", | ||
"SGDLasso", | ||
"SGDRidge", | ||
"SGDLinearRegression", | ||
"SkLearnLinearModel", | ||
"SkLearnLasso", | ||
"SkLearnRidge", | ||
"SkLearnLinearRegression", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. | ||
from typing import cast, Dict | ||
|
||
import torch | ||
from torch import Tensor | ||
|
||
|
||
def evaluate(test_data, classifier) -> Dict[str, Tensor]: | ||
classifier.eval() | ||
|
||
l1_loss = 0.0 | ||
l2_loss = 0.0 | ||
n = 0 | ||
l2_losses = [] | ||
with torch.no_grad(): | ||
for data in test_data: | ||
if len(data) == 2: | ||
x, y = data | ||
w = None | ||
else: | ||
x, y, w = data | ||
|
||
out = classifier(x) | ||
|
||
y = y.view(x.shape[0], -1) | ||
assert y.shape == out.shape | ||
|
||
if w is None: | ||
l1_loss += (out - y).abs().sum(0).to(dtype=torch.float64) | ||
l2_loss += ((out - y) ** 2).sum(0).to(dtype=torch.float64) | ||
l2_losses.append(((out - y) ** 2).to(dtype=torch.float64)) | ||
else: | ||
l1_loss += ( | ||
(w.view(-1, 1) * (out - y)).abs().sum(0).to(dtype=torch.float64) | ||
) | ||
l2_loss += ( | ||
(w.view(-1, 1) * ((out - y) ** 2)).sum(0).to(dtype=torch.float64) | ||
) | ||
l2_losses.append( | ||
(w.view(-1, 1) * ((out - y) ** 2)).to(dtype=torch.float64) | ||
) | ||
|
||
n += x.shape[0] | ||
|
||
l2_losses = torch.cat(l2_losses, dim=0) | ||
assert n > 0 | ||
|
||
# just to double check | ||
assert ((l2_losses.mean(0) - l2_loss / n).abs() <= 0.1).all() | ||
|
||
classifier.train() | ||
return {"l1": cast(Tensor, l1_loss / n), "l2": cast(Tensor, l2_loss / n)} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters