-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path: imagenette.py
111 lines (86 loc) · 3.65 KB
/
imagenette.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import argparse
import os
import importlib
import torch
import torchvision
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST
from torchvision import transforms
import pytorch_lightning as pl
from pytorch_lightning.metrics.functional import accuracy
class LiMImagenette(pl.LightningModule):
    """LightningModule for training an image classifier on Imagenette.

    Wraps a backbone architecture looked up by name in ``models.imagenette``,
    standard ImageNet-style data augmentation/normalization, and an
    SGD + cosine-annealing training recipe.

    Args:
        data_dir: Root of the Imagenette dataset; must contain ``train/`` and
            ``val/`` subdirectories in ``ImageFolder`` layout. ``~`` is expanded.
        arch: Name of a model constructor in the ``models.imagenette`` module
            (e.g. ``'resnet18'``).
        learning_rate: Initial SGD learning rate.
        momentum: SGD momentum.
        weight_decay: SGD weight decay (L2 penalty).
        num_workers: DataLoader worker processes per loader.
        batch_size: Mini-batch size for both train and val loaders.
    """

    def __init__(self, data_dir='~/data/imagenette2',
                 arch='resnet18',
                 learning_rate=0.1,
                 momentum=0.9,
                 weight_decay=5e-4,
                 num_workers=4, batch_size=256):
        super().__init__()

        # Store the hyperparameters as plain attributes.
        self.data_dir = os.path.expanduser(data_dir)
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
        self.momentum = momentum
        self.num_workers = num_workers
        self.batch_size = batch_size

        # Standard ImageNet channel statistics; Imagenette is an ImageNet
        # subset, so the same normalization applies.
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
        # Training: random resized crop + horizontal flip augmentation.
        self.train_transform = transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
        # Evaluation: deterministic resize + center crop, no augmentation.
        self.test_transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])

        # Resolve the backbone constructor by name from the project's
        # models.imagenette module; raises AttributeError for unknown arch.
        module = importlib.import_module("models.imagenette")
        self.model = getattr(module, arch)()
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, x):
        """Run the backbone; returns raw class logits."""
        return self.model(x)

    def on_epoch_start(self):
        # Log the current learning rate so the cosine schedule is visible
        # in TensorBoard.
        # NOTE(review): `on_epoch_start` is a deprecated hook name in newer
        # pytorch_lightning releases (use `on_train_epoch_start`) — confirm
        # against the pinned PL version before renaming.
        self.logger.experiment.add_scalar(
            'lr', self.trainer.optimizers[0].param_groups[0]['lr'],
            self.current_epoch)

    def training_step(self, batch, batch_idx):
        """Compute and log the cross-entropy loss for one training batch."""
        x, y = batch
        outputs = self.forward(x)
        loss = self.criterion(outputs, y)
        self.log('train_loss', loss, prog_bar=False)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute loss and top-1 accuracy for one validation batch."""
        x, y = batch
        outputs = self.forward(x)
        loss = self.criterion(outputs, y)
        preds = torch.argmax(outputs, dim=1)
        # Top-1 accuracy computed inline with plain torch; equivalent to the
        # removed pytorch_lightning.metrics.functional.accuracy for argmax
        # multiclass predictions, without the deprecated dependency.
        acc = (preds == y).float().mean()
        # self.log surfaces these scalars in TensorBoard / progress bar.
        self.log('val_loss', loss, prog_bar=True)
        self.log('val_acc', acc, prog_bar=True)
        return loss

    def test_step(self, batch, batch_idx):
        # Reuse validation logic for testing. Note the metrics are still
        # logged under the 'val_*' names.
        return self.validation_step(batch, batch_idx)

    def configure_optimizers(self):
        """SGD with momentum/weight decay and a 90-epoch cosine LR schedule."""
        optimizer = torch.optim.SGD(self.parameters(),
                                    lr=self.learning_rate,
                                    momentum=self.momentum,
                                    weight_decay=self.weight_decay)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=90)
        return [optimizer], [scheduler]

    def prepare_data(self):
        # Nothing to download; the dataset is expected to exist on disk.
        pass

    def setup(self, stage=None):
        """Build train/val ImageFolder datasets from ``data_dir``."""
        self.trainset = torchvision.datasets.ImageFolder(
            root=os.path.join(self.data_dir, 'train'),
            transform=self.train_transform)
        self.valset = torchvision.datasets.ImageFolder(
            root=os.path.join(self.data_dir, 'val'),
            transform=self.test_transform)

    def train_dataloader(self):
        return DataLoader(self.trainset, batch_size=self.batch_size,
                          shuffle=True, num_workers=self.num_workers)

    def val_dataloader(self):
        return DataLoader(self.valset, batch_size=self.batch_size,
                          shuffle=False, num_workers=self.num_workers)