From 4792ee145a3bda05f934180485ddb7b6e3bb25d4 Mon Sep 17 00:00:00 2001
From: Helen Theissen
Date: Tue, 25 Feb 2025 13:58:46 +0000
Subject: [PATCH] fix: dataset schema defined too strictly (#143)

* fix: allow dict for datasets in dataloader
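What the loosened schema buys in practice, as a minimal sketch (a stand-in
pydantic model mirroring the change in `schemas/dataloader.py` below, not the
real class; the recipe dict is made up):

```python
from pathlib import Path
from typing import Optional, Union

from pydantic import BaseModel


class DatasetSchemaSketch(BaseModel):
    # Before this patch: `dataset: Union[str, dict, Path, list[dict]]` (required).
    # After: optional with a None default; nested dicts remain accepted.
    dataset: Optional[Union[str, dict, Path, list[dict]]] = None


DatasetSchemaSketch()  # valid: `dataset` may now be omitted entirely
DatasetSchemaSketch(dataset={"join": ["dataset-a.zarr", "dataset-b.zarr"]})
```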
"So this command will validate the config regardless of the flag.", ) - self.validate_config(args.name, args.mask_env_vars) + self.validate_config(args.config_name, args.mask_env_vars) LOGGER.info("Config files validated.") return @@ -197,10 +197,10 @@ def _mask_slurm_env_variables(self, cfg: DictConfig) -> None: return OmegaConf.create(updated_cfg) - def validate_config(self, name: Path | str, mask_env_vars: bool) -> None: + def validate_config(self, config_name: Path | str, mask_env_vars: bool) -> None: """Validates the configuration files in the given directory.""" with initialize(version_base=None, config_path=""): - cfg = compose(config_name=name) + cfg = compose(config_name=config_name) if mask_env_vars: cfg = self._mask_slurm_env_variables(cfg) OmegaConf.resolve(cfg) diff --git a/training/src/anemoi/training/config/model/gnn.yaml b/training/src/anemoi/training/config/model/gnn.yaml index 16735578..5eda609b 100644 --- a/training/src/anemoi/training/config/model/gnn.yaml +++ b/training/src/anemoi/training/config/model/gnn.yaml @@ -19,7 +19,6 @@ layer_kernels: processor: _target_: anemoi.models.layers.processor.GNNProcessor - _convert_: all activation: ${model.activation} trainable_size: ${model.trainable_parameters.hidden2hidden} sub_graph_edge_attributes: ${model.attributes.edges} @@ -30,7 +29,6 @@ processor: encoder: _target_: anemoi.models.layers.mapper.GNNForwardMapper - _convert_: all trainable_size: ${model.trainable_parameters.data2hidden} sub_graph_edge_attributes: ${model.attributes.edges} activation: ${model.activation} @@ -40,7 +38,6 @@ encoder: decoder: _target_: anemoi.models.layers.mapper.GNNBackwardMapper - _convert_: all trainable_size: ${model.trainable_parameters.hidden2data} sub_graph_edge_attributes: ${model.attributes.edges} activation: ${model.activation} diff --git a/training/src/anemoi/training/config/model/graphtransformer.yaml b/training/src/anemoi/training/config/model/graphtransformer.yaml index ba95c6e0..15a35abf 100644 --- a/training/src/anemoi/training/config/model/graphtransformer.yaml +++ b/training/src/anemoi/training/config/model/graphtransformer.yaml @@ -19,7 +19,6 @@ layer_kernels: processor: _target_: anemoi.models.layers.processor.GraphTransformerProcessor - _convert_: all activation: ${model.activation} trainable_size: ${model.trainable_parameters.hidden2hidden} sub_graph_edge_attributes: ${model.attributes.edges} @@ -31,7 +30,6 @@ processor: encoder: _target_: anemoi.models.layers.mapper.GraphTransformerForwardMapper - _convert_: all trainable_size: ${model.trainable_parameters.data2hidden} sub_graph_edge_attributes: ${model.attributes.edges} activation: ${model.activation} @@ -42,7 +40,6 @@ encoder: decoder: _target_: anemoi.models.layers.mapper.GraphTransformerBackwardMapper - _convert_: all trainable_size: ${model.trainable_parameters.hidden2data} sub_graph_edge_attributes: ${model.attributes.edges} activation: ${model.activation} diff --git a/training/src/anemoi/training/config/model/transformer.yaml b/training/src/anemoi/training/config/model/transformer.yaml index eae42292..7e36326e 100644 --- a/training/src/anemoi/training/config/model/transformer.yaml +++ b/training/src/anemoi/training/config/model/transformer.yaml @@ -18,7 +18,6 @@ layer_kernels: processor: _target_: anemoi.models.layers.processor.TransformerProcessor - _convert_: all activation: ${model.activation} num_layers: 16 num_chunks: 2 @@ -34,7 +33,6 @@ processor: encoder: _target_: anemoi.models.layers.mapper.GraphTransformerForwardMapper - _convert_: all trainable_size: 
diff --git a/training/src/anemoi/training/config/model/gnn.yaml b/training/src/anemoi/training/config/model/gnn.yaml
index 16735578..5eda609b 100644
--- a/training/src/anemoi/training/config/model/gnn.yaml
+++ b/training/src/anemoi/training/config/model/gnn.yaml
@@ -19,7 +19,6 @@ layer_kernels:
 
 processor:
   _target_: anemoi.models.layers.processor.GNNProcessor
-  _convert_: all
   activation: ${model.activation}
   trainable_size: ${model.trainable_parameters.hidden2hidden}
   sub_graph_edge_attributes: ${model.attributes.edges}
@@ -30,7 +29,6 @@ processor:
 
 encoder:
   _target_: anemoi.models.layers.mapper.GNNForwardMapper
-  _convert_: all
   trainable_size: ${model.trainable_parameters.data2hidden}
   sub_graph_edge_attributes: ${model.attributes.edges}
   activation: ${model.activation}
@@ -40,7 +38,6 @@ encoder:
 
 decoder:
   _target_: anemoi.models.layers.mapper.GNNBackwardMapper
-  _convert_: all
   trainable_size: ${model.trainable_parameters.hidden2data}
   sub_graph_edge_attributes: ${model.attributes.edges}
   activation: ${model.activation}
diff --git a/training/src/anemoi/training/config/model/graphtransformer.yaml b/training/src/anemoi/training/config/model/graphtransformer.yaml
index ba95c6e0..15a35abf 100644
--- a/training/src/anemoi/training/config/model/graphtransformer.yaml
+++ b/training/src/anemoi/training/config/model/graphtransformer.yaml
@@ -19,7 +19,6 @@ layer_kernels:
 
 processor:
   _target_: anemoi.models.layers.processor.GraphTransformerProcessor
-  _convert_: all
   activation: ${model.activation}
   trainable_size: ${model.trainable_parameters.hidden2hidden}
   sub_graph_edge_attributes: ${model.attributes.edges}
@@ -31,7 +30,6 @@ processor:
 
 encoder:
   _target_: anemoi.models.layers.mapper.GraphTransformerForwardMapper
-  _convert_: all
   trainable_size: ${model.trainable_parameters.data2hidden}
   sub_graph_edge_attributes: ${model.attributes.edges}
   activation: ${model.activation}
@@ -42,7 +40,6 @@ encoder:
 
 decoder:
   _target_: anemoi.models.layers.mapper.GraphTransformerBackwardMapper
-  _convert_: all
   trainable_size: ${model.trainable_parameters.hidden2data}
   sub_graph_edge_attributes: ${model.attributes.edges}
   activation: ${model.activation}
diff --git a/training/src/anemoi/training/config/model/transformer.yaml b/training/src/anemoi/training/config/model/transformer.yaml
index eae42292..7e36326e 100644
--- a/training/src/anemoi/training/config/model/transformer.yaml
+++ b/training/src/anemoi/training/config/model/transformer.yaml
@@ -18,7 +18,6 @@ layer_kernels:
 
 processor:
   _target_: anemoi.models.layers.processor.TransformerProcessor
-  _convert_: all
   activation: ${model.activation}
   num_layers: 16
   num_chunks: 2
@@ -34,7 +33,6 @@ processor:
 
 encoder:
   _target_: anemoi.models.layers.mapper.GraphTransformerForwardMapper
-  _convert_: all
   trainable_size: ${model.trainable_parameters.data2hidden}
   sub_graph_edge_attributes: ${model.attributes.edges}
   activation: ${model.activation}
@@ -46,7 +44,6 @@ encoder:
 
 decoder:
   _target_: anemoi.models.layers.mapper.GraphTransformerBackwardMapper
-  _convert_: all
   trainable_size: ${model.trainable_parameters.hidden2data}
   sub_graph_edge_attributes: ${model.attributes.edges}
   activation: ${model.activation}
diff --git a/training/src/anemoi/training/schemas/base_schema.py b/training/src/anemoi/training/schemas/base_schema.py
index 226177e7..31d3536e 100644
--- a/training/src/anemoi/training/schemas/base_schema.py
+++ b/training/src/anemoi/training/schemas/base_schema.py
@@ -69,11 +69,15 @@ def set_read_group_size_if_not_provided(self) -> BaseSchema:
     @model_validator(mode="after")
     def check_log_paths_available_for_loggers(self) -> BaseSchema:
         logger = []
-        if self.diagnostics.log.wandb.enabled and not self.hardware.paths.logs.wandb:
+        if self.diagnostics.log.wandb.enabled and (not self.hardware.paths.logs or not self.hardware.paths.logs.wandb):
             logger.append("wandb")
-        if self.diagnostics.log.mlflow.enabled and not self.hardware.paths.logs.mlflow:
+        if self.diagnostics.log.mlflow.enabled and (
+            not self.hardware.paths.logs or not self.hardware.paths.logs.mlflow
+        ):
             logger.append("mlflow")
-        if self.diagnostics.log.tensorboard.enabled and not self.hardware.paths.logs.tensorboard:
+        if self.diagnostics.log.tensorboard.enabled and (
+            not self.hardware.paths.logs or not self.hardware.paths.logs.tensorboard
+        ):
             logger.append("tensorboard")
 
         if logger:
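Note for reviewers: the parenthesised guards above are what make the
now-optional `logs` block safe to dereference; `not self.hardware.paths.logs`
short-circuits before `.wandb` is touched. A self-contained sketch of the
pattern, using hypothetical mini-models rather than the real anemoi schemas:

```python
from typing import Optional

from pydantic import BaseModel, Field, model_validator


class LogsSketch(BaseModel):
    wandb: Optional[str] = None


class PathsSketch(BaseModel):
    logs: Optional[LogsSketch] = None


class ConfigSketch(BaseModel):
    wandb_enabled: bool = False
    paths: PathsSketch = Field(default_factory=PathsSketch)

    @model_validator(mode="after")
    def check_log_paths(self) -> "ConfigSketch":
        # `not self.paths.logs` short-circuits before touching `.wandb`,
        # which would raise AttributeError while logs is None.
        if self.wandb_enabled and (not self.paths.logs or not self.paths.logs.wandb):
            msg = "wandb logger enabled but no log path configured"
            raise ValueError(msg)
        return self


ConfigSketch()  # passes: logger disabled, logs omitted
# ConfigSketch(wandb_enabled=True)  # would raise: enabled without a path
```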
diff --git a/training/src/anemoi/training/schemas/dataloader.py b/training/src/anemoi/training/schemas/dataloader.py
index 5d4696e7..83fbb66c 100644
--- a/training/src/anemoi/training/schemas/dataloader.py
+++ b/training/src/anemoi/training/schemas/dataloader.py
@@ -14,9 +14,12 @@
 from pathlib import Path  # noqa: TC003
 from typing import Any
 from typing import Literal
+from typing import Optional
 from typing import Union
 
+from omegaconf import DictConfig  # noqa: TC002
 from pydantic import BaseModel as PydanticBaseModel
+from pydantic import ConfigDict
 from pydantic import Field
 from pydantic import PositiveInt
 from pydantic import RootModel
@@ -57,10 +60,10 @@ def as_seconds(self) -> int:
         return int(self.as_timedelta.total_seconds())
 
 
-class DatasetSchema(BaseModel):
+class DatasetSchema(PydanticBaseModel):
     """Dataset configuration schema."""
 
-    dataset: Union[str, dict, Path, list[dict]]
+    dataset: Optional[Union[str, dict, Path, list[dict]]] = None
     "Dataset, see anemoi-datasets"
     start: Union[str, int, None] = Field(default=None)
     "Starting datetime for sample of the dataset."
@@ -104,6 +107,9 @@ class MaskedGridIndicesSchema(BaseModel):
 
 
 class DataLoaderSchema(PydanticBaseModel):
+
+    model_config = ConfigDict(extra="allow", arbitrary_types_allowed=True)
+
     prefetch_factor: int = Field(example=2, ge=0)
     "Number of batches loaded in advance by each worker."
     pin_memory: bool = Field(example=True)
@@ -114,15 +120,15 @@ class DataLoaderSchema(PydanticBaseModel):
     "Per-GPU batch size."
     limit_batches: LoaderSet = Field(example=None)
     "Limit number of batches to run. Default value null, will run on all the batches."
-    training: DatasetSchema
+    training: Union[DatasetSchema, DictConfig]
     "Training DatasetSchema."
-    validation: DatasetSchema
+    validation: Union[DatasetSchema, DictConfig]
     "Validation DatasetSchema."
-    test: DatasetSchema
+    test: Union[DatasetSchema, DictConfig]
     "Test DatasetSchema."
     validation_rollout: PositiveInt = Field(example=1)
     "Number of rollouts to use for validation, must be equal or greater than rollout expected by callbacks."
-    # TODO(Helen): Ccheck that this equal or greater than the number of rollouts expected by callbacks ???
+    # TODO(Helen): Check that this is equal to or greater than the number of rollouts expected by callbacks
     read_group_size: PositiveInt = Field(example=None)
     "Number of GPUs per reader group. Defaults to number of GPUs (see BaseSchema validators)."
     grid_indices: Union[FullGridIndicesSchema, MaskedGridIndicesSchema]
diff --git a/training/src/anemoi/training/schemas/hardware.py b/training/src/anemoi/training/schemas/hardware.py
index e698cdce..b92ee6b4 100644
--- a/training/src/anemoi/training/schemas/hardware.py
+++ b/training/src/anemoi/training/schemas/hardware.py
@@ -34,9 +34,9 @@ class Checkpoint(BaseModel):
 
 
 class FilesSchema(PydanticBaseModel):
-    dataset: Union[Path, dict[str, Path]]  # dict option for multiple datasets
+    dataset: Union[Path, dict[str, Path], None] = Field(default=None)  # dict option for multiple datasets
     "Path to the dataset file."
-    graph: Union[Path, None] = Field(default=None)
+    graph: Union[Path, None] = None
     "Path to the graph file."
     checkpoint: dict[str, str]
     "Each dictionary key is a checkpoint name, and the value is the path to the checkpoint file."
@@ -53,19 +53,19 @@ class Logs(PydanticBaseModel):
 
 
 class PathsSchema(BaseModel):
-    data: Path
+    data: Union[Path, None] = None
     "Path to the data directory."
-    graph: Path
+    graph: Union[Path, None] = None
     "Path to the graph directory."
-    output: Path
+    output: Union[Path, None] = None
     "Path to the output directory."
-    logs: Logs
+    logs: Union[Logs, None] = None
     "Logging directories."
     checkpoints: Path
     "Path to the checkpoints directory."
-    plots: Path
+    plots: Union[Path, None] = None
     "Path to the plots directory."
-    profiler: Path
+    profiler: Union[Path, None] = None
     "Path to the profiler directory."
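A closing design note on the `ConfigDict(extra="allow", arbitrary_types_allowed=True)`
line added to `DataLoaderSchema`: `omegaconf.DictConfig` is not a pydantic-aware
type, so the `Union[DatasetSchema, DictConfig]` annotations only build once the
model opts in to arbitrary types, which pydantic then validates with a plain
`isinstance` check. A minimal sketch (hypothetical mini-model, not the real
schema):

```python
from omegaconf import DictConfig, OmegaConf
from pydantic import BaseModel, ConfigDict


class LoaderSketch(BaseModel):
    # Without arbitrary_types_allowed=True, pydantic refuses to build a
    # schema for the DictConfig annotation below.
    model_config = ConfigDict(extra="allow", arbitrary_types_allowed=True)

    training: DictConfig


cfg = LoaderSketch(training=OmegaConf.create({"dataset": "dataset-a.zarr"}))
print(cfg.training.dataset)
```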