diff --git a/datasail/cluster/clustering.py b/datasail/cluster/clustering.py
index 065e12c..ec5043a 100644
--- a/datasail/cluster/clustering.py
+++ b/datasail/cluster/clustering.py
@@ -153,18 +153,16 @@ def finish_clustering(dataset: DataSet) -> None:
     """
     # compute the weights and the stratification for the clusters
     dataset.cluster_weights = {}
-    if dataset.stratification is not None:
-        dataset.cluster_stratification = {}
+    dataset.cluster_stratification = {}
     for key, value in dataset.cluster_map.items():
         if value not in dataset.cluster_weights:
             dataset.cluster_weights[value] = 0
         dataset.cluster_weights[value] += dataset.weights[key]
-        if dataset.stratification is not None:
-            if value not in dataset.cluster_stratification:
-                dataset.cluster_stratification[value] = np.zeros(len(dataset.classes))
-            dataset.cluster_stratification[value] += dataset.strat2oh(name=key)
+        if value not in dataset.cluster_stratification:
+            dataset.cluster_stratification[value] = np.zeros(len(dataset.classes))
+        dataset.cluster_stratification[value] += dataset.strat2oh(name=key)
 
 
 def additional_clustering(
diff --git a/datasail/cluster/diamond.py b/datasail/cluster/diamond.py
index 9aacf09..13b7c23 100644
--- a/datasail/cluster/diamond.py
+++ b/datasail/cluster/diamond.py
@@ -69,7 +69,5 @@ def run_diamond(dataset: DataSet, threads: int = 1, log_dir: Optional[Path] = No
     shutil.rmtree(result_folder, ignore_errors=True)
 
     dataset.cluster_names = table.index.tolist()
-    print(dataset.cluster_names)
     dataset.cluster_map = {n: n for n in dataset.cluster_names}
     dataset.cluster_similarity = table.to_numpy()
-    dataset.cluster_weights = {n: 1 for n in dataset.cluster_names}
diff --git a/datasail/cluster/vectors.py b/datasail/cluster/vectors.py
index 2049cdd..e00d9c2 100644
--- a/datasail/cluster/vectors.py
+++ b/datasail/cluster/vectors.py
@@ -149,8 +149,8 @@ def run_vector(dataset: DataSet, method: SIM_OPTIONS = "tanimoto") -> None:
     else:
         raise ValueError(f"Unknown method {method}")
     fps = [dataset.data[name] for name in dataset.names]
-    run(dataset, fps, method)
 
+    run(dataset, fps, method)
     dataset.cluster_names = copy.deepcopy(dataset.names)
     dataset.cluster_map = {n: n for n in dataset.names}
diff --git a/datasail/reader/utils.py b/datasail/reader/utils.py
index 9845b7e..9d32055 100644
--- a/datasail/reader/utils.py
+++ b/datasail/reader/utils.py
@@ -269,7 +269,7 @@ def read_data(
     Returns:
         A dataset storing all information on that datatype
     """
-    # parse the protein weights
+    # parse the weights
    if isinstance(weights, Path) and weights.is_file():
         if weights.suffix[1:].lower() == "csv":
             dataset.weights = dict((n, float(w)) for n, w in read_csv(weights, ","))
@@ -286,13 +286,31 @@ def read_data(
     elif inter is not None:
         dataset.weights = dict(count_inter(inter, index))
     else:
-        dataset.weights = dict((p, 1) for p in list(dataset.data.keys()))
+        dataset.weights = {k: 1 for k in dataset.data.keys()}
 
-    dataset.classes, dataset.stratification = read_stratification(strats)
+    # parse the stratification
+    if isinstance(strats, Path) and strats.is_file():
+        if strats.suffix[1:].lower() == "csv":
+            dataset.stratification = dict(read_csv(strats, ","))
+        elif strats.suffix[1:].lower() == "tsv":
+            dataset.stratification = dict(read_csv(strats, "\t"))
+        else:
+            raise ValueError()
+    elif isinstance(strats, dict):
+        dataset.stratification = strats
+    elif isinstance(strats, Callable):
+        dataset.stratification = strats()
+    elif isinstance(strats, Generator):
+        dataset.stratification = dict(strats)
+    else:
+        dataset.stratification = {k: 0 for k in dataset.data.keys()}
+
+    # .classes maps the individual classes to their index in one-hot encoding, important for non-numeric classes
+    dataset.classes = {s: i for i, s in enumerate(set(dataset.stratification.values()))}
     dataset.class_oh = np.eye(len(dataset.classes))
     dataset.num_clusters = num_clusters
 
-    # parse the protein similarity measure
+    # parse the similarity or distance measure
     if sim is None and dist is None:
         dataset.similarity, dataset.distance = get_default(dataset.type, dataset.format)
         dataset.names = list(dataset.data.keys())
@@ -312,37 +330,6 @@ def read_data(
     return dataset
 
 
-def read_stratification(strats: DATA_INPUT) -> Tuple[Dict[Any, int], Optional[Dict[str, np.ndarray]]]:
-    """
-    Read in the stratification for the data.
-
-    Args:
-        strats: Stratification input
-
-    Returns:
-        Set of all classes and a dictionary mapping the entity names to their class
-    """
-    # parse the stratification
-    if isinstance(strats, Path) and strats.is_file():
-        if strats.suffix[1:].lower() == "csv":
-            stratification = dict(read_csv(strats, ","))
-        elif strats.suffix[1:].lower() == "tsv":
-            stratification = dict(read_csv(strats, "\t"))
-        else:
-            raise ValueError()
-    elif isinstance(strats, dict):
-        stratification = strats
-    elif isinstance(strats, Callable):
-        stratification = strats()
-    elif isinstance(strats, Generator):
-        stratification = dict(strats)
-    else:
-        return {0: 0}, None
-
-    classes = {s: i for i, s in enumerate(set(stratification.values()))}
-    return classes, stratification
-
-
 def read_folder(folder_path: Path, file_extension: Optional[str] = None) -> Generator[Tuple[str, str], None, None]:
     """
     Read in all PDB file from a folder and ignore non-PDB files.
diff --git a/datasail/version.py b/datasail/version.py
index 976498a..92192ee 100644
--- a/datasail/version.py
+++ b/datasail/version.py
@@ -1 +1 @@
-__version__ = "1.0.3"
+__version__ = "1.0.4"
diff --git a/recipe/meta.yaml b/recipe/meta.yaml
index bc5571d..e9dc1c2 100644
--- a/recipe/meta.yaml
+++ b/recipe/meta.yaml
@@ -1,6 +1,6 @@
 package:
   name: "datasail"
-  version: '1.0.3'
+  version: '1.0.4'
 
 source:
   path: ..
diff --git a/recipe_lite/meta.yaml b/recipe_lite/meta.yaml
index 5a716ab..7677e4c 100644
--- a/recipe_lite/meta.yaml
+++ b/recipe_lite/meta.yaml
@@ -1,6 +1,6 @@
 package:
   name: "datasail-lite"
-  version: '1.0.3'
+  version: '1.0.4'
 
 source:
   path: ..
diff --git a/tests/test_caching.py b/tests/test_caching.py
index 54ae77c..a528b37 100644
--- a/tests/test_caching.py
+++ b/tests/test_caching.py
@@ -1,5 +1,4 @@
 import copy
-import os
 from pathlib import Path
 
 import pytest
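
Note on the read_data hunk above: it inlines the former read_stratification helper and, when no stratification input is supplied, now assigns every entry the dummy class 0 instead of leaving dataset.stratification as None. A minimal standalone sketch of that behaviour, using toy entity names and a hypothetical build_stratification helper (illustrative only, not DataSail's API):

    import numpy as np

    def build_stratification(names, strats=None):
        # Without a stratification input, every entity gets dummy class 0,
        # so downstream code never has to special-case "no stratification".
        stratification = strats if strats is not None else {name: 0 for name in names}
        # classes maps each class label to its one-hot index (works for non-numeric labels too)
        classes = {label: i for i, label in enumerate(set(stratification.values()))}
        class_oh = np.eye(len(classes))
        return stratification, classes, class_oh

    names = ["P1", "P2", "P3"]
    strat, classes, class_oh = build_stratification(names, {"P1": "kinase", "P2": "kinase", "P3": "protease"})
    print(class_oh[classes[strat["P1"]]])  # one-hot class vector, analogous to dataset.strat2oh(name="P1")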
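
Because the stratification is now always populated, the finish_clustering hunk can aggregate per-cluster class counts unconditionally. A rough sketch of that aggregation with made-up entities, clusters, and one-hot vectors (plain dicts standing in for the DataSet fields):

    import numpy as np

    cluster_map = {"P1": "C1", "P2": "C1", "P3": "C2"}   # entity -> cluster (hypothetical)
    weights = {"P1": 1.0, "P2": 2.0, "P3": 1.0}
    strat_oh = {"P1": np.array([1.0, 0.0]), "P2": np.array([0.0, 1.0]), "P3": np.array([0.0, 1.0])}

    cluster_weights, cluster_stratification = {}, {}
    for key, value in cluster_map.items():
        # sum member weights per cluster
        cluster_weights[value] = cluster_weights.get(value, 0.0) + weights[key]
        # accumulate one-hot class vectors per cluster
        if value not in cluster_stratification:
            cluster_stratification[value] = np.zeros(2)
        cluster_stratification[value] += strat_oh[key]

    # C1 ends up with weight 3.0 and class counts [1, 1]; C2 with weight 1.0 and [0, 1]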