From 772cc203f8e22f3f3f4374e368df85fb6455ed72 Mon Sep 17 00:00:00 2001 From: Leif Denby Date: Tue, 12 Nov 2024 20:43:58 +0100 Subject: [PATCH] adapt all cli to use --config arg instead of `config` --- neural_lam/create_graph.py | 4 +++- neural_lam/datastore/plot_example.py | 17 +++++++++++++---- neural_lam/train_model.py | 3 ++- 3 files changed, 18 insertions(+), 6 deletions(-) diff --git a/neural_lam/create_graph.py b/neural_lam/create_graph.py index fad3a44f..854818b7 100644 --- a/neural_lam/create_graph.py +++ b/neural_lam/create_graph.py @@ -561,7 +561,7 @@ def create_graph_from_datastore( def cli(input_args=None): parser = ArgumentParser(description="Graph generation arguments") parser.add_argument( - "config_path", + "--config_path", type=str, help="Path to neural-lam configuration file", ) @@ -590,6 +590,8 @@ def cli(input_args=None): ) args = parser.parse_args(input_args) + assert args.config is not None, "Specify your config with --config_path" + # Load neural-lam configuration and datastore to use _, datastore = load_config_and_datastore(config_path=args.config_path) diff --git a/neural_lam/datastore/plot_example.py b/neural_lam/datastore/plot_example.py index 734aa2a6..827ec403 100644 --- a/neural_lam/datastore/plot_example.py +++ b/neural_lam/datastore/plot_example.py @@ -2,7 +2,7 @@ import matplotlib.pyplot as plt # Local -from .config import load_config_and_datastore +from . import DATASTORES, init_datastore def plot_example_from_datastore( @@ -105,6 +105,13 @@ def _parse_dict(arg_str): parser = argparse.ArgumentParser( formatter_class=argparse.ArgumentDefaultsHelpFormatter ) + parser.add_argument( + "--datastore_kind", + type=str, + choices=DATASTORES.keys(), + default="mdp", + help="Kind of datastore to use", + ) parser.add_argument( "--datastore_config_path", type=str, @@ -150,6 +157,10 @@ def _parse_dict(arg_str): ) args = parser.parse_args() + assert ( + args.datastore_config_path is not None + ), "Specify your datastore config with --datastore_config_path" + selection = dict(args.selection) index_selection = dict(args.index_selection) @@ -161,9 +172,7 @@ def _parse_dict(arg_str): "column dimension and/or selection." ) - _, datastore = load_config_and_datastore( - config_path=args.datastore_config_path - ) + _, datastore = init_datastore(config_path=args.datastore_config_path) plot_example_from_datastore( args.category, diff --git a/neural_lam/train_model.py b/neural_lam/train_model.py index a638e497..8f400b3b 100644 --- a/neural_lam/train_model.py +++ b/neural_lam/train_model.py @@ -30,7 +30,7 @@ def main(input_args=None): description="Train or evaluate NeurWP models for LAM" ) parser.add_argument( - "config_path", + "--config_path", type=str, help="Path to the configuration for neural-lam", ) @@ -209,6 +209,7 @@ def main(input_args=None): } # Asserts for arguments + assert args.config is not None, "Specify your config with --config_path" assert args.model in MODELS, f"Unknown model: {args.model}" assert args.eval in ( None,