Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] Add AEDRNNClusterer #1784

Merged
merged 72 commits into from
Nov 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
72 commits
Select commit Hold shift + click to select a range
26092a6
Add AEDRNNNetwork
aadya940 May 18, 2024
d24c685
minor fix
aadya940 May 18, 2024
58e56a1
minor
aadya940 May 18, 2024
3f27785
minor refactoring
aadya940 May 27, 2024
2f62ca3
add _* to private methods
aadya940 May 27, 2024
8838040
precommit
aadya940 May 27, 2024
b22cace
minor fix
aadya940 May 27, 2024
41d0493
minor
aadya940 May 27, 2024
9f95380
minor
aadya940 May 31, 2024
a0083fe
minor
aadya940 May 31, 2024
e622dba
minor
aadya940 May 31, 2024
253295a
minor
aadya940 May 31, 2024
db341b3
Add dilation_rate_decoder kwarg
aadya940 Jun 1, 2024
452ac75
add kwargs
aadya940 Jun 1, 2024
a4bf588
pre-commit
aadya940 Jun 1, 2024
5e73c36
minor
aadya940 Jun 1, 2024
33305a8
temporal_latent_space
aadya940 Jun 2, 2024
fdf929c
Adjust for temporal_latent_space
aadya940 Jun 2, 2024
e3f3d62
Add docstring
aadya940 Jun 3, 2024
f9c2165
Add test cases for AEDRNNNetwork
aadya940 Jun 3, 2024
c7af48c
Add skipif for pytest tests
aadya940 Jun 3, 2024
a1c51ba
minor
aadya940 Jun 3, 2024
be3735b
minor fix
aadya940 Jun 3, 2024
6cb7af0
minor
aadya940 Jun 3, 2024
4093d86
minor
aadya940 Jun 4, 2024
9efe7b1
minor
aadya940 Jun 4, 2024
b867ba3
Add tests
aadya940 Jun 4, 2024
9d4d5cf
minor
aadya940 Jun 4, 2024
0ae9957
add tag
aadya940 Jun 14, 2024
93ba4e0
minor
aadya940 Jun 18, 2024
3e41662
fix bugs
aadya940 Jun 20, 2024
86c126f
Merge branch 'main' into drnn
aadya940 Jun 21, 2024
9212dba
update base
aadya940 Jun 21, 2024
a6a5453
minor
aadya940 Jun 29, 2024
2c93c1b
minor
aadya940 Jul 10, 2024
37b0574
Add AEDRNNClusterer
aadya940 Jul 10, 2024
f971b3c
typo
aadya940 Jul 11, 2024
ed12377
add to __init__
aadya940 Jul 11, 2024
6adda18
minor:
aadya940 Jul 11, 2024
c968719
merge
aadya940 Jul 20, 2024
564ca64
modelcheckpoint callback fixes
aadya940 Jul 20, 2024
d75faad
minor fixes
aadya940 Jul 25, 2024
8b5ebbc
fixes
aadya940 Jul 25, 2024
b29a6fc
Add AEDRNNClusterer example notebook
aadya940 Jul 25, 2024
f5d22c3
docstring
aadya940 Jul 26, 2024
d562419
fixes
aadya940 Jul 28, 2024
fd443f0
Update _ae_drnn.py
aadya940 Jul 28, 2024
b436fdd
Fixes for reviews
aadya940 Jul 29, 2024
e14c8ec
Merge branch 'aedrnn-clusterer' of https://github.com/aadya940/aeon i…
aadya940 Jul 29, 2024
b53fe01
fixes
aadya940 Jul 31, 2024
e200c73
Merge branch 'aeon-toolkit:main' into aedrnn-clusterer
aadya940 Jul 31, 2024
758737c
Automatic `pre-commit` fixes
aadya940 Jul 31, 2024
ef20057
update network
aadya940 Aug 1, 2024
0dfb8a9
Add estimator kwarg in clusterer
aadya940 Aug 15, 2024
6662705
minor fixes
aadya940 Aug 15, 2024
26f60ef
some fixes
aadya940 Aug 16, 2024
039e872
fix notebooks
aadya940 Aug 16, 2024
ff6d8b1
remove deprecated
aadya940 Aug 16, 2024
e7fd765
Merge branch 'main' into aedrnn-clusterer
aadya940 Aug 23, 2024
9e266d1
Delete examples/clustering/deep_clustering.ipynb
aadya940 Aug 30, 2024
c19657b
Add metrics kwarg
aadya940 Aug 31, 2024
3ef1161
Merge branch 'aedrnn-clusterer' of https://github.com/aadya940/aeon i…
aadya940 Aug 31, 2024
27d78fc
Merge branch 'main' into aedrnn-clusterer
aadya940 Oct 28, 2024
1363d1d
remove return_X_y
aadya940 Oct 31, 2024
29474cf
Merge branch 'main' into aedrnn-clusterer
aadya940 Nov 3, 2024
417e359
Update _ae_drnn.py
aadya940 Nov 5, 2024
627b4c1
Merge branch 'main' into aedrnn-clusterer
aadya940 Nov 5, 2024
8f898ab
minor
aadya940 Nov 9, 2024
25bf98a
Merge branch 'aedrnn-clusterer' of https://github.com/aadya940/aeon i…
aadya940 Nov 9, 2024
40618f7
Automatic `pre-commit` fixes
aadya940 Nov 9, 2024
d5ec772
merge main
aadya940 Nov 11, 2024
df8e807
Merge branch 'aedrnn-clusterer' of https://github.com/aadya940/aeon i…
aadya940 Nov 11, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions aeon/clustering/deep_learning/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@
"BaseDeepClusterer",
"AEFCNClusterer",
"AEResNetClusterer",
"AEDRNNClusterer",
"AEAttentionBiGRUClusterer",
"AEBiGRUClusterer",
]
from aeon.clustering.deep_learning._ae_abgru import AEAttentionBiGRUClusterer
from aeon.clustering.deep_learning._ae_bgru import AEBiGRUClusterer
from aeon.clustering.deep_learning._ae_drnn import AEDRNNClusterer
from aeon.clustering.deep_learning._ae_fcn import AEFCNClusterer
from aeon.clustering.deep_learning._ae_resnet import AEResNetClusterer
from aeon.clustering.deep_learning.base import BaseDeepClusterer
355 changes: 355 additions & 0 deletions aeon/clustering/deep_learning/_ae_drnn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,355 @@
"""Deep Learning Auto-Encoder using DRNN Network."""

