Skip to content

Commit

Permalink
Merge branch 'main' into main
Browse files Browse the repository at this point in the history
  • Loading branch information
kevinzb56 authored Mar 9, 2025
2 parents e766b8e + 2f2c803 commit 4c063ad
Show file tree
Hide file tree
Showing 15 changed files with 1,928 additions and 234 deletions.
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ repos:
args: [ "--create", "--python-folders", "aeon" ]

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.9.7
rev: v0.9.9
hooks:
- id: ruff
args: [ "--fix"]
Expand All @@ -41,7 +41,7 @@ repos:
args: [ "--py39-plus" ]

- repo: https://github.com/pycqa/isort
rev: 6.0.0
rev: 6.0.1
hooks:
- id: isort
name: isort
Expand Down
12 changes: 12 additions & 0 deletions aeon/base/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,6 +445,18 @@ def __sklearn_is_fitted__(self):
"""Check fitted status and return a Boolean value."""
return self.is_fitted

def __sklearn_tags__(self):
"""Return sklearn style tags for the estimator."""
aeon_tags = self.get_tags()
sklearn_tags = super().__sklearn_tags__()
sklearn_tags.non_deterministic = aeon_tags.get("non_deterministic", False)
sklearn_tags.target_tags.one_d_labels = True
sklearn_tags.input_tags.three_d_array = True
sklearn_tags.input_tags.allow_nan = aeon_tags.get(
"capability:missing_values", False
)
return sklearn_tags

