
How to support Dream-Creator models in other projects


Loading Models

Models only require the inceptionv1_caffe.py file from the utils directory, and a saved checkpoint can be loaded like this:

import torch

# Load the model checkpoint file
checkpoint = torch.load(<model_name>, map_location='cpu')

# Which base model the checkpoint uses
mode = checkpoint['base_model'] # String

# List containing the (normalization mean, normalization standard deviation) values in RGB format
norm_vals = checkpoint['normalize_params'] # List of floats for preprocessing and deprocessing

# Number of classes in the FC layers
n_classes = checkpoint['num_classes'] # Integer

# Determine whether or not the model has branches
has_branches = checkpoint['has_branches'] # Boolean True or False

# How many epochs the model was trained for
n_epochs = checkpoint['epoch'] # Integer

# Optimizer's State when model was saved
optimizer = checkpoint['optimizer_state_dict'] # state_dict used for continuing training

# Learning Rate Scheduler's State when model was saved
lrscheduler = checkpoint['lrscheduler_state_dict'] # state_dict used for continuing training

# The color correlation matrix needed for color decorrelation
color_matrix = checkpoint['color_correlation_svd_sqrt'] # Tensor used for color correlation.
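
The color matrix can be used for decorrelated color parameterization when visualizing the model. The exact convention is defined by whatever visualization code you pair the model with; as a rough illustration, here is a minimal sketch following the Lucid-style recipe, where the helper name and the normalization/multiplication conventions are assumptions:

def apply_color_matrix(image, color_matrix):
    # Hypothetical helper; conventions assumed from Lucid-style color decorrelation
    # image: (3, H, W) RGB tensor, color_matrix: (3, 3) tensor
    normalized = color_matrix / torch.norm(color_matrix, dim=0).max()
    flat = image.reshape(3, -1) # Flatten the spatial dimensions
    flat = torch.matmul(normalized, flat) # Map each pixel's color through the matrix
    return flat.reshape(image.shape)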

Then to load the model:

from inceptionv1_caffe import InceptionV1_Caffe # inceptionv1_caffe.py copied from Dream-Creator's utils directory

# Create model definition
cnn = InceptionV1_Caffe(n_classes, mode=mode, load_branches=has_branches)

# Load pretrained model state_dict
cnn.load_state_dict(checkpoint['model_state_dict'])
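
After loading, the model is typically switched to evaluation mode and its weights frozen for visualization work, so that gradients flow only to the input image:

cnn.eval()
for param in cnn.parameters():
    param.requires_grad = False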

Preprocessing and Deprocessing

Models store their normalization values in RGB order, but the networks themselves expect BGR input (a holdover from Caffe). A simple channel-swap transform is therefore applied after normalization when preprocessing, and before denormalization when deprocessing.

Here's an example of DeepDream preprocessing and deprocessing:

import torch
from PIL import Image
from torchvision import transforms

# Basic preprocessing
image = Image.open('input.jpg').convert('RGB') # Example input; any RGB PIL Image works
image_size = (224, 224)
mean_vals = norm_vals[0]

tensor_transforms = transforms.Compose([
    transforms.Resize(image_size),
    transforms.ToTensor(), # Convert PIL Image to tensor
    transforms.Lambda(lambda x: x * 255), # Change value range from 0-1 to 0-255
    transforms.Normalize(mean=mean_vals, std=[1, 1, 1]), # Subtract the normalization mean
    transforms.Lambda(lambda x: x[torch.LongTensor([2, 1, 0])]), # RGB to BGR
])
input_tensor = tensor_transforms(image).unsqueeze(0)

# Basic deprocessing
mean_vals = [n * -1 for n in norm_vals[0]] # Negate the mean values so Normalize adds the mean back
tensor_transforms = transforms.Compose([
    transforms.Lambda(lambda x: x[torch.LongTensor([2, 1, 0])]), # BGR to RGB
    transforms.Normalize(mean=mean_vals, std=[1, 1, 1]), # Add the mean back
    transforms.Lambda(lambda x: x / 255), # Change value range from 0-255 to 0-1
])
output_tensor = tensor_transforms(output_tensor.squeeze(0).cpu())
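
To view or save the result, the deprocessed tensor can then be clamped to the valid range and converted back to a PIL image (the output filename below is just an example):

output_image = transforms.ToPILImage()(output_tensor.clamp(0, 1))
output_image.save('dream.png') # Example output path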

Layer Selection

Layer names are separated into top level names and lower level names.

Top Level Names:

Top level names are:

conv1, conv1_relu, pool1, localresponsenorm1, conv2, conv2_relu, conv3, conv3_relu, localresponsenorm2, pool2, mixed3a, mixed3b, pool3, mixed4a, mixed4b, mixed4c, mixed4d, mixed4e, pool4, mixed5a, mixed5b, avgpool, drop, fc, aux1, aux2 

Lower Level Names:

Layers with mixed in their names contain their own lower level layer names:

conv_1x1, conv_1x1_relu, conv_3x3_reduce, conv_3x3_reduce_relu, conv_3x3, conv_3x3_relu, conv_5x5_reduce, conv_5x5_reduce_relu, conv_5x5, conv_5x5_relu, pool, pool_proj, pool_proj_relu 

For example: mixed4a.conv_3x3_reduce.

Layers with aux in their names likewise contain their own lower level layer names:

avg_pool, loss_conv, loss_conv_relu, loss_fc, loss_fc_relu, loss_dropout, loss_classifier 

For example: aux1.loss_classifier.
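
To print the exact dotted names available on a loaded model, PyTorch's named_modules() can be used:

for name, module in cnn.named_modules():
    print(name) # e.g. 'mixed4a.conv_3x3_reduce'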

Layer hooks:

For visualizing layers, one option is registering forward hooks:

# Minimal stand-in for a class that calculates/records loss from a layer's output
class LossHook:
    def __init__(self):
        self.loss = None

    def __call__(self, module, input, output):
        self.loss = output.mean() # e.g. maximize the mean activation

loss_module = LossHook()

# Top level layer example
getattr(cnn, 'mixed5a').register_forward_hook(loss_module)

# Top and lower level name example (normally you would register just one hook)
getattr(getattr(cnn, 'mixed5a'), 'conv_3x3_reduce').register_forward_hook(loss_module)

# Run a forward pass, then collect the loss from loss_module
cnn(input_tensor)
loss = loss_module.loss