Skip to content

Commit

Permalink
fix predict() in somvq
Browse files Browse the repository at this point in the history
  • Loading branch information
SandroMartens committed Feb 22, 2024
1 parent 57b0d40 commit 6d79011
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 5 deletions.
26 changes: 21 additions & 5 deletions dbgsom/SomVQ.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import numpy as np
import numpy.typing as npt
from sklearn.base import ClusterMixin, TransformerMixin, check_array
from sklearn.base import ClusterMixin, TransformerMixin, check_array, check_is_fitted

from .BaseSom import BaseSom

Expand Down Expand Up @@ -117,10 +117,26 @@ def _prepare_inputs(self, X: npt.ArrayLike, y=None) -> tuple[npt.NDArray, None]:
X = check_array(array=X, ensure_min_samples=4, dtype=[np.float64, np.float32])
return X, y

def _predict(self, X: npt.ArrayLike) -> npt.NDArray:
labels = self._get_winning_neurons(X, n_bmu=1)
return labels

def _label_prototypes(self, X: npt.ArrayLike, y=None) -> None:
for i, neuron in enumerate(self.som_):
self.som_.nodes[neuron]["label"] = i

def predict(self, X: npt.ArrayLike) -> np.ndarray:
"""Predict the closest neuron each sample in X belongs to.
Parameters
----------
X : {array-like, sparse matrix} of shape (n_samples, n_features)
New data to predict.
Returns
-------
labels : ndarray of shape (n_samples,)
If fitted unsupervised: Index of best matching prototype.
"""
check_is_fitted(self)
X = check_array(X)
labels = self._get_winning_neurons(X, n_bmu=1)

return labels
Binary file removed profile
Binary file not shown.

0 comments on commit 6d79011

Please sign in to comment.