Skip to content

Commit

Permalink
fix: dataset schema defined too strictly (#143)
Browse files Browse the repository at this point in the history
* fix: allow dict for datasets in dataloader

---------

Co-authored-by: Ana Prieto Nemesio <91897203+anaprietonem@users.noreply.github.com>
  • Loading branch information
theissenhelen and anaprietonem authored Feb 25, 2025
1 parent f26adc9 commit 4792ee1
Show file tree
Hide file tree
Showing 8 changed files with 33 additions and 32 deletions.
4 changes: 2 additions & 2 deletions training/docs/user-guide/configuring.rst
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ run using the following command:

.. code:: bash
anemoi-training config validate --name debug.yaml
anemoi-training config validate --config-name debug.yaml
This will check that the configuration is valid and that all the
required fields are present. If your config is correctly defined then
Expand Down Expand Up @@ -195,7 +195,7 @@ values:

.. code:: bash
(anemoi_core_venv)[] $ anemoi-training config validate --name=debug --mask_env_vars
(anemoi_core_venv)[] $ anemoi-training config validate --config-name=debug --mask_env_vars
2025-02-16 17:48:38 INFO Validating configs.
2025-02-16 17:48:38 WARNING Note that this command is not taking into account if your config has a no_validation flag.So this command will validate the config regardless of the flag.
2025-01-28 09:37:23 INFO Prepending Anemoi Home (/home_path/.config/anemoi/training/config) to the search path.
Expand Down
8 changes: 4 additions & 4 deletions training/src/anemoi/training/commands/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def add_arguments(command_parser: argparse.ArgumentParser) -> None:
help_msg = "Validate the Anemoi training configs."
validate = subparsers.add_parser("validate", help=help_msg, description=help_msg)

validate.add_argument("--name", help="Name of the primary config file")
validate.add_argument("--config-name", help="Name of the primary config file")
validate.add_argument("--overwrite", "-f", action="store_true")
validate.add_argument(
"--mask_env_vars",
Expand Down Expand Up @@ -109,7 +109,7 @@ def run(self, args: argparse.Namespace) -> None:
"Note that this command is not taking into account if your config has a no_validation flag."
"So this command will validate the config regardless of the flag.",
)
self.validate_config(args.name, args.mask_env_vars)
self.validate_config(args.config_name, args.mask_env_vars)
LOGGER.info("Config files validated.")
return

Expand Down Expand Up @@ -197,10 +197,10 @@ def _mask_slurm_env_variables(self, cfg: DictConfig) -> None:

return OmegaConf.create(updated_cfg)

def validate_config(self, name: Path | str, mask_env_vars: bool) -> None:
def validate_config(self, config_name: Path | str, mask_env_vars: bool) -> None:
"""Validates the configuration files in the given directory."""
with initialize(version_base=None, config_path=""):
cfg = compose(config_name=name)
cfg = compose(config_name=config_name)
if mask_env_vars:
cfg = self._mask_slurm_env_variables(cfg)
OmegaConf.resolve(cfg)
Expand Down
3 changes: 0 additions & 3 deletions training/src/anemoi/training/config/model/gnn.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ layer_kernels:

processor:
_target_: anemoi.models.layers.processor.GNNProcessor
_convert_: all
activation: ${model.activation}
trainable_size: ${model.trainable_parameters.hidden2hidden}
sub_graph_edge_attributes: ${model.attributes.edges}
Expand All @@ -30,7 +29,6 @@ processor:

encoder:
_target_: anemoi.models.layers.mapper.GNNForwardMapper
_convert_: all
trainable_size: ${model.trainable_parameters.data2hidden}
sub_graph_edge_attributes: ${model.attributes.edges}
activation: ${model.activation}
Expand All @@ -40,7 +38,6 @@ encoder:

decoder:
_target_: anemoi.models.layers.mapper.GNNBackwardMapper
_convert_: all
trainable_size: ${model.trainable_parameters.hidden2data}
sub_graph_edge_attributes: ${model.attributes.edges}
activation: ${model.activation}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ layer_kernels:

processor:
_target_: anemoi.models.layers.processor.GraphTransformerProcessor
_convert_: all
activation: ${model.activation}
trainable_size: ${model.trainable_parameters.hidden2hidden}
sub_graph_edge_attributes: ${model.attributes.edges}
Expand All @@ -31,7 +30,6 @@ processor:

encoder:
_target_: anemoi.models.layers.mapper.GraphTransformerForwardMapper
_convert_: all
trainable_size: ${model.trainable_parameters.data2hidden}
sub_graph_edge_attributes: ${model.attributes.edges}
activation: ${model.activation}
Expand All @@ -42,7 +40,6 @@ encoder:

decoder:
_target_: anemoi.models.layers.mapper.GraphTransformerBackwardMapper
_convert_: all
trainable_size: ${model.trainable_parameters.hidden2data}
sub_graph_edge_attributes: ${model.attributes.edges}
activation: ${model.activation}
Expand Down
3 changes: 0 additions & 3 deletions training/src/anemoi/training/config/model/transformer.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ layer_kernels:

processor:
_target_: anemoi.models.layers.processor.TransformerProcessor
_convert_: all
activation: ${model.activation}
num_layers: 16
num_chunks: 2
Expand All @@ -34,7 +33,6 @@ processor:

encoder:
_target_: anemoi.models.layers.mapper.GraphTransformerForwardMapper
_convert_: all
trainable_size: ${model.trainable_parameters.data2hidden}
sub_graph_edge_attributes: ${model.attributes.edges}
activation: ${model.activation}
Expand All @@ -46,7 +44,6 @@ encoder:

decoder:
_target_: anemoi.models.layers.mapper.GraphTransformerBackwardMapper
_convert_: all
trainable_size: ${model.trainable_parameters.hidden2data}
sub_graph_edge_attributes: ${model.attributes.edges}
activation: ${model.activation}
Expand Down
10 changes: 7 additions & 3 deletions training/src/anemoi/training/schemas/base_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,11 +69,15 @@ def set_read_group_size_if_not_provided(self) -> BaseSchema:
@model_validator(mode="after")
def check_log_paths_available_for_loggers(self) -> BaseSchema:
logger = []
if self.diagnostics.log.wandb.enabled and not self.hardware.paths.logs.wandb:
if self.diagnostics.log.wandb.enabled and (not self.hardware.paths.logs or not self.hardware.paths.logs.wandb):
logger.append("wandb")
if self.diagnostics.log.mlflow.enabled and not self.hardware.paths.logs.mlflow:
if self.diagnostics.log.mlflow.enabled and (
not self.hardware.paths.logs or not self.hardware.paths.logs.mlflow
):
logger.append("mlflow")
if self.diagnostics.log.tensorboard.enabled and not self.hardware.paths.logs.tensorboard:
if self.diagnostics.log.tensorboard.enabled and (
not self.hardware.paths.logs or not self.hardware.paths.logs.tensorboard
):
logger.append("tensorboard")

if logger:
Expand Down
18 changes: 12 additions & 6 deletions training/src/anemoi/training/schemas/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,12 @@
from pathlib import Path # noqa: TC003
from typing import Any
from typing import Literal
from typing import Optional
from typing import Union

from omegaconf import DictConfig # noqa: TC002
from pydantic import BaseModel as PydanticBaseModel
from pydantic import ConfigDict
from pydantic import Field
from pydantic import PositiveInt
from pydantic import RootModel
Expand Down Expand Up @@ -57,10 +60,10 @@ def as_seconds(self) -> int:
return int(self.as_timedelta.total_seconds())


class DatasetSchema(BaseModel):
class DatasetSchema(PydanticBaseModel):
"""Dataset configuration schema."""

dataset: Union[str, dict, Path, list[dict]]
dataset: Optional[Union[str, dict, Path, list[dict]]] = None
"Dataset, see anemoi-datasets"
start: Union[str, int, None] = Field(default=None)
"Starting datetime for sample of the dataset."
Expand Down Expand Up @@ -104,6 +107,9 @@ class MaskedGridIndicesSchema(BaseModel):


class DataLoaderSchema(PydanticBaseModel):

model_config = ConfigDict(extra="allow", arbitrary_types_allowed=True)

prefetch_factor: int = Field(example=2, ge=0)
"Number of batches loaded in advance by each worker."
pin_memory: bool = Field(example=True)
Expand All @@ -114,15 +120,15 @@ class DataLoaderSchema(PydanticBaseModel):
"Per-GPU batch size."
limit_batches: LoaderSet = Field(example=None)
"Limit number of batches to run. Default value null, will run on all the batches."
training: DatasetSchema
training: Union[DatasetSchema, DictConfig]
"Training DatasetSchema."
validation: DatasetSchema
validation: Union[DatasetSchema, DictConfig]
"Validation DatasetSchema."
test: DatasetSchema
test: Union[DatasetSchema, DictConfig]
"Test DatasetSchema."
validation_rollout: PositiveInt = Field(example=1)
"Number of rollouts to use for validation, must be equal or greater than rollout expected by callbacks."
# TODO(Helen): Ccheck that this equal or greater than the number of rollouts expected by callbacks ???
# TODO(Helen): Check that this equal or greater than the number of rollouts expected by callbacks ???
read_group_size: PositiveInt = Field(example=None)
"Number of GPUs per reader group. Defaults to number of GPUs (see BaseSchema validators)."
grid_indices: Union[FullGridIndicesSchema, MaskedGridIndicesSchema]
Expand Down
16 changes: 8 additions & 8 deletions training/src/anemoi/training/schemas/hardware.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,9 @@ class Checkpoint(BaseModel):


class FilesSchema(PydanticBaseModel):
dataset: Union[Path, dict[str, Path]] # dict option for multiple datasets
dataset: Union[Path, dict[str, Path], None] = Field(default=None) # dict option for multiple datasets
"Path to the dataset file."
graph: Union[Path, None] = Field(default=None)
graph: Union[Path, None] = None
"Path to the graph file."
checkpoint: dict[str, str]
"Each dictionary key is a checkpoint name, and the value is the path to the checkpoint file."
Expand All @@ -53,19 +53,19 @@ class Logs(PydanticBaseModel):


class PathsSchema(BaseModel):
data: Path
data: Union[Path, None] = None
"Path to the data directory."
graph: Path
graph: Union[Path, None] = None
"Path to the graph directory."
output: Path
output: Union[Path, None] = None
"Path to the output directory."
logs: Logs
logs: Union[Logs, None] = None
"Logging directories."
checkpoints: Path
"Path to the checkpoints directory."
plots: Path
plots: Union[Path, None] = None
"Path to the plots directory."
profiler: Path
profiler: Union[Path, None]
"Path to the profiler directory."


Expand Down

0 comments on commit 4792ee1

Please sign in to comment.