From 9c265ad11b43cfe761602c8d6c25dcbfb0f099d0 Mon Sep 17 00:00:00 2001 From: NinV Date: Sun, 29 Aug 2021 16:45:19 +0900 Subject: [PATCH] improve edge weight --- .../networks/graph_connectivity_model.py | 21 +++++++++++++------ scripts/train_300w.py | 2 +- scripts/train_wflw.py | 2 +- 3 files changed, 17 insertions(+), 8 deletions(-) diff --git a/libs/models/networks/graph_connectivity_model.py b/libs/models/networks/graph_connectivity_model.py index 9ad756b..9b92dcf 100644 --- a/libs/models/networks/graph_connectivity_model.py +++ b/libs/models/networks/graph_connectivity_model.py @@ -56,15 +56,24 @@ def forward(self, x): class EdgeWeights(nn.Module): - def __init__(self, embedding_size=1, hidden=4): + def __init__(self, class_embedding_size, hidden_sizes): super(EdgeWeights, self).__init__() - self.linear = Linear(embedding_size + 2, hidden, activation="relu") - self.out = Linear(hidden, 1, activation="relu") + layers = [] + # self.linear = Linear(embedding_size + 2, hidden, activation="relu") + current_dim = class_embedding_size + 2 # +2 for pair of node confidences + for h in hidden_sizes: + layers.append(Linear(current_dim, h, activation="relu")) + current_dim = h + self.hidden_layers = nn.ModuleList(layers) + self.out = Linear(current_dim, 1, activation="relu") def forward(self, x): - x = self.linear(x) - x = self.out(x) - return x + # x = self.linear(x) + # x = self.out(x) + for layer in self.hidden_layers: + x = layer(x) + out = self.out(x) + return out class VisualFeatureEmbedding(nn.Module): diff --git a/scripts/train_300w.py b/scripts/train_300w.py index d0d2282..280f9b6 100644 --- a/scripts/train_300w.py +++ b/scripts/train_300w.py @@ -31,7 +31,7 @@ graph_model_config = {"num_classes": 68, "embedding_hidden_sizes": [32], "class_embedding_size": 1, - "edge_hidden_size": 4, + "edge_hidden_size": [128, 64, 64], # "visual_feature_dim": 1920, # Stacked Hourglass "visual_feature_dim": 270, # HRNet18 "visual_hidden_sizes": [512, 128, 32], diff --git a/scripts/train_wflw.py b/scripts/train_wflw.py index e8b6352..1a9ba33 100644 --- a/scripts/train_wflw.py +++ b/scripts/train_wflw.py @@ -28,7 +28,7 @@ graph_model_config = {"num_classes": 98, "embedding_hidden_sizes": [32], "class_embedding_size": 1, - "edge_hidden_size": 4, + "edge_hidden_size": [128, 64, 64], # "visual_feature_dim": 1920, # Stacked Hourglass "visual_feature_dim": 270, # HRNet18 "visual_hidden_sizes": [512, 128, 32],