[ENH] Test coverage for Resnet Network #2553

Merged: 5 commits, Feb 26, 2025
Changes from 3 commits
aeon/networks/tests/test_resnet.py: 109 additions, 0 deletions
@@ -0,0 +1,109 @@
"""Tests for the ResNet Model."""

import pytest

from aeon.networks import ResNetNetwork
from aeon.utils.validation._dependencies import _check_soft_dependencies


@pytest.mark.skipif(
    not _check_soft_dependencies(["tensorflow"], severity="none"),
    reason="skip test if required soft dependency not available",
)
def test_default_initialization():
    """Test if the network initializes with proper attributes."""
    model = ResNetNetwork()
    assert isinstance(
        model, ResNetNetwork
    ), "Model initialization failed: Incorrect type"
    assert model.n_residual_blocks == 3, "Default residual blocks count mismatch"
    assert (
        model.n_conv_per_residual_block == 3
    ), "Default convolution blocks count mismatch"
    assert model.n_filters is None, "Default n_filters should be None"
    assert model.kernel_size is None, "Default kernel_size should be None"
    assert model.strides == 1, "Default strides value mismatch"
    assert model.dilation_rate == 1, "Default dilation rate mismatch"
    assert model.activation == "relu", "Default activation mismatch"
    assert model.use_bias is True, "Default use_bias mismatch"
    assert model.padding == "same", "Default padding mismatch"


@pytest.mark.skipif(
    not _check_soft_dependencies(["tensorflow"], severity="none"),
    reason="skip test if required soft dependency not available",
)
def test_custom_initialization():
    """Test whether custom kwargs are correctly set."""
    model = ResNetNetwork(
        n_residual_blocks=3,
        n_conv_per_residual_block=3,
        n_filters=[64, 128, 128],
        kernel_size=[8, 5, 3],
        activation="relu",
        strides=1,
        padding="same",
    )
    model.build_network((128, 1))
    assert isinstance(
        model, ResNetNetwork
    ), "Custom initialization failed: Incorrect type"
    assert model._n_filters == [64, 128, 128], "n_filters list mismatch"
    assert model._kernel_size == [8, 5, 3], "kernel_size list mismatch"
    assert model._activation == ["relu", "relu", "relu"], "activation list mismatch"
    assert model._strides == [1, 1, 1], "strides list mismatch"
    assert model._padding == ["same", "same", "same"], "padding list mismatch"


@pytest.mark.skipif(
    not _check_soft_dependencies(["tensorflow"], severity="none"),
    reason="skip test if required soft dependency not available",
)
def test_invalid_initialization():
    """Test that the network raises errors for invalid configurations."""
    with pytest.raises(ValueError, match=".*same as number of residual blocks.*"):
        ResNetNetwork(n_filters=[64, 128], n_residual_blocks=3).build_network((128, 1))

    with pytest.raises(ValueError, match=".*same as number of convolution layers.*"):
        ResNetNetwork(kernel_size=[8, 5], n_conv_per_residual_block=3).build_network(
            (128, 1)
        )

    with pytest.raises(ValueError, match=".*same as number of convolution layers.*"):
        ResNetNetwork(strides=[1, 2], n_conv_per_residual_block=3).build_network(
            (128, 1)
        )


@pytest.mark.skipif(
    not _check_soft_dependencies(["tensorflow"], severity="none"),
    reason="skip test if required soft dependency not available",
)
def test_build_network():
    """Test network building with various input shapes."""
    model = ResNetNetwork()

    input_shapes = [(128, 1), (256, 3), (512, 1)]
    for shape in input_shapes:
        input_layer, output_layer = model.build_network(shape)
        assert hasattr(input_layer, "shape"), "Input layer type mismatch"
        assert hasattr(output_layer, "shape"), "Output layer type mismatch"
        assert input_layer.shape[1:] == shape, "Input shape mismatch"
        assert output_layer.shape[-1] == 128, "Output layer filter count mismatch"


@pytest.mark.skipif(
    not _check_soft_dependencies(["tensorflow"], severity="none"),
    reason="skip test if required soft dependency not available",
)
def test_shortcut_layer():
    """Test the shortcut layer functionality."""
    model = ResNetNetwork()

    input_shape = (128, 64)
    input_layer, output_layer = model.build_network(input_shape)

    shortcut = model._shortcut_layer(input_layer, output_layer)

    assert hasattr(shortcut, "shape"), "Shortcut layer output type mismatch"
    assert shortcut.shape[-1] == 128, "Shortcut output shape mismatch"
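

For context, a minimal usage sketch of the network these tests exercise, assuming tensorflow is installed and following only the calls shown in the diff above (ResNetNetwork() and build_network returning the input and output tensors); wrapping the returned tensors in a tf.keras Model is an illustrative assumption, not something the tests themselves do:

# Sketch only: assumes tensorflow is available and mirrors the API used in
# the tests above. Building a tf.keras Model from the returned tensors is an
# assumption for illustration.
import tensorflow as tf

from aeon.networks import ResNetNetwork

network = ResNetNetwork()  # defaults: 3 residual blocks, 3 conv layers per block
# Input shape is (n_timepoints, n_channels), as in the tests, e.g. (128, 1).
input_layer, output_layer = network.build_network((128, 1))
model = tf.keras.models.Model(inputs=input_layer, outputs=output_layer)
model.summary()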