Skip to content

Commit

Permalink
refactor: dataset paths and transform
Browse files Browse the repository at this point in the history
  • Loading branch information
spencerwooo committed Feb 26, 2024
1 parent 106932b commit e005064
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 17 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Dataset files
data/
datasets/

# PDM files
.pdm-python
Expand Down
6 changes: 3 additions & 3 deletions src/torchattack/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ class NIPSLoader(DataLoader):

def __init__(
self,
path: str | None,
root: str | None,
image_root: str | None = None,
pairs_path: str | None = None,
batch_size: int = 1,
Expand All @@ -115,8 +115,8 @@ def __init__(

super().__init__(
dataset=NIPSDataset(
image_root=image_root if image_root else f'{path}/images',
pairs_path=pairs_path if pairs_path else f'{path}/images.csv',
image_root=image_root if image_root else f'{root}/images',
pairs_path=pairs_path if pairs_path else f'{root}/images.csv',
transform=transform,
max_samples=max_samples,
),
Expand Down
25 changes: 12 additions & 13 deletions src/torchattack/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,27 +26,26 @@ def run_attack(attack, attack_cfg, model='resnet50', samples=100, batch_size=8)
from rich import print
from rich.progress import track

# Set up model and dataloader
# Set up model, transform, and normalize
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = tv.models.get_model(name=model, weights='DEFAULT').to(device).eval()
transform = tv.transforms.Compose(
[
tv.transforms.Resize([224]),
tv.transforms.ToTensor(),
]
)
normalize = tv.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

# Set up dataloader
dataloader = NIPSLoader(
path='data/nips2017',
root='datasets/nips2017',
batch_size=batch_size,
transform=tv.transforms.Compose(
[
tv.transforms.Resize([232]),
tv.transforms.CenterCrop([224]),
tv.transforms.ToTensor(),
]
),
transform=transform,
max_samples=samples,
)

# Set up attack and trackers
normalize = tv.transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225],
)
total, acc_clean, acc_adv = len(dataloader.dataset), 0, 0 # type: ignore
attacker = attack(model=model, normalize=normalize, device=device, **attack_cfg)
print(attacker)
Expand Down

0 comments on commit e005064

Please sign in to comment.