-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathcalculate_dfim.py
191 lines (169 loc) · 5.73 KB
/
calculate_dfim.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
"""
Calculate Deep Feature Interaction Maps (DFIM) for a bunch of examples.
"""
import argparse
import glob
import logging
import os
import numpy as np
import tqdm
import utils
# Silence TensorFlow output. TF_CPP_MIN_LOG_LEVEL is set here, before the
# `import tensorflow` further down, so that it is already in the environment
# when TensorFlow is loaded; the Python-side "tensorflow" logger is restricted
# to FATAL messages as well.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "4"
logging.getLogger("tensorflow").setLevel(logging.FATAL)
import shap
import tensorflow as tf
import clipnet
from calculate_deepshap import (
create_explainers,
load_seqs,
profile_contrib,
quantity_contrib,
)
# Register shap's generic `passthrough` handler for the "AddV2" op in shap's
# deep TensorFlow explainer module. Per the original author, this fixes an
# error message when running with tf.__version__ == 2.5.
shap.explainers._deep.deep_tf.op_handlers["AddV2"] = (
shap.explainers._deep.deep_tf.passthrough
)
# Turn off TensorFlow 2.x behaviors for the rest of the process.
# NOTE(review): presumably required by the shap explainers created through
# calculate_deepshap.create_explainers -- confirm.
tf.compat.v1.disable_v2_behavior()
def calculate_dfim(explainers, rec, start, stop, check_additivity=True, silence=False):
    """
    Build the Deep Feature Interaction Map (DFIM) of a single sequence record.

    Every position in ``range(start, stop)`` is mutated to each base returned
    by ``utils.get_mut_bases`` for the reference base at that position. For each
    position, the ensemble-averaged SHAP attributions of the mutants are
    compared with those of the unmutated sequence, and the largest absolute
    change over the mutants is kept for every position in ``[start:stop]``.

    Args:
        explainers: Iterable of objects exposing
            ``shap_values(x, check_additivity=...)``; element ``[0]`` of each
            result is used and the results are averaged over the explainers.
        rec: Sequence record; ``rec.seq`` and ``rec.name`` are read.
        start: First position to mutate (inclusive).
        stop: Last position to mutate (exclusive).
        check_additivity: Forwarded unchanged to ``shap_values``.
        silence: If True, the per-position progress bar is disabled.

    Returns:
        ``np.ndarray`` with one row per mutated position; each row holds the
        interaction scores for the window ``[start:stop]``.
        NOTE(review): presumably shape ``(stop - start, stop - start)`` --
        confirm against the array shapes returned by the explainers.
    """

    def _ensemble_shap(twohot_batch):
        # Mean over all explainers of the first SHAP output for this batch.
        per_explainer = [
            explainer.shap_values(twohot_batch, check_additivity=check_additivity)[0]
            for explainer in explainers
        ]
        return np.array(per_explainer).mean(axis=0)

    ref_seq = rec.seq
    # A leading axis of size 1 is added so the reference forms a one-element batch.
    ref_twohot = np.expand_dims(utils.TwoHotDNA(ref_seq).twohot, axis=0)
    positions = list(range(start, stop))
    ref_shap = _ensemble_shap(ref_twohot)
    # Mutant bases are resolved for every position before any mutant is scored.
    alt_bases_per_pos = [utils.get_mut_bases(ref_seq[pos]) for pos in positions]

    progress = tqdm.tqdm(
        zip(positions, alt_bases_per_pos),
        total=len(alt_bases_per_pos),
        desc=f"Calculating DFIM for {rec.name} (pos {start}-{stop})",
        disable=silence,
    )
    rows = []
    for pos, alt_bases in progress:
        # One full-length mutant sequence per alternative base at `pos`.
        alt_seqs = [ref_seq[:pos] + base + ref_seq[pos + 1 :] for base in alt_bases]
        alt_twohot = np.array([utils.TwoHotDNA(alt_seq).twohot for alt_seq in alt_seqs])
        alt_shap = _ensemble_shap(alt_twohot)
        # The reference term broadcasts over the mutant axis (axis 0).
        delta = ref_shap * ref_twohot / 2 - alt_shap * alt_twohot / 2
        # Sum over axis 2, take the maximum over the mutants, keep the window.
        score = np.abs(delta).sum(axis=2).max(axis=0)[start:stop]
        rows.append(score)
    return np.array(rows)
