# preprocessing.py

import os
from collections import defaultdict

import numpy as np
import torch
from biopandas.pdb import PandasPdb
from rich.progress import track
from sklearn.model_selection import train_test_split
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader


def pdb2pandas(pdb_path):
    """Read a PDB file into a DataFrame of ATOM records with integer residue/element ids."""
    df = PandasPdb().read_pdb(pdb_path).df["ATOM"]
    # Unique per-atom node label, e.g. "A12G" for residue 12 (guanine) on chain A.
    df["node_id"] = (
        df["chain_id"]
        + df["residue_number"].map(str)
        + df["residue_name"]
    )
    # RNA bases map to 0-3, DNA bases to 4-7; any other residue falls back to 8.
    residue_mapping = defaultdict(lambda: 8)
    residue_mapping["A"] = 0
    residue_mapping["T"] = 1
    residue_mapping["U"] = 1
    residue_mapping["C"] = 2
    residue_mapping["G"] = 3
    residue_mapping["DA"] = 4
    residue_mapping["DT"] = 5
    residue_mapping["DU"] = 5
    residue_mapping["DC"] = 6
    residue_mapping["DG"] = 7
    # Elements outside this map (e.g. H) become NaN: inputs are assumed to
    # contain only C/N/O/P atoms.
    element_mapping = {"C": 0, "N": 1, "O": 2, "P": 3}
    df["residue_id"] = df["residue_name"].map(residue_mapping)
    df["element_id"] = df["element_symbol"].map(element_mapping)
    return df
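
# A hypothetical usage sketch (the file name is an assumption, not from this repo):
#     df = pdb2pandas("example.pdb")
#     df[["node_id", "residue_id", "element_id"]].head()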


def augment_pc(points):
    """Rotate an (N, 3) point cloud by uniformly random angles about x, y and z."""
    theta_x, theta_y, theta_z = np.random.rand(3) * 2 * np.pi
    rotate_x = np.array([[1, 0, 0],
                         [0, np.cos(theta_x), -np.sin(theta_x)],
                         [0, np.sin(theta_x), np.cos(theta_x)]])
    rotate_y = np.array([[np.cos(theta_y), 0, -np.sin(theta_y)],
                         [0, 1, 0],
                         [np.sin(theta_y), 0, np.cos(theta_y)]])
    rotate_z = np.array([[np.cos(theta_z), -np.sin(theta_z), 0],
                         [np.sin(theta_z), np.cos(theta_z), 0],
                         [0, 0, 1]])
    # Compose the three rotations and apply them to every point at once.
    return points @ (rotate_z @ rotate_y @ rotate_x).T


def normalize_pc(points):
    """Center a point cloud at the origin and scale it into the unit sphere (in place)."""
    centroid = np.mean(points, axis=0)
    points -= centroid
    furthest_distance = np.max(np.sqrt(np.sum(points ** 2, axis=-1)))
    points /= furthest_distance
    return points
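
# A worked example: a cloud already centered at the origin with max radius 2
# comes back with max radius 1.
#     normalize_pc(np.array([[2.0, 0.0, 0.0], [-2.0, 0.0, 0.0]]))
#     -> array([[ 1.,  0.,  0.], [-1.,  0.,  0.]])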


def pad_pc(points, amount):
    """Zero-pad an (N, F) array with `amount` extra rows so every cloud has a fixed size."""
    return np.pad(points, ((0, amount), (0, 0)), "constant", constant_values=(0, 0))


def create_pyg_datalist(data_path, max_pointcloud_size, withfeatures=True, augment_num=10):
    """Build PyG Data objects from every PDB file in data_path.

    Each structure is normalized into the unit sphere, copied augment_num times
    under random rotations, optionally concatenated with residue/element ids as
    extra per-point features, and zero-padded to max_pointcloud_size points.
    """
    datalist = []
    shape_list = []
    for filename in track(os.listdir(data_path), description="[cyan]Creating PyG Data from RNA pdb files"):
        if filename.endswith("pdb"):
            pdb_df = pdb2pandas(os.path.join(data_path, filename))
            atom_number = torch.from_numpy(pdb_df[["atom_number"]].to_numpy())
            residue_ids = pdb_df[["residue_id"]].to_numpy()
            element_ids = pdb_df[["element_id"]].to_numpy()
            node_id = pdb_df[["node_id"]].to_numpy()
            raw_coordinates = pdb_df[["x_coord", "y_coord", "z_coord"]].to_numpy()  # shape (num_atoms, 3)
            # Keep only structures that fit the fixed cloud size and have at least 100 atoms.
            if 100 <= raw_coordinates.shape[0] <= max_pointcloud_size:
                padding_amount = max_pointcloud_size - raw_coordinates.shape[0]
                shape_list.append(raw_coordinates.shape[0])
                normalized_coordinates = normalize_pc(raw_coordinates)
                temp_coordinates = normalized_coordinates.copy()
                for _ in range(augment_num):
                    feature_coordinates = (
                        np.concatenate((temp_coordinates, residue_ids, element_ids), axis=1)
                        if withfeatures
                        else temp_coordinates
                    )
                    # Note: padded rows are all-zero, so their residue channel collides with id 0 ("A").
                    padded_coordinates = pad_pc(feature_coordinates, padding_amount)
                    data = Data(
                        pos=torch.from_numpy(padded_coordinates).float(),
                        atom_number=atom_number,
                        y=node_id,
                        num_nodes=padded_coordinates.shape[0],
                    )
                    datalist.append(data)
                    # Draw a fresh random rotation of the normalized cloud for the next copy.
                    temp_coordinates = augment_pc(normalized_coordinates.copy())
    return datalist, shape_list


def create_dataloaders(data_list, batch_size=1, with_val=False):
    """Split data_list 80/20 into train/test loaders, or 80/10/10 when with_val is set."""
    X_train, X_t = train_test_split(data_list, test_size=0.2, random_state=42)
    if with_val:
        X_val, X_test = train_test_split(X_t, test_size=0.5, random_state=42)
        assert len(X_train) + len(X_val) + len(X_test) == len(data_list)
        train_data_loader = DataLoader(X_train, batch_size=batch_size)
        val_data_loader = DataLoader(X_val, batch_size=batch_size)
        test_data_loader = DataLoader(X_test, batch_size=batch_size)
        return train_data_loader, test_data_loader, val_data_loader
    else:
        assert len(X_train) + len(X_t) == len(data_list)
        train_data_loader = DataLoader(X_train, batch_size=batch_size)
        test_data_loader = DataLoader(X_t, batch_size=batch_size)
        return train_data_loader, test_data_loader, None
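

# A minimal end-to-end sketch. The directory, cloud size, and batch size below
# are assumptions for illustration, not values from the original module.
if __name__ == "__main__":
    datalist, shapes = create_pyg_datalist("data/rna_pdbs", max_pointcloud_size=2048)
    train_loader, test_loader, _ = create_dataloaders(datalist, batch_size=8)
    print(f"{len(datalist)} graphs; largest raw cloud: {max(shapes, default=0)} atoms")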