# No maintainer is listed for this module.
__maintainer__ = []
# Names exported by this module.
__all__ = ["AEDRNNClusterer"]

import gc
import os
import time
from copy import deepcopy

from sklearn.utils import check_random_state

from aeon.clustering import DummyClusterer
from aeon.clustering.deep_learning.base import BaseDeepClusterer
from aeon.networks import AEDRNNNetwork
from aeon.utils.validation._dependencies import _check_soft_dependencies

if _check_soft_dependencies(["tensorflow"], severity="none"):
from aeon.networks._ae_drnn import _TensorDilation


class AEDRNNClusterer(BaseDeepClusterer):
    """Auto-Encoder based Dilated Recurrent Neural Network (DRNN).

    Parameters
    ----------
    n_clusters : int, default=None
        Number of clusters for the deep learning model.
    estimator : aeon clusterer, default=None
        An aeon estimator to be built using the transformed data.
        Defaults to aeon TimeSeriesKMeans() with euclidean distance
        and mean averaging method and n_clusters set to 2.
    clustering_algorithm : str, default="deprecated"
        Please use the 'estimator' parameter.
    clustering_params : dict, default=None
        Please use 'estimator' parameter.
    latent_space_dim : int, default=128
        Dimension of the latent space of the auto-encoder.
    temporal_latent_space : bool, default = False
        Flag to choose whether the latent space is an MTS or Euclidean space.
    n_layers_encoder : int, default = 3
        Number of layers in the encoder.
    n_layers_decoder : int, default = 3
        Number of layers in the decoder.
    dilation_rate_encoder : int or list of int, default = 1
        The dilation rate for the encoder.
    dilation_rate_decoder : int or list of int, default = 1
        The dilation rate for the decoder.
    n_units_encoder : list or int, default = None
        Number of Units in each DRNN Layer of the encoder.
    n_units_decoder : list or int, default = None
        Number of Units in each DRNN Layer of the decoder.
    activation_encoder : str or list of str, default = "relu"
        Activation used after DRNN Layer in the encoder.
    activation_decoder : str or list of str, default = "relu"
        Activation used after DRNN Layer in the decoder.
    n_epochs : int, default = 2000
        The number of epochs to train the model.
    batch_size : int, default = 32
        The number of samples per gradient update.
    use_mini_batch_size : bool, default = False
        Whether or not to use the mini batch size formula
        ``min(batch_size, n_cases // 10)`` (never less than 1).
    random_state : int, RandomState instance or None, default=None
        If `int`, random_state is the seed used by the random number generator;
        If `RandomState` instance, random_state is the random number generator;
        If `None`, the random number generator is the `RandomState` instance used
        by `np.random`.
        Seeded random number generation can only be guaranteed on CPU processing,
        GPU processing will be non-deterministic.
    verbose : boolean, default = False
        Whether to output extra information.
    loss : string, default="mse"
        Fit parameter for the keras model.
    metrics : str, list or None, default = None
        Metrics passed to the keras model at compile time. A single string is
        wrapped in a list; ``["mean_squared_error"]`` is used if None.
    optimizer : str or keras.optimizers object, default = "Adam"
        Specify the optimizer to be used. If None, a default constructed
        ``tf.keras.optimizers.Adam()`` is used.
    file_path : str, default = "./"
        File path to save best model. The file name is appended to this
        string directly, so it should end with a path separator.
    save_best_model : bool, default = False
        Whether or not to save the best model, if the
        modelcheckpoint callback is used by default,
        this condition, if True, will prevent the
        automatic deletion of the best saved model from
        file and the user can choose the file name.
    save_last_model : bool, default = False
        Whether or not to save the last model, last
        epoch trained, using the base class method
        save_last_model_to_file.
    best_file_name : str, default = "best_model"
        The name of the file of the best model, if
        save_best_model is set to False, this parameter
        is discarded.
    last_file_name : str, default = "last_file"
        The name of the file of the last model, if
        save_last_model is set to False, this parameter
        is discarded.
    callbacks : keras.callbacks, default = None
        List of keras callbacks.

    Examples
    --------
    >>> from aeon.clustering.deep_learning import AEDRNNClusterer
    >>> from aeon.datasets import load_unit_test
    >>> X_train, y_train = load_unit_test(split="train")
    >>> X_test, y_test = load_unit_test(split="test")
    >>> from aeon.clustering import DummyClusterer
    >>> _clst = DummyClusterer(n_clusters=2)
    >>> aedrnn = AEDRNNClusterer(estimator = _clst,
    ... n_epochs=20,batch_size=4) # doctest: +SKIP
    >>> aedrnn.fit(X_train) # doctest: +SKIP
    AEDRNNClusterer(...)
    """

    def __init__(
        self,
        n_clusters=None,
        estimator=None,
        clustering_algorithm="deprecated",
        clustering_params=None,
        latent_space_dim=128,
        temporal_latent_space=False,
        n_layers_encoder=3,
        n_layers_decoder=3,
        dilation_rate_encoder=1,
        dilation_rate_decoder=1,
        n_units_encoder=None,
        n_units_decoder=None,
        activation_encoder="relu",
        activation_decoder="relu",
        n_epochs=2000,
        batch_size=32,
        use_mini_batch_size=False,
        random_state=None,
        verbose=False,
        loss="mse",
        metrics=None,
        optimizer="Adam",
        file_path="./",
        save_best_model=False,
        save_last_model=False,
        best_file_name="best_model",
        last_file_name="last_file",
        callbacks=None,
    ):
        self.latent_space_dim = latent_space_dim
        self.temporal_latent_space = temporal_latent_space
        self.n_layers_encoder = n_layers_encoder
        self.n_layers_decoder = n_layers_decoder
        self.activation_encoder = activation_encoder
        self.activation_decoder = activation_decoder
        self.dilation_rate_encoder = dilation_rate_encoder
        self.dilation_rate_decoder = dilation_rate_decoder
        self.n_units_encoder = n_units_encoder
        self.n_units_decoder = n_units_decoder
        self.optimizer = optimizer
        self.loss = loss
        self.metrics = metrics
        self.verbose = verbose
        self.use_mini_batch_size = use_mini_batch_size
        self.callbacks = callbacks
        self.file_path = file_path
        self.n_epochs = n_epochs
        self.save_best_model = save_best_model
        self.save_last_model = save_last_model
        self.best_file_name = best_file_name
        self.random_state = random_state

        super().__init__(
            n_clusters=n_clusters,
            estimator=estimator,
            clustering_algorithm=clustering_algorithm,
            clustering_params=clustering_params,
            batch_size=batch_size,
            last_file_name=last_file_name,
        )

        # The network object only holds the architecture settings; the keras
        # graph itself is created later in ``build_model``.
        self._network = AEDRNNNetwork(
            latent_space_dim=self.latent_space_dim,
            temporal_latent_space=self.temporal_latent_space,
            n_layers_encoder=self.n_layers_encoder,
            n_layers_decoder=self.n_layers_decoder,
            dilation_rate_encoder=self.dilation_rate_encoder,
            dilation_rate_decoder=self.dilation_rate_decoder,
            activation_encoder=self.activation_encoder,
            activation_decoder=self.activation_decoder,
            n_units_encoder=self.n_units_encoder,
            n_units_decoder=self.n_units_decoder,
        )

    def build_model(self, input_shape, **kwargs):
        """Construct a compiled, un-trained, keras model that is ready for training.

        In aeon, time series are stored in numpy arrays of shape
        (n_channels,n_timepoints). Keras/tensorflow assume
        data is in shape (n_timepoints,n_channels). This method also assumes
        (n_timepoints,n_channels). Transpose should happen in fit.

        As a side effect this sets ``random_state_``, ``optimizer_`` and
        ``_metrics`` on the estimator and seeds keras via
        ``tf.keras.utils.set_random_seed``.

        Parameters
        ----------
        input_shape : tuple
            The shape of the data fed into the input layer, should be
            (n_timepoints,n_channels).
        **kwargs
            Forwarded unchanged to the network's ``build_network``.

        Returns
        -------
        output : a compiled Keras Model.

        Raises
        ------
        ValueError
            If ``metrics`` is neither None, a list nor a string.
        """
        import numpy as np
        import tensorflow as tf

        # Derive one integer seed from random_state and hand it to keras.
        rng = check_random_state(self.random_state)
        self.random_state_ = rng.randint(0, np.iinfo(np.int32).max)
        tf.keras.utils.set_random_seed(self.random_state_)
        encoder, decoder = self._network.build_network(input_shape, **kwargs)

        # Chain input -> encoder -> decoder, then reshape the reconstruction
        # back to the input shape so it can be compared with the input.
        input_layer = tf.keras.layers.Input(input_shape, name="input layer")
        encoder_output = encoder(input_layer)
        decoder_output = decoder(encoder_output)
        output_layer = tf.keras.layers.Reshape(
            target_shape=input_shape, name="outputlayer"
        )(decoder_output)

        model = tf.keras.models.Model(inputs=input_layer, outputs=output_layer)

        self.optimizer_ = (
            tf.keras.optimizers.Adam() if self.optimizer is None else self.optimizer
        )

        if self.metrics is None:
            self._metrics = ["mean_squared_error"]
        elif isinstance(self.metrics, list):
            self._metrics = self.metrics
        elif isinstance(self.metrics, str):
            self._metrics = [self.metrics]
        else:
            raise ValueError("Metrics should be a list, string, or None.")

        model.compile(optimizer=self.optimizer_, loss=self.loss, metrics=self._metrics)

        return model

    def _fit(self, X):
        """Train the auto-encoder on X, then fit the clustering estimator.

        Parameters
        ----------
        X : np.ndarray of shape = (n_cases (n), n_channels (d), n_timepoints (m))
            The training input samples.

        Returns
        -------
        self : object
        """
        import tensorflow as tf

        # Transpose to conform to Keras input style.
        X = X.transpose(0, 2, 1)

        self.input_shape = X.shape[1:]
        self.training_model_ = self.build_model(self.input_shape)

        if self.verbose:
            self.training_model_.summary()

        if self.use_mini_batch_size:
            # With fewer than 10 cases ``n_cases // 10`` is 0; clamp to 1 so the
            # formula never produces an empty batch.
            mini_batch_size = max(1, min(self.batch_size, X.shape[0] // 10))
        else:
            mini_batch_size = self.batch_size

        # A unique, timestamp-based name is used for the temporary checkpoint
        # when the user did not ask to keep the best model.
        self.file_name_ = (
            self.best_file_name if self.save_best_model else str(time.time_ns())
        )
        checkpoint_path = self.file_path + self.file_name_ + ".keras"

        if self.callbacks is None:
            self.callbacks_ = [
                tf.keras.callbacks.ReduceLROnPlateau(
                    monitor="loss", factor=0.5, patience=50, min_lr=0.0001
                ),
                tf.keras.callbacks.ModelCheckpoint(
                    filepath=checkpoint_path,
                    monitor="loss",
                    save_best_only=True,
                ),
            ]
        else:
            self.callbacks_ = self._get_model_checkpoint_callback(
                callbacks=self.callbacks,
                file_path=self.file_path,
                file_name=self.file_name_,
            )

        # Auto-encoder training: the input is also the reconstruction target.
        self.history = self.training_model_.fit(
            X,
            X,
            batch_size=mini_batch_size,
            epochs=self.n_epochs,
            verbose=self.verbose,
            callbacks=self.callbacks_,
        )

        # Check for the checkpoint explicitly so the fallback does not depend on
        # which exception type the installed keras raises for a missing file.
        if os.path.isfile(checkpoint_path):
            self.model_ = tf.keras.models.load_model(
                checkpoint_path,
                compile=False,
                custom_objects={"_TensorDilation": _TensorDilation},
            )
            if not self.save_best_model:
                os.remove(checkpoint_path)
        else:
            # No checkpoint was written: keep a copy of the trained model.
            self.model_ = deepcopy(self.training_model_)

        self._fit_clustering(X=X)

        gc.collect()

        return self

    def _score(self, X, y=None):
        """Score the fitted clustering estimator on the encoded version of X.

        Parameters
        ----------
        X : np.ndarray of shape = (n_cases (n), n_channels (d), n_timepoints (m))
            The samples to score.
        y : ignored
            Not used by this method.

        Returns
        -------
        score
            Whatever ``self._estimator.score`` returns for the latent space.
        """
        # Transpose to conform to Keras input style.
        X = X.transpose(0, 2, 1)
        # layers[1] is the encoder: it is the first layer applied to the input
        # layer when the model is assembled in ``build_model``.
        latent_space = self.model_.layers[1].predict(X)
        return self._estimator.score(latent_space)

    @classmethod
    def _get_test_params(cls, parameter_set="default"):
        """Return testing parameter settings for the estimator.

        Parameters
        ----------
        parameter_set : str, default="default"
            Name of the set of test parameters to return, for use in tests.
            This implementation returns the same single set for every value.

        Returns
        -------
        params : dict or list of dict, default={}
            Parameters to create testing instances of the class.
            Each dict are parameters to construct an "interesting" test instance, i.e.,
            `MyClass(**params)` or `MyClass(**params[i])` creates a valid test instance.
            `create_test_instance` uses the first (or only) dictionary in `params`.
        """
        # Smallest useful configuration: one epoch, one layer per side.
        param1 = {
            "estimator": DummyClusterer(n_clusters=2),
            "n_epochs": 1,
            "batch_size": 4,
            "n_layers_encoder": 1,
            "n_layers_decoder": 1,
        }

        return [param1]