
Commit

mvp

maxibor committed Apr 3, 2024
1 parent bd5bed7 commit c45bdac
Showing 11 changed files with 514 additions and 16 deletions.
29 changes: 29 additions & 0 deletions .github/workflows/ci.yaml
@@ -0,0 +1,29 @@
name: CI

on:
  push:
    branches:
      - dev
      - master
  pull_request:
  release:
    types: [published]

jobs:
  deploy:

    runs-on: ubuntu-latest

    steps:
      - uses: actions/checkout@v4
      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: '3.10'
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install poetry
      - name: Install and test
        run: |
          poetry install
          poetry run pytest -vv
31 changes: 29 additions & 2 deletions README.md
@@ -1,5 +1,32 @@
# floria-strainer

## Introduction

Given the output of the strain haplotyping software [floria](https://github.com/bluenote-1577/floria)[^1], **floria-strainer** computes the allele frequency at each variable position of each haploset identified by floria, then clusters the haplosets into the different strain mixtures using a Gaussian Mixture Model.

[^1]: [Floria: Fast and accurate strain haplotyping in metagenomes](https://www.biorxiv.org/content/10.1101/2024.01.28.577669v1.full)
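The idea in a nutshell: haplosets from the same strain should show similar allele frequencies, so the frequencies form one cluster per strain. A minimal, self-contained sketch of that clustering step (synthetic frequencies for illustration, not the package API):

```python
import numpy as np
from sklearn.mixture import GaussianMixture

# Synthetic allele frequencies from a hypothetical two-strain mixture
# at ~30% / ~70% relative abundance (values invented for illustration).
rng = np.random.default_rng(42)
freqs = np.concatenate(
    [rng.normal(0.3, 0.02, 200), rng.normal(0.7, 0.02, 200)]
).reshape(-1, 1)

gmm = GaussianMixture(n_components=2).fit(freqs)
labels = gmm.predict(freqs)  # strain assignment per variable position
proba = gmm.predict_proba(freqs).max(axis=1)  # assignment confidence

print(gmm.means_.ravel())  # ~[0.3, 0.7]: the per-strain frequency centers
```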

## Install

TBD

## Help

```bash
floria-strainer --help

Usage: floria-strainer [OPTIONS] FLORIA_OUTDIR

Strain the haplotypes in the floria output directory.
Author: Maxime Borry

╭─ Options ─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╮
│ --version Show the version and exit. │
│ --nb-strains -n INTEGER Number of strains to keep. If 0, the number of strains will be determined by the mean floria average │
│ strain count with HAPQ > 15. │
│ [default: 0] │
│ --hapq-cut -h INTEGER Minimum HAPQ threshold [default: 15] │
│ --sp-cut -s FLOAT Minimum strain clustering probability threshold [default: 0.5] │
│ --help Show this message and exit. │
╰───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯
```
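The same entry point can also be driven from Python; a hedged sketch based on the `strainer` function added in this commit (the output directory path is illustrative):

```python
from floria_strainer.main import strainer

# Equivalent to: floria-strainer --nb-strains 0 --hapq-cut 15 --sp-cut 0.5 floria_out/
strainer("floria_out", nb_strains=0, hapq_cut=15, sp_cut=0.5)
```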
4 changes: 4 additions & 0 deletions floria_strainer/__init__.py
@@ -0,0 +1,4 @@
from importlib import metadata

__version__ = metadata.metadata("floria-strainer")["Version"]
__author__ = metadata.metadata("floria-strainer")["Author"]
39 changes: 39 additions & 0 deletions floria_strainer/cli.py
@@ -0,0 +1,39 @@
import rich_click as click
from floria_strainer.main import strainer
from floria_strainer import __version__, __author__


@click.command()
@click.version_option(__version__)
@click.argument("floria_outdir", type=click.Path(exists=True))
@click.option(
    "-n",
    "--nb-strains",
    help="Number of strains to keep. If 0, the number of strains will be determined by the mean floria average strain count with HAPQ > 15.",
    type=int,
    default=0,
    show_default=True,
)
@click.option(
    "-h",
    "--hapq-cut",
    help="Minimum HAPQ threshold",
    type=int,
    default=15,
    show_default=True,
)
@click.option(
    "-s",
    "--sp-cut",
    help="Minimum strain clustering probability threshold",
    type=float,
    default=0.5,
    show_default=True,
)
def cli(floria_outdir: str, nb_strains: int, hapq_cut: int, sp_cut: float):
    """
    Strain the haplotypes in the floria output directory.
    Author: Maxime Borry
    """
    strainer(floria_outdir, nb_strains, hapq_cut, sp_cut)
172 changes: 159 additions & 13 deletions floria_strainer/main.py
@@ -1,27 +1,50 @@
import pandas as pd
import numpy as np
from sklearn.mixture import GaussianMixture
from sklearn.metrics import silhouette_score
from collections import ChainMap
import logging
import pysam
import os

from floria_strainer.parser import (
    parse_vartigs,
    parse_vartig_info,
    parse_floria_contig_ploidy,
)


def parse_files(floria_outdir: str) -> tuple[pd.DataFrame, int]:

    cp = parse_floria_contig_ploidy(
        os.path.join(floria_outdir, "contig_ploidy_info.tsv")
    )

    contigs = cp["contig"].unique().tolist()
    nb_strains = round(cp["average_straincount_min15hapq"].mean())

    vartigs_files = [
        os.path.join(floria_outdir, contig, f"{contig}.vartigs") for contig in contigs
    ]
    vartig_info_files = [
        os.path.join(floria_outdir, contig, "vartig_info.txt") for contig in contigs
    ]

    parsed_vartigs = [parse_vartigs(vartig_file) for vartig_file in vartigs_files]

    parsed_vartig_info = [
        parse_vartig_info(vartig_info_file) for vartig_info_file in vartig_info_files
    ]

    vartigs_df = pd.DataFrame.from_dict(dict(ChainMap(*parsed_vartigs)))
    vartig_info_df = pd.DataFrame.from_dict(dict(ChainMap(*parsed_vartig_info)))

    df = pd.merge(vartigs_df, vartig_info_df, on=["contig", "haploset"])

    return df, nb_strains


def compute_gmm(obs: np.ndarray, n_components: int) -> tuple[np.ndarray, np.ndarray]:
    """
    Compute the Gaussian Mixture Model for the given observations.
@@ -38,9 +61,9 @@ def compute_gmm(obs: np.array, n_components: int) -> List(np.array, np.array):
    Returns
    -------
    tuple[np.ndarray, np.ndarray]
        The labels and the maximum probability for each observation.
    """
    if n_components == 0:
        logging.warning(
            "The number of components is 0. The optimal number of components will be selected using the Silhouette score."
        )

        best_comp = 2
@@ -55,13 +78,42 @@ def compute_gmm(obs: np.array, n_components: int) -> List(np.array, np.array):
                old_score = score
        n_components = best_comp

    elif n_components < 2:
        raise ValueError("The number of components must be at least 2.")
    gmm = GaussianMixture(n_components=n_components)
    gmm.fit(obs)
    labels = gmm.predict(obs)
    max_proba = np.max(gmm.predict_proba(obs), axis=1)
    return labels, max_proba
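For reference, a sketch of the silhouette-based selection that runs when `n_components == 0`; the candidate range (2 to 5) is an assumption, since the actual loop is elided from this hunk:

```python
import numpy as np
from sklearn.metrics import silhouette_score
from sklearn.mixture import GaussianMixture

# Toy two-cluster data standing in for the allele frequency observations.
obs = np.concatenate(
    [np.random.normal(0.25, 0.02, 100), np.random.normal(0.75, 0.02, 100)]
).reshape(-1, 1)

best_comp, old_score = 2, -1.0
for n in range(2, 6):  # candidate component counts (assumed range)
    labels = GaussianMixture(n_components=n).fit_predict(obs)
    score = silhouette_score(obs, labels)
    if score > old_score:
        best_comp, old_score = n, score

print(best_comp)  # expected: 2 for this two-cluster toy data
```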


def process_df(
    df: pd.DataFrame, hapq_cut: int, sp_cut: float, nb_strains: int
) -> tuple[dict, list]:
    """
    Process the DataFrame to get the strains.

    Parameters
    ----------
    df : pd.DataFrame
        The DataFrame to process.
    hapq_cut : int
        The HAPQ cut-off to use.
    sp_cut : float
        The strain clustering probability cut-off to use.
    nb_strains : int
        The number of strains to use. Use 0 to automatically select the best number of strains.

    Returns
    -------
    dict : {str: {str: int}}
        {
            contig (str): {
                haploset (str): strain (int)
            }
        }
    list : [int]
        List of strains
    """

    df = df[df["HAPQ"] >= hapq_cut]
    support = df.groupby(["contig", "pos", "allele"])["support"].sum().reset_index()
    coverage = (
@@ -73,13 +125,107 @@ def process_df(df: pd.DataFrame, hapq_cut: int, nb_strains: int) -> pd.DataFrame
    all_freq = support.merge(
        coverage, left_on=["contig", "pos"], right_on=["contig", "pos"]
    )

    all_freq["freq"] = all_freq["support"] / all_freq["coverage"]
    all_freq["MAF"] = np.where(all_freq["freq"] > 0.5, 1, 0)
    keys = ["contig", "pos", "allele"]
    df2 = all_freq.merge(df[["contig", "pos", "allele", "haploset"]], on=keys)

    obs = df2["freq"].values.reshape(-1, 1)

    labels, max_proba = compute_gmm(obs=obs, n_components=nb_strains)

    df2["strain"] = pd.Categorical(labels)
    df2["strain_proba"] = max_proba

    df2 = df2[df2["strain_proba"] >= sp_cut]
    strains = df2["strain"].unique().tolist()
    df2["strain"] = df2["strain"].astype(int)

    df3 = (
        df2.groupby(["contig", "haploset"])["strain"]
        .mean()
        .reset_index()
        .assign(strain=lambda df: df["strain"].apply(round))
    )

    return (
        df3.set_index("haploset")
        .groupby("contig")
        .apply(lambda x: x["strain"].to_dict())
        .to_dict(),
        strains,
    )
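To make the return shape concrete, a hypothetical result for two contigs and two strains (haploset names and values invented for illustration):

```python
# haplostrain: which strain each haploset was assigned to, per contig
haplostrain = {
    "contig_1": {"contig_1.0": 0, "contig_1.1": 1},
    "contig_2": {"contig_2.0": 1, "contig_2.1": 0},
}
# strains: the strain labels that survived the probability cut-off
strains = [0, 1]
```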


def write_bam_split(
    inbam: str, basename: str, haplostrain: dict, strains: list
) -> None:
    """
    Write the BAM files for each strain.

    Parameters
    ----------
    inbam : str
        The input BAM file.
    basename : str
        The basename to use for the output BAM files.
    haplostrain : dict
        The haplostrain dictionary.
    strains : list
        The strains to write the BAM files for.
    """

    bam = pysam.AlignmentFile(inbam, "rb")
    for strain in strains:
        with pysam.AlignmentFile(
            f"{basename}.{strain}.bam", "wb", template=bam
        ) as outbam:
            for read in bam:
                refname = read.reference_name
                if refname in haplostrain:
                    try:
                        h = read.get_tag("HP")
                        if haplostrain[refname][h] == strain:
                            outbam.write(read)
                    except KeyError:
                        pass
        # Rewind to just after the header so the next strain sees all reads;
        # otherwise the iterator is exhausted after the first pass.
        bam.reset()
    bam.close()


def write_bam(inbam: str, outbam: str, haplostrain: dict) -> None:
    """
    Write a single BAM file with the strain assignment stored in the ST tag.

    Parameters
    ----------
    inbam : str
        The input BAM file.
    outbam : str
        The output BAM file.
    haplostrain : dict
        The haplostrain dictionary.
    """

    with pysam.AlignmentFile(inbam, "rb") as bam:
        with pysam.AlignmentFile(outbam, "wb", template=bam) as bamout:
            for read in bam:
                refname = read.reference_name
                if refname in haplostrain:
                    try:
                        h = read.get_tag("HP")
                        read.set_tag("ST", haplostrain[refname][h])
                    except KeyError:
                        pass
                bamout.write(read)
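A hedged usage sketch tying the two writers to `process_df` (file names and strain count invented):

```python
# haplostrain/strains as returned by process_df (see above)
haplostrain, strains = process_df(df=fl_df, hapq_cut=15, sp_cut=0.5, nb_strains=2)

# One tagged BAM: every read is kept, its strain stored in the ST tag
write_bam("sample.bam", "sample.strained.bam", haplostrain)

# Or one BAM per strain: sample.0.bam, sample.1.bam, ...
write_bam_split("sample.bam", "sample", haplostrain, strains)
```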


def strainer(floria_outdir, nb_strains: int, hapq_cut: int, sp_cut: float):
    fl_df, fl_nb_strains = parse_files(floria_outdir)
    if nb_strains == 0:
        nb_strains = fl_nb_strains
    fl_df_processed, strains = process_df(
        df=fl_df, hapq_cut=hapq_cut, sp_cut=sp_cut, nb_strains=nb_strains
    )

    logging.info(f"Strains: {strains}")
22 changes: 22 additions & 0 deletions floria_strainer/parser.py
@@ -1,3 +1,6 @@
import pandas as pd


def parse_vartigs(filename: str) -> dict:
    """
    Read a vartigs file and return a dictionary with the values.
@@ -75,3 +78,22 @@ def parse_vartig_info(filename: str) -> dict:
                continue

    return vartig_all


def parse_floria_contig_ploidy(contig_ploidy: str) -> pd.DataFrame:
    """
    Read a contig_ploidy_info.tsv file generated by Floria and return it as a DataFrame.

    Parameters
    ----------
    contig_ploidy : str
        The name of the file to read.

    Returns
    -------
    pd.DataFrame
        A pandas DataFrame with the values of the file.
    """

    df = pd.read_csv(contig_ploidy, sep="\t")
    return df
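As used in `parse_files` above, this table also drives the automatic strain-count default; for instance (path invented, column name taken from the code):

```python
cp = parse_floria_contig_ploidy("floria_out/contig_ploidy_info.tsv")
# Mean of floria's average strain count at HAPQ > 15, rounded to an integer
nb_strains = round(cp["average_straincount_min15hapq"].mean())
```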
