Skip to content

Commit

Permalink
Make mypy happy
Browse files Browse the repository at this point in the history
  • Loading branch information
spencerwooo committed May 8, 2024
1 parent 4aeea11 commit f31ce2f
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 17 deletions.
8 changes: 4 additions & 4 deletions src/torchattack/admix.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,13 +103,13 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
grad = torch.autograd.grad(loss, x_admixs)[0]

# Split gradients and compute mean
grads = torch.tensor_split(grad, 5, dim=0)
grads = [g * s for g, s in zip(grads, scales, strict=True)]
split_grads = torch.tensor_split(grad, 5, dim=0)
grads = [g * s for g, s in zip(split_grads, scales, strict=True)]
grad = torch.mean(torch.stack(grads), dim=0)

# Gather gradients
grads = torch.tensor_split(grad, self.size)
grad = torch.sum(torch.stack(grads), dim=0)
split_grads = torch.tensor_split(grad, self.size)
grad = torch.sum(torch.stack(split_grads), dim=0)

# Apply momentum term
g = self.decay * g + grad / torch.mean(
Expand Down
16 changes: 8 additions & 8 deletions src/torchattack/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def __init__(
self,
image_root: str,
pairs_path: str,
transform: Callable[[torch.Tensor], torch.Tensor] | None = None,
transform: Callable[[torch.Tensor | Image.Image], torch.Tensor] | None = None,
max_samples: int | None = None,
) -> None:
"""Initialize the NIPS 2017 Adversarial Learning Challenge dataset.
Expand Down Expand Up @@ -73,10 +73,10 @@ def __getitem__(self, index: int) -> tuple[Any, int, str]:
name = self.names[index]
label = int(self.labels[index]) - 1

image = Image.open(f'{self.image_root}/{name}.png').convert('RGB')
# image = np.array(image, dtype=np.uint8)
# image = torch.from_numpy(image).permute((2, 0, 1)).contiguous().float().div(255)
image = self.transform(image) if self.transform else image
pil_image = Image.open(f'{self.image_root}/{name}.png').convert('RGB')
# np_image = np.array(pil_image, dtype=np.uint8)
# image = torch.from_numpy(np_image).permute((2, 0, 1)).contiguous().float().div(255)
image = self.transform(pil_image) if self.transform else pil_image
return image, label, name


Expand All @@ -89,9 +89,9 @@ class NIPSLoader(DataLoader):
>>> from torchvision.transforms import transforms
>>> from torchattack.dataset import NIPSLoader
>>> transform = transforms.Resize([224])
>>> transform = transforms.Compose([transforms.Resize([224]), transforms.ToTensor()])
>>> dataloader = NIPSLoader(
>>> path="data/nips2017", batch_size=16, transform=transform
>>> path="data/nips2017", batch_size=16, transform=transform, max_samples=100
>>> )
You can specify a custom image root directory and CSV file location by
Expand All @@ -107,7 +107,7 @@ def __init__(
batch_size: int = 1,
shuffle: bool = False,
num_workers: int = 4,
transform: Callable[[torch.Tensor], torch.Tensor] | None = None,
transform: Callable[[torch.Tensor | Image.Image], torch.Tensor] | None = None,
max_samples: int | None = None,
):
# Specifing a custom image root directory is useful when evaluating
Expand Down
10 changes: 5 additions & 5 deletions src/torchattack/runner.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from contextlib import suppress
from typing import Any, Optional
from typing import Any

import torch
import torchvision as tv
Expand All @@ -11,9 +11,9 @@ class FoolingRateMetric:
"""Fooling rate metric tracker."""

def __init__(self) -> None:
self.total_count = 0
self.clean_count = 0
self.adv_count = 0
self.total_count = torch.tensor(0)
self.clean_count = torch.tensor(0)
self.adv_count = torch.tensor(0)

def update(
self, labels: torch.Tensor, clean_logits: torch.Tensor, adv_logits: torch.Tensor
Expand Down Expand Up @@ -42,7 +42,7 @@ def compute_clean_accuracy(self) -> torch.Tensor:

def run_attack(
attack: Any,
attack_cfg: Optional[dict] = None,
attack_cfg: dict | None = None,
model_name: str = 'resnet50',
max_samples: int = 100,
batch_size: int = 8,
Expand Down

0 comments on commit f31ce2f

Please sign in to comment.