Skip to content

Commit

Permalink
r1 global imp recheck
Browse files Browse the repository at this point in the history
  • Loading branch information
josejimenezluna committed Jan 14, 2021
1 parent f3406d7 commit d0ab457
Showing 1 changed file with 45 additions and 11 deletions.
56 changes: 45 additions & 11 deletions molgrad/global_imp.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
if __name__ == "__main__":
for data in TASK_GUIDE.keys():
print(f"Now computing oof global importances for dataset {data}...")
global_importances_oof = []

with open(
os.path.join(DATA_PATH, f"{data}", f"data_{data}.pt"), "rb"
Expand All @@ -32,26 +31,61 @@
else:
output_f = None

# Using production model
print("Production model running...")
w_path = os.path.join(MODELS_PATH, f"{data}_noHs.pt")

model = MPNNPredictor(
node_in_feats=49,
edge_in_feats=10,
global_feats=4,
n_tasks=1,
output_f=output_f,
).to(DEVICE)

model.load_state_dict(torch.load(w_path, map_location=DEVICE))

gis = [
molecule_importance(MolFromInchi(inchi), model)[4] for inchi in tqdm(inchis)
]
global_importances = np.vstack(gis)
np.save(
os.path.join(DATA_PATH, f"importances{data}.npy"), arr=global_importances
)

# Using oof models
global_importances_oof = []

kf = KFold(n_splits=N_FOLDS, shuffle=True, random_state=SEED)

for idx_split, (_, idx_test) in enumerate(kf.split(inchis)):
print("Split {}/{} running...".format(idx_split + 1, N_FOLDS))
inchis_test, values_test = inchis[idx_test].tolist(), values[idx_test, :].squeeze().tolist()
inchis_test, values_test = (
inchis[idx_test].tolist(),
values[idx_test, :].squeeze().tolist(),
)

w_path = os.path.join(MODELS_PATH, f"{data}_noHs_fold{idx_split}.pt")

model = MPNNPredictor(node_in_feats=49,
edge_in_feats=10,
global_feats=4,
n_tasks=1,
output_f=output_f).to(DEVICE)
model = MPNNPredictor(
node_in_feats=49,
edge_in_feats=10,
global_feats=4,
n_tasks=1,
output_f=output_f,
).to(DEVICE)

model.load_state_dict(torch.load(w_path,
map_location=DEVICE))
model.load_state_dict(torch.load(w_path, map_location=DEVICE))

gis = [molecule_importance(MolFromInchi(inchi), model)[4] for inchi in tqdm(inchis_test)]
gis = [
molecule_importance(MolFromInchi(inchi), model)[4]
for inchi in tqdm(inchis_test)
]
global_importances_oof.extend(gis)

global_importances_oof = np.vstack(global_importances_oof)

np.save(os.path.join(DATA_PATH, f"importances_oof_{data}.npy"), arr=global_importances_oof)
np.save(
os.path.join(DATA_PATH, f"importances_oof_{data}.npy"),
arr=global_importances_oof,
)

0 comments on commit d0ab457

Please sign in to comment.