|
16 | 16 | from numba.typed import List
|
17 | 17 | from sklearn.preprocessing import LabelEncoder
|
18 | 18 |
|
19 |
| - |
20 | 19 | from aeon.distances import get_distance_function
|
21 | 20 | from aeon.transformations.collection import BaseCollectionTransformer
|
22 | 21 | from aeon.utils.numba.general import (
|
@@ -173,7 +172,7 @@ def __init__(
|
173 | 172 |
|
174 | 173 | super().__init__()
|
175 | 174 |
|
176 |
| - def _fit(self, X: np.ndarray, y: Optional[Union[np.ndarray, List]] =None): |
| 175 | + def _fit(self, X: np.ndarray, y: Optional[Union[np.ndarray, List]] = None): |
177 | 176 | """Fit the random dilated shapelet transform to a specified X and y.
|
178 | 177 |
|
179 | 178 | Parameters
|
@@ -247,7 +246,7 @@ def _fit(self, X: np.ndarray, y: Optional[Union[np.ndarray, List]] =None):
|
247 | 246 |
|
248 | 247 | return self
|
249 | 248 |
|
250 |
| - def _transform(self, X: np.ndarray, y: Optional[Union[np.ndarray, List]]=None): |
| 249 | + def _transform(self, X: np.ndarray, y: Optional[Union[np.ndarray, List]] = None): |
251 | 250 | """Transform X according to the extracted shapelets.
|
252 | 251 |
|
253 | 252 | Parameters
|
@@ -347,7 +346,9 @@ def _check_input_params(self):
|
347 | 346 | self.threshold_percentiles_ = np.asarray(self.threshold_percentiles_)
|
348 | 347 |
|
349 | 348 | @classmethod
|
350 |
| - def get_test_params(cls, parameter_set: str="default") -> "Union[Dict, List[Dict]]": |
| 349 | + def get_test_params( |
| 350 | + cls, parameter_set: str = "default" |
| 351 | + ) -> "Union[Dict, List[Dict]]": |
351 | 352 | """Return testing parameter settings for the estimator.
|
352 | 353 |
|
353 | 354 | Parameters
|
@@ -540,7 +541,7 @@ def random_dilated_shapelet_extraction(
|
540 | 541 | shapelets and candidate subsequences
|
541 | 542 |
|
542 | 543 | Returns
|
543 |
| - -------- |
| 544 | + ------- |
544 | 545 | Shapelets : tuple
|
545 | 546 | The returned tuple contains 7 arrays describing the shapelets parameters:
|
546 | 547 | - values : array, shape (max_shapelets, n_channels, max(shapelet_lengths))
|
@@ -689,7 +690,19 @@ def random_dilated_shapelet_extraction(
|
689 | 690 |
|
690 | 691 |
|
691 | 692 | @njit(fastmath=True, cache=True, parallel=True)
|
692 |
| -def dilated_shapelet_transform(X: np.ndarray, shapelets: tuple[np.ndarray,np.ndarray,np.ndarray,np.ndarray,np.ndarray,np.ndarray,np.ndarray], distance: CPUDispatcher): |
| 693 | +def dilated_shapelet_transform( |
| 694 | + X: np.ndarray, |
| 695 | + shapelets: tuple[ |
| 696 | + np.ndarray, |
| 697 | + np.ndarray, |
| 698 | + np.ndarray, |
| 699 | + np.ndarray, |
| 700 | + np.ndarray, |
| 701 | + np.ndarray, |
| 702 | + np.ndarray, |
| 703 | + ], |
| 704 | + distance: CPUDispatcher, |
| 705 | +): |
693 | 706 | """Perform the shapelet transform with a set of shapelets and a set of time series.
|
694 | 707 |
|
695 | 708 | Parameters
|
@@ -832,7 +845,13 @@ def get_all_subsequences(X: np.ndarray, length: int, dilation: int) -> np.ndarra
|
832 | 845 |
|
833 | 846 |
|
834 | 847 | @njit(fastmath=True, cache=True)
|
835 |
| -def compute_shapelet_features(X_subs: np.ndarray, values: np.ndarray, length: int, threshold: float, distance: CPUDispatcher): |
| 848 | +def compute_shapelet_features( |
| 849 | + X_subs: np.ndarray, |
| 850 | + values: np.ndarray, |
| 851 | + length: int, |
| 852 | + threshold: float, |
| 853 | + distance: CPUDispatcher, |
| 854 | +): |
836 | 855 | """Extract the features from a shapelet distance vector.
|
837 | 856 |
|
838 | 857 | Given a shapelet and a time series, extract three features from the resulting
|
@@ -879,7 +898,9 @@ def compute_shapelet_features(X_subs: np.ndarray, values: np.ndarray, length: in
|
879 | 898 |
|
880 | 899 |
|
881 | 900 | @njit(fastmath=True, cache=True)
|
882 |
| -def compute_shapelet_dist_vector(X_subs: np.ndarray, values: np.ndarray, length: int, distance: CPUDispatcher): |
| 901 | +def compute_shapelet_dist_vector( |
| 902 | + X_subs: np.ndarray, values: np.ndarray, length: int, distance: CPUDispatcher |
| 903 | +): |
883 | 904 | """Extract the features from a shapelet distance vector.
|
884 | 905 |
|
885 | 906 | Given a shapelet and a time series, extract three features from the resulting
|
|
0 commit comments