Skip to content

Commit

Permalink
Remove default diags running with only Python API
Browse files Browse the repository at this point in the history
  • Loading branch information
tomvothecoder committed Nov 29, 2023
1 parent 633b52c commit 9e3f2fa
Show file tree
Hide file tree
Showing 5 changed files with 45 additions and 74 deletions.
3 changes: 2 additions & 1 deletion e3sm_diags/e3sm_diags_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,8 +354,9 @@ def main(parameters=[]):

# If no parameters are passed, use the parser args as defaults. Otherwise,
# create the dictionary of expected parameters.
if not parameters:
if len(parameters) == 0:
parameters = get_parameters(parser)

expected_parameters = create_parameter_dict(parameters)

if not os.path.exists(parameters[0].results_dir):
Expand Down
24 changes: 2 additions & 22 deletions e3sm_diags/parameter/core_parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,28 +58,8 @@ def __init__(self):
# 'model_vs_obs' (by default), 'model_vs_model', or 'obs_vs_obs'.
self.run_type: str = "model_vs_obs"

# A list of the sets to be run. Default is all sets
self.sets: List[str] = [
"zonal_mean_xy",
"zonal_mean_2d",
"zonal_mean_2d_stratosphere",
"meridional_mean_2d",
"lat_lon",
"polar",
"area_mean_time_series",
"cosp_histogram",
"enso_diags",
"qbo",
"streamflow",
"diurnal_cycle",
"arm_diags",
"tc_analysis",
"annual_cycle_zonal_mean",
"lat_lon_land",
"lat_lon_river",
"aerosol_aeronet",
"aerosol_budget",
]
# A list of the sets to be run.
self.sets: List[str] = []

# The current set that is being ran when looping over sets in
# `e3sm_diags_driver.run_diag()`.
Expand Down
8 changes: 2 additions & 6 deletions e3sm_diags/parser/core_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -784,8 +784,7 @@ def get_cfg_parameters(
RuntimeError
If the parameters input file is not `.json` or `.cfg` format.
"""

parameters = []
params = []

self._parse_arguments()

Expand All @@ -801,10 +800,7 @@ def get_cfg_parameters(
else:
raise RuntimeError("The parameters input file must be a .cfg file")

for p in params:
parameters.append(p)

return parameters
return params

def _get_cfg_parameters(
self, cfg_file, check_values=False, argparse_vals_only=True
Expand Down
80 changes: 37 additions & 43 deletions e3sm_diags/run.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import copy
from itertools import chain
from typing import List, Union

import e3sm_diags # noqa: F401
from e3sm_diags.e3sm_diags_driver import get_default_diags_path, main
from e3sm_diags.e3sm_diags_driver import main
from e3sm_diags.logger import custom_logger, move_log_to_prov_dir
from e3sm_diags.parameter import SET_TO_PARAMETERS
from e3sm_diags.parameter.core_parameter import CoreParameter
Expand All @@ -18,7 +20,6 @@ class Run:
"""

def __init__(self):
self.sets_to_run = CoreParameter().sets
self.parser = CoreParser()

def run_diags(self, parameters):
Expand Down Expand Up @@ -48,46 +49,47 @@ def get_final_parameters(self, parameters):

# For each of the passed in parameters, we can only have one of
# each type.
types = set([p.__class__ for p in parameters])
if len(types) != len(parameters):
msg = "You passed in two or more parameters of the same type."
param_types_list = [
p.__class__ for p in parameters if p.__class__ != CoreParameter
]
param_types_set = set(param_types_list)

if len(param_types_set) != len(param_types_list):
msg = (
"You passed in two or more non-CoreParameter objects of the same type."
)
raise RuntimeError(msg)

self._add_parent_attrs_to_children(parameters)

final_params = []
# Get the sets to run using the parameter objects via the Python API
sets_to_run = [param.sets for param in parameters]
self.sets_to_run = list(chain.from_iterable(sets_to_run))

api_params = []
for set_name in self.sets_to_run:
other_params = self._get_other_diags(parameters[0].run_type)

# For each of the set_names, get the corresponding parameter.
param = self._get_instance_of_param_class(
api_param = self._get_instance_of_param_class(
SET_TO_PARAMETERS[set_name], parameters
)

# Since each parameter will have lots of default values, we want to remove them.
# Otherwise when calling get_parameters(), these default values
# will take precedence over values defined in other_params.
self._remove_attrs_with_default_values(param)
param.sets = [set_name]

params = self.parser.get_parameters(
orig_parameters=param,
other_parameters=other_params,
cmd_default_vars=False,
argparse_vals_only=False,
)
self._remove_attrs_with_default_values(api_param)
api_param.sets = [set_name]

# Makes sure that any parameters that are selectors will be in param.
self._add_attrs_with_default_values(api_param)

# Makes sure that any parameters that are selectors
# will be in param.
self._add_attrs_with_default_values(param)
# The select() call in get_parameters() was made for the original
# command-line way of using CDP.
# We just call it manually with the parameter object param.
params = self.parser.select(param, params)
api_params.append(api_param)

final_params.extend(params)
# Get the diagnostic parameter objects from the .cfg file passed via
# the -d/--diags CLI argument (if set). Otherwise, it will be an empty
# list.
cfg_params = self._get_diags_from_cfg_file()

final_params = api_params + cfg_params
self.parser.check_values_of_params(final_params)

return final_params
Expand Down Expand Up @@ -233,29 +235,21 @@ def _get_instance_of_param_class(self, cls, parameters):
msg = "There's weren't any class of types {} in your parameters."
raise RuntimeError(msg.format(class_types))

def _get_other_diags(self, run_type):
def _get_diags_from_cfg_file(self) -> Union[List, List[CoreParameter]]:
"""
If the user has ran the script with a -d, get the diags for that.
If not, load the default diags based on sets_to_run and run_type.
Get parameters defined by the cfg file passed to -d/--diags (if set).
If ``-d`` is not passed, return an empty list.
"""
args = self.parser.view_args()
params = []

# If the user has passed in args with -d.
if args.other_parameters:
params = self.parser.get_cfg_parameters(argparse_vals_only=False)
else:
default_diags_paths = [
get_default_diags_path(set_name, run_type, False)
for set_name in self.sets_to_run
]
params = self.parser.get_cfg_parameters(
files_to_open=default_diags_paths, argparse_vals_only=False
)

# For each of the params, add in the default values
# using the parameter classes in SET_TO_PARAMETERS.
for i in range(len(params)):
params[i] = SET_TO_PARAMETERS[params[i].sets[0]]() + params[i]
# For each of the params, add in the default values
# using the parameter classes in SET_TO_PARAMETERS.
for i in range(len(params)):
params[i] = SET_TO_PARAMETERS[params[i].sets[0]]() + params[i]

return params

Expand Down
4 changes: 2 additions & 2 deletions e3sm_diags/viewer/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,8 @@ def insert_data_in_row(row_obj, name, url):

def create_viewer(root_dir, parameters):
"""
Based of the parameters, find the files with the
certain extension and create the viewer in root_dir.
Based of the parameters, find the files with the certain extension and
create the viewer in root_dir.
"""
# Group each parameter object based on the `sets` parameter.
set_to_parameters = collections.defaultdict(list)
Expand Down

0 comments on commit 9e3f2fa

Please sign in to comment.