Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] Added useful attributes to extracted shapelets for RDST #1959

Merged
merged 14 commits into from
Aug 13, 2024
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -96,12 +96,14 @@ class RandomDilatedShapeletTransform(BaseCollectionTransformer):
shapelets : list
The stored shapelets. Each item in the list is a tuple containing:
- shapelet values
- startpoint values
- length parameter
- dilation parameter
- threshold parameter
- normalization parameter
- mean parameter
- standard deviation parameter
- class value
max_shapelet_length_ : int
The maximum actual shapelet length fitted to train data.
min_n_timepoints_ : int
Expand Down Expand Up @@ -240,7 +242,7 @@ def _fit(self, X, y=None):

# Shapelet "length" is length-1 times dilation
self.max_shapelet_length_ = np.max(
(self.shapelets_[1] - 1) * self.shapelets_[2]
(self.shapelets_[2] - 1) * self.shapelets_[3]
)

return self
Expand Down Expand Up @@ -410,6 +412,8 @@ def _init_random_shapelet_params(
-------
values : array, shape (max_shapelets, n_channels, max(shapelet_lengths))
An initialized (empty) value array for each shapelet
startpoints: array, shape (max_shapelets)
An initialized (empty) startpoint array for each shapelet
lengths : array, shape (max_shapelets)
The randomly initialized length of each shapelet
dilations : array, shape (max_shapelets)
Expand All @@ -422,8 +426,15 @@ def _init_random_shapelet_params(
Means of the shapelets
stds : array, shape (max_shapelets, n_channels)
Standard deviation of the shapelets
class: array, shape (max_shapelets)
An initialized (empty) class array for each shapelet

"""
# Init startpoint array
startpoints = np.zeros(max_shapelets, dtype=np.int32)
# Init class array
classes = np.zeros(max_shapelets, dtype=np.int32)

# Lengths of the shapelets
# test dtypes correctness
lengths = np.random.choice(shapelet_lengths, size=max_shapelets).astype(np.int32)
Expand Down Expand Up @@ -461,7 +472,17 @@ def _init_random_shapelet_params(
means = np.zeros((max_shapelets, n_channels), dtype=np.float64)
stds = np.zeros((max_shapelets, n_channels), dtype=np.float64)

return values, lengths, dilations, threshold, normalize, means, stds
return (
values,
startpoints,
lengths,
dilations,
threshold,
normalize,
means,
stds,
classes,
)


@njit(cache=True)
Expand Down Expand Up @@ -541,6 +562,8 @@ def random_dilated_shapelet_extraction(
The returned tuple contains 7 arrays describing the shapelets parameters:
- values : array, shape (max_shapelets, n_channels, max(shapelet_lengths))
Values of the shapelets.
- startpoints : array, shape (max_shapelets)
Start points parameter of the shapelets
- lengths : array, shape (max_shapelets)
Length parameter of the shapelets
- dilations : array, shape (max_shapelets)
Expand All @@ -553,6 +576,8 @@ def random_dilated_shapelet_extraction(
Means of the shapelets
- stds : array, shape (max_shapelets, n_channels)
Standard deviation of the shapelets
- classes : array, shape (max_shapelets)
An initialized (empty) startpoint array for each shapelet
"""
n_cases = len(X)
n_channels = X[0].shape[0]
Expand All @@ -567,12 +592,14 @@ def random_dilated_shapelet_extraction(
# Initialize shapelets
(
values,
startpoints,
lengths,
dilations,
threshold,
normalize,
means,
stds,
classes,
) = _init_random_shapelet_params(
max_shapelets,
shapelet_lengths,
Expand Down Expand Up @@ -664,6 +691,13 @@ def random_dilated_shapelet_extraction(

threshold[i_shp] = np.random.uniform(lower_bound, upper_bound)
values[i_shp, :, :length] = _val

# Extract the starting point index of the shapelet
startpoints[i_shp] = idx_timestamp

# Extract the class value of the shapelet
classes[i_shp] = y[idx_sample]

if norm:
means[i_shp] = _means
stds[i_shp] = _stds
Expand All @@ -675,12 +709,14 @@ def random_dilated_shapelet_extraction(

return (
values[mask_values],
startpoints[mask_values],
lengths[mask_values],
dilations[mask_values],
threshold[mask_values],
normalize[mask_values],
means[mask_values],
stds[mask_values],
classes[mask_values],
)


Expand All @@ -696,6 +732,8 @@ def dilated_shapelet_transform(X, shapelets, distance):
The returned tuple contains 7 arrays describing the shapelets parameters:
- values : array, shape (n_shapelets, n_channels, max(shapelet_lengths))
Values of the shapelets.
- startpoints : array, shape (max_shapelets)
Start points parameter of the shapelets
- lengths : array, shape (n_shapelets)
Length parameter of the shapelets
- dilations : array, shape (n_shapelets)
Expand All @@ -708,6 +746,8 @@ def dilated_shapelet_transform(X, shapelets, distance):
Means of the shapelets
- stds : array, shape (n_shapelets, n_channels)
Standard deviation of the shapelets
- classes : array, shape (max_shapelets)
An initialized (empty) startpoint array for each shapelet
distance: CPUDispatcher
A Numba function used to compute the distance between two multidimensional
time series of shape (n_channels, length).
Expand All @@ -722,12 +762,14 @@ def dilated_shapelet_transform(X, shapelets, distance):
"""
(
values,
startpoints,
lengths,
dilations,
threshold,
normalize,
means,
stds,
classes,
) = shapelets
n_shapelets = len(lengths)
n_cases = len(X)
Expand Down
9 changes: 5 additions & 4 deletions aeon/visualisation/estimator/_shapelets.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,11 +400,12 @@ def __init__(self, estimator):

def _get_shapelet(self, id_shapelet):
if isinstance(self.estimator, RandomDilatedShapeletTransform):
length_ = self.estimator.shapelets_[1][id_shapelet]
values_ = self.estimator.shapelets_[0][id_shapelet]
dilation_ = self.estimator.shapelets_[2][id_shapelet]
threshold_ = self.estimator.shapelets_[3][id_shapelet]
normalize_ = self.estimator.shapelets_[4][id_shapelet]
# startpos_ = self.estimator.shapelets_[1][id_shapelet]
length_ = self.estimator.shapelets_[2][id_shapelet]
dilation_ = self.estimator.shapelets_[3][id_shapelet]
threshold_ = self.estimator.shapelets_[4][id_shapelet]
normalize_ = self.estimator.shapelets_[5][id_shapelet]
distance = self.estimator.distance

elif isinstance(self.estimator, (RSAST, SAST)):
Expand Down