model.py
import torch
import torch.nn as nn
import torch.nn.functional as F


class NETWORK(nn.Module):
    def __init__(
        self,
        input_dim,
        latent_dim,
        hidden_dim,
        nhead=1,
        num_encoder_layers=1,
    ):
        super().__init__()
        self.encoder_layer = nn.TransformerEncoderLayer(d_model=input_dim, nhead=nhead, dim_feedforward=hidden_dim)
        self.encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=num_encoder_layers)
        self.derivatives_dim = input_dim  # two velocity components (unspliced, spliced) per gene, with n_genes = input_dim // 2
        self.probabilities_dim = 4 * input_dim // 2  # four sign probabilities per gene
        self.derivative_decoder = nn.Sequential(
            nn.Linear(input_dim, self.derivatives_dim),  # output the velocity components for each gene
        )
        self.probabilities_decoder = nn.Linear(input_dim, self.probabilities_dim)  # output the 4 sign probabilities for each gene
        self.v_u = None
        self.v_s = None
        self.v_u_pos = None
        self.v_s_pos = None
        self.pp = None
        self.nn = None
        self.pn = None
        self.np = None
    def forward(self, x):
        x = x.unsqueeze(1)  # add a sequence dimension
        z = self.encoder(x)
        z = z.squeeze(1)  # remove the sequence dimension
        # Decode the (positive) velocity magnitudes for unspliced and spliced counts.
        self.derivatives = self.derivative_decoder(z)
        self.v_u_pos, self.v_s_pos = torch.split(self.derivatives, self.derivatives_dim // 2, dim=1)
        v_u_neg = -1 * self.v_u_pos
        v_s_neg = -1 * self.v_s_pos
        # Decode, per gene, a softmax distribution over the four sign combinations
        # of (unspliced, spliced) velocity: ++, --, +-, -+.
        p_sign = self.probabilities_decoder(z)
        p_sign = p_sign.view(-1, self.probabilities_dim // 4, 4)
        p_sign = F.softmax(p_sign, dim=-1)
        self.pp = p_sign[:, :, 0]
        self.nn = p_sign[:, :, 1]
        self.pn = p_sign[:, :, 2]
        self.np = p_sign[:, :, 3]
        unspliced, spliced = torch.split(x.squeeze(1), x.size(2) // 2, dim=1)
        # Expected velocities under the sign probabilities.
        self.v_u = self.v_u_pos * self.pp + v_u_neg * self.nn + self.v_u_pos * self.pn + v_u_neg * self.np
        self.v_s = self.v_s_pos * self.pp + v_s_neg * self.nn + v_s_neg * self.pn + self.v_s_pos * self.np
        # One-step extrapolation of the input expression along the predicted velocities.
        unspliced_pred = unspliced + self.v_u
        spliced_pred = spliced + self.v_s
        self.prediction = torch.cat([unspliced_pred, spliced_pred], dim=1)
        self.out_dic = {
            "x": x.squeeze(1),
            "pred": self.prediction,
            "v_u": self.v_u,
            "v_s": self.v_s,
            "v_u_pos": self.v_u_pos,
            "v_s_pos": self.v_s_pos,
            "pp": self.pp,
            "nn": self.nn,
            "pn": self.pn,
            "np": self.np,
        }
        return self.out_dic
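
    # Shape summary for forward (assuming the input concatenates unspliced and
    # spliced features, so input_dim = 2 * n_genes; this layout is inferred from
    # how x is split above and in heuristic_loss below):
    #   x                   : (batch, 2 * n_genes)
    #   v_u_pos, v_s_pos    : (batch, n_genes)
    #   pp, nn, pn, np      : (batch, n_genes), summing to 1 per gene
    #   v_u, v_s            : (batch, n_genes)
    #   prediction ("pred") : (batch, 2 * n_genes)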
    def heuristic_loss(
        self,
        adata,
        x,
        batch_indices,
        lambda1,
        lambda2,
        out_dic,
        device,
        K,
    ):
        x = out_dic["x"]
        prediction_nn = out_dic["pred"]
        reference_data = x  # gene-expression data of the samples in the batch
        neighbor_indices = adata.uns["indices"][batch_indices, 1:K]  # fetch the nearest neighbors (the first column is skipped)
        neighbor_data_u = torch.from_numpy(adata.layers["Mu"][neighbor_indices]).to(device)
        neighbor_data_s = torch.from_numpy(adata.layers["Ms"][neighbor_indices]).to(device)
        neighbor_data = torch.cat([neighbor_data_u, neighbor_data_s], dim=2)  # gene-expression data of the neighbors of each sample in the batch
        model_prediction_vector = prediction_nn - reference_data  # difference vector: model prediction vs. input samples
        neighbor_prediction_vectors = neighbor_data - reference_data.unsqueeze(1)  # difference vectors: neighbor data vs. input samples
        # Normalize the vectors cell-wise
        model_prediction_vector_normalized = F.normalize(model_prediction_vector, p=2, dim=1)
        neighbor_prediction_vectors_normalized = F.normalize(neighbor_prediction_vectors, p=2, dim=2)
        # Calculate the norms of the normalized vectors (used only by the optional checks below)
        model_prediction_vector_norms = torch.norm(model_prediction_vector_normalized, p=2, dim=1)
        neighbor_prediction_vectors_norms = torch.norm(neighbor_prediction_vectors_normalized, p=2, dim=2)
        # Optional assertions to ensure each vector is a unit vector, within a small tolerance
        tolerance = 1e-4  # adjust the tolerance if needed
        # assert torch.allclose(model_prediction_vector_norms, torch.ones_like(model_prediction_vector_norms), atol=tolerance), "Model prediction vectors are not properly normalized"
        # assert torch.allclose(neighbor_prediction_vectors_norms, torch.ones_like(neighbor_prediction_vectors_norms), atol=tolerance), "Neighbor prediction vectors are not properly normalized"
        # Cosine similarity between the model's predicted direction and each neighbor direction;
        # each cell is scored by its best-aligned neighbor.
        cos_sim = F.cosine_similarity(neighbor_prediction_vectors_normalized, model_prediction_vector_normalized.unsqueeze(1), dim=-1)
        aggr, _ = cos_sim.max(dim=1)
        cell_loss = 1 - aggr
        heuristic_loss = torch.mean(cell_loss)  # compute the batch loss
        # Penalize sign probabilities that deviate from the uniform value of 0.25.
        discrepancy_loss = 0
        for p in ["pp", "nn", "pn", "np"]:
            discrepancy_loss += (torch.tensor(0.25, device=device).expand_as(out_dic[p]) - out_dic[p]) ** 2
        discrepancy_loss = (discrepancy_loss / 4).mean()
        weighted_heuristic_loss = lambda1 * heuristic_loss
        weighted_discrepancy_loss = lambda2 * discrepancy_loss
        total_loss = weighted_heuristic_loss + weighted_discrepancy_loss  # combine the lambda1- and lambda2-weighted terms
        losses_dic = {
            "heuristic_loss": heuristic_loss,
            "heuristic_loss_weighted": weighted_heuristic_loss,
            "cell_loss": cell_loss,
            "uniform_p_loss": discrepancy_loss,
            "uniform_p_loss_weighted": weighted_discrepancy_loss,
            "total_loss": total_loss,
            "batch_indices": batch_indices,
        }
        return losses_dic
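

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original training pipeline). It assumes
# the input concatenates unspliced and spliced features, so input_dim equals
# 2 * n_genes, and that adata carries a precomputed neighbor-index matrix in
# adata.uns["indices"] plus smoothed counts in adata.layers["Mu"]/["Ms"], as
# read by heuristic_loss above. The AnnData stand-in and all numbers below are
# illustrative only.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import numpy as np
    from types import SimpleNamespace

    n_cells, n_genes, K = 20, 8, 5
    model = NETWORK(input_dim=2 * n_genes, latent_dim=16, hidden_dim=64)

    rng = np.random.default_rng(0)
    mu = rng.random((n_cells, n_genes), dtype=np.float32)  # smoothed unspliced counts
    ms = rng.random((n_cells, n_genes), dtype=np.float32)  # smoothed spliced counts
    adata = SimpleNamespace(
        uns={"indices": rng.integers(0, n_cells, size=(n_cells, K))},  # stand-in kNN index matrix
        layers={"Mu": mu, "Ms": ms},
    )

    batch_indices = np.arange(4)
    batch = torch.from_numpy(np.concatenate([mu[batch_indices], ms[batch_indices]], axis=1))
    out = model(batch)
    print(out["pred"].shape)  # torch.Size([4, 16])

    losses = model.heuristic_loss(
        adata, batch, batch_indices, lambda1=1.0, lambda2=0.1,
        out_dic=out, device="cpu", K=K,
    )
    print(losses["total_loss"])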