-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathmodel.py
91 lines (72 loc) · 2.95 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import torch
from torchvision.models import resnet50
import torch.nn as nn
def get_model(sources, device, checkpoint=None):
    """Return a regression model suitable for the given input sources.

    Args:
        sources: Which inputs the model consumes. ``"S2"`` builds a
            Sentinel-2-only model; ``"S2S5P"`` builds a two-stream
            Sentinel-2 + Sentinel-5P model.
        device: Device passed through to the selected model builder.
        checkpoint: Optional checkpoint path, passed through to the
            selected model builder.

    Returns:
        The constructed model.

    Raises:
        ValueError: If ``sources`` is not one of the supported values.
    """
    if sources == "S2":
        return get_S2_no2_model(device, checkpoint)
    if sources == "S2S5P":
        return get_S2S5P_no2_model(device, checkpoint)
    # Previously an unknown value fell through and returned None. That only
    # surfaced later, as an obscure error when the caller used the "model".
    raise ValueError(
        f"Unknown sources {sources!r}; expected 'S2' or 'S2S5P'"
    )
def get_S2_no2_model(device, checkpoint=None):
    """Return a ResNet for Sentinel-2 data with a dense regression head.

    The ResNet classifier (``fc``) is replaced by an identity, so the backbone
    emits its feature vector. The head assumes that vector is 2048-wide.
    A two-layer MLP then maps it to a single output value.

    Args:
        device: Device the whole model (backbone and head) is placed on.
        checkpoint: Optional path to a backbone state dict, forwarded to
            ``get_resnet_model``.

    Returns:
        A ``ResnetRegressionHead`` wrapping the backbone and the head.
    """
    backbone = get_resnet_model(device, checkpoint)
    # Drop the classifier so the backbone outputs features instead of logits.
    backbone.fc = nn.Identity()
    head = nn.Sequential(nn.Linear(2048, 512), nn.ReLU(), nn.Linear(512, 1))
    regression_model = ResnetRegressionHead(backbone, head)
    # The newly created head lives on the CPU. Move the whole model so that
    # forward() does not hit a device mismatch when `device` is a GPU.
    # nn.Module.to() returns the module itself.
    return regression_model.to(device)
def get_S2S5P_no2_model(device, checkpoint=None):
    """Return a two-stream model (Sentinel-2 and Sentinel-5P) followed by a
    dense regression head.

    The S2 stream is a ResNet with its classifier removed; the head assumes
    it yields 2048 features. The S5P stream is a small CNN producing 128
    features. The two feature vectors are concatenated and mapped to a
    single output by an MLP.

    Args:
        device: Device the whole model (both streams and head) is placed on.
        checkpoint: Optional path to a state dict for the S2 ResNet backbone,
            forwarded to ``get_resnet_model``.

    Returns:
        A ``MultiBackboneRegressionHead`` combining both streams and the head.
    """
    backbone_S2 = get_resnet_model(device, checkpoint)
    # Drop the classifier so the backbone outputs features instead of logits.
    backbone_S2.fc = nn.Identity()
    # NOTE(review): this stream expects single-channel input (Conv2d(1, ...)).
    # Linear(1815, ...) hard-codes the flattened size 15 * 11 * 11. For a
    # square input of side H, that holds only for 113 <= H <= 121
    # (e.g. 120x120). Other sizes fail at this layer; confirm against the
    # data loader.
    backbone_S5P = nn.Sequential(
        nn.Conv2d(1, 10, 3),
        nn.ReLU(),
        nn.MaxPool2d(3),
        nn.Conv2d(10, 15, 5),
        nn.ReLU(),
        nn.MaxPool2d(3),
        nn.Flatten(),
        nn.Linear(1815, 128),
    )
    head = nn.Sequential(nn.Linear(2048 + 128, 544), nn.ReLU(), nn.Linear(544, 1))
    regression_model = MultiBackboneRegressionHead(backbone_S2, backbone_S5P, head)
    # Only the ResNet backbone was moved to `device`; the S5P CNN and the
    # head were created on the CPU. Move the whole model to avoid a device
    # mismatch in forward().
    return regression_model.to(device)
def get_resnet_model(device, checkpoint=None):
    """Build a 12-input-channel ResNet-50, move it to ``device``, and
    optionally load a checkpoint into it.

    Args:
        device: Target device. It is also used as ``map_location`` when
            loading the checkpoint.
        checkpoint: Optional path to a saved state dict. It is loaded with
            ``load_state_dict``, so it must match this architecture
            (19-way ``fc``, 12-channel ``conv1``).

    Returns:
        The ResNet-50 module on ``device``.
    """
    net = resnet50(pretrained=False, num_classes=19)
    # Replace the stem so the network accepts 12 input channels instead of 3.
    # NOTE(review): a 3x3 kernel with padding 3 differs from the stock 7x7
    # stem. This is presumably intentional, to match the checkpoint; confirm.
    net.conv1 = torch.nn.Conv2d(
        12,
        64,
        kernel_size=(3, 3),
        stride=(2, 2),
        padding=(3, 3),
        bias=False,
    )
    net.to(device)
    if checkpoint is None:
        return net
    # NOTE(review): torch.load relies on pickle; only load trusted checkpoints.
    state_dict = torch.load(checkpoint, map_location=device)
    net.load_state_dict(state_dict)
    return net
class ResnetRegressionHead(nn.Module):
    """Chain a feature-extracting backbone with a regression head.

    ``forward`` runs the input through ``backbone`` and returns ``head``
    applied to the result.
    """

    def __init__(self, backbone, head):
        super().__init__()
        # Assigning modules as attributes registers them as submodules, so
        # their parameters appear in parameters() and state_dict().
        self.backbone = backbone
        self.head = head

    def forward(self, x):
        features = self.backbone(x)
        return self.head(features)
class MultiBackboneRegressionHead(nn.Module):
    """Combine features from two inputs (S2 and S5P) with a regression head.

    ``forward`` expects a mapping with an ``"img"`` entry, fed to
    ``backbone_S2``, and an ``"s5p"`` entry, fed to ``backbone_S5P``. The two
    feature tensors are concatenated along dim 1 and passed to ``head``.
    """

    def __init__(self, backbone_S2, backbone_S5P, head):
        super(MultiBackboneRegressionHead, self).__init__()
        self.backbone_S2 = backbone_S2
        self.backbone_S5P = backbone_S5P
        self.head = head

    def forward(self, x):
        """Run both streams and regress on their concatenated features.

        Args:
            x: Mapping with keys ``"img"`` and ``"s5p"``. The two backbone
               outputs are concatenated along dim 1, so they are assumed to
               share every other dimension (typically the batch dim 0).

        Returns:
            The output of ``head``.

        Raises:
            KeyError: If ``"img"`` or ``"s5p"`` is missing from ``x``.
        """
        # Index rather than .get(). With .get(), a missing key became None
        # and failed later inside a backbone or torch.cat with an
        # unrelated-looking error.
        img = x["img"]
        s5p = x["s5p"]
        img_features = self.backbone_S2(img)
        s5p_features = self.backbone_S5P(s5p)
        features = torch.cat((img_features, s5p_features), dim=1)
        return self.head(features)