Skip to content

Commit

Permalink
gmt length should be checked, #303
Browse files Browse the repository at this point in the history
  • Loading branch information
Zhuoqing Fang authored and Zhuoqing Fang committed Feb 13, 2025
1 parent 624aec3 commit 060d920
Showing 1 changed file with 150 additions and 41 deletions.
191 changes: 150 additions & 41 deletions gseapy/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,62 +14,168 @@


class GMT:
"""A collection of gene set dictionaries with metadata.
Attributes:
_collections: Dict[str, Dict[str, Any]] - Stores gene set collections
key: collection name
value: {
'genes': Dict[str, List[str]] - Gene set mappings
'description': str - Collection description
'source': str - Source of the gene sets
}
"""

def __init__(
self,
mapping: Optional[Dict[str, str]] = None,
mapping: Optional[Dict[str, List[str]]] = None,
description: Optional[str] = None,
source: Optional[str] = None,
name: Optional[str] = "default",
):
"""Initialize a GMT collection.
Args:
mapping: Initial gene set dictionary
description: Description of the gene sets
source: Source of the gene sets
name: Name of this collection
"""
wrapper of dict. this helps merge multiple dict into one
the original key will changed to new key with suffix '__{description}'
"""
self.description = description
self.source = source
self._mapping = {}
self._collections = {}
if mapping is not None:
self.update(mapping)
self.add(mapping, description, source, name)

def add(
self,
mapping: Dict[str, List[str]],
description: Optional[str] = None,
source: Optional[str] = None,
name: Optional[str] = "default",
):
"""Add a gene set collection with metadata.
def update(self, mapping: Dict[str, str]):
Args:
mapping: Gene set dictionary to add
description: Description of the gene sets
source: Source of the gene sets
name: Name for this collection
"""
update the mapping in place
self._collections[name] = {
"genes": mapping,
"description": description,
"source": source,
}

def get(self, name: str = "default") -> Dict[str, List[str]]:
"""Get gene sets by collection name."""
return self._collections[name]["genes"]

def get_metadata(self, name: str = "default") -> Dict[str, Any]:
"""Get metadata for a collection."""
collection = self._collections[name]
return {
"description": collection["description"],
"source": collection["source"],
}

def write(self, ofname: str):
"""Write GMT file to disk."""
with open(ofname, "w") as out:
for name, collection in self._collections.items():
for key, genes in collection["genes"].items():
desc = collection["description"] or ""
line = [key, desc] + genes
out.write("\t".join(line) + "\n")

def filter(
self,
min_size: Optional[int] = None,
max_size: Optional[int] = None,
gene_list: Optional[List[str]] = None,
collections: Optional[List[str]] = None,
) -> "GMT":
"""Filter gene sets based on size and gene membership.
Args:
min_size: Minimum number of genes in a set
max_size: Maximum number of genes in a set
gene_list: Only keep genes present in this list
collections: Only keep these named collections
Returns:
A new filtered GMT object
"""
for key, value in mapping.items():
k = key + "__" + self.description if self.description else key
self._mapping[k] = value
filtered = GMT()

def apply(self, func):
"""apply function in place"""
for key, value in self._mapping.items():
self._mapping[key] = func(value)
# Filter collections if specified
colls = collections or self._collections.keys()

def is_empty(self):
return len(self._mapping) == 0
for name in colls:
if name not in self._collections:
continue

def write(self, ofname: str):
"""
write gmt file to disk
"""
with open(ofname, "w") as out:
for key, value in self._mapping.items():
collections = key.split("__")
collections += list(value)
out.write("\t".join(collections) + "\n")
collection = self._collections[name]
filtered_mapping = {}

for term, genes in collection["genes"].items():
# Filter genes if gene_list provided
if gene_list is not None:
genes = [g for g in genes if g in gene_list]

# Apply size filters
if min_size is not None and len(genes) < min_size:
continue
if max_size is not None and len(genes) > max_size:
continue

# Only add if genes remain after filtering
if genes:
filtered_mapping[term] = genes

# Only add collection if it has gene sets after filtering
if filtered_mapping:
filtered.add(
filtered_mapping,
description=collection["description"],
source=collection["source"],
name=name,
)

return filtered

@classmethod
def read(cls, paths, source=None):
paths = paths.strip().split(",")
# mapping
mapping = {}
for path in paths:
with open(path, "r") as inp:
for line in inp:
def read(cls, paths: str, source: Optional[str] = None) -> "GMT":
"""Read GMT files into a collection.
Args:
paths: Comma-separated list of GMT file paths
source: Source annotation for the files
"""
gmt = cls()
for path in paths.strip().split(","):
name = os.path.basename(path)
mapping = {}
with open(path) as f:
for line in f:
items = line.strip().split("\t")
key = items[0]
if items[1] != "":
key += "__" + items[1]
mapping[key] = items[2:]
return cls(mapping, source=source)
desc = items[1]
genes = items[2:]
mapping[key] = genes
gmt.add(mapping, desc, source, name)
return gmt

def __getitem__(self, name: str) -> Dict[str, List[str]]:
"""Dictionary-like access to gene sets."""
return self.get(name)

def __iter__(self):
"""Iterate over collection names."""
return iter(self._collections)

def items(self):
"""Iterate over (name, gene_sets) pairs."""
return ((name, coll["genes"]) for name, coll in self._collections.items())


class GSEAbase(object):
Expand Down Expand Up @@ -325,12 +431,15 @@ def load_gmt(
"""load gene set dict"""

genesets_dict = self.load_gmt_only(gmt)

if not subsets: # Check if empty
raise ValueError("Empty gene sets dictionary")
subsets = list(genesets_dict.keys())
entry1st = genesets_dict[subsets[0]]
gene_dict = {g: i for i, g in enumerate(gene_list)}
# Check uppercase for up to 20 sets
sample_size = min(20, len(subsets))
ups = []
for s in subsets[:20]:
for s in subsets[:sample_size]:
ups.append(self.check_uppercase(genesets_dict[s]))

if (not self._gene_isupper) and all(ups):
Expand Down

0 comments on commit 060d920

Please sign in to comment.