Skip to content

Commit

Permalink
added more test cases and updated fit_predict method to set fitted to…
Browse files Browse the repository at this point in the history
… true
  • Loading branch information
Ramana-Raja committed Mar 3, 2025
1 parent b19a3be commit 75a9dce
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 11 deletions.
8 changes: 5 additions & 3 deletions aeon/clustering/feature_based/_r_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,15 +448,17 @@ def _fit_predict(self, X, y=None) -> np.ndarray:
pca = PCA().fit(X_std)
optimal_dimensions = np.argmax(pca.explained_variance_ratio_ < 0.01)

pca = PCA(n_components=optimal_dimensions, random_state=self.random_state)
transformed_data_pca = pca.fit_transform(X_std)
self.pca = PCA(n_components=optimal_dimensions, random_state=self.random_state)
self.pca.fit(X_std)
transformed_data_pca = self.pca.transform(X_std)
self.estimator = KMeans(
n_clusters=self.n_clusters,
random_state=self.random_state,
n_init=self.n_init,
)
Y = self.estimator.fit_predict(transformed_data_pca)
self.labels_ = self.estimator.labels_
self.is_fitted = True
return Y

@classmethod
Expand All @@ -479,4 +481,4 @@ def _get_test_params(cls, parameter_set="default") -> dict:
return {
"n_clusters": 2,
"random_state": 1,
}
}
46 changes: 38 additions & 8 deletions aeon/clustering/feature_based/tests/test_r_cluster.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
"""Test For RCluster."""

import numpy as np
import pytest
from sklearn import metrics

from aeon.datasets import load_gunpoint
from aeon.clustering.feature_based._r_cluster import RClusterer
from aeon.utils.validation._dependencies import _check_estimator_deps

X_ = [
[
Expand Down Expand Up @@ -132,15 +130,47 @@
Y = ["22", "28", "21", "15", "2", "18", "21", "36", "11", "21"]


@pytest.mark.skipif(
not _check_estimator_deps(RClusterer, severity="none"),
reason="skip test if required soft dependencies not available",
)
def test_r_cluster():
def test_r_cluster_custom_dataset():
"""Test implementation of RCluster."""
X_train = np.array(X_)
X = np.expand_dims(X_train, axis=1)
Rcluster = RClusterer(n_clusters=8, n_init=10, random_state=1)
labels_pred1 = Rcluster.fit_predict(X)
score = metrics.adjusted_rand_score(labels_true=Y, labels_pred=labels_pred1)
assert score > 0.36

def test_r_cluster_dataset():
"""Test implementation of RCluster using aeon dataset."""

X_train, y_train = load_gunpoint(split="train")
X_test, y_test = load_gunpoint(split="test")
num_points = 20

X_train = X_train[:num_points]
y_train = y_train[:num_points]
X_test = X_test[:num_points]
y_test = y_test[:num_points]

rcluster = RClusterer(
random_state=1,
n_init=2,
n_clusters=2,
)
train_result = rcluster.fit_predict(X_train)
train_score = metrics.rand_score(y_train, train_result)
test_result = rcluster.predict(X_test)
test_score = metrics.rand_score(y_test, test_result)
assert np.array_equal(
test_result,
[1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
)
assert np.array_equal(
train_result,
[1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0],
)
assert test_score == 0.5210526315789473
assert train_score == 0.5210526315789473
assert rcluster.estimator.n_iter_ == 3
assert np.array_equal(
rcluster.labels_, [1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0]
)

0 comments on commit 75a9dce

Please sign in to comment.