Skip to content

Commit 1e27566

Browse files
aryanpolagithub-actions[bot]
authored andcommitted
Automatic pre-commit fixes
1 parent 2a35014 commit 1e27566

File tree

1 file changed

+29
-8
lines changed

1 file changed

+29
-8
lines changed

aeon/transformations/collection/shapelet_based/_dilated_shapelet_transform.py

+29-8
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
from numba.typed import List
1717
from sklearn.preprocessing import LabelEncoder
1818

19-
2019
from aeon.distances import get_distance_function
2120
from aeon.transformations.collection import BaseCollectionTransformer
2221
from aeon.utils.numba.general import (
@@ -173,7 +172,7 @@ def __init__(
173172

174173
super().__init__()
175174

176-
def _fit(self, X: np.ndarray, y: Optional[Union[np.ndarray, List]] =None):
175+
def _fit(self, X: np.ndarray, y: Optional[Union[np.ndarray, List]] = None):
177176
"""Fit the random dilated shapelet transform to a specified X and y.
178177
179178
Parameters
@@ -247,7 +246,7 @@ def _fit(self, X: np.ndarray, y: Optional[Union[np.ndarray, List]] =None):
247246

248247
return self
249248

250-
def _transform(self, X: np.ndarray, y: Optional[Union[np.ndarray, List]]=None):
249+
def _transform(self, X: np.ndarray, y: Optional[Union[np.ndarray, List]] = None):
251250
"""Transform X according to the extracted shapelets.
252251
253252
Parameters
@@ -347,7 +346,9 @@ def _check_input_params(self):
347346
self.threshold_percentiles_ = np.asarray(self.threshold_percentiles_)
348347

349348
@classmethod
350-
def get_test_params(cls, parameter_set: str="default") -> "Union[Dict, List[Dict]]":
349+
def get_test_params(
350+
cls, parameter_set: str = "default"
351+
) -> "Union[Dict, List[Dict]]":
351352
"""Return testing parameter settings for the estimator.
352353
353354
Parameters
@@ -540,7 +541,7 @@ def random_dilated_shapelet_extraction(
540541
shapelets and candidate subsequences
541542
542543
Returns
543-
--------
544+
-------
544545
Shapelets : tuple
545546
The returned tuple contains 7 arrays describing the shapelets parameters:
546547
- values : array, shape (max_shapelets, n_channels, max(shapelet_lengths))
@@ -689,7 +690,19 @@ def random_dilated_shapelet_extraction(
689690

690691

691692
@njit(fastmath=True, cache=True, parallel=True)
692-
def dilated_shapelet_transform(X: np.ndarray, shapelets: tuple[np.ndarray,np.ndarray,np.ndarray,np.ndarray,np.ndarray,np.ndarray,np.ndarray], distance: CPUDispatcher):
693+
def dilated_shapelet_transform(
694+
X: np.ndarray,
695+
shapelets: tuple[
696+
np.ndarray,
697+
np.ndarray,
698+
np.ndarray,
699+
np.ndarray,
700+
np.ndarray,
701+
np.ndarray,
702+
np.ndarray,
703+
],
704+
distance: CPUDispatcher,
705+
):
693706
"""Perform the shapelet transform with a set of shapelets and a set of time series.
694707
695708
Parameters
@@ -832,7 +845,13 @@ def get_all_subsequences(X: np.ndarray, length: int, dilation: int) -> np.ndarra
832845

833846

834847
@njit(fastmath=True, cache=True)
835-
def compute_shapelet_features(X_subs: np.ndarray, values: np.ndarray, length: int, threshold: float, distance: CPUDispatcher):
848+
def compute_shapelet_features(
849+
X_subs: np.ndarray,
850+
values: np.ndarray,
851+
length: int,
852+
threshold: float,
853+
distance: CPUDispatcher,
854+
):
836855
"""Extract the features from a shapelet distance vector.
837856
838857
Given a shapelet and a time series, extract three features from the resulting
@@ -879,7 +898,9 @@ def compute_shapelet_features(X_subs: np.ndarray, values: np.ndarray, length: in
879898

880899

881900
@njit(fastmath=True, cache=True)
882-
def compute_shapelet_dist_vector(X_subs: np.ndarray, values: np.ndarray, length: int, distance: CPUDispatcher):
901+
def compute_shapelet_dist_vector(
902+
X_subs: np.ndarray, values: np.ndarray, length: int, distance: CPUDispatcher
903+
):
883904
"""Extract the features from a shapelet distance vector.
884905
885906
Given a shapelet and a time series, extract three features from the resulting

0 commit comments

Comments
 (0)