def main():
    """
    Command-line entry point: compute DFIM scores for every record of a fasta file.

    Parses the command-line arguments, loads the sequences to explain and the
    background set through ``load_seqs``, builds the explainers with
    ``create_explainers`` (from ``--model_fp`` if given, otherwise from every
    ``*.h5`` file in ``--model_dir``), runs ``calculate_dfim`` on each record
    and writes the results to ``score_fp`` with ``np.savez``, one array per
    record keyed by ``rec.name``.

    Raises:
        ValueError: If ``--mode`` is neither 'quantity' nor 'profile', if
            ``--start``/``--stop`` do not describe a non-negative, non-empty
            range, if no sequences were loaded, or if ``--model_dir`` contains
            no ``*.h5`` file while ``--model_fp`` is not given.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("fasta_fp", type=str, help="Fasta file path.")
    parser.add_argument("score_fp", type=str, help="Where to write DFIM scores.")
    parser.add_argument(
        "--model_fp",
        type=str,
        default=None,
        help="Model file path. If None, will use all models in model_dir. Overwrites model_dir.",
    )
    parser.add_argument(
        "--model_dir",
        type=str,
        default="ensemble_models/",
        help="Directory to load models from",
    )
    parser.add_argument(
        "--mode",
        type=str,
        default="quantity",
        help="Calculate contrib scores for quantity or profile.",
    )
    parser.add_argument(
        "--start", type=int, default=400, help="Start position for calculating DFIM."
    )
    parser.add_argument(
        "--stop", type=int, default=600, help="Stop position for calculating DFIM."
    )
    parser.add_argument(
        "--background_fp",
        type=str,
        default=None,
        help="Background sequences (if None, will select from main seqs).",
    )
    parser.add_argument(
        "--n_subset",
        type=int,
        default=20,
        # Adjacent literals instead of a backslash continuation, so source
        # indentation can never leak into the help text.
        help="Maximum number of sequences to use as background. "
        "Default is 20 to ensure reasonably fast compute on large datasets.",
    )
    parser.add_argument(
        "--gpu",
        type=int,
        default=None,
        help="Index of GPU to use (starting from 0). If not invoked, uses CPU.",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        help="Random seed for selecting background sequences.",
    )
    parser.add_argument(
        "--silence",
        action="store_true",
        help="Disables progress bars and other non-essential print statements.",
    )
    parser.add_argument(
        "--skip_check_additivity",
        action="store_true",
        help="Disables check for additivity of shap results.",
    )
    args = parser.parse_args()

    # Check arguments ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    contrib_by_mode = {"quantity": quantity_contrib, "profile": profile_contrib}
    if args.mode not in contrib_by_mode:
        raise ValueError(f"Invalid mode: {args.mode}. Must be 'quantity' or 'profile'.")
    contrib = contrib_by_mode[args.mode]
    # A negative or empty range would not fail downstream: Python's negative
    # indexing would silently produce malformed score arrays instead.
    if args.start < 0 or args.stop <= args.start:
        raise ValueError(
            f"Invalid range: --start ({args.start}) must be non-negative and "
            f"smaller than --stop ({args.stop})."
        )

    # Load sequences ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    seqs_to_explain, twohot_background = load_seqs(
        args.fasta_fp, False, args.background_fp, args.n_subset, args.seed
    )
    if len(seqs_to_explain) == 0:
        raise ValueError(f"No sequences to explain were loaded from {args.fasta_fp}.")

    # Create explainers ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # NOTE(review): the CLIPNET instance is not used below; it is presumably
    # constructed for its GPU-selection side effects -- confirm. The reference
    # is kept so the object's lifetime is unchanged.
    nn = clipnet.CLIPNET(n_gpus=1, use_specific_gpu=args.gpu)
    if args.model_fp is not None:
        model_fps = [args.model_fp]
    else:
        # glob returns matches in arbitrary order; sorting keeps the ensemble
        # order reproducible between runs and machines.
        model_fps = sorted(glob.glob(os.path.join(args.model_dir, "*.h5")))
        if not model_fps:
            # An empty ensemble would otherwise only surface later as a
            # cryptic failure or as meaningless scores.
            raise ValueError(
                f"No model files (*.h5) found in {args.model_dir}. "
                "Specify --model_fp or a valid --model_dir."
            )
    explainers = create_explainers(model_fps, twohot_background, contrib, args.silence)

    # Calculate DFIM ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    print(
        f"Calculating DFIM scores from pos {args.start} to {args.stop} in sequences of length {len(seqs_to_explain[0])}."
    )
    dfims = {
        rec.name: calculate_dfim(
            explainers,
            rec,
            args.start,
            args.stop,
            check_additivity=not args.skip_check_additivity,
            # The per-position bar is always off here; progress is reported
            # by the per-record bar below.
            silence=True,
        )
        for rec in tqdm.tqdm(
            seqs_to_explain,
            total=len(seqs_to_explain),
            desc="Calculating DFIM",
            disable=args.silence,
        )
    }
    np.savez(args.score_fp, **dfims)


# Run the command-line interface only when executed as a script.
if __name__ == "__main__":
    main()