Skip to content

Commit 8a21f1c

Browse files
authored
[ENH] Import RandomChannelSelector through init (#1933)
* remove Theta * docstring * docstring * fix bug * all in channel selection * docstring
1 parent a049098 commit 8a21f1c

File tree

3 files changed

+25
-3
lines changed

3 files changed

+25
-3
lines changed

aeon/transformations/collection/channel_selection/__init__.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,12 @@
55
collection using transform.
66
"""
77

8-
__all__ = ["ChannelScorer", "ElbowClassPairwise", "ElbowClassSum"]
8+
__all__ = [
9+
"ChannelScorer",
10+
"ElbowClassPairwise",
11+
"ElbowClassSum",
12+
"RandomChannelSelector",
13+
]
914

1015

1116
from aeon.transformations.collection.channel_selection._channel_scorer import (
@@ -15,3 +20,6 @@
1520
ElbowClassPairwise,
1621
ElbowClassSum,
1722
)
23+
from aeon.transformations.collection.channel_selection._random import (
24+
RandomChannelSelector,
25+
)

aeon/transformations/collection/channel_selection/_channel_scorer.py

+3
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@
77
from aeon.classification.convolution_based._rocket_classifier import RocketClassifier
88
from aeon.transformations.collection.channel_selection.base import BaseChannelSelector
99

10+
__maintainer__ = ["TonyBagnall"]
11+
__all__ = ["ChannelScorer"]
12+
1013

1114
class ChannelScorer(BaseChannelSelector):
1215
"""Channel scorer performs channel selection using a single channel classifier.

aeon/transformations/collection/channel_selection/_random.py

+13-2
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from aeon.transformations.collection.channel_selection.base import BaseChannelSelector
88

99
__maintainer__ = ["TonyBagnall"]
10+
__all__ = ["RandomChannelSelector"]
1011

1112

1213
class RandomChannelSelector(BaseChannelSelector):
@@ -15,13 +16,23 @@ class RandomChannelSelector(BaseChannelSelector):
1516
Parameters
1617
----------
1718
p: float, default 0.4
18-
proportion of channels to keep. If p*len(X) is non integer it is rounded up
19+
proportion of channels to keep. If p*n_channels is non integer it is rounded up
1920
to the nearest integer.
2021
2122
Attributes
2223
----------
2324
channels_selected_ : list[int]
2425
List of channels selected in fit.
26+
27+
Examples
28+
--------
29+
>>> import numpy as np
30+
>>> from aeon.transformations.collection.channel_selection import RandomChannelSelector # noqa
31+
>>> X = np.random.rand(10, 10, 100)
32+
>>> selector = RandomChannelSelector(p=0.4)
33+
>>> XNew = selector.fit_transform(X)
34+
>>> XNew.shape
35+
(10, 4, 100)
2536
"""
2637

2738
_tags = {
@@ -41,7 +52,7 @@ def __init__(self, p=0.4, random_state=None):
4152
def _fit(self, X, y):
4253
"""Randomly select channels to retain."""
4354
rng = check_random_state(self.random_state)
44-
to_select = math.ceil(self.p * len(X))
55+
to_select = math.ceil(self.p * X.shape[1])
4556
self.channels_selected_ = rng.choice(
4657
list(range(X[0].shape[0])), size=to_select, replace=False
4758
)

0 commit comments

Comments
 (0)