-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.py
131 lines (103 loc) · 3.94 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import random
import numpy as np
import torchvision.transforms as transforms
from PIL import ImageFilter
from torch.utils.data import DataLoader
from .backdoor import BadNets, Blend
from .cifar import CIFAR10
from .prefetch import PrefetchLoader
class GaussianBlur(object):
    """Gaussian blur augmentation in SimCLR.

    Borrowed from https://github.com/facebookresearch/moco/blob/master/moco/loader.py.

    Args:
        sigma: Two-element sequence ``(low, high)``; every call draws the blur
            radius uniformly from this range.
    """

    # The default is an immutable tuple: a list default would be one shared
    # object across all instances, so mutating it would affect every blur
    # created without an explicit ``sigma``.
    def __init__(self, sigma=(0.1, 2.0)):
        self.sigma = sigma

    def __call__(self, x):
        """Blur ``x`` with a freshly sampled radius and return the result.

        NOTE(review): ``x`` is presumably a PIL image; the code only requires
        that it exposes ``filter(...)``.
        """
        radius = random.uniform(self.sigma[0], self.sigma[1])
        return x.filter(ImageFilter.GaussianBlur(radius=radius))

    def __repr__(self):
        return "{}(sigma={})".format(type(self).__name__, self.sigma)
def query_transform(name, kwargs):
    """Build a single transform from its configuration entry.

    Args:
        name: Key selecting the transform (``"random_crop"``,
            ``"random_resize_crop"``, ``"resize"``, ``"center_crop"``,
            ``"random_horizontal_flip"``, ``"random_color_jitter"``,
            ``"random_grayscale"``, ``"gaussian_blur"``, ``"to_tensor"`` or
            ``"normalize"``).
        kwargs: Keyword arguments forwarded to the transform constructor.
            For ``"random_color_jitter"`` and ``"gaussian_blur"`` it must
            contain ``"p"``, which is passed to ``transforms.RandomApply``
            instead of the wrapped transform. For ``"to_tensor"`` it is only
            used as an on/off flag.

    Returns:
        The constructed transform, or ``None`` when ``name`` is
        ``"to_tensor"`` and ``kwargs`` is falsy.

    Raises:
        ValueError: If ``name`` is not a supported transform.
        KeyError: If a randomly-applied transform is configured without
            ``"p"``.
    """
    if name == "random_crop":
        return transforms.RandomCrop(**kwargs)
    if name == "random_resize_crop":
        return transforms.RandomResizedCrop(**kwargs)
    if name == "resize":
        return transforms.Resize(**kwargs)
    if name == "center_crop":
        return transforms.CenterCrop(**kwargs)
    if name == "random_horizontal_flip":
        return transforms.RandomHorizontalFlip(**kwargs)
    if name == "random_color_jitter":
        # Pop "p" from a copy: popping from the caller's dict would strip the
        # key from the shared config and make a second build raise KeyError.
        params = dict(kwargs)
        p = params.pop("p")
        return transforms.RandomApply([transforms.ColorJitter(**params)], p=p)
    if name == "random_grayscale":
        return transforms.RandomGrayscale(**kwargs)
    if name == "gaussian_blur":
        # Same copy-before-pop treatment as the color jitter branch.
        params = dict(kwargs)
        p = params.pop("p")
        return transforms.RandomApply([GaussianBlur(**params)], p=p)
    if name == "to_tensor":
        # The config value acts as a flag; a falsy value yields no transform.
        return transforms.ToTensor() if kwargs else None
    if name == "normalize":
        return transforms.Normalize(**kwargs)
    raise ValueError("Transformation {} is not supported!".format(name))
def get_transform(transform_config):
    """Compose a transform pipeline from an ordered configuration mapping.

    Args:
        transform_config: Mapping from transform name to its configuration
            value, or ``None`` for an empty pipeline. Entries whose value is
            ``None`` are ignored. Transforms are applied in the mapping's
            iteration order.

    Returns:
        A ``transforms.Compose`` wrapping the constructed transforms
        (possibly empty).

    Raises:
        ValueError: Propagated from ``query_transform`` for an unsupported
            transform name.
    """
    steps = []
    if transform_config is not None:
        for name, params in transform_config.items():
            if params is None:
                continue
            step = query_transform(name, params)
            # query_transform returns None for a disabled "to_tensor" entry;
            # keeping it would put a non-callable into Compose and fail with
            # a TypeError the first time the pipeline is applied.
            if step is not None:
                steps.append(step)
    return transforms.Compose(steps)
