Skip to content

Commit 98ee176

Browse files
committed
added attributes to rsast too
1 parent d20abf8 commit 98ee176

File tree

2 files changed

+44
-7
lines changed

2 files changed

+44
-7
lines changed

aeon/transformations/collection/shapelet_based/_rsast.py

+40-5
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,8 @@ class RSAST(BaseCollectionTransformer):
5454
5555
Parameters
5656
----------
57-
n_random_points: int default = 10 the number of initial random points to extract
57+
n_random_points: int default = 10
58+
the number of initial random points to extract
5859
len_method: string default="both" the type of statistical tool used to get
5960
the length of shapelets. "both"=ACF&PACF, "ACF"=ACF, "PACF"=PACF,
6061
"None"=Extract randomly any length from the TS
@@ -63,10 +64,27 @@ class RSAST(BaseCollectionTransformer):
6364
the number of reference time series to select per class
6465
seed : int, default = None
6566
the seed of the random generator
66-
classifier : sklearn compatible classifier, default = None
67+
estimator : sklearn compatible classifier, default = None
6768
if None, a RidgeClassifierCV(alphas=np.logspace(-3, 3, 10)) is used.
6869
n_jobs : int, default -1
6970
Number of threads to use for the transform.
71+
72+
Attributes
73+
---------
74+
_kernels : list
75+
The z-normalized subsequences used for transformation.
76+
_kernel_orig : list
77+
The original (non z-normalized) subsequences.
78+
_start_positions : list
79+
The starting positions of each subsequence within the original time series.
80+
_classes : list
81+
The class labels associated with each subsequence.
82+
_source_series: list
83+
The index of the original time series in the training set from which each
84+
subsequence was derived.
85+
_kernels_generators_ : dict
86+
A dictionary mapping class labels to the selected reference time series
87+
for that class.
7088
7189
References
7290
----------
@@ -112,6 +130,9 @@ def __init__(
112130
self._kernels = None # z-normalized subsequences
113131
self._cand_length_list = {}
114132
self._kernel_orig = []
133+
self._start_positions = []
134+
self._classes = []
135+
self._source_series = [] # To store the index of the original time series
115136
self._kernels_generators = {} # Reference time series
116137
super().__init__()
117138

@@ -154,7 +175,12 @@ def _fit(self, X: np.ndarray, y: Union[np.ndarray, List]) -> "RSAST":
154175
self.num_classes = classes.shape[0]
155176
m_kernel = 0
156177

157-
# 1--calculate ANOVA per each time t throught the lenght of the TS
178+
# Initialize lists to store start positions, classes, and source series
179+
self._start_positions = []
180+
self._classes = []
181+
self._source_series = []
182+
183+
# 1--calculate ANOVA per each time t throughout the length of the TS
158184
for i in range(X_.shape[1]):
159185
statistic_per_class = {}
160186
for c in classes:
@@ -184,12 +210,16 @@ def _fit(self, X: np.ndarray, y: Union[np.ndarray, List]) -> "RSAST":
184210
X_c = X_[y == c]
185211

186212
cnt = np.min([self.nb_inst_per_class, X_c.shape[0]]).astype(int)
213+
214+
# Store the original indices of the sampled time series
215+
original_indices = np.where(y == c)[0]
187216

188-
choosen = self._random_state.permutation(X_c.shape[0])[:cnt]
217+
chosen_indices = self._random_state.permutation(X_c.shape[0])[:cnt]
189218

190219
self._kernels_generators[c] = []
191220

192-
for rep, idx in enumerate(choosen):
221+
for rep, idx in enumerate(chosen_indices):
222+
original_idx = original_indices[idx] # Get the original index
193223
# defining indices for length list
194224
idx_len_list = c + "," + str(idx) + "," + str(rep)
195225

@@ -290,6 +320,11 @@ def _fit(self, X: np.ndarray, y: Union[np.ndarray, List]) -> "RSAST":
290320
self._kernel_orig.append(np.squeeze(kernel))
291321
self._kernels_generators[c].extend(X_c[idx].reshape(1, -1))
292322

323+
# Store the start position, class, and the original index in the training set
324+
self._start_positions.append(i)
325+
self._classes.append(c)
326+
self._source_series.append(original_idx)
327+
293328
# 3--save the calculated subsequences
294329
n_kernels = len(self._kernel_orig)
295330

aeon/transformations/collection/shapelet_based/_sast.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -70,9 +70,11 @@ class SAST(BaseCollectionTransformer):
7070
_classes : list
7171
The class labels associated with each subsequence.
7272
_source_series: list
73-
The index of the original time series in the training set from which each subsequence was derived.
73+
The index of the original time series in the training set from which each
74+
subsequence was derived.
7475
kernels_generators_ : dict
75-
A dictionary mapping class labels to the selected reference time series for that class.
76+
A dictionary mapping class labels to the selected reference time series
77+
for that class.
7678
7779
7880
References

0 commit comments

Comments
 (0)