Skip to content

Commit

Permalink
Merge pull request #162 from JonathanShor/typing
Browse files Browse the repository at this point in the history
Add python typing
  • Loading branch information
JonathanShor authored Feb 6, 2025
2 parents a4156e5 + 4a10465 commit c4a4f42
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 32 deletions.
42 changes: 22 additions & 20 deletions doubletdetection/doubletdetection.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
"""Doublet detection in single-cell RNA-seq data."""

import collections
from collections.abc import Callable
import io
import warnings
from contextlib import redirect_stdout

import anndata
import numpy as np
from numpy.typing import NDArray
import phenograph
import scanpy as sc
import scipy.sparse as sp_sparse
Expand Down Expand Up @@ -83,20 +85,20 @@ class BoostClassifier:

def __init__(
self,
boost_rate=0.25,
n_components=30,
n_top_var_genes=10000,
replace=False,
clustering_algorithm="phenograph",
clustering_kwargs=None,
n_iters=10,
normalizer=None,
pseudocount=0.1,
random_state=0,
verbose=False,
standard_scaling=False,
n_jobs=1,
):
boost_rate: float = 0.25,
n_components: int = 30,
n_top_var_genes: int = 10000,
replace: bool = False,
clustering_algorithm: str = "phenograph",
clustering_kwargs: dict | None = None,
n_iters: int = 10,
normalizer: Callable | None = None,
pseudocount: float = 0.1,
random_state: int = 0,
verbose: bool = False,
standard_scaling: bool = False,
n_jobs: int = 1,
) -> None:
self.boost_rate = boost_rate
self.replace = replace
self.clustering_algorithm = clustering_algorithm
Expand Down Expand Up @@ -145,7 +147,7 @@ def __init__(
n_components, n_top_var_genes
)

def fit(self, raw_counts):
def fit(self, raw_counts: NDArray | sp_sparse.csr_matrix) -> "BoostClassifier":
"""Fits the classifier on raw_counts.
Args:
Expand Down Expand Up @@ -226,7 +228,7 @@ def fit(self, raw_counts):

return self

def predict(self, p_thresh=1e-7, voter_thresh=0.9):
def predict(self, p_thresh: float = 1e-7, voter_thresh: float = 0.9) -> NDArray:
"""Produce doublet calls from fitted classifier
Args:
Expand Down Expand Up @@ -266,7 +268,7 @@ def predict(self, p_thresh=1e-7, voter_thresh=0.9):

return self.labels_

def doublet_score(self):
def doublet_score(self) -> NDArray:
"""Produce doublet scores
The doublet score is the average negative log p-value of doublet enrichment
Expand All @@ -284,7 +286,7 @@ def doublet_score(self):

return -avg_log_p

def _one_fit(self):
def _one_fit(self) -> tuple[NDArray, NDArray]:
if self.verbose:
print("\nCreating synthetic doublets...")
self._createDoublets()
Expand Down Expand Up @@ -395,7 +397,7 @@ def _one_fit(self):

return scores, log_p_values

def _createDoublets(self):
def _createDoublets(self) -> None:
"""Create synthetic doublets.
Sets .parents_
Expand All @@ -414,7 +416,7 @@ def _createDoublets(self):
self._raw_synthetics = synthetic
self.parents_ = parents

def _set_clustering_kwargs(self):
def _set_clustering_kwargs(self) -> None:
"""Sets .clustering_kwargs"""
if self.clustering_algorithm == "phenograph":
if "prune" not in self.clustering_kwargs:
Expand Down
32 changes: 20 additions & 12 deletions doubletdetection/plot.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import os
import warnings
from typing import Any

import matplotlib
import numpy as np
from numpy.typing import NDArray
from matplotlib.figure import Figure

try:
os.environ["DISPLAY"]
Expand All @@ -11,7 +14,7 @@
import matplotlib.pyplot as plt


def normalize_counts(raw_counts, pseudocount=0.1):
def normalize_counts(raw_counts: NDArray, pseudocount: float = 0.1) -> NDArray:
"""Normalize count array. Default normalizer used by BoostClassifier.
Args:
Expand All @@ -22,7 +25,6 @@ def normalize_counts(raw_counts, pseudocount=0.1):
ndarray: Normalized data.
"""
# Sum across cells

cell_sums = np.sum(raw_counts, axis=1)

# Multiply by median and divide each cell by cell sum
Expand All @@ -34,7 +36,13 @@ def normalize_counts(raw_counts, pseudocount=0.1):
return normed


def convergence(clf, show=False, save=None, p_thresh=1e-7, voter_thresh=0.9):
def convergence(
clf: Any,
show: bool = False,
save: str | None = None,
p_thresh: float = 1e-7,
voter_thresh: float = 0.9,
) -> Figure:
"""Produce a plot showing number of cells called doublet per iter
Args:
Expand Down Expand Up @@ -81,15 +89,15 @@ def convergence(clf, show=False, save=None, p_thresh=1e-7, voter_thresh=0.9):


def threshold(
clf,
show=False,
save=None,
log10=True,
log_p_grid=None,
voter_grid=None,
v_step=2,
p_step=5,
):
clf: Any,
show: bool = False,
save: str | None = None,
log10: bool = True,
log_p_grid: NDArray | None = None,
voter_grid: NDArray | None = None,
v_step: int = 2,
p_step: int = 5,
) -> Figure:
"""Produce a plot showing number of cells called doublet across
various thresholds
Expand Down

0 comments on commit c4a4f42

Please sign in to comment.