Skip to content

Commit

Permalink
optimize mem usage in latent2gene
Browse files: browse the repository at this point in the history
  • Loading branch information
Ganten-Hornby committed Dec 3, 2024
1 parent b65ff7c commit f98a45d
Showing 1 changed file with 12 additions and 10 deletions.
22 changes: 12 additions & 10 deletions src/gsMap/latent_to_gene.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -9,7 +9,7 @@
from scipy.stats import rankdata
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.neighbors import NearestNeighbors
from tqdm import tqdm
from tqdm import tqdm, trange

from gsMap.config import LatentToGeneConfig

Expand Down Expand Up @@ -153,10 +153,6 @@ def run_latent_to_gene(config: LatentToGeneConfig):
adata.var_names = homologs.loc[adata.var_names, 'HUMAN_GENE_SYM'].values
adata = adata[:, ~adata.var_names.duplicated()]

# Create mappings
n_cells = adata.n_obs
n_genes = adata.n_vars

if config.annotation is not None:
cell_annotations = adata.obs[config.annotation].values
logger.info(f'Using cell annotations for {len(cell_annotations)} cells.')
Expand Down Expand Up @@ -204,14 +200,18 @@ def run_latent_to_gene(config: LatentToGeneConfig):
else:
adata_X = adata.X.tocsr()

ranks = np.zeros((n_cells, adata.n_vars), dtype=np.float32)
# Create mappings
n_cells = adata.n_obs
n_genes = adata.n_vars

ranks = np.zeros((n_cells, adata.n_vars), dtype=np.float16)
for i in tqdm(range(n_cells), desc="Computing ranks per cell"):
data = adata_X[i, :].toarray().flatten()
ranks[i, :] = rankdata(data, method='average')

if gM is None:
gM = gmean(ranks, axis=0)
gM = gM.astype(np.float16)

adata_X_bool = adata_X.astype(bool)
if frac_whole is None:
Expand All @@ -225,15 +225,17 @@ def run_latent_to_gene(config: LatentToGeneConfig):
# Normalize the ranks
ranks /= gM

# Compute marker scores in parallel
logger.info('------Computing marker scores...')
def compute_mk_score_wrapper(cell_pos):
return compute_regional_mkscore(
cell_pos, spatial_net_dict, coor_latent, config, cell_annotations, ranks, frac_whole, adata_X_bool
)

mk_scores = [compute_mk_score_wrapper(cell_pos) for cell_pos in tqdm(range(n_cells), desc="Calculating marker scores")]
mk_score = np.vstack(mk_scores).T
logger.info('------Computing marker scores...')
mk_score = np.zeros((n_cells, n_genes), dtype=np.float16)
for cell_pos in trange(n_cells, desc="Calculating marker scores"):
mk_score[cell_pos, :] = compute_mk_score_wrapper(cell_pos)

mk_score = mk_score.T
logger.info('Marker scores computed.')

# Remove mitochondrial genes
Expand Down

0 comments on commit f98a45d

Please sign in to comment.