Skip to content

Commit

Permalink
Setup for testing benchmark loading
Browse files Browse the repository at this point in the history
  • Loading branch information
PGijsbers committed Nov 29, 2024
1 parent 70c74b4 commit 4985e34
Show file tree
Hide file tree
Showing 2 changed files with 134 additions and 12 deletions.
26 changes: 14 additions & 12 deletions amlb/resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,13 +209,18 @@ def _constraints(self):
return constraints_lookup

def benchmark_definition(self, name: str, defaults: TaskConstraint | None = None):
    """Load the benchmark definition `name` using this instance's configuration.

    Thin public wrapper around `_benchmark_definition`, passing `self.config`
    as the configuration namespace.

    :param name: name of the benchmark as defined by resources/benchmarks/{name}.yaml,
        the path to a user-defined benchmark description file, or a study id.
    :param defaults: defaults used as a base config for each task in the benchmark definition
    :return: the tasks, benchmark name, and benchmark path (see `_benchmark_definition`).
    """
    return self._benchmark_definition(name, self.config, defaults)

def _benchmark_definition(
self, name: str, config_: Namespace, defaults: TaskConstraint | None = None
):
"""
:param name: name of the benchmark as defined by resources/benchmarks/{name}.yaml, the path to a user-defined benchmark description file or a study id.
:param defaults: defaults used as a base config for each task in the benchmark definition
:return:
"""
file_defaults, tasks, benchmark_path, benchmark_name = benchmark_load(
name, self.config.benchmarks.definition_dir
name, config_.benchmarks.definition_dir
)
if defaults is not None:
defaults = Namespace(**dataclasses.asdict(defaults))
Expand All @@ -224,15 +229,16 @@ def benchmark_definition(self, name: str, defaults: TaskConstraint | None = None
)
for task in tasks:
task |= defaults # add missing keys from hard defaults + defaults
self._validate_task(task)
Resources._validate_task(task, config_)

self._validate_task(defaults, lenient=True)
Resources._validate_task(defaults, config_, lenient=True)
defaults.enabled = False
tasks.append(defaults)
log.debug("Available task definitions:\n%s", tasks)
return tasks, benchmark_name, benchmark_path

