Skip to content

Commit

Permalink
feat: Add a CLI to dump the Hydra configuration files into a single Y…
Browse files Browse the repository at this point in the history
…AML file. (#137)

Add a CLI to dump the Hydra configuration files into a single YAML file.

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
jjlk and pre-commit-ci[bot] authored Feb 20, 2025
1 parent 9331e63 commit ef1e76e
Show file tree
Hide file tree
Showing 2 changed files with 106 additions and 17 deletions.
83 changes: 66 additions & 17 deletions training/src/anemoi/training/commands/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,13 @@

from __future__ import annotations

import contextlib
import importlib.resources as pkg_resources
import logging
import os
import re
import shutil
import tempfile
from pathlib import Path
from typing import TYPE_CHECKING
from typing import Any
Expand All @@ -29,6 +31,7 @@

if TYPE_CHECKING:
import argparse
from collections.abc import Generator

from pydantic import BaseModel

Expand Down Expand Up @@ -72,25 +75,32 @@ def add_arguments(command_parser: argparse.ArgumentParser) -> None:
action="store_true",
)

help_msg = "Dump Anemoi configs to a YAML file."
dump = subparsers.add_parser(
"dump",
help=help_msg,
description=help_msg,
)
dump.add_argument("--config-path", "-i", default=Path.cwd(), type=Path, help="Configuration directory")
dump.add_argument("--config-name", "-n", default="dev", help="Name of the configuration")
dump.add_argument("--output", "-o", default="./config.yaml", type=Path, help="Output file path")
dump.add_argument("--overwrite", "-f", action="store_true")

def run(self, args: argparse.Namespace) -> None:

self.overwrite = args.overwrite

if args.subcommand == "generate":
LOGGER.info(
"Generating configs, please wait.",
)
self.traverse_config(args.output)

LOGGER.info("Inference checkpoint saved to %s", args.output)
return

if args.subcommand == "training-home":
anemoi_home = Path.home() / ".config" / "anemoi" / "training" / "config"
LOGGER.info(
"Generating configs, please wait.",
)
self.traverse_config(anemoi_home)
LOGGER.info("Inference checkpoint saved to %s", anemoi_home)
self.traverse_config(anemoi_home)
return

if args.subcommand == "validate":
Expand All @@ -103,6 +113,11 @@ def run(self, args: argparse.Namespace) -> None:
LOGGER.info("Config files validated.")
return

if args.subcommand == "dump":
LOGGER.info("Dumping config to %s", args.output)
self.dump_config(args.config_path, args.config_name, args.output)
return

def traverse_config(self, destination_dir: Path | str) -> None:
"""Writes the given configuration data to the specified file path."""
config_package = "anemoi.training.config"
Expand All @@ -113,17 +128,7 @@ def traverse_config(self, destination_dir: Path | str) -> None:

# Traverse through the package's config directory
with pkg_resources.as_file(pkg_resources.files(config_package)) as config_path:
for data in config_path.rglob("*"): # Recursively walk through all files and directories
item = Path(data)
if item.is_file() and item.suffix == ".yaml":
file_path = Path(destination_dir, item.relative_to(config_path))

file_path.parent.mkdir(parents=True, exist_ok=True)

if not file_path.exists() or self.overwrite:
self.copy_file(item, file_path)
else:
LOGGER.info("File %s already exists, skipping", file_path)
self.copy_files(config_path, destination_dir)

@staticmethod
def copy_file(item: Path, file_path: Path) -> None:
Expand All @@ -134,6 +139,20 @@ def copy_file(item: Path, file_path: Path) -> None:
except Exception:
LOGGER.exception("Failed to copy %s", item.name)

def copy_files(self, source_directory: Path, target_directory: Path) -> None:
"""Copies directory files to a target directory."""
for data in source_directory.rglob("*yaml"): # Recursively walk through all files and directories
item = Path(data)
if item.is_file():
file_path = Path(target_directory, item.relative_to(source_directory))

file_path.parent.mkdir(parents=True, exist_ok=True)

if not file_path.exists() or self.overwrite:
self.copy_file(item, file_path)
else:
LOGGER.info("File %s already exists, skipping", file_path)

def _mask_slurm_env_variables(self, cfg: DictConfig) -> None:
"""Mask environment variables are set."""
# Convert OmegaConf dict to YAML format (raw string)
Expand Down Expand Up @@ -187,6 +206,36 @@ def validate_config(self, name: Path | str, mask_env_vars: bool) -> None:
OmegaConf.resolve(cfg)
BaseSchema(**cfg)

def dump_config(self, config_path: Path, name: str, output: Path) -> None:
"""Dump config files in one YAML file."""
# Copy config files in temporary directory
with tempfile.TemporaryDirectory() as tmpdirname:
tmp_dir = Path(tmpdirname)
self.copy_files(config_path, tmp_dir)
if not tmp_dir.exists():
LOGGER.error("No files found in %s", config_path.absolute())
raise FileNotFoundError

# Move to config directory to be able to handle hydra
with change_directory(tmp_dir), initialize(version_base=None, config_path="./"):
cfg = compose(config_name=name)

# Dump configuration in output file
LOGGER.info("Dumping file in %s.", output)
with output.open("w") as f:
f.write(OmegaConf.to_yaml(cfg))


@contextlib.contextmanager
def change_directory(destination: Path) -> Generator[None, None, None]:
"""A context manager to temporarily change the current working directory."""
original_directory = Path.cwd()
try:
os.chdir(destination)
yield
finally:
os.chdir(original_directory)


def extract_primitive_type_hints(model: type[BaseModel], prefix: str = "") -> dict[str, Any]:
field_types = {}
Expand Down
40 changes: 40 additions & 0 deletions training/tests/commands/test_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# (C) Copyright 2024 Anemoi contributors.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
#
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.


import tempfile
from pathlib import Path
from unittest import mock

import pytest
from omegaconf import OmegaConf

from anemoi.training.commands.config import ConfigGenerator


@pytest.fixture
def config_generator() -> ConfigGenerator:
return ConfigGenerator()


def test_dump_config(config_generator: ConfigGenerator) -> None:
with tempfile.TemporaryDirectory() as tmpdirname:
config_path = Path(tmpdirname) / "config"
config_path.mkdir(parents=True, exist_ok=True)
(config_path / "test.yaml").write_text("test: value")

output_path = Path(tmpdirname) / "output.yaml"
with mock.patch("anemoi.training.commands.config.ConfigGenerator.copy_files") as mock_copy_files, mock.patch(
"anemoi.training.commands.config.initialize",
), mock.patch("anemoi.training.commands.config.compose", return_value=OmegaConf.create({"test": "value"})):
config_generator.dump_config(config_path, "test", output_path)

mock_copy_files.assert_called_once_with(config_path, mock.ANY)
assert output_path.exists()
assert OmegaConf.load(output_path) == {"test": "value"}

0 comments on commit ef1e76e

Please sign in to comment.