diff --git a/dbgsom/SomVQ.py b/dbgsom/SomVQ.py index a495141..cb38133 100644 --- a/dbgsom/SomVQ.py +++ b/dbgsom/SomVQ.py @@ -3,7 +3,7 @@ import numpy as np import numpy.typing as npt -from sklearn.base import ClusterMixin, TransformerMixin, check_array +from sklearn.base import ClusterMixin, TransformerMixin, check_array, check_is_fitted from .BaseSom import BaseSom @@ -117,10 +117,26 @@ def _prepare_inputs(self, X: npt.ArrayLike, y=None) -> tuple[npt.NDArray, None]: X = check_array(array=X, ensure_min_samples=4, dtype=[np.float64, np.float32]) return X, y - def _predict(self, X: npt.ArrayLike) -> npt.NDArray: - labels = self._get_winning_neurons(X, n_bmu=1) - return labels - def _label_prototypes(self, X: npt.ArrayLike, y=None) -> None: for i, neuron in enumerate(self.som_): self.som_.nodes[neuron]["label"] = i + + def predict(self, X: npt.ArrayLike) -> np.ndarray: + """Predict the closest neuron each sample in X belongs to. + + Parameters + ---------- + X : {array-like, sparse matrix} of shape (n_samples, n_features) + New data to predict. + + Returns + ------- + labels : ndarray of shape (n_samples,) + If fitted unsupervised: Index of best matching prototype. + + """ + check_is_fitted(self) + X = check_array(X) + labels = self._get_winning_neurons(X, n_bmu=1) + + return labels diff --git a/profile b/profile deleted file mode 100644 index 64bf7ec..0000000 Binary files a/profile and /dev/null differ