Skip to content

Commit

Permalink
added catch22regressor
Browse files Browse the repository at this point in the history
  • Loading branch information
harshithasudhakar authored Jul 13, 2024
1 parent 276c739 commit e917433
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 3 deletions.
30 changes: 30 additions & 0 deletions aeon/testing/expected_results/expected_regressor_outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,33 @@
0.4744,
]
)

covid_3month_preds["Catch22Regressor"] = np.array(
[
0.0302,
0.0354,
0.0352,
0.0345,
0.0259,
0.0484,
0.0369,
0.0827,
0.0737,
0.0526,
]
)

cardano_sentiment_preds["Catch22Regressor"] = np.array(
[
0.2174,
0.1394,
0.3623,
0.1496,
0.3502,
0.2719,
0.1378,
0.076,
0.0587,
0.3773,
]
)
10 changes: 7 additions & 3 deletions aeon/testing/expected_results/regressor_results_reproduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from sklearn.utils._testing import set_random_state

from aeon.datasets import load_cardano_sentiment, load_covid_3month
from aeon.regression.feature_based import FreshPRINCERegressor
from aeon.regression.feature_based import FreshPRINCERegressor, Catch22Regressor


def _reproduce_regression_covid_3month(estimator):
Expand Down Expand Up @@ -42,9 +42,13 @@ def _print_results_for_regressor(regressor_name, dataset_name):
regressor = FreshPRINCERegressor.create_test_instance(
parameter_set="results_comparison"
)
elif regressor_name == "Catch22Regressor":
regressor = Catch22Regressor.create_test_instance(
parameter_set="results_comparison"
)
else:
raise ValueError(f"Unknown regressor: {regressor_name}")

if dataset_name == "Covid3Month":
data_function = _reproduce_regression_covid_3month
elif dataset_name == "CardanoSentiment":
Expand All @@ -62,4 +66,4 @@ def _print_results_for_regressor(regressor_name, dataset_name):

if __name__ == "__main__":
# change as required when adding new classifiers, datasets or updating results
_print_results_for_regressor("FreshPRINCERegressor", "Covid3Month")
_print_results_for_regressor("Catch22Regressor", "CardanoSentiment")

0 comments on commit e917433

Please sign in to comment.