feat(rotnet): add checkpointer and logger, as well as a test script
zjykzj committed Aug 22, 2020
1 parent d844c4c commit b50a960
Showing 6 changed files with 188 additions and 12 deletions.
14 changes: 13 additions & 1 deletion rotnet/data/build.py
@@ -14,7 +14,7 @@
from .mnist import FMNIST


def build_transform():
def build_train_transform():
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
@@ -30,6 +30,18 @@ def build_transform():
    return transform, target_transform


def build_test_transform():
    transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((224, 224)),
        transforms.Grayscale(),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ])

    return transform


def build_dataset(data_dir, transform=None, target_transform=None):
    train_dataset = FMNIST(data_dir, download=True, train=True, transform=transform, target_transform=target_transform)
    test_dataset = FMNIST(data_dir, download=True, train=False, transform=transform, target_transform=target_transform)
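A side note on build_test_transform, with a minimal sketch that is not part of the commit: it begins with ToPILImage because, unlike the training pipeline, which receives PIL images from the FMNIST dataset, it is fed raw NumPy arrays read with OpenCV. The random array below is only an illustrative stand-in for such an image.

import numpy as np

from rotnet.data.build import build_test_transform

transform = build_test_transform()
# stand-in for a grayscale image loaded with cv2.imread(..., cv2.IMREAD_GRAYSCALE)
img = np.random.randint(0, 256, size=(100, 80), dtype=np.uint8)
tensor = transform(img)
print(tensor.shape)  # expected: torch.Size([1, 224, 224])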
19 changes: 11 additions & 8 deletions rotnet/engine/build.py
@@ -12,20 +12,24 @@
import torch

from rotnet.util.metrics import topk_accuracy
from rotnet.util.logger import setup_logger


def train_model(model_name, model, criterion, optimizer, lr_scheduler, data_loaders, data_sizes,
def train_model(model_name, model, criterion, optimizer, lr_scheduler, data_loaders, data_sizes, checkpointer,
                epoches=100, device=None):
    since = time.time()

    logger = setup_logger("RotNet")
    logger.info("Start training ...")

    best_model_weights = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    loss_dict = {'train': [], 'test': []}
    acc_dict = {'train': [], 'test': []}
    for epoch in range(epoches):
        print('{} - Epoch {}/{}'.format(model_name, epoch, epoches - 1))
        print('-' * 10)
        logger.info('{} - Epoch {}/{}'.format(model_name, epoch, epoches - 1))
        logger.info('-' * 10)

        # Each epoch has a training and test phase
        for phase in ['train', 'test']:
@@ -74,7 +78,7 @@ def train_model(model_name, model, criterion, optimizer, lr_scheduler, data_load
            loss_dict[phase].append(epoch_loss)
            acc_dict[phase].append(epoch_acc)

            print('{} Loss: {:.4f} Top-1 Acc: {:.4f}'.format(
            logger.info('{} Loss: {:.4f} Top-1 Acc: {:.4f}'.format(
                phase, epoch_loss, epoch_acc))

            # deep copy the model
@@ -83,12 +87,11 @@ def train_model(model_name, model, criterion, optimizer, lr_scheduler, data_load
                best_model_weights = copy.deepcopy(model.state_dict())

        # save a checkpoint after every training epoch
        # util.save_model(model.cpu(), '../data/models/%s_%d.pth' % (model_name, epoch))
        # model = model.to(device)
        checkpointer.save("model_{:06d}".format(epoch))

    time_elapsed = time.time() - since
    print('Training {} complete in {:.0f}m {:.0f}s'.format(model_name, time_elapsed // 60, time_elapsed % 60))
    print('Best test Top-1 Acc: {:4f}'.format(best_acc))
    logger.info('Training {} complete in {:.0f}m {:.0f}s'.format(model_name, time_elapsed // 60, time_elapsed % 60))
    logger.info('Best test Top-1 Acc: {:4f}'.format(best_acc))

    # load best model weights
    model.load_state_dict(best_model_weights)
99 changes: 99 additions & 0 deletions rotnet/util/checkpoint.py
@@ -0,0 +1,99 @@
import logging
import os

import torch
from torch.nn.parallel import DistributedDataParallel


class CheckPointer:
    _last_checkpoint_name = 'last_checkpoint.txt'

    def __init__(self,
                 model,
                 optimizer=None,
                 scheduler=None,
                 save_dir="",
                 save_to_disk=None,
                 logger=None):
        self.model = model
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.save_dir = save_dir
        self.save_to_disk = save_to_disk
        if logger is None:
            logger = logging.getLogger(__name__)
        self.logger = logger

    def save(self, name, **kwargs):
        if not self.save_dir:
            return

        if not self.save_to_disk:
            return

        data = {}
        if isinstance(self.model, DistributedDataParallel):
            data['model'] = self.model.module.state_dict()
        else:
            data['model'] = self.model.state_dict()
        if self.optimizer is not None:
            data["optimizer"] = self.optimizer.state_dict()
        if self.scheduler is not None:
            data["scheduler"] = self.scheduler.state_dict()
        data.update(kwargs)

        save_file = os.path.join(self.save_dir, "{}.pth".format(name))
        self.logger.info("Saving checkpoint to {}".format(save_file))
        torch.save(data, save_file)

        self.tag_last_checkpoint(save_file)

    def load(self, f=None, use_latest=True):
        if self.has_checkpoint() and use_latest:
            # override argument with existing checkpoint
            f = self.get_checkpoint_file()
        if not f:
            # no checkpoint could be found
            self.logger.info("No checkpoint found.")
            return {}

        self.logger.info("Loading checkpoint from {}".format(f))
        checkpoint = self._load_file(f)
        model = self.model
        if isinstance(model, DistributedDataParallel):
            model = self.model.module

        model.load_state_dict(checkpoint.pop("model"))
        if "optimizer" in checkpoint and self.optimizer:
            self.logger.info("Loading optimizer from {}".format(f))
            self.optimizer.load_state_dict(checkpoint.pop("optimizer"))
        if "scheduler" in checkpoint and self.scheduler:
            self.logger.info("Loading scheduler from {}".format(f))
            self.scheduler.load_state_dict(checkpoint.pop("scheduler"))

        # return any further checkpoint data
        return checkpoint

    def get_checkpoint_file(self):
        save_file = os.path.join(self.save_dir, self._last_checkpoint_name)
        try:
            with open(save_file, "r") as f:
                last_saved = f.read()
                last_saved = last_saved.strip()
        except IOError:
            # if file doesn't exist, maybe because it has just been
            # deleted by a separate process
            last_saved = ""
        return last_saved

    def has_checkpoint(self):
        save_file = os.path.join(self.save_dir, self._last_checkpoint_name)
        return os.path.exists(save_file)

    def tag_last_checkpoint(self, last_filename):
        save_file = os.path.join(self.save_dir, self._last_checkpoint_name)
        with open(save_file, "w") as f:
            f.write(last_filename)

    def _load_file(self, f):
        return torch.load(f, map_location=torch.device("cpu"))
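For illustration only (not part of the commit), a minimal sketch of how this CheckPointer could be exercised on its own; the tiny linear model, the ./outputs directory, and the extra epoch keyword are assumptions made just for the example.

import os

import torch
import torch.nn as nn

from rotnet.util.checkpoint import CheckPointer

os.makedirs('./outputs', exist_ok=True)
model = nn.Linear(10, 2)                       # stand-in model
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

checkpointer = CheckPointer(model, optimizer=optimizer, save_dir='./outputs', save_to_disk=True)
checkpointer.save('model_000000', epoch=0)     # extra kwargs are stored alongside the state dicts

# Later: load() follows last_checkpoint.txt and returns whatever extra data was saved
extra = checkpointer.load()
print(extra.get('epoch'))                      # -> 0

Note that save() is a no-op unless both save_dir and save_to_disk are set, which is why the commit's train.py below passes save_to_disk=True.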
20 changes: 20 additions & 0 deletions rotnet/util/logger.py
@@ -0,0 +1,20 @@
import logging
import os
import sys


def setup_logger(name, save_dir=None):
    logger = logging.getLogger(name)
    logger.setLevel(logging.DEBUG)

    stream_handler = logging.StreamHandler(stream=sys.stdout)
    stream_handler.setLevel(logging.DEBUG)
    formatter = logging.Formatter("%(asctime)s %(name)s %(levelname)s: %(message)s")
    stream_handler.setFormatter(formatter)
    logger.addHandler(stream_handler)
    if save_dir:
        fh = logging.FileHandler(os.path.join(save_dir, 'log.txt'))
        fh.setLevel(logging.DEBUG)
        fh.setFormatter(formatter)
        logger.addHandler(fh)
    return logger
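A brief usage sketch for setup_logger (not part of the commit); the logger names and directory are illustrative.

from rotnet.util.logger import setup_logger

logger = setup_logger("RotNet")                               # console only
logger.info("hello from RotNet")

logger = setup_logger("RotNet-train", save_dir="./outputs")   # console + ./outputs/log.txt
logger.info("this line also goes to the log file")

Two caveats that follow from the code above: save_dir must already exist, since logging.FileHandler does not create directories, and calling setup_logger twice with the same name attaches duplicate handlers, so each message would then be emitted twice.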
34 changes: 34 additions & 0 deletions tools/predict.py
@@ -0,0 +1,34 @@
# -*- coding: utf-8 -*-

"""
@date: 2020/8/22 4:20 PM
@file: predict.py
@author: zj
@description:
"""

import cv2
import torch
from rotnet.data.build import build_test_transform
from rotnet.model.build import build_model
from rotnet.util.checkpoint import CheckPointer

if __name__ == '__main__':
    epoches = 10
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    model = build_model(num_classes=360).to(device)
    output_dir = './outputs'
    checkpointer = CheckPointer(model, save_dir=output_dir)
    checkpointer.load()
    # switch to inference mode so BatchNorm/Dropout layers behave deterministically
    model.eval()

    transform = build_test_transform()
    img_path = 'imgs/RotNet.png'
    img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
    print(img.shape)
    res_img = transform(img).unsqueeze(0)
    print(res_img.shape)

    with torch.no_grad():
        outputs = model(res_img.to(device))
    _, preds = torch.max(outputs, 1)
    print(preds)
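Since the model is built with num_classes=360, the predicted class index can plausibly be read as a rotation angle in degrees. Below is a hedged continuation of the script above that undoes the rotation with OpenCV; the class-to-angle mapping and the sign of the correction are assumptions that depend on how the training labels were generated.

# continues the script above; 'img' and 'preds' come from predict.py
angle = preds.item()                               # assumed: class index == rotation angle in degrees
h, w = img.shape[:2]
center = (w / 2.0, h / 2.0)
M = cv2.getRotationMatrix2D(center, -angle, 1.0)   # rotate back; flip the sign if labels are defined the other way
restored = cv2.warpAffine(img, M, (w, h))
cv2.imwrite('imgs/RotNet_restored.png', restored)  # hypothetical output path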
14 changes: 11 additions & 3 deletions tools/train.py
@@ -7,18 +7,20 @@
@description:
"""

import os
import torch

from rotnet.data.build import build_dataset, build_transform, build_dataloader
from rotnet.data.build import build_dataset, build_train_transform, build_dataloader
from rotnet.model.build import build_model, build_criterion
from rotnet.optim.build import build_optimizer, build_lr_scheduler
from rotnet.engine.build import train_model
from rotnet.util.checkpoint import CheckPointer

if __name__ == '__main__':
    epoches = 10
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    transform, target_transform = build_transform()
    transform, target_transform = build_train_transform()
    data_dir = './data/'
    data_sets, data_sizes = build_dataset(data_dir, transform, target_transform)
    data_loaders = build_dataloader(data_sets)
@@ -28,5 +30,11 @@
    optimizer = build_optimizer(model)
    lr_scheduler = build_lr_scheduler(optimizer)

    train_model('MobileNet_v2', model, criterion, optimizer, lr_scheduler, data_loaders, data_sizes,
    output_dir = './outputs'
    if not os.path.exists(output_dir):
        os.mkdir(output_dir)
    checkpointer = CheckPointer(model, optimizer=optimizer, scheduler=lr_scheduler, save_dir=output_dir,
                                save_to_disk=True, logger=None)

    train_model('MobileNet_v2', model, criterion, optimizer, lr_scheduler, data_loaders, data_sizes, checkpointer,
                epoches=epoches, device=device)
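As a possible follow-up (not in this commit), the same checkpointer could restore the latest weights before train_model is called, so an interrupted run resumes instead of starting from scratch; a minimal sketch:

# hypothetical resume step, placed just before the train_model(...) call
if checkpointer.has_checkpoint():
    extra = checkpointer.load()        # restores model/optimizer/scheduler state in place
    print('Resumed from', checkpointer.get_checkpoint_file())

Since the commit's checkpointer.save call stores no extra metadata, the starting epoch would still have to be tracked separately.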
