Skip to content

Commit

Permalink
Add initial code for fixing default sets behavior
Browse files Browse the repository at this point in the history
  • Loading branch information
tomvothecoder committed Nov 28, 2023
1 parent a0a3561 commit 643bd93
Showing 3 changed files with 28 additions and 47 deletions.
22 changes: 1 addition & 21 deletions e3sm_diags/parameter/core_parameter.py
Original file line number Diff line number Diff line change
@@ -59,27 +59,7 @@ def __init__(self):
self.run_type: str = "model_vs_obs"

# A list of the sets to be run. Default is all sets
self.sets: List[str] = [
"zonal_mean_xy",
"zonal_mean_2d",
"zonal_mean_2d_stratosphere",
"meridional_mean_2d",
"lat_lon",
"polar",
"area_mean_time_series",
"cosp_histogram",
"enso_diags",
"qbo",
"streamflow",
"diurnal_cycle",
"arm_diags",
"tc_analysis",
"annual_cycle_zonal_mean",
"lat_lon_land",
"lat_lon_river",
"aerosol_aeronet",
"aerosol_budget",
]
self.sets: List[str] = []

# The current set that is being ran when looping over sets in
# `e3sm_diags_driver.run_diag()`.
49 changes: 25 additions & 24 deletions e3sm_diags/run.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import copy
from itertools import chain
from typing import List

import e3sm_diags # noqa: F401
from e3sm_diags.e3sm_diags_driver import get_default_diags_path, main
from e3sm_diags.e3sm_diags_driver import main
from e3sm_diags.logger import custom_logger, move_log_to_prov_dir
from e3sm_diags.parameter import SET_TO_PARAMETERS
from e3sm_diags.parameter.core_parameter import CoreParameter
@@ -18,7 +20,6 @@ class Run:
"""

def __init__(self):
self.sets_to_run = CoreParameter().sets
self.parser = CoreParser()

def run_diags(self, parameters):
@@ -48,18 +49,24 @@ def get_final_parameters(self, parameters):

# For each of the passed in parameters, we can only have one of
# each type.
types = set([p.__class__ for p in parameters])
if len(types) != len(parameters):
msg = "You passed in two or more parameters of the same type."
param_types_list = [
p.__class__ for p in parameters if p.__class__ != CoreParameter
]
param_types_set = set(param_types_list)

if len(param_types_set) != len(param_types_list):
msg = "You passed in two or more non-CoreParameter parameters of the same type."
raise RuntimeError(msg)

self._add_parent_attrs_to_children(parameters)

final_params = []

for set_name in self.sets_to_run:
other_params = self._get_other_diags(parameters[0].run_type)
# Get the sets to run using the parameter objects.
sets_to_run = [param.sets for param in parameters]
self.sets_to_run = list(chain.from_iterable(sets_to_run))

for set_name in self.sets_to_run:
# For each of the set_names, get the corresponding parameter.
param = self._get_instance_of_param_class(
SET_TO_PARAMETERS[set_name], parameters
@@ -71,6 +78,7 @@ def get_final_parameters(self, parameters):
self._remove_attrs_with_default_values(param)
param.sets = [set_name]

other_params = self._get_diags_from_cfg_file()
params = self.parser.get_parameters(
orig_parameters=param,
other_parameters=other_params,
@@ -233,31 +241,24 @@ def _get_instance_of_param_class(self, cls, parameters):
msg = "There's weren't any class of types {} in your parameters."
raise RuntimeError(msg.format(class_types))

def _get_other_diags(self, run_type):
"""
If the user has ran the script with a -d, get the diags for that.
If not, load the default diags based on sets_to_run and run_type.
def _get_diags_from_cfg_file(self) -> List[CoreParameter] | List:
"""Get diagnostics defined in a cfg file using ``-d`` CLI arg.
If ``-d`` is not passed, return an empty list.
"""
args = self.parser.view_args()

# If the user has passed in args with -d.
if args.other_parameters:
params = self.parser.get_cfg_parameters(argparse_vals_only=False)
else:
default_diags_paths = [
get_default_diags_path(set_name, run_type, False)
for set_name in self.sets_to_run
]
params = self.parser.get_cfg_parameters(
files_to_open=default_diags_paths, argparse_vals_only=False
)
# For each of the params, add in the default values
# using the parameter classes in SET_TO_PARAMETERS.
for i in range(len(params)):
params[i] = SET_TO_PARAMETERS[params[i].sets[0]]() + params[i]

# For each of the params, add in the default values
# using the parameter classes in SET_TO_PARAMETERS.
for i in range(len(params)):
params[i] = SET_TO_PARAMETERS[params[i].sets[0]]() + params[i]
return params

return params
return []


runner = Run()
4 changes: 2 additions & 2 deletions e3sm_diags/viewer/main.py
Original file line number Diff line number Diff line change
@@ -113,8 +113,8 @@ def insert_data_in_row(row_obj, name, url):

def create_viewer(root_dir, parameters):
"""
Based of the parameters, find the files with the
certain extension and create the viewer in root_dir.
Based of the parameters, find the files with the certain extension and
create the viewer in root_dir.
"""
# Group each parameter object based on the `sets` parameter.
set_to_parameters = collections.defaultdict(list)

0 comments on commit 643bd93

Please sign in to comment.