
fix: hooks and layer feature extraction
spencerwooo committed Feb 26, 2024
1 parent cf1dbe2 commit fe3c9eb
Showing 1 changed file with 16 additions and 16 deletions.
src/torchattack/fia.py (32 changes: 16 additions & 16 deletions)
```diff
@@ -79,8 +79,8 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
         if self.alpha is None:
             self.alpha = self.eps / self.steps

-        h = self.feature_layer.register_forward_hook(self.__forward_hook)
-        h2 = self.feature_layer.register_full_backward_hook(self.__backward_hook)
+        h = self.feature_layer.register_forward_hook(self.__forward_hook)  # type: ignore
+        h2 = self.feature_layer.register_full_backward_hook(self.__backward_hook)  # type: ignore

         # Gradient aggregation on ensembles
         agg_grad = 0
```
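For context, a minimal sketch of the hook pattern these lines annotate: a forward hook stores the layer's intermediate output, and a full backward hook stores the gradient flowing out of that layer. The toy model and names below are illustrative, not from the repo.

```python
import torch
import torch.nn as nn

# Toy stand-in for self.model; hook its first conv as the "feature layer".
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Flatten(), nn.Linear(8 * 30 * 30, 10))
feature_layer = model[0]
captured = {}

def forward_hook(m, i, o):
    captured['mid_output'] = o  # intermediate feature map

def backward_hook(m, grad_input, grad_output):
    captured['mid_grad'] = grad_output  # tuple of grads w.r.t. the layer's outputs

h = feature_layer.register_forward_hook(forward_hook)
h2 = feature_layer.register_full_backward_hook(backward_hook)

x = torch.randn(2, 3, 32, 32)
model(x).sum().backward()  # one forward/backward pass populates both entries
h.remove()
h2.remove()
```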
```diff
@@ -93,7 +93,7 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
                 loss += output_random[batch_i][y[batch_i]]  # type: ignore
             self.model.zero_grad()
             loss.backward()  # type: ignore
-            agg_grad += self.mid_grad[0].detach()
+            agg_grad += self.mid_grad[0].detach()  # type: ignore
         for batch_i in range(x.shape[0]):
             agg_grad[batch_i] /= agg_grad[batch_i].norm(p=2)  # type: ignore
         h2.remove()
```
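The closing lines of this hunk divide each sample's aggregated gradient by its own L2 norm. A small sketch under an assumed (B, C, H, W) shape, with an equivalent vectorized form:

```python
import torch

agg_grad = torch.randn(4, 8, 30, 30)  # assumed: one aggregated gradient per sample

# Loop form, as in the hunk:
for batch_i in range(agg_grad.shape[0]):
    agg_grad[batch_i] /= agg_grad[batch_i].norm(p=2)

# Equivalent vectorized form:
# agg_grad /= agg_grad.flatten(1).norm(p=2, dim=1).view(-1, 1, 1, 1)
```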
```diff
@@ -104,10 +104,10 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
             _ = self.model(self.normalize(x + delta))

             # Hooks are updated during forward pass
-            loss = (self.mid_output * agg_grad).sum()
+            outs = (self.mid_output * agg_grad).sum()

             self.model.zero_grad()
-            grad = torch.autograd.grad(loss, delta, retain_graph=False)[0]
+            grad = torch.autograd.grad(outs, delta, retain_graph=False)[0]

             # Update delta
             delta.data = delta.data - self.alpha * grad.sign()
```
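Renaming `loss` to `outs` avoids reusing the name from the aggregation phase and matches `torch.autograd.grad`, whose first positional argument is the output to differentiate. A standalone sketch of that call on toy tensors:

```python
import torch

delta = torch.zeros(2, 3, 4, 4, requires_grad=True)
outs = (delta * 2.0).sum()  # any scalar computed from delta

# Differentiate outs w.r.t. delta; retain_graph=False frees the graph afterwards.
grad = torch.autograd.grad(outs, delta, retain_graph=False)[0]
print(grad.shape)  # torch.Size([2, 3, 4, 4]), every entry 2.0
```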
```diff
@@ -143,17 +143,17 @@ def find_layer(self, feature_layer_name) -> nn.Module:
             The layer to compute feature importance.
         """

-        parser = feature_layer_name.split(' ')
-        m = self.model
-
-        for layer in parser:
-            if layer in m._modules:
-                m = m._modules[layer]
-                break
-            else:
-                raise ValueError(f'Layer {layer} not found in the model.')
-
-        return m
+        # for layer in feature_layer_name.split(' '):
+        #     if layer not in self.model._modules:
+        #         raise ValueError(f'Layer {layer} not found in the model.')
+        #     return self.model._modules[layer]
+
+        if feature_layer_name not in self.model._modules:
+            raise ValueError(f'Layer {feature_layer_name} not found in the model.')
+        feature_layer = self.model._modules[feature_layer_name]
+        if not isinstance(feature_layer, nn.Module):
+            raise ValueError(f'Layer {feature_layer_name} invalid.')
+        return feature_layer

     def __forward_hook(self, m: nn.Module, i: torch.Tensor, o: torch.Tensor):
         self.mid_output = o
```
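After the rewrite, `find_layer` resolves only a direct child of the model by name and validates the result; the old space-separated traversal is kept commented out. A toy sketch of the new lookup behavior (the model and names are illustrative):

```python
import torch.nn as nn

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU())  # children are named '0' and '1'

name = '0'
if name not in model._modules:
    raise ValueError(f'Layer {name} not found in the model.')
feature_layer = model._modules[name]
assert isinstance(feature_layer, nn.Module)  # e.g. the Conv2d child
```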
