Skip to content

Commit

Permalink
[ENH] Improve deep learning networks test coverage for parameters whi…
Browse files Browse the repository at this point in the history
…ch can be list (#1851)

* Improve deep learning networks test coverage for parameters which can be list or not

* checklist : I've added myself to the list of contributors

* remove test for AE networks as they will be having their own tests for special reasons

* restore test_all_networks_functionality and add param test in test_all_networks_params

* use tag instead of the network name

* refactor test_all_networks_params to improve lisibility and skip messages
  • Loading branch information
Cyril-Meyer authored Jul 30, 2024
1 parent c980d7f commit d6baeef
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 0 deletions.
9 changes: 9 additions & 0 deletions .all-contributorsrc
Original file line number Diff line number Diff line change
Expand Up @@ -2447,6 +2447,15 @@
"doc"
]
},
{
"login": "Cyril-Meyer",
"name": "Cyril Meyer",
"avatar_url": "https://avatars.githubusercontent.com/u/69190238?v=4",
"profile": "https://github.com/Cyril-Meyer",
"contributions": [
"test"
]
},
{
"login": "Moonzyyy",
"name": "Daniele Carli",
Expand Down
72 changes: 72 additions & 0 deletions aeon/networks/tests/test_all_networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,3 +62,75 @@ def test_all_networks_functionality(network):
)
else:
pytest.skip(f"{network.__name__} not to be tested since its a base class.")


@pytest.mark.parametrize("network", _networks)
def test_all_networks_params(network):
"""Test the functionality of all networks."""
input_shape = (100, 2)

if network.__name__ in ["BaseDeepLearningNetwork", "EncoderNetwork"]:
pytest.skip(f"{network.__name__} not to be tested since its a base class.")

if network._config["structure"] == "auto-encoder":
pytest.skip(
f"{network.__name__} not to be tested (AE networks have their own tests)."
)

if not (
_check_soft_dependencies(
network._config["python_dependencies"], severity="none"
)
and _check_python_version(network._config["python_version"], severity="none")
):
pytest.skip(
f"{network.__name__} dependencies not satisfied or invalid \
Python version."
)

# check with default parameters
my_network = network()
my_network.build_network(input_shape=input_shape)

# check with list parameters
params = dict()
for attrname in [
"kernel_size",
"n_filters",
"avg_pool_size",
"activation",
"padding",
"strides",
"dilation_rate",
"use_bias",
]:

# Exceptions to fix
if (
attrname in ["kernel_size", "padding"]
and network.__name__ == "TapNetNetwork"
):
continue
# LITENetwork does not seem to work with list args
if network.__name__ == "LITENetwork":
continue

# Here we use 'None' string as default to differentiate with None values
attr = getattr(my_network, attrname, "None")
if attr != "None":
if attr is None:
attr = 3
elif isinstance(attr, list):
attr = attr[0]
else:
if network.__name__ in ["ResNetNetwork"]:
attr = [attr] * my_network.n_conv_per_residual_block
elif network.__name__ in ["InceptionNetwork"]:
attr = [attr] * my_network.depth
else:
attr = [attr] * my_network.n_layers
params[attrname] = attr

if params:
my_network = network(**params)
my_network.build_network(input_shape=input_shape)

0 comments on commit d6baeef

Please sign in to comment.