Skip to content

Commit

Permalink
add score_variants function
Browse files Browse the repository at this point in the history
  • Loading branch information
bfclarke committed Jan 15, 2025
1 parent a85ee57 commit 5921ef7
Showing 1 changed file with 85 additions and 0 deletions.
85 changes: 85 additions & 0 deletions deeprvat/deeprvat/associate.py
Original file line number Diff line number Diff line change
Expand Up @@ -697,6 +697,91 @@ def reverse_models(
Path(checkpoint + ".reverse").touch()


@cli.command()
@click.option("--batch-size", type=int, default=100_000)
@click.argument("model-config-file", type=click.Path(exists=True))
@click.argument("data-config-file", type=click.Path(exists=True))
@click.argument("annotation-file", type=click.Path(exists=True))
@click.argument("output-file", type=click.Path())
@click.argument("checkpoint-files", type=click.Path(exists=True), nargs=-1)
def score_variants(
    batch_size: int,
    model_config_file: str,
    data_config_file: str,
    annotation_file: str,
    output_file: str,
    checkpoint_files: Tuple[str, ...],
):
    """
    Score every variant in an annotation table with the mean output of the
    loaded models and write the result to a parquet file.

    Each row of the annotation table is fed to every non-dropped model as a
    single-variant input; the per-model outputs are averaged into one score
    per row. The output contains the columns ``id``, ``chrom``, ``pos``,
    ``ref``, ``alt`` and ``gene_id`` from the annotation table plus a new
    ``DeepRVAT_score`` column.

    :param batch_size: Number of annotation rows passed to the models at once.
    :type batch_size: int
    :param model_config_file: Path to the YAML model configuration file.
    :type model_config_file: str
    :param data_config_file: Path to the YAML data configuration file; its
        ``annotations`` entry lists the annotation columns used as model input.
    :type data_config_file: str
    :param annotation_file: Path to the parquet file with variant annotations.
    :type annotation_file: str
    :param output_file: Path of the parquet file to write.
    :type output_file: str
    :param checkpoint_files: Paths to model checkpoints. A checkpoint is
        skipped when a sibling ``<checkpoint>.dropped`` file exists.
    :type checkpoint_files: Tuple[str, ...]
    :raises click.BadParameter: If ``batch_size`` is smaller than 1.
    :raises click.UsageError: If no usable (non-dropped) checkpoint remains.
    :return: None; the scores are written to ``output_file``.
    """
    if batch_size < 1:
        raise click.BadParameter(
            "batch size must be a positive integer", param_hint="--batch-size"
        )

    with open(model_config_file) as f:
        model_config = yaml.safe_load(f)

    with open(data_config_file) as f:
        data_config = yaml.safe_load(f)

    if torch.cuda.is_available():
        logger.info("Using GPU")
        device = torch.device("cuda")
    else:
        logger.info("Using CPU")
        device = torch.device("cpu")

    # NOTE(review): only ".dropped" markers are honoured here; ".reverse"
    # markers are not consulted -- confirm whether reversed models need
    # their output flipped before averaging.
    models = [
        load_one_model(model_config, checkpoint, device=device)
        for checkpoint in checkpoint_files
        if not Path(checkpoint + ".dropped").is_file()
    ]
    if not models:
        # Fail early with a clear message instead of an IndexError (or a
        # division by zero) deep inside the scoring loop.
        raise click.UsageError(
            "No usable checkpoints: none were given or all are marked as dropped"
        )

    anno_cols = data_config["annotations"]
    annotation_df = pd.read_parquet(annotation_file)
    annotation_array = annotation_df[anno_cols].to_numpy()

    n_variants = annotation_df.shape[0]
    n_batches = math.ceil(n_variants / batch_size)
    score_list = []

    with torch.no_grad():
        for start_idx in trange(0, n_variants, batch_size, total=n_batches):
            batch = annotation_array[start_idx : start_idx + batch_size]
            n_rows = batch.shape[0]
            # The model input does not depend on the model, so build it once
            # per batch. The reshape yields (rows, 1, n_annotations, 1);
            # presumably (gene, variant, annotation, sample-like) axes with
            # one variant per "gene" -- TODO confirm against the model.
            batch_tensor = torch.tensor(
                batch,
                dtype=torch.float,
                device=device,
            ).reshape((n_rows, 1, len(anno_cols), 1))

            batch_sum = None
            for agg_model in models:
                model_scores = agg_model(batch_tensor).reshape(n_rows).cpu().numpy()
                batch_sum = (
                    model_scores if batch_sum is None else batch_sum + model_scores
                )

            score_list.append(batch_sum / len(models))

    # np.concatenate rejects an empty list, which happens when the
    # annotation table has no rows.
    scores = (
        np.concatenate(score_list) if score_list else np.empty(0, dtype=np.float32)
    )

    # Copy the column subset so adding the score column does not write into
    # a slice of annotation_df (avoids pandas' chained-assignment warning).
    score_df = annotation_df[["id", "chrom", "pos", "ref", "alt", "gene_id"]].copy()
    score_df["DeepRVAT_score"] = scores
    score_df.to_parquet(output_file)


def load_models(
config: Dict,
checkpoint_files: Tuple[str],
Expand Down

0 comments on commit 5921ef7

Please sign in to comment.