diff --git a/deeprvat/deeprvat/associate.py b/deeprvat/deeprvat/associate.py index d77cda4c..77f52fe3 100644 --- a/deeprvat/deeprvat/associate.py +++ b/deeprvat/deeprvat/associate.py @@ -10,6 +10,7 @@ from typing import Dict, Iterable, List, Optional, Tuple, Union import click +import h5py import numpy as np import pandas as pd import pyranges as pr @@ -47,6 +48,9 @@ AGG_FCT = {"mean": np.mean, "max": np.max} +IntLike = Union[int, np.int16, np.int32, np.int64] +PathLike = Union[str, Path] + def get_burden( batch: Dict, @@ -290,7 +294,7 @@ def compute_xy( zarr.save_array(y_file, y) -def make_regenie_skat_input_( +def make_skat_input_( debug: bool, data_config_file: Union[str, Path], model_config_file: Union[str, Path], @@ -308,278 +312,166 @@ def make_regenie_skat_input_( with open(model_config_file) as f: model_config = yaml.safe_load(f) + ## Make group file + + logger.info("Reading variant metadata") annotation_file = data_config["association_testing_data"]["dataset_config"][ "annotation_file" ] thresholds = data_config["association_testing_data"]["dataset_config"][ "rare_embedding" - ][config]["thresholds"] - maf = pd.read_parquet( - annotation_file, columns=["id", "chrom", "pos", "ref", "alt", "MAF", "gene_id"] - ) - for col, op in thresholds.items(): - maf = maf.query(f"{col} {op}") - - # TODO: pass gene_id as well! - rare_variant_ids = maf["id"] - - ## Make REGENIE annotation file - - regenie_annotations = maf.copy() - regenie_annotations["chrom_pos_ref_alt"] = maf["chrom"].str.cat( - maf[["pos", "ref", "alt"]], sep=":" + ]["config"]["thresholds"] + annotation_cols = data_config["association_testing_data"]["dataset_config"][ + "rare_embedding" + ]["config"]["annotations"] + variant_file = data_config["association_testing_data"]["variant_file"] + annotations = pd.read_parquet( + annotation_file, + columns=list( + set(["id", "MAF", "gene_id"]) + | set(thresholds.keys()) + | set(annotation_cols) + ), ) - regenie_annotations.to_csv( - out_dir / "regenie_annotations.tsv", sep="\t", index=False, header=False + variants = pd.merge( + annotations, + pd.read_parquet(variant_file, columns=["id", "chrom", "pos", "ref", "alt"]), ) - - ## Make REGENIE set list file - - regenie_sets = ( - regenie_annotations[["gene_id", "chrom", "pos", "chrom_pos_ref_alt"]] - .group_by("gene_id") - .agg( - chrom=pd.NamedAgg(column="chrom", aggfunc="first"), - pos=pd.NamedAgg(column="pos", aggfunc="min"), - variants=pd.NamedAgg( - column="chrom_pos_ref_alt", aggfunc=lambda x: ",".join(x) - ), + for op in thresholds.values(): + variants = variants.query(op) + + if variants["gene_id"].dtype == object: + variants = ( + variants.loc[variants["gene_id"].apply(len) > 0] + .explode("gene_id") + .astype({"gene_id": int}) ) - ) - regenie_sets.to_csv( - out_dir / "regenie_set_list.tsv", sep="\t", index=False, header=False - ) - ## Make AAF file - ## - Score each variant, use 1 - (DeepRVAT score) + variants = variants.set_index(["gene_id", "id"]) logger.info("Scoring all rare variants") scores = score_variants( - rare_variant_ids, model_config_file, data_config_file, checkpoint_files + variants[annotation_cols], model_config_file, data_config_file, checkpoint_files ) - # TODO: merge scores into maf dataframe; need to return id, gene_id from score_variants - - ## Make BGEN file - ## - Make zarr file (in memory?) 
to transpose genotype file (rare variants only)
-    ## - Make BGEN file
-
-    ## Make BGEN
-
-    # Load data
-    logger.info("Loading computed burdens, covariates, phenotypes and metadata")
-
-    phenotype_names = [p[0] for p in phenotype]
-    dataset_files = [p[1] for p in phenotype]
-    xy_dirs = [p[2] for p in phenotype]
-
-    # load only first sample_ids zarr here
-    sample_ids = zarr.load(xy_dirs[0] / "sample_ids.zarr")
-    covariates = zarr.load(xy_dirs[0] / "x.zarr")
-    ys = [zarr.load(b / "y.zarr") for b in xy_dirs]
-
-    if debug:
-        sample_ids = sample_ids[:1000]
-        covariates = covariates[:1000]
-        ys = [y[:1000] for y in ys]
-
-    n_samples = sample_ids.shape[0]
-    assert covariates.shape[0] == n_samples
-    # assert that ALL y.zarrs are the same lengths as the single sample_ids zarr loaded above
-    assert all([y.shape[0] == n_samples for y in ys])
-
-    # Sanity check: sample_ids and covariates should be consistent for all phenotypes
-    if not debug:
-        for i in range(1, len(phenotype)):
-            assert np.array_equal(sample_ids, zarr.load(xy_dirs[i] / "sample_ids.zarr"))
-            this_cov = zarr.load(xy_dirs[i] / "x.zarr")
-            assert this_cov.shape == covariates.shape
-            unequal_rows = np.array(
-                [
-                    i
-                    for i in range(covariates.shape[0])
-                    if not np.array_equal(covariates[i], this_cov[i])
-                ]
-            )
-            for i in unequal_rows:
-                assert np.all(
-                    (np.abs(covariates[i] - this_cov[i]) < 1e-6)
-                    | (np.isnan(covariates[i]) & np.isnan(this_cov[i]))
-                )
-
-    sample_df = pd.DataFrame({"FID": sample_ids, "IID": sample_ids})
-
-    if not skip_covariates:
-        ## Make covariate file
-        logger.info(f"Creating covariate file {covariate_file}")
-        with open(dataset_files[0], "rb") as f:
-            dataset = pickle.load(f)
-
-        covariate_names = dataset.x_phenotypes
-        cov_df = pd.DataFrame(covariates, columns=covariate_names)
-        cov_df = pd.concat([sample_df, cov_df], axis=1)
-        cov_df.to_csv(covariate_file, sep=" ", index=False, na_rep="NA")
-
-    if not skip_phenotypes:
-        ## Make phenotype file
-        logger.info(f"Creating phenotype file {phenotype_file}")
-        pheno_df_list = []
-        for p, y in zip(phenotype_names, ys):
-            pheno_df_list.append(pd.DataFrame({p: y.squeeze()}))
+    variants["chrom_pos_ref_alt"] = variants["chrom"].str.cat(
+        variants[["pos", "ref", "alt"]].astype(str), sep=":"
+    )

-    pheno_df = pd.concat([sample_df] + pheno_df_list, axis=1)
-    pheno_df.to_csv(phenotype_file, sep=" ", index=False, na_rep="NA")
+    logger.info("Writing group file")
+    with open(Path(out_dir) / "groups.txt", "w") as f:
+        scored_variants = pd.merge(
+            variants[["chrom_pos_ref_alt"]], scores, left_index=True, right_index=True
+        )
+        for k, v in tqdm(scored_variants.reset_index().groupby("gene_id")):
+            f.write(f"{k} var " + " ".join(v["chrom_pos_ref_alt"]) + "\n")

-    if not skip_burdens:
-        burden_file, gene_file, b_sample_file = burdens_genes_samples
+            f.write(f"{k} anno")
+            for _ in range(len(v)):
+                f.write(" deeprvat")
+            f.write("\n")

-        genes = np.load(gene_file)
-        n_genes = genes.shape[0]
+            # TODO: Could also transform scores, e.g., beta(score, 25, 1)
+            f.write(
+                f"{k} weight " + " ".join(v["DeepRVAT_score"].astype(str)) + "\n"
+            )

-        sample_ids = zarr.load(
-            b_sample_file
-        )  # Might be different from those for the phenotypes
-        n_samples = sample_ids.shape[0]
+    ## Make BGEN file

-        ## Make sample file
-        logger.info(f"Creating sample file {sample_file}")
-        sample_df = pd.DataFrame({"FID": sample_ids, "IID": sample_ids})
-        samples_out = pd.concat(
-            [
-                pd.DataFrame({"ID_1": 0, "ID_2": 0}, index=[0]),
-                sample_df.rename(
-                    columns={
-                        "FID": "ID_1",
-                        "IID": "ID_2",
-                    }
-                ),
-            ]
-        )
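+    # Layout assumed for the preprocessed genotype HDF5 (illustrative example):
+    # row i of variant_matrix lists the variant ids carried by sample i, and
+    # gt_matrix holds the matching calls (1 = het, 2 = hom alt), with -1 used
+    # as right-padding in both, e.g.
+    #
+    #   variant_matrix[i] = [12, 57, 603, -1, -1]
+    #   gt_matrix[i]      = [ 1,  2,   1, -1, -1]
+    #
+    # The loop below inverts this sample-major sparse layout into variant-major
+    # arrays (one row per variant id) so each variant's carriers and genotypes
+    # can be read off directly when writing the BGEN file.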
+    # Transpose genotype file to be variants x samples
+    logger.info("Transforming genotype file")
+    # NOTE: assumes dataset_config provides "phenotype_file" analogously to
+    # "annotation_file" above; its index must list samples in the same order
+    # as the genotype file
+    phenotype_file = data_config["association_testing_data"]["dataset_config"][
+        "phenotype_file"
+    ]
+    samples = list(pd.read_parquet(phenotype_file).index.astype(str))
+    # "id" lives in the (gene_id, id) MultiIndex set above
+    max_variant_id = variants.index.get_level_values("id").max()
+    with h5py.File(genotype_file, "r") as g:
+        gt_matrix = g["gt_matrix"]
+        variant_matrix = g["variant_matrix"]
+        n_samples = gt_matrix.shape[0]
+        max_samples = n_samples // 50  # TODO: Compute from max MAF
+        sample_zarr = zarr.full(
+            (max_variant_id + 1, max_samples),
+            fill_value=-1,
+            chunks=(1000, 1000),
+            dtype=np.int32,  # must hold sample indices, so wider than int8
+        )
-    samples_out.to_csv(sample_file, sep=" ", index=False)
-
-    burdens_zarr = zarr.open(burden_file)
-    if not debug:
-        assert burdens_zarr.shape[0] == n_samples
-        assert burdens_zarr.shape[1] == n_genes
-
-    if average_repeats:
-        logger.info("Averaging burdens across all repeats")
-        burdens = np.zeros((n_samples, n_genes))
-        for repeat in trange(burdens_zarr.shape[2]):
-            burdens += burdens_zarr[:n_samples, :, repeat]
-        burdens = burdens / burdens_zarr.shape[2]
-    else:
-        logger.info(f"Using burdens from repeat {repeat}")
-        assert repeat < burdens_zarr.shape[2]
-        burdens = burdens_zarr[:n_samples, :, repeat]
-
-    # Read GTF file and get positions for pseudovariants (center of interval [Start, End])
-    logger.info(
-        f"Assigning positions to pseudovariants based on provided GTF file {gtf}"
-    )
+        gt_zarr = zarr.full(
+            (max_variant_id + 1, max_samples),
+            fill_value=-1,
+            chunks=(1000, 1000),
+            dtype=np.int8,
+        )
-    gene_pos = pr.read_gtf(gtf)
-    gene_pos = gene_pos[
-        (gene_pos.Feature == "gene") & (gene_pos.gene_type == "protein_coding")
-    ][["Chromosome", "Start", "End", "gene_id"]].as_df()
-    gene_pos = gene_pos.set_index("gene_id")
-    gene_metadata = pd.read_parquet(gene_metadata_file).set_index("id")
-    this_gene_pos = gene_pos.loc[gene_metadata.loc[genes, "gene"]]
-    pseudovar_pos = (this_gene_pos.End - this_gene_pos.Start).to_numpy().astype(int)
-    ensgids = this_gene_pos.index.to_numpy()
+        last_filled = np.zeros(max_variant_id + 1, dtype=np.int64)
+        for i in trange(n_samples):
+            var_ids = variant_matrix[i, :]
+            # Drop -1 padding and ids beyond the rare-variant table (assumed
+            # sparse layout, see above)
+            mask = (var_ids >= 0) & (var_ids <= max_variant_id)
+            rows = var_ids[mask]
+            cols = last_filled[rows]
+            # Store the sample's integer index (not its ID) so it can index
+            # the dense per-variant genotype vectors below
+            sample_zarr.vindex[rows, cols] = np.full(rows.shape[0], i)
+            gt_zarr.vindex[rows, cols] = gt_matrix[i, :][mask]
+            last_filled[rows] = cols + 1
+
+    logger.info("Creating BGEN file")
+    variants = (
+        variants.reset_index()[
+            ["id", "chrom", "pos", "ref", "alt", "chrom_pos_ref_alt"]
+        ]
+        .drop_duplicates(subset="id")
+        .set_index("id")
+    )
+    with BgenWriter(
+        Path(out_dir) / "genotypes.bgen",
+        n_samples,
+        samples=samples,
+        metadata=f"Created from {genotype_file}",
+    ) as f:
+        for i in trange(max_variant_id + 1):
+            if i not in variants.index:  # variant failed the rare thresholds
+                continue
+            if sample_zarr[i, 0] == -1:  # variant has no associated samples
+                continue

-    logger.info(f"Writing pseudovariants to {bgen}")
-    with BgenWriter(
-        bgen,
-        n_samples,
-        samples=list(sample_ids.astype(str)),
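+            # Hard-call encoding: per sample exactly one of the three genotype
+            # probabilities P(hom ref), P(het), P(hom alt) is 1, so each
+            # probability fits in a single bit (bit_depth=1). The column order
+            # below follows the (hom ref, het, hom alt) convention assumed for
+            # the bgen package's BgenWriter.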
-        metadata="Pseudovariants containing DeepRVAT gene impairment scores. One pseudovariant per gene.",
-    ) as f:
-        for i in trange(n_genes):
-            varid = f"pseudovariant_gene_{ensgids[i]}"
-            this_burdens = burdens[:, i]  # Rescale scores to be in range (0, 2)
-            genotypes = np.stack(
-                (this_burdens, np.zeros(this_burdens.shape), 1 - this_burdens),
-                axis=1,
-            )
+            het_genotypes = np.zeros(n_samples, dtype=np.int8)
+            homalt_genotypes = np.zeros(n_samples, dtype=np.int8)
+            sparse_gt = gt_zarr[i]
+            sparse_samples = sample_zarr[i]
+            het_mask = sparse_gt == 1
+            het_genotypes[sparse_samples[het_mask]] = 1
+            hom_mask = sparse_gt == 2
+            homalt_genotypes[sparse_samples[hom_mask]] = 1
+            genotypes = np.stack(
+                (
+                    1 - (het_genotypes + homalt_genotypes),
+                    het_genotypes,
+                    homalt_genotypes,
+                ),
+                axis=1,
+            ).astype(np.float64)  # probabilities as floats; each row sums to 1

-            f.add_variant(
-                varid=varid,
-                rsid=varid,
-                chrom=this_gene_pos.iloc[i].Chromosome,
-                pos=pseudovar_pos[i],
-                alleles=[
-                    "A",
-                    "C",
-                ],  # TODO: This is completely arbitrary, however, we might want to match it to a reference FASTA at some point
-                genotypes=genotypes,
-                ploidy=2,
-                bit_depth=16,
-            )
+            varid = variants.loc[i, "chrom_pos_ref_alt"]
+            f.add_variant(
+                varid=varid,
+                rsid=varid,
+                chrom=variants.loc[i, "chrom"],
+                pos=variants.loc[i, "pos"],
+                alleles=[
+                    variants.loc[i, "ref"],
+                    variants.loc[i, "alt"],
+                ],
+                genotypes=genotypes,
+                ploidy=2,
+                bit_depth=1,
+            )

@cli.command()
@click.option("--debug", is_flag=True)
-@click.option("--skip-covariates", is_flag=True)
-@click.option("--skip-phenotypes", is_flag=True)
-@click.option("--skip-burdens", is_flag=True)
-@click.option(
-    "--burdens-genes-samples",
-    type=(
-        click.Path(path_type=Path, exists=True),
-        click.Path(path_type=Path, exists=True),
-        click.Path(path_type=Path, exists=True),
-    ),
+@click.argument("data-config-file", type=click.Path(exists=True, path_type=Path))
+@click.argument("model-config-file", type=click.Path(exists=True, path_type=Path))
+@click.argument(
+    "checkpoint-files", type=click.Path(exists=True, path_type=Path), nargs=-1
)
-@click.option("--repeat", type=int, default=-1)
-@click.option("--average-repeats", is_flag=True)
-@click.option(
-    "--phenotype",
-    type=(
-        str,
-        click.Path(exists=True, path_type=Path),
-        click.Path(exists=True, path_type=Path),
-    ),
-    multiple=True,
-)  # phenotype_name, dataset_file, burden_dir
-@click.option("--sample-file", type=click.Path(path_type=Path))
-@click.option("--bgen", type=click.Path(path_type=Path))
-@click.option("--covariate-file", type=click.Path(path_type=Path))
-@click.option("--phenotype-file", type=click.Path(path_type=Path))
-# @click.argument("dataset-file", type=click.Path(exists=True, path_type=Path))
-# @click.argument("burden-dir", type=click.Path(exists=True, path_type=Path))
-@click.argument("gene-metadata-file", type=click.Path(exists=True, path_type=Path))
-@click.argument("gtf", type=click.Path(exists=True, path_type=Path))
-def make_regenie_skat_input(
+@click.argument("genotype-file", type=click.Path(exists=True, path_type=Path))
+@click.argument("out-dir", type=click.Path(exists=True, path_type=Path))
+def make_skat_input(
    debug: bool,
-    skip_covariates: bool,
-    skip_phenotypes: bool,
-    skip_burdens: bool,
-    burdens_genes_samples: Optional[Tuple[Path, Path, Path]],
-    repeat: int,
-    average_repeats: bool,
-    phenotype: Tuple[Tuple[str, Path, Path]],
-    sample_file: Optional[Path],
-    covariate_file: Optional[Path],
-    phenotype_file: Optional[Path],
-    bgen: Optional[Path],
-    gene_metadata_file: Path,
-    gtf: Path,
+    data_config_file: Path,
+    model_config_file: Path,
+    checkpoint_files: Tuple[Path, ...],
+    genotype_file: Path,
+    out_dir: Path,
):
-    make_regenie_skat_input_(
+    make_skat_input_(
        debug=debug,
-        skip_covariates=skip_covariates,
-        skip_phenotypes=skip_phenotypes,
-        skip_burdens=skip_burdens,
-        burdens_genes_samples=burdens_genes_samples,
-        repeat=repeat,
-        average_repeats=average_repeats,
-        phenotype=phenotype,
-        sample_file=sample_file,
-        covariate_file=covariate_file,
-        phenotype_file=phenotype_file,
-        bgen=bgen,
-        gene_metadata_file=gene_metadata_file,
-        gtf=gtf,
+        data_config_file=data_config_file,
+        model_config_file=model_config_file,
+        checkpoint_files=checkpoint_files,
+        genotype_file=genotype_file,
+        out_dir=out_dir,
    )

@@ -883,7 +775,7 @@ def convert_regenie_output(
def load_one_model(
    config: Dict,
-    checkpoint: str,
+    checkpoint: PathLike,
    device: torch.device = torch.device("cpu"),
):
    """
@@ -910,16 +802,17 @@ def load_one_model(
def score_variants(
-    variant_ids: Iterable[Union[int, np.int16, np.int32, np.int64]],
-    model_config_file: str,
-    data_config_file: str,
-    checkpoint_files: Tuple[str],
-):
+    variants: pd.DataFrame,
+    model_config_file: PathLike,
+    data_config_file: PathLike,
+    checkpoint_files: Tuple[PathLike, ...],
+    batch_size: int = 2_000_000,
+) -> pd.DataFrame:
    """
    Score individual variants using a collection of DeepRVAT checkpoints.

-    :param variant_ids: Iterable containing variant IDs to score.
-    :type variant_ids: Iterable[Union[int, np.int16, np.int32, np.int64]]
+    :param variants: Annotated variants to score, one row per variant and one
+        column per model input annotation.
+    :type variants: pandas.DataFrame
    :param model_config_file: Path to the model configuration file.
    :type model_config_file: str
    :param data_config_file: Path to the data configuration file.
@@ -931,12 +824,12 @@
    with open(model_config_file) as f:
        model_config = yaml.safe_load(f)

-    with open(data_config_file) as f:
-        data_config = yaml.safe_load(f)
+    # with open(data_config_file) as f:
+    #     data_config = yaml.safe_load(f)

-    annotation_file = data_config["association_testing_data"]["dataset_config"][
-        "annotation_file"
-    ]
+    # annotation_file = data_config["association_testing_data"]["dataset_config"][
+    #     "annotation_file"
+    # ]

    if torch.cuda.is_available():
        logger.info("Using GPU to score variants")
        device = torch.device("cuda")
    else:
        logger.info("Using CPU to score variants")
        device = torch.device("cpu")

-    variant_df = (
-        pd.read_parquet(
-            annotation_file,
-            columns=data_config["association_testing_data"]["dataset_config"][
-                "rare_embedding"
-            ]["config"]["annotations"],
-        )
-        .set_index("id")
-        .loc[variant_ids]
-        .to_numpy()
-    )
+    # logger.info("Loading variant annotations")
+    # variant_df = pd.read_parquet(
+    #     annotation_file,
+    #     columns=data_config["association_testing_data"]["dataset_config"][
+    #         "rare_embedding"
+    #     ]["config"]["annotations"]
+    #     + ["gene_id", "id"],
+    # )
+    # if variant_df["gene_id"].dtype == object:
+    #     variant_df = (
+    #         variant_df.loc[variant_df["gene_id"].apply(len) > 0]
+    #         .explode("gene_id")
+    #         .astype({"gene_id": int})
+    #     )

+    # logger.info("Selecting requested variants")
+    # variant_df = (
+    #     variant_df.set_index(["gene_id", "id"]).loc[variant_gene_ids].to_numpy()
+    # )

+    variants_np = variants.to_numpy()

-    n_variants = variant_df.shape[0]
-    scores = np.zeros(n_variants)
-    for checkpoint in checkpoint_files:
-        if Path(checkpoint + ".dropped").is_file():
-            # Ignore checkpoints that were chosen to be dropped
-            continue
-
-        agg_model = load_one_model(model_config, checkpoint, device=device)
-        if Path(checkpoint + ".reverse").is_file():
-            agg_model.set_reverse()
-
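+    # Ensemble scoring sketch: every non-dropped checkpoint contributes
+    # score / len(checkpoint_files) to the running total, so the result is the
+    # ensemble average. Each batch of annotations is reshaped to
+    # (batch, 1, n_annotations, 1), i.e. one pseudo-gene containing a single
+    # variant, so the aggregation model's output can be read directly as a
+    # per-variant score. batch_size bounds peak device memory per forward pass.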
-        scores += agg_model(
-            torch.tensor(variant_df, dtype=torch.float, device=device).reshape(
-                (n_variants, 1, -1, 1)
-            )
-        ).reshape(n_variants) / len(checkpoint_files)
+    logger.info("Scoring variants using all ensemble models")
+    n_variants = variants_np.shape[0]
+    scores = np.zeros(n_variants)
+    with torch.no_grad():
+        for start_idx in trange(0, n_variants, batch_size, desc="Batches"):
+            end_idx = min(start_idx + batch_size, n_variants)
+            this_batch_size = end_idx - start_idx
+            for checkpoint in tqdm(checkpoint_files, desc="Models"):
+                if Path(str(checkpoint) + ".dropped").is_file():
+                    # Ignore checkpoints that were chosen to be dropped.
+                    # NOTE: dropped checkpoints still count in the
+                    # len(checkpoint_files) denominator below, matching the
+                    # previous behavior
+                    continue
+
+                agg_model = load_one_model(model_config, checkpoint, device=device)
+                if Path(str(checkpoint) + ".reverse").is_file():
+                    agg_model.set_reverse()
+
+                scores[start_idx:end_idx] += (
+                    (
+                        agg_model(
+                            torch.tensor(
+                                variants_np[start_idx:end_idx],
+                                dtype=torch.float,
+                                device=device,
+                            ).reshape((this_batch_size, 1, -1, 1))
+                        ).reshape(this_batch_size)
+                        / len(checkpoint_files)
+                    )
+                    .cpu()
+                    .numpy()
+                )

-    return scores
+    return pd.DataFrame(scores, index=variants.index, columns=["DeepRVAT_score"])

@cli.command()