From 9e3f2fa4fa0fd453e545dc54e9755a24a5fa1de4 Mon Sep 17 00:00:00 2001 From: tomvothecoder Date: Tue, 28 Nov 2023 17:07:26 -0800 Subject: [PATCH] Remove default diags running with only Python API --- e3sm_diags/e3sm_diags_driver.py | 3 +- e3sm_diags/parameter/core_parameter.py | 24 +------- e3sm_diags/parser/core_parser.py | 8 +-- e3sm_diags/run.py | 80 ++++++++++++-------------- e3sm_diags/viewer/main.py | 4 +- 5 files changed, 45 insertions(+), 74 deletions(-) diff --git a/e3sm_diags/e3sm_diags_driver.py b/e3sm_diags/e3sm_diags_driver.py index 111d20f0db..f99d6422bf 100644 --- a/e3sm_diags/e3sm_diags_driver.py +++ b/e3sm_diags/e3sm_diags_driver.py @@ -354,8 +354,9 @@ def main(parameters=[]): # If no parameters are passed, use the parser args as defaults. Otherwise, # create the dictionary of expected parameters. - if not parameters: + if len(parameters) == 0: parameters = get_parameters(parser) + expected_parameters = create_parameter_dict(parameters) if not os.path.exists(parameters[0].results_dir): diff --git a/e3sm_diags/parameter/core_parameter.py b/e3sm_diags/parameter/core_parameter.py index 4b96c13ab2..6b677a38f1 100644 --- a/e3sm_diags/parameter/core_parameter.py +++ b/e3sm_diags/parameter/core_parameter.py @@ -58,28 +58,8 @@ def __init__(self): # 'model_vs_obs' (by default), 'model_vs_model', or 'obs_vs_obs'. self.run_type: str = "model_vs_obs" - # A list of the sets to be run. Default is all sets - self.sets: List[str] = [ - "zonal_mean_xy", - "zonal_mean_2d", - "zonal_mean_2d_stratosphere", - "meridional_mean_2d", - "lat_lon", - "polar", - "area_mean_time_series", - "cosp_histogram", - "enso_diags", - "qbo", - "streamflow", - "diurnal_cycle", - "arm_diags", - "tc_analysis", - "annual_cycle_zonal_mean", - "lat_lon_land", - "lat_lon_river", - "aerosol_aeronet", - "aerosol_budget", - ] + # A list of the sets to be run. + self.sets: List[str] = [] # The current set that is being ran when looping over sets in # `e3sm_diags_driver.run_diag()`. diff --git a/e3sm_diags/parser/core_parser.py b/e3sm_diags/parser/core_parser.py index 2d1defd23b..169eaf0072 100644 --- a/e3sm_diags/parser/core_parser.py +++ b/e3sm_diags/parser/core_parser.py @@ -784,8 +784,7 @@ def get_cfg_parameters( RuntimeError If the parameters input file is not `.json` or `.cfg` format. """ - - parameters = [] + params = [] self._parse_arguments() @@ -801,10 +800,7 @@ def get_cfg_parameters( else: raise RuntimeError("The parameters input file must be a .cfg file") - for p in params: - parameters.append(p) - - return parameters + return params def _get_cfg_parameters( self, cfg_file, check_values=False, argparse_vals_only=True diff --git a/e3sm_diags/run.py b/e3sm_diags/run.py index d7a6ba8709..86e018e053 100644 --- a/e3sm_diags/run.py +++ b/e3sm_diags/run.py @@ -1,7 +1,9 @@ import copy +from itertools import chain +from typing import List, Union import e3sm_diags # noqa: F401 -from e3sm_diags.e3sm_diags_driver import get_default_diags_path, main +from e3sm_diags.e3sm_diags_driver import main from e3sm_diags.logger import custom_logger, move_log_to_prov_dir from e3sm_diags.parameter import SET_TO_PARAMETERS from e3sm_diags.parameter.core_parameter import CoreParameter @@ -18,7 +20,6 @@ class Run: """ def __init__(self): - self.sets_to_run = CoreParameter().sets self.parser = CoreParser() def run_diags(self, parameters): @@ -48,46 +49,47 @@ def get_final_parameters(self, parameters): # For each of the passed in parameters, we can only have one of # each type. - types = set([p.__class__ for p in parameters]) - if len(types) != len(parameters): - msg = "You passed in two or more parameters of the same type." + param_types_list = [ + p.__class__ for p in parameters if p.__class__ != CoreParameter + ] + param_types_set = set(param_types_list) + + if len(param_types_set) != len(param_types_list): + msg = ( + "You passed in two or more non-CoreParameter objects of the same type." + ) raise RuntimeError(msg) self._add_parent_attrs_to_children(parameters) - final_params = [] + # Get the sets to run using the parameter objects via the Python API + sets_to_run = [param.sets for param in parameters] + self.sets_to_run = list(chain.from_iterable(sets_to_run)) + api_params = [] for set_name in self.sets_to_run: - other_params = self._get_other_diags(parameters[0].run_type) - # For each of the set_names, get the corresponding parameter. - param = self._get_instance_of_param_class( + api_param = self._get_instance_of_param_class( SET_TO_PARAMETERS[set_name], parameters ) # Since each parameter will have lots of default values, we want to remove them. # Otherwise when calling get_parameters(), these default values # will take precedence over values defined in other_params. - self._remove_attrs_with_default_values(param) - param.sets = [set_name] - - params = self.parser.get_parameters( - orig_parameters=param, - other_parameters=other_params, - cmd_default_vars=False, - argparse_vals_only=False, - ) + self._remove_attrs_with_default_values(api_param) + api_param.sets = [set_name] + + # Makes sure that any parameters that are selectors will be in param. + self._add_attrs_with_default_values(api_param) - # Makes sure that any parameters that are selectors - # will be in param. - self._add_attrs_with_default_values(param) - # The select() call in get_parameters() was made for the original - # command-line way of using CDP. - # We just call it manually with the parameter object param. - params = self.parser.select(param, params) + api_params.append(api_param) - final_params.extend(params) + # Get the diagnostic parameter objects from the .cfg file passed via + # the -d/--diags CLI argument (if set). Otherwise, it will be an empty + # list. + cfg_params = self._get_diags_from_cfg_file() + final_params = api_params + cfg_params self.parser.check_values_of_params(final_params) return final_params @@ -233,29 +235,21 @@ def _get_instance_of_param_class(self, cls, parameters): msg = "There's weren't any class of types {} in your parameters." raise RuntimeError(msg.format(class_types)) - def _get_other_diags(self, run_type): + def _get_diags_from_cfg_file(self) -> Union[List, List[CoreParameter]]: """ - If the user has ran the script with a -d, get the diags for that. - If not, load the default diags based on sets_to_run and run_type. + Get parameters defined by the cfg file passed to -d/--diags (if set). + + If ``-d`` is not passed, return an empty list. """ args = self.parser.view_args() + params = [] - # If the user has passed in args with -d. if args.other_parameters: params = self.parser.get_cfg_parameters(argparse_vals_only=False) - else: - default_diags_paths = [ - get_default_diags_path(set_name, run_type, False) - for set_name in self.sets_to_run - ] - params = self.parser.get_cfg_parameters( - files_to_open=default_diags_paths, argparse_vals_only=False - ) - - # For each of the params, add in the default values - # using the parameter classes in SET_TO_PARAMETERS. - for i in range(len(params)): - params[i] = SET_TO_PARAMETERS[params[i].sets[0]]() + params[i] + # For each of the params, add in the default values + # using the parameter classes in SET_TO_PARAMETERS. + for i in range(len(params)): + params[i] = SET_TO_PARAMETERS[params[i].sets[0]]() + params[i] return params diff --git a/e3sm_diags/viewer/main.py b/e3sm_diags/viewer/main.py index 4227b3e1ee..23b6c786b2 100644 --- a/e3sm_diags/viewer/main.py +++ b/e3sm_diags/viewer/main.py @@ -113,8 +113,8 @@ def insert_data_in_row(row_obj, name, url): def create_viewer(root_dir, parameters): """ - Based of the parameters, find the files with the - certain extension and create the viewer in root_dir. + Based of the parameters, find the files with the certain extension and + create the viewer in root_dir. """ # Group each parameter object based on the `sets` parameter. set_to_parameters = collections.defaultdict(list)