def _validate_task(self, task, lenient=False):
@staticmethod
def _validate_task(task: Namespace, config_: Namespace, lenient: bool = False):
missing = []
for conf in ["name"]:
if task[conf] is None:
Expand All @@ -253,7 +259,7 @@ def _validate_task(self, task, lenient=False):
"quantile_levels",
]:
if task[conf] is None:
task[conf] = self.config.benchmarks.defaults[conf]
task[conf] = config_.benchmarks.defaults[conf]
log.debug(
"Config `{config}` not set for task {name}, using default `{value}`.".format(
config=conf, name=task.name, value=task[conf]
Expand Down Expand Up @@ -287,14 +293,10 @@ def _validate_task(self, task, lenient=False):
"but task definition is {task}".format(task=str(task))
)

conf = "metric"
if task[conf] is None:
task[conf] = None

conf = "ec2_instance_type"
if task[conf] is None:
i_series = self.config.aws.ec2.instance_type.series
i_map = self.config.aws.ec2.instance_type.map
i_series = config_.aws.ec2.instance_type.series
i_map = config_.aws.ec2.instance_type.map
if str(task.cores) in i_map:
i_size = i_map[str(task.cores)]
elif task.cores > 0:
Expand All @@ -315,7 +317,7 @@ def _validate_task(self, task, lenient=False):

conf = "ec2_volume_type"
if task[conf] is None:
task[conf] = self.config.aws.ec2.volume_type
task[conf] = config_.aws.ec2.volume_type
log.debug(
"Config `{config}` not set for task {name}, using default `{value}`.".format(
config=conf, name=task.name, value=task[conf]
Expand Down
120 changes: 120 additions & 0 deletions tests/unit/amlb/resources/test_benchmark_definition.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
from functools import partial

import pytest

from amlb import Resources
from amlb.utils import Namespace


@pytest.fixture
def amlb_dummy_configuration() -> Namespace:
    """Provide a minimal stand-in for the AMLB configuration namespace.

    Contains only what `Resources._validate_task` reads: the benchmark task
    defaults and the AWS EC2 volume/instance-type settings.
    """
    task_defaults = {
        "max_runtime_seconds": 0,
        "cores": 1,
        "folds": 2,
        "max_mem_size_mb": 3,
        "min_vol_size_mb": 4,
        "quantile_levels": 5,
    }
    ec2_settings = {
        "ec2": {
            "volume_type": "gp3",
            "instance_type": {
                "series": "m5",
                # keys are core counts as strings; "default" is the fallback size
                "map": {"4": "small", "default": "large"},
            },
        }
    }
    return Namespace(
        aws=Namespace.from_dict(ec2_settings),
        benchmarks=Namespace(defaults=Namespace.from_dict(task_defaults)),
    )


def test_validate_task_strict_requires_name():
    """A task definition without a `name` must be rejected under strict validation."""
    nameless_task = Namespace()
    with pytest.raises(ValueError) as raised:
        Resources._validate_task(
            task=nameless_task,
            config_=Namespace(),
            lenient=False,
        )
    assert "mandatory properties as missing" in raised.value.args[0]


def test_validate_task_strict_requires_id(amlb_dummy_configuration: Namespace):
    """Strict validation demands some identifier property beyond the task name."""
    validate_strictly = partial(
        Resources._validate_task, config_=amlb_dummy_configuration, lenient=False
    )
    named_only_task = Namespace(name="foo")
    with pytest.raises(ValueError) as raised:
        validate_strictly(task=named_only_task)
    assert "must contain an ID or one property" in raised.value.args[0]


@pytest.mark.parametrize(
    ("id_properties", "resolved_id"),
    [
        (Namespace(id="bar"), "bar"),
        (Namespace(openml_task_id=42), "openml.org/t/42"),
        (Namespace(openml_dataset_id=42), "openml.org/d/42"),
        (Namespace(dataset="bar"), "bar"),
        (Namespace(dataset=Namespace(id="bar")), "bar"),
    ],
)
def test_validate_task_id_formatting(
    id_properties: Namespace, resolved_id: str, amlb_dummy_configuration: Namespace
):
    """Each supported identifier property is normalized into the task `id`."""
    task_definition = Namespace(name="foo") | id_properties
    Resources._validate_task(task=task_definition, config_=amlb_dummy_configuration)
    assert task_definition["id"] == resolved_id


def test_validate_task_adds_benchmark_defaults(amlb_dummy_configuration: Namespace):
    """Settings left unset on a task are filled in from the benchmark defaults."""
    task = Namespace(name=None)
    Resources._validate_task(task, amlb_dummy_configuration, lenient=True)

    config_as_dict = Namespace.dict(amlb_dummy_configuration, deep=True)
    for setting, default_value in config_as_dict["benchmarks"]["defaults"].items():
        assert task[setting] == default_value
    assert task["ec2_volume_type"] == amlb_dummy_configuration.aws.ec2.volume_type


def test_validate_task_does_not_overwrite(amlb_dummy_configuration: Namespace):
    """Values set explicitly on the task survive validation untouched."""
    task = Namespace(name=None, cores=42)
    Resources._validate_task(task, amlb_dummy_configuration, lenient=True)

    assert task.cores == 42
    config_as_dict = Namespace.dict(amlb_dummy_configuration, deep=True)
    for setting, default_value in config_as_dict["benchmarks"]["defaults"].items():
        if setting == "cores":
            continue  # explicitly set above; must not be replaced by the default
        assert task[setting] == default_value


def test_validate_task_looks_up_instance_type(amlb_dummy_configuration: Namespace):
    """EC2 instance type resolution: exact core match, next size up, then default."""
    instance_config = amlb_dummy_configuration.aws.ec2.instance_type
    cores_by_size = {
        size: cores for cores, size in Namespace.dict(instance_config.map).items()
    }
    small_core_count = int(cores_by_size["small"])

    def resolved_instance_type(task: Namespace) -> str:
        # Validate the task in place and report the instance type it was given.
        Resources._validate_task(task, amlb_dummy_configuration, lenient=True)
        return task["ec2_instance_type"]

    exact_match = Namespace(name="foo", cores=small_core_count)
    assert (
        resolved_instance_type(exact_match) == "m5.small"
    ), "Should resolve to the instance type with the exact amount of cores"

    just_below = Namespace(name="foo", cores=small_core_count - 1)
    assert (
        resolved_instance_type(just_below) == "m5.small"
    ), "If exact amount of cores are not available, should resolve to next biggest"

    above_largest = Namespace(name="foo", cores=small_core_count + 1)
    assert (
        resolved_instance_type(above_largest) == "m5.large"
    ), "If bigger than largest in map, should revert to default"

    preconfigured = Namespace(name="foo", ec2_instance_type="bar")
    assert (
        resolved_instance_type(preconfigured) == "bar"
    ), "Should not overwrite explicit configuration"

0 comments on commit 4985e34

Please sign in to comment.