def _validate_data(self, **kwargs):
"""Sklearn data validation."""
raise NotImplementedError(
Expand Down
4 changes: 2 additions & 2 deletions aeon/classification/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ class name: BaseClassifier

import numpy as np
import pandas as pd
from sklearn.base import ClassifierMixin
from sklearn.metrics import get_scorer, get_scorer_names
from sklearn.model_selection import cross_val_predict

Expand All @@ -35,7 +36,7 @@ class name: BaseClassifier
from aeon.utils.validation.labels import check_classification_y


class BaseClassifier(BaseCollectionEstimator):
class BaseClassifier(ClassifierMixin, BaseCollectionEstimator):
"""
Abstract base class for time series classifiers.
Expand Down Expand Up @@ -66,7 +67,6 @@ def __init__(self):
self.classes_ = [] # classes seen in y, unique labels
self.n_classes_ = -1 # number of unique classes in y
self._class_dictionary = {}
self._estimator_type = "classifier"

super().__init__()

Expand Down
25 changes: 2 additions & 23 deletions aeon/clustering/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,12 @@
from typing import final

import numpy as np
from sklearn.base import ClusterMixin

from aeon.base import BaseCollectionEstimator


class BaseClusterer(BaseCollectionEstimator):
class BaseClusterer(ClusterMixin, BaseCollectionEstimator):
"""Abstract base class for time series clusterers.
Parameters
Expand All @@ -26,10 +27,6 @@ class BaseClusterer(BaseCollectionEstimator):

@abstractmethod
def __init__(self):
# required for compatibility with some sklearn interfaces e.g.
# CalibratedClassifierCV
self._estimator_type = "clusterer"

super().__init__()

@final
Expand Down Expand Up @@ -132,24 +129,6 @@ def fit_predict(self, X, y=None) -> np.ndarray:
to return.
y: ignored, exists for API consistency reasons.
Returns
-------
np.ndarray (1d array of shape (n_cases,))
Index of the cluster each time series in X belongs to.
"""
return self._fit_predict(X, y)

def _fit_predict(self, X, y=None) -> np.ndarray:
"""Fit predict using base methods.
Parameters
----------
X : np.ndarray (2d or 3d array of shape (n_cases, n_timepoints) or shape
(n_cases, n_channels, n_timepoints)).
Time series instances to train clusterer and then have indexes each belong
to return.
y: ignored, exists for API consistency reasons.
Returns
-------
np.ndarray (1d array of shape (n_cases,))
Expand Down
6 changes: 2 additions & 4 deletions aeon/regression/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ class name: BaseRegressor

import numpy as np
import pandas as pd
from sklearn.base import RegressorMixin
from sklearn.metrics import get_scorer, get_scorer_names
from sklearn.model_selection import cross_val_predict
from sklearn.utils.multiclass import type_of_target
Expand All @@ -33,7 +34,7 @@ class name: BaseRegressor
from aeon.base._base import _clone_estimator


class BaseRegressor(BaseCollectionEstimator):
class BaseRegressor(RegressorMixin, BaseCollectionEstimator):
"""Abstract base class for time series regressors.
The base regressor specifies the methods and method signatures that all
Expand All @@ -54,9 +55,6 @@ class BaseRegressor(BaseCollectionEstimator):

@abstractmethod
def __init__(self):
# required for compatibility with some sklearn interfaces
self._estimator_type = "regressor"

super().__init__()

@final
Expand Down
4 changes: 2 additions & 2 deletions aeon/regression/feature_based/_catch22.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,8 @@ class Catch22Regressor(BaseRegressor):
>>> reg.fit(X, y)
Catch22Regressor(...)
>>> reg.predict(X)
array([0.63821896, 1.0906666 , 0.58323551, 1.57550709, 0.48413489,
0.70976176, 1.33206165, 1.09927538, 1.51673405, 0.31683308])
array([0.63821896, 1.0906666 , 0.64351536, 1.57550709, 0.46036267,
0.79297397, 1.32882497, 1.12603087, 1.51673405, 0.31683308])
"""

_tags = {
Expand Down
30 changes: 15 additions & 15 deletions aeon/regression/sklearn/tests/test_rotation_forest_regressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,21 +24,21 @@ def test_rotf_output():
rotf.fit(X_train, y_train)

expected = [
0.02694297,
0.02694297,
0.01997832,
0.04276962,
0.09027588,
0.02706564,
0.02553648,
0.04075808,
0.02900289,
0.04248546,
0.02694297,
0.03667328,
0.0235855,
0.03444119,
0.0235855,
0.026,
0.0245,
0.0224,
0.0453,
0.0892,
0.0314,
0.026,
0.0451,
0.0287,
0.04,
0.026,
0.0378,
0.0265,
0.0356,
0.0281,
]

np.testing.assert_array_almost_equal(expected, rotf.predict(X_test[:15]), decimal=4)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
def _yield_classification_checks(estimator_class, estimator_instances, datatypes):
"""Yield all classification checks for an aeon classifier."""
# only class required
if sys.platform != "darwin": # We cannot guarantee same results on ARM macOS
if sys.platform == "linux": # We cannot guarantee same results on ARM macOS
# Compare against results for both UnitTest and BasicMotions if available
yield partial(
check_classifier_against_expected_results,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
def _yield_regression_checks(estimator_class, estimator_instances, datatypes):
"""Yield all regression checks for an aeon regressor."""
# only class required
if sys.platform != "darwin": # We cannot guarantee same results on ARM macOS
if sys.platform == "linux": # We cannot guarantee same results on ARM macOS
# Compare against results for both Covid3Month and CardanoSentiment if available
yield partial(
check_regressor_against_expected_results,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@
def _yield_transformation_checks(estimator_class, estimator_instances, datatypes):
"""Yield all transformation checks for an aeon transformer."""
# only class required
if sys.platform != "darwin":
if sys.platform == "linux": # We cannot guarantee same results on ARM macOS
# Compare against results for both UnitTest and BasicMotions if available
yield partial(
check_transformer_against_expected_results,
estimator_class=estimator_class,
Expand Down
60 changes: 30 additions & 30 deletions aeon/testing/expected_results/expected_classifier_outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,16 +67,16 @@
)
unit_test_proba["TemporalDictionaryEnsemble"] = np.array(
[
[0.2778, 0.7222],
[0.7222, 0.2778],
[0.3307, 0.6693],
[0.6693, 0.3307],
[0.0, 1.0],
[0.6251, 0.3749],
[0.3749, 0.6251],
[0.5538, 0.4462],
[0.6693, 0.3307],
[1.0, 0.0],
[0.3749, 0.6251],
[0.4462, 0.5538],
[0.0, 1.0],
[0.4653, 0.5347],
[0.3749, 0.6251],
[0.5538, 0.4462],
[0.4462, 0.5538],
]
)
unit_test_proba["WEASEL"] = np.array(
Expand Down Expand Up @@ -263,16 +263,16 @@
)
unit_test_proba["HIVECOTEV2"] = np.array(
[
[0.0613, 0.9387],
[0.5531, 0.4479],
[0.0431, 0.9569],
[0.2239, 0.7761],
[0.6732, 0.3268],
[0.1211, 0.8789],
[1.0, 0.0],
[0.9751, 0.0249],
[0.9818, 0.0182],
[1.0, 0.0],
[0.7398, 0.2602],
[0.0365, 0.9635],
[0.7829, 0.2171],
[0.9236, 0.0764],
[0.7201, 0.2799],
[0.2058, 0.7942],
[0.8412, 0.1588],
[0.9441, 0.0559],
]
)
unit_test_proba["CanonicalIntervalForestClassifier"] = np.array(
Expand All @@ -293,12 +293,12 @@
[
[0.1, 0.9],
[0.8, 0.2],
[0.0, 1.0],
[0.1, 0.9],
[1.0, 0.0],
[0.7, 0.3],
[0.9, 0.1],
[0.8, 0.2],
[0.4, 0.6],
[0.5, 0.5],
[0.9, 0.1],
[1.0, 0.0],
]
Expand Down Expand Up @@ -379,11 +379,11 @@
[0.3505, 0.6495],
[0.1753, 0.8247],
[0.8247, 0.1753],
[0.3505, 0.6495],
[0.6495, 0.3505],
[0.701, 0.299],
[0.6495, 0.3505],
[0.1753, 0.8247],
[0.5258, 0.4742],
[0.8247, 0.1753],
[1.0, 0.0],
]
)
Expand Down Expand Up @@ -656,12 +656,12 @@
)
basic_motions_proba["FreshPRINCEClassifier"] = np.array(
[
[0.0, 0.0, 0.1, 0.9],
[0.0, 0.0, 0.2, 0.8],
[0.9, 0.1, 0.0, 0.0],
[0.0, 0.0, 0.8, 0.2],
[0.1, 0.9, 0.0, 0.0],
[0.1, 0.0, 0.0, 0.9],
[0.0, 0.0, 0.1, 0.9],
[0.1, 0.0, 0.1, 0.8],
[0.0, 0.0, 0.2, 0.8],
[0.7, 0.3, 0.0, 0.0],
[0.0, 0.0, 1.0, 0.0],
[0.0, 1.0, 0.0, 0.0],
Expand Down Expand Up @@ -782,15 +782,15 @@
)
basic_motions_proba["DrCIFClassifier"] = np.array(
[
[0.1, 0.0, 0.2, 0.7],
[0.5, 0.4, 0.0, 0.1],
[0.0, 0.0, 0.8, 0.2],
[0.1, 0.9, 0.0, 0.0],
[0.1, 0.0, 0.3, 0.6],
[0.0, 0.0, 0.2, 0.8],
[0.4, 0.5, 0.1, 0.0],
[0.0, 0.0, 0.7, 0.3],
[0.2, 0.8, 0.0, 0.0],
[0.0, 0.0, 0.3, 0.7],
[0.0, 0.0, 0.3, 0.7],
[0.7, 0.2, 0.1, 0.0],
[0.0, 0.0, 0.7, 0.3],
[0.1, 0.7, 0.1, 0.1],
[0.5, 0.3, 0.0, 0.2],
[0.0, 0.0, 0.8, 0.2],
[0.2, 0.7, 0.0, 0.1],
[0.0, 0.9, 0.0, 0.1],
]
)
Expand Down
Loading

0 comments on commit 4c063ad

Please sign in to comment.