-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcustom_transforms.py
31 lines (28 loc) · 1.31 KB
/
custom_transforms.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import torchvision.transforms as transforms
from torchvision.transforms import InterpolationMode
# Weak augmentation pipeline: random resized crop, random horizontal flip and
# colour jitter, followed by conversion to a tensor.
# NOTE(review): brightness/contrast/saturation/hue are all set to 0.4 here;
# confirm this jitter strength is what is intended for the "weak" view.
weak_transform = transforms.Compose(
    [
        transforms.RandomResizedCrop(
            size=224,
            scale=(0.7, 1.0),
            interpolation=InterpolationMode.BILINEAR,
        ),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ColorJitter(
            brightness=0.4,
            contrast=0.4,
            saturation=0.4,
            hue=0.4,
        ),
        transforms.ToTensor(),
        # Normalization is currently disabled; enable the line below if needed.
        # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ]
)
# Strong augmentation pipeline: random resized crop, random horizontal flip,
# colour jitter and RandAugment (default settings), followed by conversion to
# a tensor. RandAugment is applied before ToTensor, i.e. on the un-converted
# input image.
strong_transform = transforms.Compose(
    [
        transforms.RandomResizedCrop(
            size=224,
            scale=(0.7, 1.0),
            interpolation=InterpolationMode.BILINEAR,
        ),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ColorJitter(
            brightness=0.4,
            contrast=0.4,
            saturation=0.4,
            hue=0.4,
        ),
        transforms.RandAugment(),
        transforms.ToTensor(),
        # Normalization is currently disabled; enable the line below if needed.
        # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ]
)
# Test-time pipeline: no random augmentation, only a fixed bilinear resize to
# 224x224 followed by conversion to a tensor.
test_transform = transforms.Compose(
    [
        transforms.Resize(
            size=[224, 224],
            interpolation=InterpolationMode.BILINEAR,
        ),
        transforms.ToTensor(),
        # Normalization is currently disabled; enable the line below if needed.
        # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ]
)