-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmodel.py
63 lines (55 loc) · 2.33 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
# -*- coding: utf-8 -*-
"""
Created on Thu May 7 17:13:18 2020
@author: Kartik
"""
import torch
import torch.nn as nn
import torchvision.models as models
def load_alexnet():
    """Return a torchvision AlexNet initialized with ImageNet weights."""
    # NOTE(review): `pretrained=True` is deprecated in torchvision >= 0.13;
    # switch to `weights=models.AlexNet_Weights.IMAGENET1K_V1` when the
    # project's torchvision floor allows it.
    return models.alexnet(pretrained=True)
def get_accuracy(prototypes, embeddings, targets):
    """Mean nearest-prototype classification accuracy on the query points.

    Each query embedding is assigned the class whose prototype is closest
    in squared Euclidean distance; the result is the fraction of queries
    whose assignment matches the target label.

    Parameters
    ----------
    prototypes : `torch.FloatTensor` instance
        Class prototypes, shape `(meta_batch_size, num_classes, embedding_size)`.
    embeddings : `torch.FloatTensor` instance
        Query embeddings, shape `(meta_batch_size, num_examples, embedding_size)`.
    targets : `torch.LongTensor` instance
        True class indices, shape `(meta_batch_size, num_examples)`.

    Returns
    -------
    accuracy : `torch.FloatTensor` instance
        Scalar mean accuracy over every query point in the meta-batch.
    """
    # Broadcast to (meta_batch, num_examples, num_classes, embedding_size)
    # and reduce over the embedding axis: sq_distances[b, n, c] is the
    # squared distance from query n to prototype c.
    diff = embeddings.unsqueeze(2) - prototypes.unsqueeze(1)
    sq_distances = (diff ** 2).sum(dim=-1)
    # The nearest prototype's index is the predicted class.
    predictions = sq_distances.argmin(dim=-1)
    return predictions.eq(targets).float().mean()
def conv3x3(in_channels, out_channels, **kwargs):
    """Build a 3x3 conv -> batch-norm -> ReLU -> 2x2 max-pool block.

    Padding of 1 keeps the convolution size-preserving, so the max-pool
    stage halves the spatial resolution. Extra keyword arguments are
    forwarded to ``nn.Conv2d``.
    """
    layers = (
        nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, **kwargs),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(),
        nn.MaxPool2d(2),
    )
    return nn.Sequential(*layers)
class PrototypicalNetwork(nn.Module):
    """Four-stage convolutional embedding network for prototypical few-shot
    classification.

    The encoder stacks four ``conv3x3`` blocks; each block halves the
    spatial resolution, and the final feature map is flattened into a
    per-example embedding vector.
    """
    def __init__(self, in_channels, out_channels, hidden_size=64):
        super().__init__()
        self.in_channels = in_channels    # number of input image channels (default: 1)
        self.out_channels = out_channels  # dimension of the embedding/latent space (default: 64)
        self.hidden_size = hidden_size    # channels of each hidden convolutional layer (default: 64)
        # Channel sizes along the encoder: input -> 3 hidden stages -> embedding.
        widths = [in_channels, hidden_size, hidden_size, hidden_size, out_channels]
        self.encoder = nn.Sequential(
            *(conv3x3(w_in, w_out) for w_in, w_out in zip(widths[:-1], widths[1:]))
        )

    def forward(self, inputs):
        # inputs is assumed to be (meta_batch, num_examples, C, H, W) —
        # TODO confirm against callers. Fold the first two dims together
        # so the encoder sees a plain image batch.
        flat = inputs.view(-1, *inputs.shape[2:])
        features = self.encoder(flat)
        # Restore (meta_batch, num_examples, ...) and flatten the feature
        # map into one embedding vector per example.
        return features.view(*inputs.shape[:2], -1)