Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add in Snakemake Log-Files #147

Merged
merged 8 commits into from
Dec 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion deeprvat/cv_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

logging.basicConfig(
format="[%(asctime)s] %(levelname)s:%(name)s: %(message)s",
level="INFO",
level=logging.INFO,
stream=sys.stdout,
)
logger = logging.getLogger(__name__)
Expand Down
2 changes: 1 addition & 1 deletion deeprvat/data/dense_gt.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

logging.basicConfig(
format="[%(asctime)s] %(levelname)s:%(name)s: %(message)s",
level="INFO",
level=logging.INFO,
stream=sys.stdout,
)
logger = logging.getLogger(__name__)
Expand Down
2 changes: 1 addition & 1 deletion deeprvat/data/rare.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

logging.basicConfig(
format="[%(asctime)s] %(levelname)s:%(name)s: %(message)s",
level="INFO",
level=logging.INFO,
stream=sys.stdout,
)
logger = logging.getLogger(__name__)
Expand Down
23 changes: 12 additions & 11 deletions deeprvat/deeprvat/associate.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,10 @@
from tqdm import tqdm, trange
import zarr
import re

import deeprvat.deeprvat.models as deeprvat_models
from deeprvat.data import DenseGTDataset

logging.root.handlers.clear() # Remove all handlers associated with the root logger object
logging.basicConfig(
format="[%(asctime)s] %(levelname)s:%(name)s: %(message)s",
level=logging.INFO,
Expand All @@ -48,6 +48,11 @@
AGG_FCT = {"mean": np.mean, "max": np.max}


@click.group()
def cli():
pass


def get_burden(
batch: Dict,
agg_models: Dict[str, List[nn.Module]],
Expand Down Expand Up @@ -99,11 +104,6 @@ def separate_parallel_results(results: List) -> Tuple[List, ...]:
return tuple(map(list, zip(*results)))


@click.group()
def cli():
pass


def make_dataset_(
config: Dict,
debug: bool = False,
Expand Down Expand Up @@ -306,7 +306,6 @@ def make_regenie_input_(
gene_metadata_file: Path,
gtf: Path,
):
logger.setLevel(logging.INFO)

## Check options
if not skip_burdens and burdens_genes_samples is None:
Expand Down Expand Up @@ -420,7 +419,7 @@ def make_regenie_input_(
if average_repeats:
logger.info("Averaging burdens across all repeats")
burdens = np.zeros((n_samples, n_genes))
for repeat in trange(burdens_zarr.shape[2]):
for repeat in trange(burdens_zarr.shape[2], file=sys.stdout):
burdens += burdens_zarr[:n_samples, :, repeat]
burdens = burdens / burdens_zarr.shape[2]
else:
Expand Down Expand Up @@ -449,7 +448,7 @@ def make_regenie_input_(
samples=list(sample_ids.astype(str)),
metadata="Pseudovariants containing DeepRVAT gene impairment scores. One pseudovariant per gene.",
) as f:
for i in trange(n_genes):
for i in trange(n_genes, file=sys.stdout):
varid = f"pseudovariant_gene_{ensgids[i]}"
this_burdens = burdens[:, i] # Rescale scores to be in range (0, 2)
genotypes = np.stack(
Expand Down Expand Up @@ -747,7 +746,7 @@ def load_models(
}

if len(checkpoint_files[first_repeat]) > 1:
logging.info(
logger.info(
f" Averaging results from {len(checkpoint_files[first_repeat])} models for each repeat"
)

Expand Down Expand Up @@ -1065,7 +1064,9 @@ def combine_burden_chunks_(
end_id = 0

for i, chunk in tqdm(
enumerate(range(0, n_chunks)), desc=f"Merging {n_chunks} chunks"
enumerate(range(0, n_chunks)),
desc=f"Merging {n_chunks} chunks",
file=sys.stdout,
):
chunk_dir = burdens_chunks_dir / f"chunk_{chunk}"

Expand Down
2 changes: 1 addition & 1 deletion deeprvat/deeprvat/common_variant_condition_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

logging.basicConfig(
format="[%(asctime)s] %(levelname)s:%(name)s: %(message)s",
level="INFO",
level=logging.INFO,
stream=sys.stdout,
)
logger = logging.getLogger(__name__)
Expand Down
2 changes: 1 addition & 1 deletion deeprvat/deeprvat/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

logging.basicConfig(
format="[%(asctime)s] %(levelname)s:%(name)s: %(message)s",
level="INFO",
level=logging.INFO,
stream=sys.stdout,
)
logger = logging.getLogger(__name__)
Expand Down
2 changes: 1 addition & 1 deletion deeprvat/deeprvat/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

logging.basicConfig(
format="[%(asctime)s] %(levelname)s:%(name)s: %(message)s",
level="INFO",
level=logging.INFO,
stream=sys.stdout,
)
logger = logging.getLogger(__name__)
Expand Down
2 changes: 1 addition & 1 deletion deeprvat/deeprvat/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

logging.basicConfig(
format="[%(asctime)s] %(levelname)s:%(name)s: %(message)s",
level="INFO",
level=logging.INFO,
stream=sys.stdout,
)
logger = logging.getLogger(__name__)
Expand Down
29 changes: 21 additions & 8 deletions deeprvat/deeprvat/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from pprint import pformat, pprint
from tempfile import TemporaryDirectory
from typing import Dict, Optional, Tuple, Union

import re
import click
import math
import numpy as np
Expand Down Expand Up @@ -37,10 +37,9 @@
from torch.utils.data import DataLoader, Dataset, Subset
from tqdm import tqdm


logging.basicConfig(
format="[%(asctime)s] %(levelname)s:%(name)s: %(message)s",
level="INFO",
level=logging.INFO,
stream=sys.stdout,
)
logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -872,20 +871,20 @@ def run_bagging(
trainer.fit(model, dm)
except RuntimeError as e:
# if batch_size is chosen too big, it will be reduced until it fits the GPU
logging.error(f"Caught RuntimeError: {e}")
logger.error(f"Caught RuntimeError: {e}")
if str(e).find("CUDA out of memory") != -1:
if dm.hparams.batch_size > 4:
logging.error(
logger.error(
"Retrying training with half the original batch size"
)
gc.collect()
torch.cuda.empty_cache()
dm.hparams.batch_size = dm.hparams.batch_size // 2
else:
logging.error("Batch size is already <= 4, giving up")
logger.error("Batch size is already <= 4, giving up")
raise RuntimeError("Could not find small enough batch size")
else:
logging.error(f"Caught unknown error: {e}")
logger.error(f"Caught unknown error: {e}")
raise e
else:
break
Expand Down Expand Up @@ -1167,7 +1166,21 @@ def best_training_run(
config = yaml.safe_load(f)

with open(config_file_out, "w") as f:
yaml.dump({"model": config["model"]}, f)
yaml.dump(
{
"model": config["model"],
"rare_variant_annotations": config["training_data"]["dataset_config"][
"rare_embedding"
]["config"]["annotations"],
"training_data_thresholds": {
k: str(re.sub(f"^{k} ", "", v))
for k, v in config["training_data"]["dataset_config"][
"rare_embedding"
]["config"]["thresholds"].items()
},
},
f,
)

n_bags = config["training"]["n_bags"] if not debug else 3
for k in range(n_bags):
Expand Down
2 changes: 1 addition & 1 deletion deeprvat/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

logging.basicConfig(
format="[%(asctime)s] %(levelname)s:%(name)s: %(message)s",
level="INFO",
level=logging.INFO,
stream=sys.stdout,
)
logger = logging.getLogger(__name__)
Expand Down
2 changes: 1 addition & 1 deletion deeprvat/seed_gene_discovery/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

logging.basicConfig(
format="[%(asctime)s] %(levelname)s:%(name)s: %(message)s",
level="INFO",
level=logging.INFO,
stream=sys.stdout,
)
logger = logging.getLogger(__name__)
Expand Down
2 changes: 1 addition & 1 deletion deeprvat/seed_gene_discovery/seed_gene_discovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

logging.basicConfig(
format="[%(asctime)s] %(levelname)s:%(name)s: %(message)s",
level="INFO",
level=logging.INFO,
stream=sys.stdout,
)
logger = logging.getLogger(__name__)
Expand Down
2 changes: 1 addition & 1 deletion deeprvat/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

logging.basicConfig(
format="[%(asctime)s] %(levelname)s:%(name)s: %(message)s",
level="INFO",
level=logging.INFO,
stream=sys.stdout,
)
logger = logging.getLogger(__name__)
Expand Down
12 changes: 10 additions & 2 deletions pipelines/association_testing/association_dataset.snakefile
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,16 @@ rule association_dataset:
resources:
mem_mb = lambda wildcards, attempt: 32000 * (attempt + 1),
priority: 30
log:
stdout="logs/association_dataset/{phenotype}.stdout",
stderr="logs/association_dataset/{phenotype}.stderr"
shell:
'deeprvat_associate make-dataset '
+ debug +
"--skip-genotypes "
'{input.data_config} '
'{output}'
'{output} '
+ logging_redirct


rule association_dataset_burdens:
Expand All @@ -33,8 +37,12 @@ rule association_dataset_burdens:
resources:
mem_mb = lambda wildcards, attempt: 32000 * (attempt + 1)
priority: 30
log:
stdout=f"logs/association_dataset_burdens/{phenotypes[0]}.stdout",
stderr=f"logs/association_dataset_burdens/{phenotypes[0]}.stderr"
shell:
'deeprvat_associate make-dataset '
+ debug +
'{input.data_config} '
'{output}'
'{output} '
+ logging_redirct
27 changes: 22 additions & 5 deletions pipelines/association_testing/burdens.snakefile
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,16 @@ rule combine_burdens:
threads: 1
resources:
mem_mb = lambda wildcards, attempt: 4098 + (attempt - 1) * 4098,
log:
stdout="logs/combine_burdens/combine_burdens.stdout",
stderr="logs/combine_burdens/combine_burdens.stderr"
shell:
' '.join([
'deeprvat_associate combine-burden-chunks',
'{params.prefix}/burdens/chunks/',
' --n-chunks ' + str(n_burden_chunks),
'{params.prefix}/burdens',
'{params.prefix}/burdens ',
logging_redirct
])

rule all_xy:
Expand All @@ -42,14 +46,18 @@ rule compute_xy:
threads: 8
resources:
mem_mb = lambda wildcards, attempt: 20480 + (attempt - 1) * 4098,
log:
stdout="logs/compute_xy/{phenotype}.stdout",
stderr="logs/compute_xy/{phenotype}.stderr"
shell:
' && '.join([
('deeprvat_associate compute-xy '
'--dataset-file {input.dataset} '
'{input.data_config} '
"{output.samples} "
"{output.x} "
"{output.y}")
"{output.y} "
+ logging_redirct)
])


Expand All @@ -73,6 +81,9 @@ rule compute_burdens:
resources:
mem_mb = 32000,
gpus = 1
log:
stdout="logs/compute_burdens/compute_burdens_{chunk}.stdout",
stderr="logs/compute_burdens/compute_burdens_{chunk}.stderr"
shell:
' '.join([
'deeprvat_associate compute-burdens '
Expand All @@ -83,7 +94,8 @@ rule compute_burdens:
'{input.data_config} '
'{input.model_config} '
'{input.checkpoints} '
'{params.prefix}/burdens'],
'{params.prefix}/burdens '
+ logging_redirct ],
)


Expand All @@ -98,11 +110,16 @@ rule reverse_models:
threads: 4
resources:
mem_mb = 20480,
log:
stdout="logs/reverse_models/reverse_models.stdout",
stderr="logs/reverse_models/reverse_models.stderr"
shell:
" && ".join([
("deeprvat_associate reverse-models "
"{input.model_config} "
"{input.data_config} "
"{input.checkpoints}"),
"touch {output}"
"{input.checkpoints} "
+ logging_redirct),
"touch {output} "

])
Loading