Skip to content

Commit

Permalink
predict_proba() now sums to 1
Browse files Browse the repository at this point in the history
  • Loading branch information
SandroMartens committed Feb 22, 2024
1 parent 5e4aee8 commit fd0f26c
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 15 deletions.
15 changes: 5 additions & 10 deletions dbgsom/BaseSom.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,8 @@ def fit(self, X: npt.ArrayLike, y: None | npt.ArrayLike = None):
"""
# Horizontal growing phase

X, y = self._prepare_inputs(X, y)
X, y = self._check_input_data(X, y)
# self._fit(X, y)
if y is not None:
classes, y = np.unique(y, return_inverse=True)
self.classes_ = np.array(classes)
Expand All @@ -116,16 +117,11 @@ def fit(self, X: npt.ArrayLike, y: None | npt.ArrayLike = None):
self.labels_ = self.predict(X)
self.n_iter_ = self._current_epoch

self._fit(X, y)

return self

def _prepare_inputs(self, X, y):
def _check_input_data(self, X, y):
raise NotImplementedError

def _fit(self, X, y):
pass

def predict(self, X):
raise NotImplementedError

Expand Down Expand Up @@ -921,9 +917,8 @@ def linear_decay(
learning_rate=None,
) -> float:
"""Linear decay from sigma_start (at current_iter = 0) to sigma_end (at current_iter = max_iter)."""
sigma = sigma_start * (1 - current_iter / max_iter) + sigma_end * (
current_iter / max_iter
)
ratio = current_iter / max_iter
sigma = sigma_start * (1 - ratio) + sigma_end * ratio

return sigma

Expand Down
12 changes: 8 additions & 4 deletions dbgsom/SomClassifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
check_is_fitted,
check_X_y,
)

from dbgsom.BaseSom import BaseSom


Expand Down Expand Up @@ -121,7 +120,7 @@ class SomClassifier(BaseSom, TransformerMixin, ClassifierMixin):
Average distance from all training samples to their nearest prototypes.
"""

def _prepare_inputs(
def _check_input_data(
self, X: npt.ArrayLike, y=npt.ArrayLike
) -> tuple[npt.NDArray, npt.ArrayLike]:
X, y = check_X_y(X=X, y=y, ensure_min_samples=4, dtype=[np.float64, np.float32])
Expand Down Expand Up @@ -151,7 +150,8 @@ def _label_prototypes(self, X, y) -> None:
)

def _fit(self, X: npt.ArrayLike, y: None | npt.ArrayLike = None):
self.classes_, y = np.unique(y, return_inverse=True)
classes, y = np.unique(y, return_inverse=True)
self.classes_ = classes

def predict(self, X: npt.ArrayLike) -> np.ndarray:
"""Predict class labels for samples in X.
Expand Down Expand Up @@ -208,6 +208,10 @@ def predict_proba(self, X: npt.ArrayLike) -> np.ndarray:
else:
X_transformed = self.transform(X)
node_probabilities = self._extract_values_from_graph("probabilities")
sample_probabilities = X_transformed @ node_probabilities
# The raw products do not sum to 1 per sample, so divide each row by its sum.
sample_probabilities_unnormalized = X_transformed @ node_probabilities
sample_probabilities = sample_probabilities_unnormalized / (
sample_probabilities_unnormalized.sum(axis=1)[np.newaxis].T
)

return sample_probabilities
2 changes: 1 addition & 1 deletion dbgsom/SomVQ.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ class SomVQ(BaseSom, ClusterMixin, TransformerMixin):
Average distance from all training samples to their nearest prototype.
"""

def _prepare_inputs(self, X: npt.ArrayLike, y=None) -> tuple[npt.NDArray, None]:
def _check_input_data(self, X: npt.ArrayLike, y=None) -> tuple[npt.NDArray, None]:
X = check_array(array=X, ensure_min_samples=4, dtype=[np.float64, np.float32])
return X, y

Expand Down

0 comments on commit fd0f26c

Please sign in to comment.