def get_dataset(dataset_dir, transform, train=True, prefetch=False):
    """Instantiate the dataset stored under ``dataset_dir``.

    The dataset type is chosen by a substring check on ``dataset_dir``; only
    paths containing ``"cifar"`` are recognised.

    Args:
        dataset_dir: Dataset location; also used to select the dataset type.
        transform: Forwarded to the dataset constructor.
        train: Forwarded to the dataset constructor.
        prefetch: Forwarded to the dataset constructor.

    Returns:
        A ``CIFAR10`` dataset instance.

    Raises:
        ValueError: If ``dataset_dir`` does not contain ``"cifar"``.
    """
    if "cifar" not in dataset_dir:
        raise ValueError("Dataset in {} is not supported.".format(dataset_dir))
    return CIFAR10(dataset_dir, transform=transform, train=train, prefetch=prefetch)
def get_loader(dataset, loader_config=None, **kwargs):
    """Create a ``DataLoader`` for ``dataset``, optionally with prefetching.

    Args:
        dataset: Dataset to iterate. Its ``prefetch`` attribute decides
            whether the loader is wrapped; ``mean`` and ``std`` are read only
            in that case.
        loader_config: Optional mapping of ``DataLoader`` keyword arguments.
        **kwargs: Additional ``DataLoader`` keyword arguments. A key given in
            both places raises ``TypeError``.

    Returns:
        The ``DataLoader``, wrapped in a ``PrefetchLoader`` when
        ``dataset.prefetch`` is truthy.
    """
    configured = {} if loader_config is None else loader_config
    base_loader = DataLoader(dataset, **configured, **kwargs)
    if not dataset.prefetch:
        return base_loader
    return PrefetchLoader(base_loader, dataset.mean, dataset.std)
def gen_poison_idx(dataset, target_label, poison_ratio=None):
    """Mark which samples of ``dataset`` are to be poisoned.

    For a training dataset with a ``poison_ratio``, each sample whose label
    differs from ``target_label`` is selected independently with probability
    ``poison_ratio``. Otherwise every sample whose label differs from
    ``target_label`` is selected.

    Args:
        dataset: Object providing ``__len__``, ``train`` and ``targets``.
        target_label: Label that is never selected.
        poison_ratio: Per-sample selection probability, or ``None``.

    Returns:
        A float NumPy array of length ``len(dataset)`` holding 1 for selected
        samples and 0 elsewhere.
    """
    flags = np.zeros(len(dataset))
    # The mode does not depend on the sample, so decide it once up front.
    sample_randomly = dataset.train and poison_ratio is not None
    for position, label in enumerate(dataset.targets):
        if sample_randomly:
            # One random draw per sample, taken before the label check.
            chosen = random.random() < poison_ratio and label != target_label
        else:
            chosen = label != target_label
        if chosen:
            flags[position] = 1
    return flags
def get_bd_transform(bd_config):
    """Create the backdoor transform described by ``bd_config``.

    ``"badnets"`` takes precedence over ``"blend"`` when both keys are
    present.

    Args:
        bd_config: Mapping with a ``"badnets"`` entry (providing
            ``"trigger_path"``) or a ``"blend"`` entry (keyword arguments for
            ``Blend``).

    Returns:
        A ``BadNets`` or ``Blend`` instance.

    Raises:
        ValueError: If neither key is present.
    """
    if "badnets" in bd_config:
        return BadNets(bd_config["badnets"]["trigger_path"])
    if "blend" in bd_config:
        return Blend(**bd_config["blend"])
    raise ValueError("Backdoor {} is not supported.".format(bd_config))
def get_semi_idx(record_list, ratio, logger):
    """Get labeled and unlabeled index.

    Selects the ``ratio`` fraction of samples with the smallest loss and logs
    how many of them are flagged in the ``"poison"`` record.

    Args:
        record_list: Sequence of records exposing ``name`` and ``data``;
            must include records named ``"loss"`` and ``"poison"`` whose
            ``data`` supports ``.numpy()``.
        ratio: Fraction of samples to select; the count is
            ``int(len(loss) * ratio)``.
        logger: Object with an ``info`` method.

    Returns:
        A float NumPy array with 1 at the selected indices and 0 elsewhere.
    """
    names = [record.name for record in record_list]
    loss_values = record_list[names.index("loss")].data.numpy()
    poison_flags = record_list[names.index("poison")].data.numpy()

    total = len(loss_values)
    keep = int(total * ratio)
    # Ascending sort of the losses; the first `keep` positions are the
    # lowest-loss samples.
    lowest = loss_values.argsort()[:keep]

    message = "{}/{} poisoned samples in semi_idx".format(
        poison_flags[lowest].sum(), len(lowest)
    )
    logger.info(message)

    mask = np.zeros(total)
    mask[lowest] = 1
    return mask