diff --git a/training/src/anemoi/training/commands/config.py b/training/src/anemoi/training/commands/config.py index 4314ced2..67cd8c12 100644 --- a/training/src/anemoi/training/commands/config.py +++ b/training/src/anemoi/training/commands/config.py @@ -10,11 +10,13 @@ from __future__ import annotations +import contextlib import importlib.resources as pkg_resources import logging import os import re import shutil +import tempfile from pathlib import Path from typing import TYPE_CHECKING from typing import Any @@ -29,6 +31,7 @@ if TYPE_CHECKING: import argparse + from collections.abc import Generator from pydantic import BaseModel @@ -72,25 +75,32 @@ def add_arguments(command_parser: argparse.ArgumentParser) -> None: action="store_true", ) + help_msg = "Dump Anemoi configs to a YAML file." + dump = subparsers.add_parser( + "dump", + help=help_msg, + description=help_msg, + ) + dump.add_argument("--config-path", "-i", default=Path.cwd(), type=Path, help="Configuration directory") + dump.add_argument("--config-name", "-n", default="dev", help="Name of the configuration") + dump.add_argument("--output", "-o", default="./config.yaml", type=Path, help="Output file path") + dump.add_argument("--overwrite", "-f", action="store_true") + def run(self, args: argparse.Namespace) -> None: self.overwrite = args.overwrite + if args.subcommand == "generate": LOGGER.info( "Generating configs, please wait.", ) self.traverse_config(args.output) - - LOGGER.info("Inference checkpoint saved to %s", args.output) return if args.subcommand == "training-home": anemoi_home = Path.home() / ".config" / "anemoi" / "training" / "config" - LOGGER.info( - "Generating configs, please wait.", - ) - self.traverse_config(anemoi_home) LOGGER.info("Inference checkpoint saved to %s", anemoi_home) + self.traverse_config(anemoi_home) return if args.subcommand == "validate": @@ -103,6 +113,11 @@ def run(self, args: argparse.Namespace) -> None: LOGGER.info("Config files validated.") return + if args.subcommand == "dump": + LOGGER.info("Dumping config to %s", args.output) + self.dump_config(args.config_path, args.config_name, args.output) + return + def traverse_config(self, destination_dir: Path | str) -> None: """Writes the given configuration data to the specified file path.""" config_package = "anemoi.training.config" @@ -113,17 +128,7 @@ def traverse_config(self, destination_dir: Path | str) -> None: # Traverse through the package's config directory with pkg_resources.as_file(pkg_resources.files(config_package)) as config_path: - for data in config_path.rglob("*"): # Recursively walk through all files and directories - item = Path(data) - if item.is_file() and item.suffix == ".yaml": - file_path = Path(destination_dir, item.relative_to(config_path)) - - file_path.parent.mkdir(parents=True, exist_ok=True) - - if not file_path.exists() or self.overwrite: - self.copy_file(item, file_path) - else: - LOGGER.info("File %s already exists, skipping", file_path) + self.copy_files(config_path, destination_dir) @staticmethod def copy_file(item: Path, file_path: Path) -> None: @@ -134,6 +139,20 @@ def copy_file(item: Path, file_path: Path) -> None: except Exception: LOGGER.exception("Failed to copy %s", item.name) + def copy_files(self, source_directory: Path, target_directory: Path) -> None: + """Copies directory files to a target directory.""" + for data in source_directory.rglob("*yaml"): # Recursively walk through all files and directories + item = Path(data) + if item.is_file(): + file_path = Path(target_directory, item.relative_to(source_directory)) + + file_path.parent.mkdir(parents=True, exist_ok=True) + + if not file_path.exists() or self.overwrite: + self.copy_file(item, file_path) + else: + LOGGER.info("File %s already exists, skipping", file_path) + def _mask_slurm_env_variables(self, cfg: DictConfig) -> None: """Mask environment variables are set.""" # Convert OmegaConf dict to YAML format (raw string) @@ -187,6 +206,36 @@ def validate_config(self, name: Path | str, mask_env_vars: bool) -> None: OmegaConf.resolve(cfg) BaseSchema(**cfg) + def dump_config(self, config_path: Path, name: str, output: Path) -> None: + """Dump config files in one YAML file.""" + # Copy config files in temporary directory + with tempfile.TemporaryDirectory() as tmpdirname: + tmp_dir = Path(tmpdirname) + self.copy_files(config_path, tmp_dir) + if not tmp_dir.exists(): + LOGGER.error("No files found in %s", config_path.absolute()) + raise FileNotFoundError + + # Move to config directory to be able to handle hydra + with change_directory(tmp_dir), initialize(version_base=None, config_path="./"): + cfg = compose(config_name=name) + + # Dump configuration in output file + LOGGER.info("Dumping file in %s.", output) + with output.open("w") as f: + f.write(OmegaConf.to_yaml(cfg)) + + +@contextlib.contextmanager +def change_directory(destination: Path) -> Generator[None, None, None]: + """A context manager to temporarily change the current working directory.""" + original_directory = Path.cwd() + try: + os.chdir(destination) + yield + finally: + os.chdir(original_directory) + def extract_primitive_type_hints(model: type[BaseModel], prefix: str = "") -> dict[str, Any]: field_types = {} diff --git a/training/tests/commands/test_config.py b/training/tests/commands/test_config.py new file mode 100644 index 00000000..27e2dd6c --- /dev/null +++ b/training/tests/commands/test_config.py @@ -0,0 +1,40 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + + +import tempfile +from pathlib import Path +from unittest import mock + +import pytest +from omegaconf import OmegaConf + +from anemoi.training.commands.config import ConfigGenerator + + +@pytest.fixture +def config_generator() -> ConfigGenerator: + return ConfigGenerator() + + +def test_dump_config(config_generator: ConfigGenerator) -> None: + with tempfile.TemporaryDirectory() as tmpdirname: + config_path = Path(tmpdirname) / "config" + config_path.mkdir(parents=True, exist_ok=True) + (config_path / "test.yaml").write_text("test: value") + + output_path = Path(tmpdirname) / "output.yaml" + with mock.patch("anemoi.training.commands.config.ConfigGenerator.copy_files") as mock_copy_files, mock.patch( + "anemoi.training.commands.config.initialize", + ), mock.patch("anemoi.training.commands.config.compose", return_value=OmegaConf.create({"test": "value"})): + config_generator.dump_config(config_path, "test", output_path) + + mock_copy_files.assert_called_once_with(config_path, mock.ANY) + assert output_path.exists() + assert OmegaConf.load(output_path) == {"test": "value"}