diff --git a/aeon/regression/feature_based/_catch22.py b/aeon/regression/feature_based/_catch22.py index 162d0e653b..579ec874e7 100644 --- a/aeon/regression/feature_based/_catch22.py +++ b/aeon/regression/feature_based/_catch22.py @@ -94,8 +94,8 @@ class Catch22Regressor(BaseRegressor): >>> reg.fit(X, y) Catch22Regressor(...) >>> reg.predict(X) - array([0.66497445, 1.52167747, 0.73353397, 1.57550709, 0.46036267, - 0.6494623 , 1.08156127, 1.09927538, 1.46025772, 0.37711294]) + array([0.63821896, 1.0906666 , 0.58323551, 1.57550709, 0.48413489, + 0.70976176, 1.33206165, 1.09927538, 1.51673405, 0.31683308]) """ _tags = { diff --git a/aeon/testing/expected_results/expected_classifier_outputs.py b/aeon/testing/expected_results/expected_classifier_outputs.py index d27f2d18cf..3bff249dd2 100644 --- a/aeon/testing/expected_results/expected_classifier_outputs.py +++ b/aeon/testing/expected_results/expected_classifier_outputs.py @@ -180,15 +180,15 @@ unit_test_proba["Catch22Classifier"] = np.array( [ [0.2, 0.8], - [1.0, 0.0], - [0.2, 0.8], - [0.6, 0.4], [0.9, 0.1], - [0.6, 0.4], - [0.6, 0.4], [0.0, 1.0], - [0.8, 0.2], - [0.6, 0.4], + [0.7, 0.3], + [0.7, 0.3], + [0.9, 0.1], + [0.7, 0.3], + [0.1, 0.9], + [0.7, 0.3], + [0.9, 0.1], ] ) unit_test_proba["FreshPRINCEClassifier"] = np.array( @@ -291,40 +291,40 @@ ) unit_test_proba["HIVECOTEV2"] = np.array( [ - [0.2469, 0.7531], - [0.6344, 0.3656], - [0.0959, 0.9041], + [0.0613, 0.9387], + [0.5531, 0.4479], + [0.0431, 0.9569], [1.0, 0.0], - [0.9796, 0.0204], + [0.9751, 0.0249], [1.0, 0.0], - [0.802, 0.198], - [0.2265, 0.7735], - [0.8224, 0.1776], - [0.9374, 0.0626], + [0.7398, 0.2602], + [0.0365, 0.9635], + [0.7829, 0.2171], + [0.9236, 0.0764], ] ) unit_test_proba["CanonicalIntervalForestClassifier"] = np.array( [ [0.3, 0.7], - [0.75, 0.25], - [0.3, 0.7], - [0.85, 0.15], - [0.7, 0.3], [0.9, 0.1], + [0.2, 0.8], + [0.8, 0.2], [0.7, 0.3], - [0.0, 1.0], - [0.75, 0.25], - [0.75, 0.25], + [0.9, 0.1], + [0.4, 0.6], + [0.3, 0.7], + [0.6, 0.4], + [0.8, 0.2], ] ) unit_test_proba["DrCIFClassifier"] = np.array( [ - [0.3, 0.7], - [0.8, 0.2], [0.2, 0.8], + [0.9, 0.1], + [0.1, 0.9], [1.0, 0.0], - [0.8, 0.2], [0.9, 0.1], + [1.0, 0.0], [0.8, 0.2], [0.5, 0.5], [1.0, 0.0], @@ -516,16 +516,16 @@ basic_motions_proba["ChannelEnsembleClassifier"] = np.array( [ - [0.0, 0.0825, 0.0, 0.9175], + [0.0, 0.0825, 0.25, 0.6675], [0.0, 0.3325, 0.6675, 0.0], - [0.0, 0.0825, 0.9175, 0.0], + [0.0, 0.0825, 0.6675, 0.25], [0.0, 0.0825, 0.6675, 0.25], [0.0, 0.0825, 0.0, 0.9175], - [0.0, 0.0825, 0.0, 0.9175], - [0.0, 0.3325, 0.6675, 0.0], - [0.0, 0.3325, 0.6675, 0.0], - [0.0, 0.5825, 0.4175, 0.0], + [0.0, 0.0825, 0.25, 0.6675], + [0.0, 0.3325, 0.4175, 0.25], [0.0, 0.3325, 0.4175, 0.25], + [0.0, 0.5825, 0.4175, 0.0], + [0.25, 0.0825, 0.6675, 0.0], ] ) basic_motions_proba["ClassifierPipeline"] = np.array( @@ -670,15 +670,15 @@ ) basic_motions_proba["Catch22Classifier"] = np.array( [ - [0.1, 0.0, 0.1, 0.8], - [0.3, 0.4, 0.2, 0.1], + [0.1, 0.0, 0.2, 0.7], + [0.2, 0.3, 0.2, 0.3], [0.0, 0.2, 0.6, 0.2], - [0.0, 0.8, 0.1, 0.1], - [0.1, 0.0, 0.0, 0.9], - [0.2, 0.0, 0.1, 0.7], - [0.4, 0.2, 0.2, 0.2], + [0.0, 1.0, 0.0, 0.0], + [0.0, 0.1, 0.2, 0.7], + [0.1, 0.1, 0.2, 0.6], + [0.2, 0.5, 0.2, 0.1], [0.1, 0.1, 0.6, 0.2], - [0.1, 0.9, 0.0, 0.0], + [0.0, 0.7, 0.2, 0.1], [0.0, 1.0, 0.0, 0.0], ] ) @@ -782,16 +782,16 @@ ) basic_motions_proba["CanonicalIntervalForestClassifier"] = np.array( [ - [0.2, 0.1, 0.2, 0.5], - [0.3, 0.2, 0.3, 0.2], - [0.1, 0.1, 0.7, 0.1], - [0.1, 0.5, 0.2, 0.2], - [0.1, 0.1, 0.0, 0.8], - [0.0, 0.1, 0.2, 0.7], - [0.3, 0.4, 0.1, 0.2], - [0.2, 0.0, 0.7, 0.1], - [0.2, 0.6, 0.1, 0.1], - [0.1, 0.5, 0.3, 0.1], + [0.0, 0.0, 0.1, 0.9], + [0.3, 0.5, 0.0, 0.2], + [0.0, 0.0, 0.8, 0.2], + [0.4, 0.2, 0.2, 0.2], + [0.1, 0.0, 0.0, 0.9], + [0.0, 0.0, 0.2, 0.8], + [0.3, 0.3, 0.2, 0.2], + [0.0, 0.2, 0.7, 0.1], + [0.0, 1.0, 0.0, 0.0], + [0.1, 0.7, 0.0, 0.2], ] ) basic_motions_proba["RandomIntervalSpectralEnsembleClassifier"] = np.array( @@ -811,15 +811,15 @@ basic_motions_proba["DrCIFClassifier"] = np.array( [ [0.0, 0.0, 0.2, 0.8], - [0.8, 0.2, 0.0, 0.0], - [0.0, 0.0, 0.6, 0.4], - [0.1, 0.5, 0.0, 0.4], - [0.0, 0.0, 0.4, 0.6], - [0.0, 0.0, 0.2, 0.8], - [0.5, 0.5, 0.0, 0.0], - [0.0, 0.0, 0.8, 0.2], - [0.4, 0.6, 0.0, 0.0], - [0.3, 0.6, 0.0, 0.1], + [0.4, 0.5, 0.1, 0.0], + [0.0, 0.0, 0.7, 0.3], + [0.2, 0.8, 0.0, 0.0], + [0.0, 0.0, 0.3, 0.7], + [0.0, 0.0, 0.3, 0.7], + [0.7, 0.2, 0.1, 0.0], + [0.0, 0.0, 0.7, 0.3], + [0.1, 0.7, 0.1, 0.1], + [0.0, 0.9, 0.0, 0.1], ] ) basic_motions_proba["IntervalForestClassifier"] = np.array( diff --git a/aeon/testing/expected_results/expected_regressor_outputs.py b/aeon/testing/expected_results/expected_regressor_outputs.py index cef5d95109..ecc4a1ef07 100644 --- a/aeon/testing/expected_results/expected_regressor_outputs.py +++ b/aeon/testing/expected_results/expected_regressor_outputs.py @@ -25,18 +25,7 @@ ) covid_3month_preds["Catch22Regressor"] = np.array( - [ - 0.0302, - 0.0354, - 0.0352, - 0.0345, - 0.0259, - 0.0484, - 0.0369, - 0.0827, - 0.0737, - 0.0526, - ] + [0.0310, 0.0555, 0.0193, 0.0359, 0.0261, 0.0361, 0.0387, 0.0835, 0.0827, 0.0414] ) covid_3month_preds["RandomForestRegressor"] = np.array( @@ -131,33 +120,11 @@ ) covid_3month_preds["CanonicalIntervalForestRegressor"] = np.array( - [ - 0.049, - 0.04, - 0.0299, - 0.0352, - 0.0423, - 0.0315, - 0.0519, - 0.0605, - 0.0647, - 0.037, - ] + [0.0412, 0.0420, 0.0292, 0.0202, 0.0432, 0.0192, 0.0155, 0.0543, 0.0412, 0.0399] ) covid_3month_preds["DrCIFRegressor"] = np.array( - [ - 0.0302, - 0.0778, - 0.0272, - 0.03, - 0.0405, - 0.0388, - 0.0351, - 0.093, - 0.1041, - 0.0263, - ] + [0.0376, 0.0317, 0.0274, 0.0143, 0.0332, 0.0397, 0.0386, 0.0721, 0.0632, 0.0211] ) covid_3month_preds["RandomIntervalRegressor"] = np.array( @@ -252,16 +219,16 @@ cardano_sentiment_preds["Catch22Regressor"] = np.array( [ - 0.2174, - 0.1394, - 0.3623, - 0.1496, - 0.3502, - 0.2719, - 0.1378, - 0.076, - 0.0587, - 0.3773, + 0.2537, + 0.1417, + 0.2980, + 0.1324, + 0.3519, + 0.1919, + 0.1790, + 0.1295, + 0.1644, + 0.3836, ] ) @@ -341,48 +308,15 @@ ) cardano_sentiment_preds["RISTRegressor"] = np.array( - [ - 0.3002, - 0.3174, - 0.718, - 0.089, - 0.4002, - 0.0825, - 0.5342, - 0.0, - 0.3503, - 0.448, - ] + [0.0825, 0.1924, 0.7180, 0.0413, 0.4840, 0.0825, 0.2336, 0.0000, 0.0413, 0.2814] ) cardano_sentiment_preds["CanonicalIntervalForestRegressor"] = np.array( - [ - 0.276, - 0.1466, - 0.282, - 0.205, - 0.125, - 0.0111, - 0.3672, - 0.0677, - 0.1773, - 0.2586, - ] + [0.2546, 0.1796, 0.3423, 0.2016, 0.2369, 0.3178, 0.2051, 0.2286, 0.1956, 0.2452] ) cardano_sentiment_preds["DrCIFRegressor"] = np.array( - [ - 0.2361, - 0.2222, - 0.2046, - 0.1709, - 0.2462, - 0.2369, - 0.1916, - 0.1995, - 0.0407, - 0.1428, - ] + [0.2569, 0.2196, 0.2948, 0.2677, 0.0829, 0.0677, 0.1584, 0.1840, 0.0498, 0.1383] ) cardano_sentiment_preds["IntervalForestRegressor"] = np.array( diff --git a/aeon/transformations/collection/_collection_wrapper.py b/aeon/transformations/collection/_collection_wrapper.py index 11bc859750..2b05d27116 100644 --- a/aeon/transformations/collection/_collection_wrapper.py +++ b/aeon/transformations/collection/_collection_wrapper.py @@ -37,8 +37,8 @@ class CollectionToSeriesWrapper(BaseTransformer): >>> y = load_airline() >>> wrap = CollectionToSeriesWrapper(Catch22()) >>> wrap.fit_transform(y) - 0 1 2 3 ... 18 19 20 21 - 0 155.800003 181.700012 49.0 0.541667 ... 0.282051 0.769231 0.166667 11.0 + 0 1 2 3 ... 18 19 20 21 + 0 155.8 181.7 27.498346 8.0 ... 0.769231 0.282051 0.024544 48.355452 [1 rows x 22 columns] """ diff --git a/aeon/transformations/collection/feature_based/_catch22.py b/aeon/transformations/collection/feature_based/_catch22.py index a874e189df..2134f45f9f 100644 --- a/aeon/transformations/collection/feature_based/_catch22.py +++ b/aeon/transformations/collection/feature_based/_catch22.py @@ -20,26 +20,51 @@ feature_names = [ "DN_HistogramMode_5", "DN_HistogramMode_10", - "SB_BinaryStats_diff_longstretch0", - "DN_OutlierInclude_p_001_mdrmd", - "DN_OutlierInclude_n_001_mdrmd", "CO_f1ecac", "CO_FirstMin_ac", - "SP_Summaries_welch_rect_area_5_1", - "SP_Summaries_welch_rect_centroid", - "FC_LocalSimple_mean3_stderr", - "CO_trev_1_num", "CO_HistogramAMI_even_2_5", - "IN_AutoMutualInfoStats_40_gaussian_fmmi", + "CO_trev_1_num", "MD_hrv_classic_pnn40", "SB_BinaryStats_mean_longstretch1", - "SB_MotifThree_quantile_hh", - "FC_LocalSimple_mean1_tauresrat", - "CO_Embed2_Dist_tau_d_expfit_meandiff", - "SC_FluctAnal_2_dfa_50_1_2_logi_prop_r1", - "SC_FluctAnal_2_rsrangefit_50_1_logi_prop_r1", "SB_TransitionMatrix_3ac_sumdiagcov", "PD_PeriodicityWang_th0_01", + "CO_Embed2_Dist_tau_d_expfit_meandiff", + "IN_AutoMutualInfoStats_40_gaussian_fmmi", + "FC_LocalSimple_mean1_tauresrat", + "DN_OutlierInclude_p_001_mdrmd", + "DN_OutlierInclude_n_001_mdrmd", + "SP_Summaries_welch_rect_area_5_1", + "SB_BinaryStats_diff_longstretch0", + "SB_MotifThree_quantile_hh", + "SC_FluctAnal_2_rsrangefit_50_1_logi_prop_r1", + "SC_FluctAnal_2_dfa_50_1_2_logi_prop_r1", + "SP_Summaries_welch_rect_centroid", + "FC_LocalSimple_mean3_stderr", +] + +feature_names_short = [ + "mode_5", + "mode_10", + "acf_timescale", + "acf_first_min", + "ami2", + "trev", + "high_fluctuation", + "stretch_high", + "transition_matrix", + "periodicity", + "embedding_dist", + "ami_timescale", + "whiten_timescale", + "outlier_timing_pos", + "outlier_timing_neg", + "centroid_freq", + "stretch_decreasing", + "entropy_pairs", + "rs_range", + "dfa", + "low_freq_power", + "forecast_error", ] @@ -57,17 +82,28 @@ class Catch22(BaseCollectionTransformer): list of names or indices for multiple features. If "all", all features are extracted. Valid features are as follows: - ["DN_HistogramMode_5", "DN_HistogramMode_10", - "SB_BinaryStats_diff_longstretch0", "DN_OutlierInclude_p_001_mdrmd", - "DN_OutlierInclude_n_001_mdrmd", "CO_f1ecac", "CO_FirstMin_ac", - "SP_Summaries_welch_rect_area_5_1", "SP_Summaries_welch_rect_centroid", - "FC_LocalSimple_mean3_stderr", "CO_trev_1_num", "CO_HistogramAMI_even_2_5", - "IN_AutoMutualInfoStats_40_gaussian_fmmi", "MD_hrv_classic_pnn40", - "SB_BinaryStats_mean_longstretch1", "SB_MotifThree_quantile_hh", - "FC_LocalSimple_mean1_tauresrat", "CO_Embed2_Dist_tau_d_expfit_meandiff", + ["DN_HistogramMode_5", "DN_HistogramMode_10", "CO_f1ecac","CO_FirstMin_ac", + "CO_HistogramAMI_even_2_5", "CO_trev_1_num", "MD_hrv_classic_pnn40", + "SB_BinaryStats_mean_longstretch1", "SB_TransitionMatrix_3ac_sumdiagcov", + "PD_PeriodicityWang_th0_01", "CO_Embed2_Dist_tau_d_expfit_meandiff", + "IN_AutoMutualInfoStats_40_gaussian_fmmi", "FC_LocalSimple_mean1_tauresrat", + "DN_OutlierInclude_p_001_mdrmd", "DN_OutlierInclude_n_001_mdrmd", + "SP_Summaries_welch_rect_area_5_1", "SB_BinaryStats_diff_longstretch0", + "SB_MotifThree_quantile_hh", "SC_FluctAnal_2_rsrangefit_50_1_logi_prop_r1", "SC_FluctAnal_2_dfa_50_1_2_logi_prop_r1", - "SC_FluctAnal_2_rsrangefit_50_1_logi_prop_r1", - "SB_TransitionMatrix_3ac_sumdiagcov", "PD_PeriodicityWang_th0_01"] + "SP_Summaries_welch_rect_centroid", "FC_LocalSimple_mean3_stderr"] + Shortened: + ["mode_5", "mode_10", "acf_timescale", "acf_first_min", + "ami2", "trev", "high_fluctuation", + "stretch_high", "transition_matrix", + "periodicity", "embedding_dist", + "ami_timescale", "whiten_timescale", + "outlier_timing_pos", "outlier_timing_neg", + "centroid_freq", "stretch_decreasing", + "entropy_pairs", "rs_range", + "dfa", + "low_freq_power", "forecast_error"] + catch24 : bool, default=False Extract the mean and standard deviation as well as the 22 Catch22 features if true. If a List of specific features to extract is provided, "Mean" and/or @@ -122,12 +158,12 @@ class Catch22(BaseCollectionTransformer): >>> tnf.fit(X) Catch22(...) >>> print(tnf.transform(X)[0]) - [1.15639532e+00 1.31700575e+00 3.00000000e+00 2.00000000e-01 - 0.00000000e+00 1.00000000e+00 2.00000000e+00 1.10933565e-32 - 1.96349541e+00 5.10744398e-01 2.33853577e-01 3.89048349e-01 - 2.00000000e+00 1.00000000e+00 4.00000000e+00 1.88915916e+00 - 1.00000000e+00 1.70859420e-01 0.00000000e+00 0.00000000e+00 - 2.46913580e-02 0.00000000e+00] + [1.15639531e+00 1.31700577e+00 5.66227710e-01 2.00000000e+00 + 3.89048349e-01 2.33853577e-01 1.00000000e+00 3.00000000e+00 + 8.23045267e-03 0.00000000e+00 1.70859420e-01 2.00000000e+00 + 1.00000000e+00 2.00000000e-01 0.00000000e+00 1.10933565e-32 + 4.00000000e+00 2.04319187e+00 0.00000000e+00 0.00000000e+00 + 1.96349541e+00 5.51667002e-01] """ _tags = { @@ -174,7 +210,7 @@ def _transform(self, X, y=None): Returns ------- - Xt : array-like, shape = [n_cases, n_features*n_channels] + Xt : array-like, shape = [n_cases, num_features*n_channels] The catch22 features for each dimension. """ n_cases = len(X) @@ -189,51 +225,51 @@ def _transform(self, X, y=None): features = [ pycatch22.DN_HistogramMode_5, pycatch22.DN_HistogramMode_10, - pycatch22.SB_BinaryStats_diff_longstretch0, - pycatch22.DN_OutlierInclude_p_001_mdrmd, - pycatch22.DN_OutlierInclude_n_001_mdrmd, pycatch22.CO_f1ecac, pycatch22.CO_FirstMin_ac, - pycatch22.SP_Summaries_welch_rect_area_5_1, - pycatch22.SP_Summaries_welch_rect_centroid, - pycatch22.FC_LocalSimple_mean3_stderr, - pycatch22.CO_trev_1_num, pycatch22.CO_HistogramAMI_even_2_5, - pycatch22.IN_AutoMutualInfoStats_40_gaussian_fmmi, + pycatch22.CO_trev_1_num, pycatch22.MD_hrv_classic_pnn40, pycatch22.SB_BinaryStats_mean_longstretch1, - pycatch22.SB_MotifThree_quantile_hh, - pycatch22.FC_LocalSimple_mean1_tauresrat, - pycatch22.CO_Embed2_Dist_tau_d_expfit_meandiff, - pycatch22.SC_FluctAnal_2_dfa_50_1_2_logi_prop_r1, - pycatch22.SC_FluctAnal_2_rsrangefit_50_1_logi_prop_r1, pycatch22.SB_TransitionMatrix_3ac_sumdiagcov, pycatch22.PD_PeriodicityWang_th0_01, + pycatch22.CO_Embed2_Dist_tau_d_expfit_meandiff, + pycatch22.IN_AutoMutualInfoStats_40_gaussian_fmmi, + pycatch22.FC_LocalSimple_mean1_tauresrat, + pycatch22.DN_OutlierInclude_p_001_mdrmd, + pycatch22.DN_OutlierInclude_n_001_mdrmd, + pycatch22.SP_Summaries_welch_rect_area_5_1, + pycatch22.SB_BinaryStats_diff_longstretch0, + pycatch22.SB_MotifThree_quantile_hh, + pycatch22.SC_FluctAnal_2_rsrangefit_50_1_logi_prop_r1, + pycatch22.SC_FluctAnal_2_dfa_50_1_2_logi_prop_r1, + pycatch22.SP_Summaries_welch_rect_centroid, + pycatch22.FC_LocalSimple_mean3_stderr, ] else: features = [ Catch22._DN_HistogramMode_5, Catch22._DN_HistogramMode_10, - Catch22._SB_BinaryStats_diff_longstretch0, - Catch22._DN_OutlierInclude_p_001_mdrmd, - Catch22._DN_OutlierInclude_n_001_mdrmd, Catch22._CO_f1ecac, Catch22._CO_FirstMin_ac, - Catch22._SP_Summaries_welch_rect_area_5_1, - Catch22._SP_Summaries_welch_rect_centroid, - Catch22._FC_LocalSimple_mean3_stderr, - Catch22._CO_trev_1_num, Catch22._CO_HistogramAMI_even_2_5, - Catch22._IN_AutoMutualInfoStats_40_gaussian_fmmi, + Catch22._CO_trev_1_num, Catch22._MD_hrv_classic_pnn40, Catch22._SB_BinaryStats_mean_longstretch1, - Catch22._SB_MotifThree_quantile_hh, - Catch22._FC_LocalSimple_mean1_tauresrat, - Catch22._CO_Embed2_Dist_tau_d_expfit_meandiff, - Catch22._SC_FluctAnal_2_dfa_50_1_2_logi_prop_r1, - Catch22._SC_FluctAnal_2_rsrangefit_50_1_logi_prop_r1, Catch22._SB_TransitionMatrix_3ac_sumdiagcov, Catch22._PD_PeriodicityWang_th0_01, + Catch22._CO_Embed2_Dist_tau_d_expfit_meandiff, + Catch22._IN_AutoMutualInfoStats_40_gaussian_fmmi, + Catch22._FC_LocalSimple_mean1_tauresrat, + Catch22._DN_OutlierInclude_p_001_mdrmd, + Catch22._DN_OutlierInclude_n_001_mdrmd, + Catch22._SP_Summaries_welch_rect_area_5_1, + Catch22._SB_BinaryStats_diff_longstretch0, + Catch22._SB_MotifThree_quantile_hh, + Catch22._SC_FluctAnal_2_rsrangefit_50_1_logi_prop_r1, + Catch22._SC_FluctAnal_2_dfa_50_1_2_logi_prop_r1, + Catch22._SP_Summaries_welch_rect_centroid, + Catch22._FC_LocalSimple_mean3_stderr, ] c22_list = Parallel( @@ -284,17 +320,17 @@ def _transform_case(self, X, f_idx, features): args = [series] - if feature == 0 or feature == 1 or feature == 11: + if feature == 0 or feature == 1 or feature == 4: if smin is None: smin = numba_min(series) if smax is None: smax = numba_max(series) args = [series, smin, smax] - elif feature == 2 or feature == 22: + elif feature == 7 or feature == 22: if smean is None: smean = mean(series) args = [series, smean] - elif feature == 3 or feature == 4: + elif feature == 13 or feature == 14: if self.outlier_norm: if smean is None: smean = mean(series) @@ -303,7 +339,7 @@ def _transform_case(self, X, f_idx, features): args = [outlier_series] else: args = [series] - elif feature == 7 or feature == 8: + elif feature == 15 or feature == 20: if smean is None: smean = mean(series) if fft is None: @@ -312,34 +348,16 @@ def _transform_case(self, X, f_idx, features): ) fft = np.fft.fft(series - smean, n=nfft) args = [series, fft] - elif feature == 5 or feature == 6 or feature == 12: - if smean is None: - smean = mean(series) - if fft is None: - nfft = int( - np.power(2, np.ceil(np.log(len(series)) / np.log(2))) - ) - fft = np.fft.fft(series - smean, n=nfft) + elif feature == 2 or feature == 3: if ac is None: - ac = _autocorr(series, fft) - args = [ac] - elif feature == 15: - indices = np.argsort(series) - args = [series, indices] - elif feature == 16 or feature == 17 or feature == 20: - if smean is None: - smean = mean(series) - if fft is None: - nfft = int( - np.power(2, np.ceil(np.log(len(series)) / np.log(2))) - ) - fft = np.fft.fft(series - smean, n=nfft) + ac = _compute_autocorrelations(series) + args = [ac, len(series)] + elif feature == 12 or feature == 10 or feature == 8: if ac is None: - ac = _autocorr(series, fft) + ac = _compute_autocorrelations(series) if acfz is None: acfz = _ac_first_zero(ac) args = [series, acfz] - if feature == 22: c22[dim + n] = smean elif feature == 23: @@ -408,14 +426,14 @@ def _DN_HistogramMode_10(X, smin, smax): @staticmethod @njit(fastmath=True, cache=True) - def _SB_BinaryStats_diff_longstretch0(X, smean): - # Longest period of consecutive values above the mean. - mean_binary = np.zeros(len(X)) - for i in range(len(X)): - if X[i] - smean > 0: - mean_binary[i] = 1 + def _SB_BinaryStats_diff_longstretch0(X): + # Longest period of successive incremental decreases. + diff_binary = np.zeros(len(X) - 1) + for i in range(len(X) - 1): + if X[i + 1] - X[i] >= 0: + diff_binary[i] = 1 - return _long_stretch(mean_binary, 1) + return _long_stretch(diff_binary, 0) @staticmethod def _DN_OutlierInclude_p_001_mdrmd(X): @@ -430,22 +448,30 @@ def _DN_OutlierInclude_n_001_mdrmd(X): @staticmethod @njit(fastmath=True, cache=True) - def _CO_f1ecac(X_ac): + def _CO_f1ecac(X_ac, size): + # Parameter has already been transformed using _autocorr # First 1/e crossing of autocorrelation function. threshold = 0.36787944117144233 # 1 / np.exp(1) - for i in range(1, len(X_ac)): - if (X_ac[i - 1] - threshold) * (X_ac[i] - threshold) < 0: - return i + for i in range(len(X_ac) - 2): + if X_ac[i + 1] < threshold: + m = X_ac[i + 1] - X_ac[i] + if m == 0: + return size + dy = threshold - X_ac[i] + dx = dy / m + out = np.float64(i) + dx + return out + return len(X_ac) @staticmethod @njit(fastmath=True, cache=True) - def _CO_FirstMin_ac(X_ac): + def _CO_FirstMin_ac(X_ac, size): # First minimum of autocorrelation function. for i in range(1, len(X_ac) - 1): if X_ac[i] < X_ac[i - 1] and X_ac[i] < X_ac[i + 1]: return i - return len(X_ac) + return size @staticmethod def _SP_Summaries_welch_rect_area_5_1(X, X_fft): @@ -464,7 +490,7 @@ def _FC_LocalSimple_mean3_stderr(X): if len(X) - 3 < 3: return 0 res = _local_simple_mean(X, 3) - return np.std(res) + return _stddev(res, len(X) - 3) @staticmethod @njit(fastmath=True, cache=True) @@ -511,17 +537,32 @@ def _IN_AutoMutualInfoStats_40_gaussian_fmmi(X_ac): # First minimum of the automutual information function. tau = int(min(40, np.ceil(len(X_ac) / 2))) - diffs = np.zeros(tau - 1) - prev = -0.5 * np.log(1 - np.power(X_ac[1], 2)) - for i in range(len(diffs)): - corr = -0.5 * np.log(1 - np.power(X_ac[i + 2], 2)) - diffs[i] = corr - prev - prev = corr - - for i in range(len(diffs) - 1): - if diffs[i] * diffs[i + 1] < 0 and diffs[i] < 0: - return i + 1 + ami = np.zeros(len(X_ac), dtype=np.float64) + for i in range(tau): + lag_size = len(X_ac) - (i + 1) + y = X_ac[i + 1 :] + nom = 0.0 + denomX = 0.0 + denomY = 0.0 + meanX = 0.0 + for j in range(lag_size): + meanX += X_ac[j] + meanX = meanX / lag_size + meanY = np.mean(y) + for j in range(lag_size): + nom += (X_ac[j] - meanX) * (y[j] - meanY) + denomX += (X_ac[j] - meanX) * (X_ac[j] - meanX) + denomY += (y[j] - meanY) * (y[j] - meanY) + divisor = np.sqrt(denomX * denomY) + if divisor == 0: + return np.nan + ac = nom / np.sqrt(denomX * denomY) + ami[i] = -0.5 * np.log(1 - np.power(ac, 2)) + + for i in range(1, tau - 1): + if ami[i] < ami[i - 1] and ami[i] < ami[i + 1]: + return i return tau @staticmethod @@ -541,67 +582,69 @@ def _MD_hrv_classic_pnn40(X): @staticmethod @njit(fastmath=True, cache=True) - def _SB_BinaryStats_mean_longstretch1(X): - # Longest period of successive incremental decreases. - diff_binary = np.zeros(len(X) - 1) - for i in range(len(diff_binary)): - if X[i + 1] - X[i] >= 0: - diff_binary[i] = 1 + def _SB_BinaryStats_mean_longstretch1(X, smean): + # Longest period of consecutive values above the mean. + mean_binary = np.zeros(len(X) - 1) + for i in range(len(mean_binary)): + if X[i] - smean > 0: + mean_binary[i] = 1 - return _long_stretch(diff_binary, 0) + return _long_stretch(mean_binary, 1) @staticmethod @njit(fastmath=True, cache=True) - def _SB_MotifThree_quantile_hh(X, indices): - # Shannon entropy of two successive letters in equiprobable 3-letter - # symbolization. - bins = np.zeros(len(X)) - q1 = int(len(X) / 3) - q2 = q1 * 2 - l1 = np.zeros(q1, dtype=np.int_) - for i in range(q1): - l1[i] = indices[i] - l2 = np.zeros(q1, dtype=np.int_) - c1 = 0 - for i in range(q1, q2): - bins[indices[i]] = 1 - l2[c1] = indices[i] - c1 += 1 - l3 = np.zeros(len(indices) - q2, dtype=np.int_) - c2 = 0 - for i in range(q2, len(indices)): - bins[indices[i]] = 2 - l3[c2] = indices[i] - c2 += 1 - - found_last = False - nsum = 0 - for i in range(3): - if i == 0: - o = l1 - elif i == 1: - o = l2 - else: - o = l3 - - if not found_last: - for n in range(len(o)): - if o[n] == len(X) - 1: - o = np.delete(o, n) - break - - for n in range(3): - nsum2 = 0 - - for v in o: - if bins[v + 1] == n: - nsum2 += 1 - - if nsum2 > 0: - nsum2 /= len(X) - 1 - nsum += nsum2 * np.log(nsum2) - - return -nsum + def _SB_MotifThree_quantile_hh(X): + alphabet_size = 3 + yt = np.zeros(len(X), dtype=np.int32) + _sb_coarsegrain(X, 3, yt) + r1 = [np.zeros(len(X), np.int32) for i in range(alphabet_size)] + sizes_r1 = np.zeros(alphabet_size, np.int32) + for i in range(alphabet_size): + r_idx = 0 + sizes_r1[i] = 0 + for j in range(len(X)): + if yt[j] == i + 1: + r1[i][r_idx] = j + r_idx += 1 + sizes_r1[i] += 1 + + for i in range(alphabet_size): + if sizes_r1[i] != 0 and r1[i][sizes_r1[i] - 1] == len(X) - 1: + tmp_ar = np.zeros(sizes_r1[i], np.int32) + # isn't this doing the same thing? + for x in range(sizes_r1[i]): + tmp_ar[x] = r1[i][x] + for y in range(sizes_r1[i] - 1): + r1[i][y] = tmp_ar[y] + sizes_r1[i] -= 1 + + r2 = [ + [np.zeros(len(X), np.int32) for j in range(alphabet_size)] + for i in range(alphabet_size) + ] + sizes_r2 = [np.zeros(alphabet_size, np.int32) for i in range(alphabet_size)] + out2 = [np.zeros(alphabet_size, np.float64) for i in range(alphabet_size)] + + for i in range(alphabet_size): + for j in range(alphabet_size): + sizes_r2[i][j] = 0 + dynamic_idx = 0 + for k in range(sizes_r1[i]): + tmp_idx = yt[r1[i][k] + 1] + if tmp_idx == j + 1: + r2[i][j][dynamic_idx] = r1[i][k] + dynamic_idx += 1 + sizes_r2[i][j] += 1 + tmp = np.float64(sizes_r2[i][j]) / (np.float64(len(X)) - 1.0) + out2[i][j] = tmp + hh = 0.0 + for i in range(alphabet_size): + f = 0.0 + for j in range(alphabet_size): + if out2[i][j] > 0: + f += out2[i][j] * np.log(out2[i][j]) + hh += -1 * f + return hh @staticmethod def _FC_LocalSimple_mean1_tauresrat(X, acfz): @@ -609,12 +652,8 @@ def _FC_LocalSimple_mean1_tauresrat(X, acfz): if len(X) < 2: return 0 res = _local_simple_mean(X, 1) - mean = np.mean(res) - - nfft = int(np.power(2, np.ceil(np.log(len(res)) / np.log(2)))) - fft = np.fft.fft(res - mean, n=nfft) - ac = _autocorr(res, fft) + ac = _compute_autocorrelations(res) return _ac_first_zero(ac) / acfz @staticmethod @@ -624,50 +663,57 @@ def _CO_Embed2_Dist_tau_d_expfit_meandiff(X, acfz): tau = acfz if tau > len(X) / 10: tau = int(len(X) / 10) - d = np.zeros(len(X) - tau - 1) d_mean = 0 for i in range(len(d)): n = np.sqrt( - np.power(X[i + 1] - X[i], 2) + np.power(X[i + tau + 1] - X[i + tau], 2) + np.power(X[i + 1] - X[i], 2) + np.power(X[i + tau] - X[i + tau + 1], 2) ) d[i] = n d_mean += n - d_mean /= len(X) - tau - 1 - + d_mean /= len(d) smin = np.min(d) smax = np.max(d) srange = smax - smin std = np.std(d) - - if std == 0: - return np.nan - + if std < 0.001: + return 0 num_bins = int( - np.ceil(srange / (3.5 * np.std(d) / np.power(len(d), 0.3333333333333333))) + np.ceil( + srange + / (3.5 * _stddev(d, len(d)) / np.power(len(d), 0.3333333333333333)) + ) ) - if num_bins == 0: - return np.nan + return 0 bin_width = srange / num_bins - histogram = np.zeros(num_bins) + histogram = np.zeros(num_bins, dtype=np.int32) + binEdges = np.zeros(num_bins + 1, dtype=np.float64) for val in d: idx = int((val - smin) / bin_width) + if idx < 0: + idx = 0 if idx >= num_bins: idx = num_bins - 1 histogram[idx] += 1 - sum = 0 + for i in range(num_bins + 1): + binEdges[i] = i * bin_width + smin + + histogramNormalise = np.zeros(num_bins, dtype=np.float64) + for i in range(len(histogramNormalise)): + histogramNormalise[i] = histogram[i] / len(d) + + d_exp_fit = np.zeros(num_bins, dtype=np.float64) for i in range(num_bins): - center = ((smin + bin_width * i) * 2 + bin_width) / 2 - n = np.exp(-center / d_mean) / d_mean - if n < 0: - n = 0 + expf = np.exp(-(binEdges[i] + binEdges[i + 1]) * 0.5 / d_mean) / d_mean + if expf < 0: + expf = 0 - sum += np.abs(histogram[i] / len(d) - n) + d_exp_fit[i] = np.abs(histogramNormalise[i] - expf) - return sum / num_bins + return np.mean(d_exp_fit) @staticmethod @njit(fastmath=True, cache=True) @@ -696,46 +742,54 @@ def _SC_FluctAnal_2_rsrangefit_50_1_logi_prop_r1(X): @staticmethod @njit(fastmath=True, cache=True) def _SB_TransitionMatrix_3ac_sumdiagcov(X, acfz): - # Trace of covariance of transition matrix between symbols in 3-letter - # alphabet. - ds = np.zeros(int((len(X) - 1) / acfz + 1)) + # Trace of covariance of transition matrix between symbols in 3-letter alphabet. + ds = np.zeros(int(((len(X) - 1) / acfz) + 1), dtype=np.float64) for i in range(len(ds)): ds[i] = X[i * acfz] - indicies = np.argsort(ds) - - bins = np.zeros(len(ds), dtype=np.int32) - q1 = int(len(ds) / 3) - q2 = q1 * 2 - for i in range(q1 + 1, q2 + 1): - bins[indicies[i]] = 1 - for i in range(q2 + 1, len(indicies)): - bins[indicies[i]] = 2 - - t = np.zeros((3, 3)) + # swap to alphabet: + yCG = np.zeros(len(ds), dtype=np.int32) + _sb_coarsegrain(ds, 3, yCG) + T = np.zeros((3, 3), dtype=np.float64) for i in range(len(ds) - 1): - t[bins[i + 1]][bins[i]] += 1 - t /= len(ds) - 1 + T[yCG[i] - 1][yCG[i + 1] - 1] += 1 - means = np.zeros(3) for i in range(3): - means[i] = np.mean(t[i]) + for j in range(3): + if (len(ds) - 1) == 0: + T[i][j] = np.nan + else: + T[i][j] /= len(ds) - 1 + column1 = np.zeros(3, dtype=np.float64) + column2 = np.zeros(3, dtype=np.float64) + column3 = np.zeros(3, dtype=np.float64) - cov = np.zeros((3, 3)) for i in range(3): - for n in range(3): - covariance = 0 - for j in range(3): - covariance += (t[i][j] - means[i]) * (t[n][j] - means[n]) - covariance /= 2 - - cov[i][n] = covariance - cov[n][i] = covariance - - ssum = 0 + column1[i] = T[i][0] + column2[i] = T[i][1] + column3[i] = T[i][2] + columns = np.zeros((3, 3), dtype=np.float64) + columns[0] = column1 + columns[1] = column2 + columns[2] = column3 + + # columns = [column1, column2, column3] + cov_array = np.zeros((3, 3), dtype=np.float64) + for i in range(3): + for j in range(3): + covTemp = 0 + meanX = np.mean(columns[i]) + meanY = np.mean(columns[j]) + for k in range(3): + covTemp += (columns[i][k] - meanX) * (columns[j][k] - meanY) + covTemp = covTemp / 2 + cov_array[i][j] = covTemp + cov_array[j][i] = covTemp + + sum_of_diagonal_cov = 0.0 for i in range(3): - ssum += cov[i][i] + sum_of_diagonal_cov += cov_array[i][i] - return ssum + return sum_of_diagonal_cov @staticmethod @njit(fastmath=True, cache=True) @@ -787,19 +841,23 @@ def _PD_PeriodicityWang_th0_01(X): @njit(fastmath=True, cache=True) def _histogram_mode(X, num_bins, smin, smax): - bin_width = (smax - smin) / num_bins + srange = smax - smin + bin_width = srange / num_bins if bin_width == 0: return np.nan - histogram = np.zeros(num_bins) + histogram = np.zeros(num_bins, dtype=np.int32) + edges = np.zeros(num_bins + 1, dtype=np.float64) for val in X: idx = int((val - smin) / bin_width) - idx = num_bins - 1 if idx >= num_bins else idx + if idx < 0: + idx = 0 + if idx >= num_bins: + idx = num_bins - 1 histogram[idx] += 1 - edges = np.zeros(num_bins + 1, dtype=np.float32) - for i in range(len(edges)): + for i in range(num_bins + 1): edges[i] = i * bin_width + smin max_count = 0 @@ -814,12 +872,12 @@ def _histogram_mode(X, num_bins, smin, smax): elif histogram[i] == max_count: num_maxs += 1 max_sum += v - return max_sum / num_maxs @njit(fastmath=True, cache=True) def _long_stretch(X_binary, val): + # look for the longest consecutive given value in an array last_val = 0 max_stretch = 0 for i in range(len(X_binary)): @@ -1268,6 +1326,8 @@ def _verify_features(features, catch24): f_idx = [i for i in range(22)] if catch24: f_idx += [22, 23] + elif features in feature_names_short: + f_idx = [feature_names_short.index(features)] elif features in feature_names: f_idx = [feature_names.index(features)] elif catch24 and features == "Mean": @@ -1290,7 +1350,9 @@ def _verify_features(features, catch24): f_idx = [] for f in features: if isinstance(f, str): - if f in feature_names: + if f in feature_names_short: + f_idx.append(feature_names_short.index(f)) + elif f in feature_names: f_idx.append(feature_names.index(f)) elif catch24 and f == "Mean": f_idx.append(22) @@ -1315,3 +1377,108 @@ def _verify_features(features, catch24): raise ValueError("Invalid feature selection.") return f_idx + + +@njit(fastmath=True, cache=True) +def _compute_autocorrelations(X): + mean = np.mean(X) + nFFT = int(np.log2(len(X))) + if 2**nFFT == len(X): + nFFT = len(X) * 2 + else: + nFFT = (2 ** (nFFT + 1)) * 2 + F = np.zeros(nFFT * 2, dtype=np.complex128) + for i in range(len(X)): + F[i] = complex(X[i] - mean, 0.0) + for i in range(len(X), nFFT): + F[i] = complex(0.0, 0.0) + tw = np.zeros(nFFT * 2, dtype=np.complex128) + # twiddles + PI = np.pi + for i in range(nFFT): + tmp = 0.0 - PI * i / nFFT * 1j + tw[i] = np.exp(tmp) + F = _fft(F, tw) + # dot multiply + F = np.multiply(F, np.conj(F)) + F = _fft(F, tw) + divisor = F[0] + if np.real(divisor) == 0 and np.imag(divisor) == 0: + return np.zeros(nFFT * 2, dtype=np.float64) + F = F / divisor + out = np.real(F) + return out + + +@njit(fastmath=True, cache=True) +def _fft(a, tw): + n = a.shape[0] + log_n = int(np.log2(n)) + out = np.empty_like(a) + + # Bit-reversed addressing permutation + for i in range(n): + j = 0 + for k in range(log_n): + j = (j << 1) | ((i >> k) & 1) + out[j] = a[i] + + # Iterative FFT computation + step = 1 + while step < n: + halfstep = step + step = 2 * step + for i in range(0, n, step): + for j in range(halfstep): + t = tw[j * (n // step)] * out[i + j + halfstep] + u = out[i + j] + out[i + j] = u + t + out[i + j + halfstep] = u - t + + return out + + +@njit(fastmath=True, cache=True) +def _stddev(a, size): + m = np.mean(a[:size]) + sd = np.sqrt(np.sum((a[:size] - m) ** 2) / (size - 1)) + return sd + + +@njit(fastmath=True, cache=True) +def _sb_coarsegrain(y, num_groups, labels): + th = np.zeros((num_groups + 1), dtype=np.float64) + ls = np.zeros((num_groups + 1), dtype=np.float64) + # linspace + step_size = 1 / (num_groups) + start = 0 + for i in range(num_groups + 1): + ls[i] = start + start += step_size + for i in range(num_groups + 1): + th[i] = _quantile(y, ls[i]) + th[0] -= 1 + for i in range(num_groups): + for j in range(len(y)): + if y[j] > th[i] and y[j] <= th[i + 1]: + labels[j] = i + 1 + + +@njit(fastmath=True, cache=True) +def _quantile(X, quant): + tmp = np.sort(X) + q = 0.5 / len(X) + if quant < q: + value = tmp[0] + return value + elif quant > (1 - q): + value = tmp[len(X) - 1] + return value + + quant_idx = len(X) * quant - 0.5 + idx_left = int(np.floor(quant_idx)) + idx_right = int(np.ceil(quant_idx)) + value = tmp[idx_left] + (quant_idx - idx_left) * ( + tmp[idx_right] - tmp[idx_left] + ) / (idx_right - idx_left) + return value diff --git a/aeon/transformations/collection/feature_based/tests/test_catch22.py b/aeon/transformations/collection/feature_based/tests/test_catch22.py index 0f047793dc..5b5bc28925 100644 --- a/aeon/transformations/collection/feature_based/tests/test_catch22.py +++ b/aeon/transformations/collection/feature_based/tests/test_catch22.py @@ -23,7 +23,26 @@ def test_catch22_on_basic_motions(): data = c22.fit_transform(X_train[indices]) testing.assert_array_almost_equal( data, - catch22_basic_motions_data[:, np.sort(np.r_[0:132:22, 5:132:22, 9:132:22])], + catch22_basic_motions_data[:, np.sort(np.r_[0:132:22, 2:132:22, 21:132:22])], + decimal=4, + ) + + +def test_catch22_short_on_basic_motions(): + """Test of Catch22 on basic motions data.""" + X_train, _ = load_basic_motions(split="train") + indices = [28, 39, 4, 15, 26] + # fit Catch22 and assert transformed data is the same + c22 = Catch22(features=feature_names_short, replace_nans=True) + data = c22.fit_transform(X_train[indices]) + testing.assert_array_almost_equal(data, catch22_basic_motions_data, decimal=4) + + # fit Catch22 with select features and assert transformed data is the same + c22 = Catch22(replace_nans=True, features=feature_names) + data = c22.fit_transform(X_train[indices]) + testing.assert_array_almost_equal( + data, + catch22_basic_motions_data[:, np.sort(np.r_[0:132:22, 2:132:22, 21:132:22])], decimal=4, ) @@ -51,7 +70,7 @@ def test_catch22_wrapper_on_basic_motions(): testing.assert_array_almost_equal( data, catch22wrapper_basic_motions_data[ - :, np.sort(np.r_[0:132:22, 5:132:22, 9:132:22]) + :, np.sort(np.r_[0:132:22, 2:132:22, 21:132:22]) ], decimal=4, ) @@ -59,677 +78,703 @@ def test_catch22_wrapper_on_basic_motions(): feature_names = ["DN_HistogramMode_5", "CO_f1ecac", "FC_LocalSimple_mean3_stderr"] +feature_names_short = [ + "mode_5", + "mode_10", + "acf_timescale", + "acf_first_min", + "ami2", + "trev", + "high_fluctuation", + "stretch_high", + "transition_matrix", + "periodicity", + "embedding_dist", + "ami_timescale", + "whiten_timescale", + "outlier_timing_pos", + "outlier_timing_neg", + "centroid_freq", + "stretch_decreasing", + "entropy_pairs", + "rs_range", + "dfa", + "low_freq_power", + "forecast_error", +] + +# Both results should be the same catch22_basic_motions_data = np.array( [ [ 0.8918, 1.2091, - 7.0000, - -0.2000, - 0.0400, - 2.0000, - 3.0000, - 0.4158, - 0.9327, - 1.3956, - 0.9846, + 1.1428, + 3.0, 0.0956, - 1.0000, + 0.9846, 0.9192, - 4.0000, - 2.0822, - 0.5000, + 7.0, + 0.0108, + 13.0, 0.0616, - 0.2000, + 1.0, + 0.5, + -0.2, + 0.04, + 0.4158, + 4.0, + 2.0909, 0.3714, - 0.0108, - 13.0000, + 0.2, + 0.9327, + 1.4028, -1.8851, -3.4695, - 8.0000, - -0.5600, - 0.0200, - 3.0000, - 6.0000, - 10.6256, - 0.4909, - 3.4816, - -0.9328, + 2.2361, + 6.0, 0.2865, - 2.0000, + -0.9328, 0.9192, - 8.0000, - 1.8188, - 0.7500, + 8.0, + 0.0243, + 11.0, 0.0934, - 0.1714, + 2.0, + 0.75, + -0.56, + 0.02, + 10.6256, + 8.0, + 1.8244, 0.2857, - 0.0168, - 11.0000, - 0.0650, + 0.1714, + 0.4909, + 3.4997, + 0.065, 0.2902, - 9.0000, - -0.7800, - -0.3400, - 2.0000, - 5.0000, - 0.3406, - 0.6381, - 0.9140, - 0.1154, + 1.5664, + 5.0, 0.1343, - 2.0000, + 0.1154, 0.8485, - 8.0000, - 1.9982, - 0.3333, + 9.0, + 0.0046, + 10.0, 0.0967, - 0.6857, + 2.0, + 0.3333, + -0.78, + -0.34, + 0.3406, + 8.0, + 1.9998, 0.3143, - 0.0034, - 10.0000, + 0.6857, + 0.6381, + 0.9188, 0.0216, 0.0216, - 10.0000, - -0.3800, - 0.0600, - 2.0000, - 7.0000, - 0.1938, - 0.6381, - 0.7083, - -0.0848, + 1.2329, + 7.0, 0.1159, - 2.0000, - 0.8990, - 5.0000, - 2.0236, - 0.3333, + -0.0848, + 0.899, + 10.0, + 0.0101, + 12.0, 0.1549, - 0.1714, + 2.0, + 0.3333, + -0.38, + 0.06, + 0.1938, + 5.0, + 2.0078, 0.3143, - 0.0052, - 12.0000, + 0.1714, + 0.6381, + 0.712, 0.3028, 0.1962, - 10.0000, - 0.0300, - 0.3500, - 2.0000, - 7.0000, - 0.1221, - 0.4909, - 0.4555, - 0.0199, + 1.815, + 7.0, 0.1564, - 2.0000, + 0.0199, 0.8283, - 8.0000, - 1.9428, - 0.6667, + 10.0, + 0.0046, + 13.0, 0.5965, - 0.8000, + 2.0, + 0.6667, + 0.03, + 0.35, + 0.1221, + 8.0, + 1.9443, 0.2286, - 0.0037, - 13.0000, + 0.8, + 0.4909, + 0.4579, 0.5644, 0.9089, - 8.0000, - -0.1100, - -0.7200, - 3.0000, - 6.0000, - 1.7181, - 0.4909, - 1.4064, - -0.4162, + 2.1736, + 6.0, 0.1914, - 2.0000, + -0.4162, 0.9091, - 8.0000, - 1.8064, - 1.0000, + 8.0, + 0.0174, + 11.0, 0.1072, + 2.0, + 1.0, + -0.11, + -0.72, + 1.7181, + 8.0, + 1.8142, + 0.2, 0.7429, - 0.2000, - 0.0125, - 11.0000, + 0.4909, + 1.4137, ], [ 0.1467, 1.7614, - 9.0000, - -0.1500, - 0.0300, - 2.0000, - 8.0000, - 32.2850, - 0.4909, - 7.0546, - 159.5298, + 1.851, + 8.0, 0.1971, - 4.0000, + 159.5298, 0.9091, - 8.0000, - 1.9656, - 0.3333, + 9.0, + 0.0091, + 3.0, 0.1012, - 0.1714, + 4.0, + 0.3333, + -0.15, + 0.03, + 32.285, + 8.0, + 1.9501, 0.6571, - 0.0091, - 3.0000, - -1.0310, + 0.1714, + 0.4909, + 7.0913, + -1.031, 0.5692, - 13.0000, - 0.2000, - -0.3700, - 2.0000, - 3.0000, - 18.3029, - 0.7363, - 7.4032, - -298.4341, + 1.3705, + 3.0, 0.1008, - 3.0000, + -298.4341, 0.9091, - 7.0000, - 2.0451, - 0.3333, + 13.0, + 0.0037, + 8.0, 0.0992, + 3.0, + 0.3333, + 0.2, + -0.37, + 18.3029, + 7.0, + 2.0298, + 0.2, 0.3143, - 0.2000, - 0.0015, - 8.0000, + 0.7363, + 7.4417, 0.1721, -1.3777, - 12.0000, - -0.2300, - -0.0400, - 2.0000, - 3.0000, - 14.3938, - 0.7854, - 6.4700, + 1.2531, + 3.0, + 0.156, 91.0107, - 0.1560, - 3.0000, 0.9091, - 5.0000, - 2.0199, - 0.3333, + 12.0, + 0.0116, + 13.0, 0.1303, - 0.7429, + 3.0, + 0.3333, + -0.23, + -0.04, + 14.3938, + 5.0, + 2.0059, 0.5429, - 0.0107, - 13.0000, + 0.7429, + 0.7854, + 6.5036, -0.2157, -1.7152, - 13.0000, - 0.1500, - -0.1800, - 2.0000, - 5.0000, - 7.1407, - 0.6872, - 4.2252, - 4.1153, + 1.6441, + 5.0, 0.0992, - 2.0000, - 0.8990, - 6.0000, - 1.9989, - 0.3333, + 4.1153, + 0.899, + 13.0, + 0.0037, + 14.0, 0.1047, - 0.6571, + 2.0, + 0.3333, + 0.15, + -0.18, + 7.1407, + 6.0, + 2.0097, 0.2286, - 0.0024, - 14.0000, + 0.6571, + 0.6872, + 4.2472, 0.6856, -0.1177, - 14.0000, - 0.1800, - 0.3000, - 2.0000, - 4.0000, - 1.6007, - 0.8345, - 2.9957, - 3.2192, + 1.2534, + 4.0, 0.0709, - 1.0000, + 3.2192, 0.8687, - 5.0000, - 2.0294, + 14.0, + 0.0073, + 6.0, + 0.064, + 1.0, 0.3333, - 0.0640, - 0.7429, + 0.18, + 0.3, + 1.6007, + 5.0, + 2.0191, 0.8286, - 0.0061, - 6.0000, + 0.7429, + 0.8345, + 3.0112, -1.5437, 0.0972, - 8.0000, - -0.1400, - 0.1000, - 2.0000, - 5.0000, - 7.3076, - 0.6872, - 5.1671, - 119.5049, + 1.1481, + 5.0, 0.0292, - 2.0000, + 119.5049, 0.8889, - 5.0000, - 1.9850, - 0.3333, + 8.0, + 0.0092, + 8.0, 0.1307, + 2.0, + 0.3333, + -0.14, + 0.1, + 7.3076, + 5.0, + 1.9736, + 0.2, 0.2286, - 0.2000, - 0.0070, - 8.0000, + 0.6872, + 5.194, ], [ -0.2176, - -0.2520, - 7.0000, - -0.1300, - 0.0200, - 1.0000, - 6.0000, - 0.0081, - 0.8836, - 0.1506, - -0.0018, + -0.252, + 0.9342, + 6.0, 0.0569, - 2.0000, - 0.6970, - 5.0000, - 2.1494, - 0.3333, - 1.6321, - 0.6000, - 0.2000, + -0.0018, + 0.697, + 7.0, 0.0024, - 19.0000, - 0.2537, - 0.3414, - 8.0000, - -0.1200, - -0.0200, - 2.0000, - 5.0000, - 0.1288, - 0.5890, - 0.4545, + 19.0, + 1.6321, + 2.0, + 0.3333, + -0.13, + 0.02, + 0.0081, + 5.0, + 2.133, + 0.2, + 0.6, + 0.8836, + 0.1514, + 0.2537, + 0.3414, + 1.8247, + 5.0, + 0.213, 0.0234, - 0.2130, - 2.0000, 0.8586, - 7.0000, - 1.9398, - 1.0000, + 8.0, + 0.0116, + 10.0, 0.5715, + 2.0, + 1.0, + -0.12, + -0.02, + 0.1288, + 7.0, + 1.9505, + 0.2, 0.8286, - 0.2000, - 0.0107, - 10.0000, - -0.0450, + 0.589, + 0.4569, + -0.045, -0.0049, - 9.0000, - -0.9100, - -0.2200, - 2.0000, - 3.0000, - 0.0037, - 0.9327, - 0.1481, - -0.0013, + 1.1595, + 3.0, 0.0428, - 1.0000, + -0.0013, 0.6162, - 6.0000, - 2.0848, - 1.0000, + 9.0, + 0.0004, + 0.0, 1.2419, - 0.6000, + 1.0, + 1.0, + -0.91, + -0.22, + 0.0037, + 6.0, + 2.0869, 0.1714, - 0.0004, - 0.0000, + 0.6, + 0.9327, + 0.1488, -0.0368, -0.0176, - 11.0000, - -0.3900, - -0.0800, - 2.0000, - 5.0000, - 0.0041, - 0.5890, - 0.0770, - -0.0000, + 1.8915, + 5.0, 0.1782, - 2.0000, + -0.0, 0.3939, - 7.0000, - 1.9786, - 0.2500, - 3.9760, - 0.7429, + 11.0, + 0.0122, + 0.0, + 3.976, + 2.0, + 0.25, + -0.39, + -0.08, + 0.0041, + 7.0, + 1.9728, 0.2571, - 0.0295, - 0.0000, - 0.0200, + 0.7429, + 0.589, + 0.0774, + 0.02, 0.0044, - 8.0000, - -0.0600, - 0.0500, - 2.0000, - 5.0000, - 0.0013, - 0.5890, - 0.0467, - -0.0000, + 1.7819, + 5.0, 0.0615, - 2.0000, + -0.0, 0.2424, - 7.0000, - 2.0341, - 0.3333, + 8.0, + 0.0043, + 0.0, 6.8497, + 2.0, + 0.3333, + -0.06, + 0.05, + 0.0013, + 7.0, + 2.039, + 0.2, 0.7429, - 0.2000, - 0.0034, - 0.0000, + 0.589, + 0.0469, -0.0914, -0.1258, - 11.0000, - -0.1550, - 0.1250, - 2.0000, - 5.0000, - 0.0212, - 0.5890, - 0.1719, - 0.0001, + 1.9983, + 5.0, 0.1025, - 2.0000, + 0.0001, 0.7071, - 7.0000, - 1.8809, - 1.0000, + 11.0, + 0.0037, + 10.0, 3.1416, - 0.8000, - 0.2000, - 0.0024, - 10.0000, + 2.0, + 1.0, + -0.155, + 0.125, + 0.0212, + 7.0, + 1.8706, + 0.2, + 0.8, + 0.589, + 0.1728, ], [ - 14.7530, + 14.753, 12.6115, - 8.0000, - -0.0100, - -0.1700, - 2.0000, - 5.0000, - 8.3186, - 0.7363, - 15.1680, - 611.2311, + 1.3139, + 5.0, 0.2837, - 1.0000, + 611.2311, 0.8485, - 5.0000, - 2.0580, - 1.0000, + 8.0, + 0.0046, + 7.0, 0.1723, - 0.2571, + 1.0, + 1.0, + -0.01, + -0.17, + 8.3186, + 5.0, + 2.0538, 0.2286, - 0.0046, - 7.0000, + 0.2571, + 0.7363, + 15.2468, -8.7478, -11.0416, - 5.0000, - 0.0900, - 0.0100, - 2.0000, - 4.0000, - 5.3016, - 0.7363, - 16.0359, - -666.9228, + 1.2448, + 4.0, 0.2358, - 1.0000, + -666.9228, 0.8586, - 4.0000, - 2.0048, - 1.0000, + 5.0, + 0.005, + 7.0, 0.1222, - 0.8286, + 1.0, + 1.0, + 0.09, + 0.01, + 5.3016, + 4.0, + 2.0075, 0.1714, - 0.0050, - 7.0000, + 0.8286, + 0.7363, + 16.1192, -1.1495, -3.2478, - 8.0000, - 0.1300, - -0.0800, - 1.0000, - 2.0000, - 1.7627, - 1.4726, - 3.3270, + 0.8858, + 2.0, + 0.095, 3.9326, - 0.0950, - 2.0000, 0.8586, - 5.0000, - 2.1462, - 0.5000, + 8.0, + 0.0012, + 3.0, 0.0841, - 0.6000, + 2.0, + 0.5, + 0.13, + -0.08, + 1.7627, + 5.0, + 2.1476, 0.7714, - 0.0012, - 3.0000, - 0.0946, + 0.6, + 1.4726, + 3.3443, + 0.0945, -1.7292, - 5.0000, - 0.0600, - 0.1200, - 1.0000, - 3.0000, - 0.5924, - 0.8836, - 2.8071, - 13.9244, + 0.9238, + 3.0, 0.1318, - 1.0000, + 13.9244, 0.8586, - 7.0000, - 2.1384, - 0.5000, + 5.0, + 0.0124, + 8.0, 0.0608, - 0.8286, + 1.0, + 0.5, + 0.06, + 0.12, + 0.5924, + 7.0, + 2.1388, 0.2286, - 0.0124, - 8.0000, + 0.8286, + 0.8836, + 2.8217, -0.2413, -0.7284, - 6.0000, - -0.0500, - -0.1100, - 2.0000, - 4.0000, - 0.2086, - 0.7363, - 2.8863, - -0.3140, + 1.124, + 4.0, 0.1682, - 1.0000, + -0.314, 0.8384, - 6.0000, - 2.0693, - 0.5000, + 6.0, + 0.0165, + 8.0, 0.0498, - 0.8286, + 1.0, + 0.5, + -0.05, + -0.11, + 0.2086, + 6.0, + 2.0597, 0.1714, - 0.0165, - 8.0000, + 0.8286, + 0.7363, + 2.9013, -0.2211, 0.9037, - 6.0000, - 0.0100, - 0.0400, - 2.0000, - 4.0000, - 1.3957, - 0.7363, - 7.0085, - -63.8967, + 1.473, + 4.0, 0.2979, - 1.0000, + -63.8967, 0.8586, - 5.0000, - 1.9304, - 0.6667, + 6.0, + 0.0165, + 8.0, 0.0848, - 0.8286, + 1.0, + 0.6667, + 0.01, + 0.04, + 1.3957, + 5.0, + 1.9356, 0.1714, - 0.0162, - 8.0000, + 0.8286, + 0.7363, + 7.0449, ], [ -0.0619, 0.1991, - 6.0000, - 0.0300, - 0.1300, - 2.0000, - 4.0000, - 0.5010, - 0.8836, - 1.5209, - 1.2447, + 1.2578, + 4.0, 0.1921, - 1.0000, - 0.8990, - 6.0000, - 2.0481, - 0.3333, + 1.2447, + 0.899, + 6.0, + 0.0089, + 6.0, 0.0718, + 1.0, + 0.3333, + 0.03, + 0.13, + 0.501, + 6.0, + 2.0492, + 0.2, 0.1714, - 0.2000, - 0.0092, - 6.0000, + 0.8836, + 1.5288, -3.0176, -3.5235, - 9.0000, - 0.0000, - 0.1000, - 3.0000, - 7.0000, - 7.6385, - 0.4418, - 2.4961, - -1.2227, + 2.5291, + 7.0, 0.3019, - 2.0000, + -1.2227, 0.9091, - 8.0000, - 1.7771, - 1.0000, - 0.0990, - 0.1714, + 9.0, + 0.0156, + 13.0, + 0.099, + 2.0, + 1.0, + 0.0, + 0.1, + 7.6385, + 8.0, + 1.7807, 0.3143, - 0.0081, - 13.0000, - -0.5190, - -0.6649, - 14.0000, - -0.1300, - 0.1900, - 3.0000, - 6.0000, - 0.3096, + 0.1714, 0.4418, - 0.6655, - -0.0086, + 2.509, + -0.519, + -0.6649, + 2.0582, + 6.0, 0.2148, - 2.0000, + -0.0086, 0.8384, - 6.0000, - 1.8777, - 0.5000, + 14.0, + 0.0122, + 13.0, 0.4267, + 2.0, + 0.5, + -0.13, + 0.19, + 0.3096, + 6.0, + 1.8881, + 0.2, 0.6571, - 0.2000, - 0.0133, - 13.0000, + 0.4418, + 0.669, -0.2996, -0.0955, - 8.0000, - 0.1800, - 0.2200, - 3.0000, - 8.0000, - 0.3790, - 0.4418, - 0.7076, - -0.1388, + 2.0035, + 8.0, 0.2115, - 2.0000, + -0.1388, 0.8384, - 7.0000, - 1.9195, - 0.5000, + 8.0, + 0.0069, + 14.0, 0.2553, - 0.1714, + 2.0, + 0.5, + 0.18, + 0.22, + 0.379, + 7.0, + 1.9223, 0.2857, - 0.0064, - 14.0000, + 0.1714, + 0.4418, + 0.7113, -0.3873, 0.4293, - 9.0000, - 0.1100, - 0.3500, - 3.0000, - 7.0000, - 0.2719, - 0.4418, - 0.4774, - 0.0065, + 2.5648, + 7.0, 0.2377, - 3.0000, + 0.0065, 0.8081, - 7.0000, - 1.8065, - 1.0000, + 9.0, + 0.0122, + 13.0, 0.6829, - 0.1714, + 3.0, + 1.0, + 0.11, + 0.35, + 0.2719, + 7.0, + 1.7647, 0.2571, - 0.0098, - 13.0000, + 0.1714, + 0.4418, + 0.4799, -0.9868, -1.4515, - 9.0000, - -0.0400, - 0.0800, - 3.0000, - 7.0000, - 1.5658, - 0.4418, - 1.1244, - -0.0475, + 2.5809, + 7.0, 0.2511, - 3.0000, + -0.0475, 0.8687, - 8.0000, - 1.7647, - 1.0000, - 0.3430, - 0.1714, + 9.0, + 0.0208, + 13.0, + 0.343, + 3.0, + 1.0, + -0.04, + 0.08, + 1.5658, + 8.0, + 1.7683, 0.2857, - 0.0168, - 13.0000, + 0.1714, + 0.4418, + 1.1302, ], ] ) @@ -739,672 +784,672 @@ def test_catch22_wrapper_on_basic_motions(): [ 0.0804, 0.3578, - 4.0, - -0.28, - 0.04, 1.1428, 3.0, - 0.3176, - 0.9327, - 1.226, - 0.6572, 0.0956, - 1.0, + 0.6572, 0.9192, 7.0, - 2.0909, - 0.5, + 0.0108, + 13.0, 0.0653, - 0.2, + 1.0, + 0.5, + -0.28, + 0.04, + 0.3176, + 4.0, + 2.0909, 0.3714, - 0.0108, - 13.0, + 0.2, + 0.9327, + 1.226, -0.5321, -0.9703, - 8.0, - -0.05, - 0.02, 2.2361, 6.0, - 0.8128, - 0.4909, - 0.968, - -0.0197, 0.281, - 2.0, + -0.0197, 0.8889, 8.0, - 1.8244, - 0.75, - 0.3456, - 0.1714, - 0.2857, 0.0243, 11.0, + 0.3456, + 2.0, + 0.75, + -0.05, + 0.02, + 0.8128, + 8.0, + 1.8244, + 0.2857, + 0.1714, + 0.4909, + 0.968, 0.2931, 0.5641, - 8.0, - -0.78, - -0.34, 1.5664, 5.0, - 0.4932, - 0.6381, - 1.1057, - 0.2011, 0.1343, - 2.0, + 0.2011, 0.8485, 9.0, - 1.9998, - 0.3333, - 0.0597, - 0.6857, - 0.3143, 0.0046, 10.0, + 0.0597, + 2.0, + 0.3333, + -0.78, + -0.34, + 0.4932, + 8.0, + 1.9998, + 0.3143, + 0.6857, + 0.6381, + 1.1057, 0.0448, 0.0448, - 5.0, - -0.38, - 0.06, 1.2329, 7.0, - 0.474, - 0.6381, - 1.1134, - -0.3243, 0.1159, - 2.0, + -0.3243, 0.899, 10.0, - 2.0078, - 0.3333, - 0.0806, - 0.1714, - 0.3143, 0.0101, 12.0, + 0.0806, + 2.0, + 0.3333, + -0.38, + 0.06, + 0.474, + 5.0, + 2.0078, + 0.3143, + 0.1714, + 0.6381, + 1.1134, 0.7638, 0.5217, - 8.0, - 0.03, - 0.35, 1.815, 7.0, - 0.6291, - 0.4909, - 1.0392, - 0.2326, 0.13, - 2.0, + 0.2326, 0.899, 10.0, - 1.9443, - 0.6667, - 0.1696, - 0.8, - 0.2286, 0.0046, 13.0, + 0.1696, + 2.0, + 0.6667, + 0.03, + 0.35, + 0.6291, + 8.0, + 1.9443, + 0.2286, + 0.8, + 0.4909, + 1.0392, 0.4149, 0.6463, - 8.0, - -0.11, - -0.805, 2.1736, 6.0, - 0.7752, - 0.4909, - 0.9495, - -0.1261, 0.185, - 2.0, + -0.1261, 0.9091, 8.0, - 1.8142, - 1.0, - 0.2083, - 0.7429, - 0.2, 0.0174, 11.0, + 0.2083, + 2.0, + 1.0, + -0.11, + -0.805, + 0.7752, + 8.0, + 1.8142, + 0.2, + 0.7429, + 0.4909, + 0.9495, ], [ -0.8062, -0.5798, - 8.0, - -0.16, - 0.01, 1.851, 8.0, - 0.635, - 0.4909, - 0.9945, - 0.44, 0.1823, - 4.0, + 0.44, 0.8586, 9.0, - 1.9501, - 0.3333, - 0.1544, - 0.1714, - 0.6571, 0.0091, 24.0, + 0.1544, + 4.0, + 0.3333, + -0.16, + 0.01, + 0.635, + 8.0, + 1.9501, + 0.6571, + 0.1714, + 0.4909, + 0.9945, -0.1266, 0.1157, - 7.0, - 0.2, - -0.37, 1.3705, 3.0, - 0.4195, - 0.7363, - 1.1266, - -1.0354, 0.1041, - 3.0, + -1.0354, 0.8889, 13.0, - 2.0298, - 0.3333, - 0.0602, - 0.3143, - 0.2, 0.0037, 8.0, + 0.0602, + 3.0, + 0.3333, + 0.2, + -0.37, + 0.4195, + 7.0, + 2.0298, + 0.2, + 0.3143, + 0.7363, + 1.1266, 0.4734, 0.207, - 5.0, - -0.23, - 0.06, 1.2531, 3.0, - 0.4252, - 0.7854, - 1.1178, - 0.4621, 0.1742, - 3.0, + 0.4621, 0.8687, 12.0, - 2.0059, - 0.3333, - 0.0523, - 0.7429, - 0.5429, 0.0116, 13.0, + 0.0523, + 3.0, + 0.3333, + -0.23, + 0.06, + 0.4252, + 5.0, + 2.0059, + 0.5429, + 0.7429, + 0.7854, + 1.1178, -0.0013, -0.3798, - 6.0, - 0.06, - -0.2, 1.6441, 5.0, - 0.4549, - 0.6872, - 1.0719, - 0.0662, 0.0992, - 2.0, + 0.0662, 0.8788, 13.0, - 2.0097, - 0.3333, - 0.0879, - 0.6571, - 0.2286, 0.0037, 14.0, + 0.0879, + 2.0, + 0.3333, + 0.06, + -0.2, + 0.4549, + 6.0, + 2.0097, + 0.2286, + 0.6571, + 0.6872, + 1.0719, 0.2664, -0.0514, - 5.0, - 0.19, - 0.3, 1.2534, 4.0, - 0.2506, - 0.8345, - 1.1916, - 0.1995, 0.0692, - 1.0, + 0.1995, 0.8081, 14.0, + 0.0073, + 6.0, + 0.0673, + 1.0, + 0.3333, + 0.19, + 0.3, + 0.2506, + 5.0, 2.0191, - 0.3333, - 0.0673, - 0.7429, 0.8286, - 0.0073, - 6.0, + 0.7429, + 0.8345, + 1.1916, -0.3504, 0.0039, - 5.0, - -0.14, - -0.04, 1.1481, 5.0, - 0.3408, - 0.6872, - 1.1216, - 1.2034, 0.0276, - 2.0, + 1.2034, 0.7778, 8.0, - 1.9736, - 0.3333, - 0.0342, - 0.2286, - 0.2, 0.0092, 8.0, + 0.0342, + 2.0, + 0.3333, + -0.14, + -0.04, + 0.3408, + 5.0, + 1.9736, + 0.2, + 0.2286, + 0.6872, + 1.1216, ], [ 0.3645, 0.1094, - 5.0, - -0.13, - -0.34, 0.9342, 6.0, - 0.4448, - 0.8836, - 1.1237, - -0.7506, 0.0879, - 2.0, + -0.7506, 0.8687, 7.0, - 2.133, - 0.3333, - 0.0845, - 0.6, - 0.2, 0.0024, 7.0, + 0.0845, + 2.0, + 0.3333, + -0.13, + -0.34, + 0.4448, + 5.0, + 2.133, + 0.2, + 0.6, + 0.8836, + 1.1237, 0.3658, 0.5744, - 7.0, - -0.23, - -0.01, 1.8247, 5.0, - 0.7293, - 0.589, - 1.087, - 0.3158, 0.2149, - 2.0, + 0.3158, 0.8687, 8.0, - 1.9505, - 1.0, - 0.1758, - 0.8286, - 0.2, 0.0116, 10.0, + 0.1758, + 2.0, + 1.0, + -0.23, + -0.01, + 0.7293, + 7.0, + 1.9505, + 0.2, + 0.8286, + 0.589, + 1.087, -0.3643, -0.0421, - 6.0, - -0.91, - -0.22, 1.1595, 3.0, - 0.2371, - 0.9327, - 1.1946, - -0.6767, 0.13, - 1.0, + -0.6767, 0.8182, 9.0, - 2.0869, - 1.0, - 0.0611, - 0.6, - 0.1714, 0.0004, 7.0, + 0.0611, + 1.0, + 1.0, + -0.91, + -0.22, + 0.2371, + 6.0, + 2.0869, + 0.1714, + 0.6, + 0.9327, + 1.1946, -0.3788, -0.1279, - 7.0, - -0.39, - -0.08, 1.8915, 5.0, - 0.6975, - 0.589, - 1.0127, - -0.0644, 0.1358, - 2.0, + -0.0644, 0.8586, 11.0, - 1.9728, - 0.25, - 0.1512, - 0.7429, - 0.2571, 0.0122, 10.0, + 0.1512, + 2.0, + 0.25, + -0.39, + -0.08, + 0.6975, + 7.0, + 1.9728, + 0.2571, + 0.7429, + 0.589, + 1.0127, 0.2208, -0.1147, - 7.0, - -0.08, - 0.03, 1.7819, 5.0, - 0.6061, - 0.589, - 1.0105, - -0.1311, 0.0862, - 2.0, + -0.1311, 0.8687, 8.0, - 2.039, - 0.3333, - 0.1591, - 0.7429, - 0.2, 0.0043, 12.0, + 0.1591, + 2.0, + 0.3333, + -0.08, + 0.03, + 0.6061, + 7.0, + 2.039, + 0.2, + 0.7429, + 0.589, + 1.0105, -0.6362, -0.8457, - 7.0, - -0.2, - 0.12, 1.9983, 5.0, - 0.7828, - 0.589, - 1.0502, - 0.0228, 0.1284, - 2.0, + 0.0228, 0.8586, 11.0, - 1.8706, - 1.0, - 0.3778, - 0.8, - 0.2, 0.0037, 10.0, + 0.3778, + 2.0, + 1.0, + -0.2, + 0.12, + 0.7828, + 7.0, + 1.8706, + 0.2, + 0.8, + 0.589, + 1.0502, ], [ 0.7285, 0.5526, - 5.0, - -0.07, - -0.17, 1.3139, 5.0, - 0.0561, - 0.7363, - 1.2518, - 0.3383, 0.254, - 1.0, + 0.3383, 0.8081, 8.0, - 2.0538, - 1.0, - 0.1392, - 0.2571, - 0.2286, 0.0046, 7.0, + 0.1392, + 1.0, + 1.0, + -0.07, + -0.17, + 0.0561, + 5.0, + 2.0538, + 0.2286, + 0.2571, + 0.7363, + 1.2518, -0.3475, -0.5286, - 4.0, - 0.08, - -0.01, 1.2448, 4.0, - 0.0331, - 0.7363, - 1.2729, - -0.3284, 0.2395, - 1.0, + -0.3284, 0.8283, 5.0, - 2.0075, - 1.0, - 0.1219, - 0.8286, - 0.1714, 0.005, 7.0, - 0.4846, - -0.2633, - 5.0, - -0.02, - -0.12, - 0.8858, - 2.0, - 0.2239, - 1.4726, - 1.192, - 0.1781, - 0.1024, + 0.1219, + 1.0, + 1.0, + 0.08, + -0.01, + 0.0331, + 4.0, + 2.0075, + 0.1714, + 0.8286, + 0.7363, + 1.2729, + 0.4846, + -0.2633, + 0.8858, 2.0, + 0.1024, + 0.1781, 0.8485, 8.0, - 2.1476, - 0.5, - 0.0817, - 0.6, - 0.7714, 0.0012, 3.0, + 0.0817, + 2.0, + 0.5, + -0.02, + -0.12, + 0.2239, + 5.0, + 2.1476, + 0.7714, + 0.6, + 1.4726, + 1.192, 0.0544, -0.747, - 7.0, - 0.06, - 0.12, 0.9238, 3.0, - 0.1144, - 0.8836, - 1.2399, - 1.1815, 0.1503, - 1.0, + 1.1815, 0.8283, 5.0, - 2.1388, - 0.5, - 0.0664, - 0.8286, - 0.2286, 0.0124, 8.0, + 0.0664, + 1.0, + 0.5, + 0.06, + 0.12, + 0.1144, + 7.0, + 2.1388, + 0.2286, + 0.8286, + 0.8836, + 1.2399, -0.1774, -0.3877, - 6.0, - -0.05, - -0.1, 1.124, 4.0, - 0.0388, - 0.7363, - 1.2521, - -0.0252, 0.1556, - 1.0, + -0.0252, 0.8182, 6.0, - 2.0597, - 0.5, - 0.086, - 0.8286, - 0.1714, 0.0165, 8.0, + 0.086, + 1.0, + 0.5, + -0.05, + -0.1, + 0.0388, + 6.0, + 2.0597, + 0.1714, + 0.8286, + 0.7363, + 1.2521, -0.0394, 0.1582, - 5.0, - 0.01, - 0.04, 1.473, 4.0, - 0.0431, - 0.7363, - 1.2381, - -0.3468, 0.3, - 1.0, + -0.3468, 0.8586, 6.0, - 1.9356, - 0.6667, - 0.2129, - 0.8286, - 0.1714, 0.0165, 8.0, + 0.2129, + 1.0, + 0.6667, + 0.01, + 0.04, + 0.0431, + 5.0, + 1.9356, + 0.1714, + 0.8286, + 0.7363, + 1.2381, ], [ -0.6177, -0.4164, - 6.0, - 0.03, - -0.14, 1.2578, 4.0, - 0.2982, - 0.8836, - 1.1794, - 0.5714, 0.1921, - 1.0, + 0.5714, 0.899, 6.0, - 2.0492, - 0.3333, - 0.1321, - 0.1714, - 0.2, 0.0089, 6.0, + 0.1321, + 1.0, + 0.3333, + 0.03, + -0.14, + 0.2982, + 6.0, + 2.0492, + 0.2, + 0.1714, + 0.8836, + 1.1794, -1.0401, -1.2153, - 8.0, - 0.0, - 0.1, 2.5291, 7.0, - 0.9158, - 0.4418, - 0.8688, - -0.0508, 0.3019, - 2.0, + -0.0508, 0.8889, 9.0, - 1.7807, - 1.0, - 0.3581, - 0.1714, - 0.3143, 0.0156, 13.0, + 0.3581, + 2.0, + 1.0, + 0.0, + 0.1, + 0.9158, + 8.0, + 1.7807, + 0.3143, + 0.1714, + 0.4418, + 0.8688, -0.2369, -0.4506, - 6.0, - -0.13, - 0.3, 2.0582, 6.0, - 0.665, - 0.4418, - 0.9804, - -0.0271, 0.2042, - 2.0, + -0.0271, 0.8586, 14.0, - 1.8881, - 0.5, - 0.2277, - 0.6571, - 0.2, 0.0122, 13.0, + 0.2277, + 2.0, + 0.5, + -0.13, + 0.3, + 0.665, + 6.0, + 1.8881, + 0.2, + 0.6571, + 0.4418, + 0.9804, -0.4152, -0.1326, - 7.0, - 0.18, - 0.22, 2.0035, 8.0, - 0.7259, - 0.4418, - 0.9844, - -0.3678, 0.1768, - 2.0, + -0.3678, 0.8384, 8.0, - 1.9223, - 0.5, - 0.1498, - 0.1714, - 0.2857, 0.0069, 14.0, + 0.1498, + 2.0, + 0.5, + 0.18, + 0.22, + 0.7259, + 7.0, + 1.9223, + 0.2857, + 0.1714, + 0.4418, + 0.9844, -0.7112, 0.7554, - 7.0, - 0.11, - 0.22, 2.5648, 7.0, - 0.8769, - 0.4418, - 0.8619, - 0.0377, 0.2621, - 3.0, + 0.0377, 0.8283, 9.0, - 1.7647, - 1.0, - 0.317, - 0.1714, - 0.2571, 0.0122, 13.0, + 0.317, + 3.0, + 1.0, + 0.11, + 0.22, + 0.8769, + 7.0, + 1.7647, + 0.2571, + 0.1714, + 0.4418, + 0.8619, -0.7754, -1.133, - 8.0, - -0.04, - 0.08, 2.5809, 7.0, - 0.9274, - 0.4418, - 0.8698, - -0.0217, 0.2648, - 3.0, + -0.0217, 0.8485, 9.0, - 1.7683, - 1.0, - 0.4885, - 0.1714, - 0.2857, 0.0208, 13.0, + 0.4885, + 3.0, + 1.0, + -0.04, + 0.08, + 0.9274, + 8.0, + 1.7683, + 0.2857, + 0.1714, + 0.4418, + 0.8698, ], ] ) diff --git a/aeon/transformations/collection/interval_based/tests/test_intervals.py b/aeon/transformations/collection/interval_based/tests/test_intervals.py index 23aeda4e31..d0374eb5ea 100644 --- a/aeon/transformations/collection/interval_based/tests/test_intervals.py +++ b/aeon/transformations/collection/interval_based/tests/test_intervals.py @@ -58,4 +58,4 @@ def test_supervised_transformers(): ) X_t = sit.fit_transform(X, y) - assert X_t.shape == (X.shape[0], 7) + assert X_t.shape == (X.shape[0], 8)