
Merge pull request #5 from SandroMartens/dev
Dev
SandroMartens authored Mar 22, 2024
2 parents b65d424 + 8600b0b commit acfc6d7
Showing 12 changed files with 2,173 additions and 360 deletions.
351 changes: 216 additions & 135 deletions dbgsom/BaseSom.py

Large diffs are not rendered by default.

14 changes: 8 additions & 6 deletions dbgsom/SomClassifier.py
@@ -12,6 +12,7 @@
check_is_fitted,
check_X_y,
)

from .BaseSom import BaseSom


@@ -126,8 +127,9 @@ def _check_input_data(
X, y = check_X_y(X=X, y=y, ensure_min_samples=4, dtype=[np.float64, np.float32])
return X, y

def _label_prototypes(self, X, y) -> None:
winners = self._get_winning_neurons(X, n_bmu=1)
def _label_prototypes(self, X: npt.ArrayLike, y: npt.ArrayLike) -> None:
"""Assign labels to the prototypes based on the input data."""
_, winners = self._get_winning_neurons(X, n_bmu=1)
for winner_index, neuron in enumerate(self.neurons_):
labels = y[winners == winner_index]
# dead neuron
@@ -150,8 +152,9 @@ def _label_prototypes(self, X, y) -> None:
)

def _fit(self, X: npt.ArrayLike, y: None | npt.ArrayLike = None):
classes, y = np.unique(y, return_inverse=True)
self.classes_ = classes
pass
# classes, y = np.unique(y, return_inverse=True)
# self.classes_ = classes

def predict(self, X: npt.ArrayLike) -> np.ndarray:
"""Predict class labels for samples in X.
@@ -190,7 +193,7 @@ def predict_proba(self, X: npt.ArrayLike) -> np.ndarray:
check_is_fitted(self)
X = check_array(X)
if self.vertical_growth:
winners = self._get_winning_neurons(X, n_bmu=1)
_, winners = self._get_winning_neurons(X, n_bmu=1)
probabilities_rows = []
for sample, winner in zip(X, winners):
node = self.neurons_[winner]
@@ -206,7 +209,6 @@ def predict_proba(self, X: npt.ArrayLike) -> np.ndarray:
sample_probabilities = np.array(probabilities_rows)

else:
# pass
X_transformed = self.transform(X)
node_probabilities = self._extract_values_from_graph("probabilities")
# Sample Probabilities do not sum to 1
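The _label_prototypes change above unpacks _get_winning_neurons as a (distances, winners) tuple and then labels each prototype from the samples mapped to it. As a point of reference, here is a minimal, self-contained sketch of that majority-vote idea in plain NumPy; the function and variable names are illustrative, not dbgsom's API:

import numpy as np
from scipy.spatial.distance import cdist


def label_prototypes_by_majority(X, y, prototypes, fallback_label=-1):
    """Give each prototype the most common label among the samples whose
    best matching unit (BMU) it is. Illustrative sketch only; dbgsom's
    actual implementation differs in detail."""
    X, y = np.asarray(X, dtype=float), np.asarray(y)
    # BMU index for every sample: nearest prototype by Euclidean distance.
    winners = np.argmin(cdist(X, prototypes), axis=1)
    labels = np.full(len(prototypes), fallback_label)
    for proto_index in range(len(prototypes)):
        hits = y[winners == proto_index]
        if hits.size == 0:
            continue  # dead prototype: no samples mapped to it
        values, counts = np.unique(hits, return_counts=True)
        labels[proto_index] = values[np.argmax(counts)]
    return labels

The classifier itself handles the dead-neuron case explicitly (the "# dead neuron" branch in the hunk above); the fallback label merely stands in for that here.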
16 changes: 13 additions & 3 deletions dbgsom/SomVQ.py
@@ -3,7 +3,12 @@

import numpy as np
import numpy.typing as npt
from sklearn.base import ClusterMixin, TransformerMixin, check_array, check_is_fitted
from sklearn.base import (
ClusterMixin,
TransformerMixin,
check_array,
check_is_fitted,
)

from .BaseSom import BaseSom

@@ -115,7 +120,8 @@ class SomVQ(BaseSom, ClusterMixin, TransformerMixin):

def _check_input_data(self, X: npt.ArrayLike, y=None) -> tuple[npt.NDArray, None]:
X = check_array(array=X, ensure_min_samples=4, dtype=[np.float64, np.float32])
return X, y
# throw away any y
return X, None

def _label_prototypes(self, X: npt.ArrayLike, y=None) -> None:
for i, neuron in enumerate(self.som_):
@@ -137,6 +143,10 @@ def predict(self, X: npt.ArrayLike) -> np.ndarray:
"""
check_is_fitted(self)
X = check_array(X)
labels = self._get_winning_neurons(X, n_bmu=1)
_, labels = self._get_winning_neurons(X, n_bmu=1)

return labels

def _fit(self, X: npt.NDArray):

self.labels_ = self.predict(X)
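
SomVQ.predict above now unpacks the second element returned by _get_winning_neurons, i.e. the index of each sample's best matching unit. A hedged, standalone sketch of that BMU lookup, assuming the prototypes are available as a plain 2-D array rather than the library's graph of neurons:

import numpy as np


def predict_bmu(X, prototypes):
    """Return the index of the nearest prototype (BMU) for every row of X.
    Minimal sketch; dbgsom keeps its neurons in a graph and may compute
    distances differently."""
    X = np.asarray(X, dtype=float)
    prototypes = np.asarray(prototypes, dtype=float)
    # Squared Euclidean distance between every sample and every prototype.
    distances = ((X[:, None, :] - prototypes[None, :, :]) ** 2).sum(axis=-1)
    return distances.argmin(axis=1)


# Example: the second and third samples are closest to the second prototype.
X = np.array([[0.0, 0.0], [1.0, 1.0], [0.9, 1.1]])
prototypes = np.array([[0.0, 0.0], [1.0, 1.0]])
print(predict_bmu(X, prototypes))  # -> [0 1 1]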
Empty file removed dbgsom/dbgsom.test
Empty file.
463 changes: 439 additions & 24 deletions examples/2d_example.ipynb

Large diffs are not rendered by default.

Binary file modified examples/2d_example.png
466 changes: 438 additions & 28 deletions examples/chain_link.ipynb

Large diffs are not rendered by default.

582 changes: 514 additions & 68 deletions examples/darknet.ipynb

Large diffs are not rendered by default.

43 changes: 28 additions & 15 deletions examples/digits.ipynb
@@ -4,21 +4,34 @@
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"outputs": [
{
"ename": "ModuleNotFoundError",
"evalue": "No module named 'dbgsom'",
"output_type": "error",
"traceback": [
"\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[1;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[1;32mIn[1], line 9\u001b[0m\n\u001b[0;32m 6\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01msklearn\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mpipeline\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m Pipeline\n\u001b[0;32m 7\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01msklearn\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mpreprocessing\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m StandardScaler\n\u001b[1;32m----> 9\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mdbgsom\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mdbgsom_\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m DBGSOM\n",
"\u001b[1;31mModuleNotFoundError\u001b[0m: No module named 'dbgsom'"
]
}
],
"source": [
"import seaborn.objects as so\n",
"import seaborn as sns\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"from dbgsom.dbgsom_ import DBGSOM\n",
"from sklearn.preprocessing import StandardScaler\n",
"import numpy as np\n",
"import seaborn as sns\n",
"import seaborn.objects as so\n",
"from sklearn.datasets import load_digits\n",
"from sklearn.pipeline import Pipeline\n",
"from sklearn.datasets import load_digits"
"from sklearn.preprocessing import StandardScaler\n",
"\n",
"from dbgsom.dbgsom_ import DBGSOM"
]
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": null,
"metadata": {},
"outputs": [
{
@@ -70,7 +83,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": null,
"metadata": {},
"outputs": [
{
@@ -114,7 +127,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": null,
"metadata": {},
"outputs": [
{
@@ -134,7 +147,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -144,7 +157,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": null,
"metadata": {},
"outputs": [
{
@@ -174,7 +187,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": null,
"metadata": {},
"outputs": [
{
@@ -206,7 +219,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": null,
"metadata": {},
"outputs": [
{
@@ -253,7 +266,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.13"
"version": "3.12.1"
},
"orig_nbformat": 4
},
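The committed output in digits.ipynb above records a ModuleNotFoundError: No module named 'dbgsom', which usually just means the package is not installed in the environment the notebook kernel runs in (for a development checkout, an editable install such as `pip install -e .` from the repository root typically resolves it). A small, optional diagnostic cell along these lines could make that failure mode obvious; it is a suggestion, not part of this commit:

import importlib.util
import sys

# Report which interpreter the kernel uses and whether dbgsom is importable there.
print("kernel interpreter:", sys.executable)
if importlib.util.find_spec("dbgsom") is None:
    print("dbgsom is not importable here; install it into this environment, "
          "e.g. with `pip install -e .` from the repository root.")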
