Skip to content

Commit

Permalink
hyda bugfix (#1771)
Browse files Browse the repository at this point in the history
  • Loading branch information
MatthewMiddlehurst authored Jul 9, 2024
1 parent 2003cf2 commit 9bb6c99
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions aeon/transformations/collection/convolution_based/_hydra.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,6 @@ class HydraTransformer(BaseCollectionTransformer):
"output_data_type": "Tabular",
"algorithm_type": "convolution",
"python_dependencies": "torch",
"fit_is_empty": True,
}

def __init__(
Expand All @@ -82,7 +81,7 @@ def __init__(

super().__init__()

def _transform(self, X, y=None):
def _fit(self, X, y=None):
import torch

if isinstance(self.random_state, int):
Expand All @@ -91,14 +90,16 @@ def _transform(self, X, y=None):
n_jobs = check_n_jobs(self.n_jobs)
torch.set_num_threads(n_jobs)

self.hydra = _HydraInternal(
self._hydra = _HydraInternal(
X.shape[2],
X.shape[1],
k=self.n_kernels,
g=self.n_groups,
max_num_channels=self.max_num_channels,
)
return self.hydra(torch.tensor(X).float())

def _transform(self, X, y=None):
return self._hydra(torch.tensor(X).float())


if _check_soft_dependencies("torch", severity="none"):
Expand Down

0 comments on commit 9bb6c99

Please sign in to comment.