diff --git a/.github/workflows/build_workflow.yml b/.github/workflows/build_workflow.yml
index d19f29a5b..f3557732c 100644
--- a/.github/workflows/build_workflow.yml
+++ b/.github/workflows/build_workflow.yml
@@ -5,7 +5,7 @@ on:
branches: [main]
pull_request:
- branches: [main]
+ branches: [main, cdat-migration-fy24]
workflow_dispatch:
diff --git a/auxiliary_tools/template_cdat_regression_test.ipynb b/auxiliary_tools/template_cdat_regression_test.ipynb
new file mode 100644
index 000000000..8b4d00bd1
--- /dev/null
+++ b/auxiliary_tools/template_cdat_regression_test.ipynb
@@ -0,0 +1,1333 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# CDAT Migration Regression Test (FY24)\n",
+ "\n",
+ "This notebook is used to perform regression testing between the development and\n",
+ "production versions of a diagnostic set.\n",
+ "\n",
+ "## How it works\n",
+ "\n",
+ "It compares the relative differences (%) between two sets of `.json` files in two\n",
+ "separate directories, one for the refactored code and the other for the `main` branch.\n",
+ "\n",
+ "It will display metrics values with relative differences >= 2%. Relative differences are used instead of absolute differences because:\n",
+ "\n",
+ "- Relative differences are in percentages, which shows the scale of the differences.\n",
+ "- Absolute differences are raw numbers that don't account for the\n",
+ " magnitude of the values being compared (e.g., 100.00 vs. 0.0001), which can be misleading.\n",
+ "\n",
+ "## How to use\n",
+ "\n",
+ "PREREQUISITE: The diagnostic set's metrics stored in `.json` files in two directories\n",
+ "(dev and `main` branches).\n",
+ "\n",
+ "1. Make a copy of this notebook.\n",
+ "2. Run `mamba create -n cdat_regression_test -y -c conda-forge \"python<3.12\" pandas matplotlib-base ipykernel`\n",
+ "3. Run `mamba activate cdat_regression_test`\n",
+ "4. Update `DEV_PATH` and `PROD_PATH` in the copy of your notebook.\n",
+ "5. Run all cells IN ORDER.\n",
+ "6. Review results for any outstanding differences (>= 2%).\n",
+ " - Debug these differences (e.g., bug in metrics functions, incorrect variable references, etc.)\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Setup Code\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import glob\n",
+ "import math\n",
+ "from typing import List\n",
+ "\n",
+ "import pandas as pd\n",
+ "\n",
+ "# TODO: Update DEV_PATH and PROD_PATH to your diagnostic sets.\n",
+ "DEV_PATH = \"/global/cfs/cdirs/e3sm/www/vo13/examples_658/ex1_modTS_vs_modTS_3years/lat_lon/model_vs_model\"\n",
+ "PROD_PATH = \"/global/cfs/cdirs/e3sm/www/vo13/examples/ex1_modTS_vs_modTS_3years/lat_lon/model_vs_model\"\n",
+ "\n",
+ "DEV_GLOB = sorted(glob.glob(DEV_PATH + \"/*.json\"))\n",
+ "PROD_GLOB = sorted(glob.glob(PROD_PATH + \"/*.json\"))\n",
+ "\n",
+ "# The names of the columns that store percentage difference values.\n",
+ "PERCENTAGE_COLUMNS = [\n",
+ " \"test DIFF (%)\",\n",
+ " \"ref DIFF (%)\",\n",
+ " \"test_regrid DIFF (%)\",\n",
+ " \"ref_regrid DIFF (%)\",\n",
+ " \"diff DIFF (%)\",\n",
+ " \"misc DIFF (%)\",\n",
+ "]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Core Functions\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def get_metrics(filepaths: List[str]) -> pd.DataFrame:\n",
+ " \"\"\"Get the metrics using a glob of `.json` metric files in a directory.\n",
+ "\n",
+ " Parameters\n",
+ " ----------\n",
+ " filepaths : List[str]\n",
+ " The filepaths for metrics `.json` files.\n",
+ "\n",
+ " Returns\n",
+ " -------\n",
+ " pd.DataFrame\n",
+ " The DataFrame containing the metrics for all of the variables in\n",
+ " the results directory.\n",
+ " \"\"\"\n",
+ " metrics = []\n",
+ "\n",
+ " for filepath in filepaths:\n",
+ " df = pd.read_json(filepath)\n",
+ "\n",
+ " filename = filepath.split(\"/\")[-1]\n",
+ " var_key = filename.split(\"-\")[1]\n",
+ "\n",
+ " # Add the variable key to the index as a MultiIndex level. NOTE(review):\n",
+ " # `df.stack()` below is not in-place and its result is discarded - confirm.\n",
+ " multiindex = pd.MultiIndex.from_product([[var_key], [*df.index]])\n",
+ " df = df.set_index(multiindex)\n",
+ " df.stack()\n",
+ "\n",
+ " metrics.append(df)\n",
+ "\n",
+ " df_final = pd.concat(metrics)\n",
+ "\n",
+ " # Reorder columns and drop \"unit\" column (string dtype breaks Pandas\n",
+ " # arithmetic).\n",
+ " df_final = df_final[[\"test\", \"ref\", \"test_regrid\", \"ref_regrid\", \"diff\", \"misc\"]]\n",
+ "\n",
+ " return df_final\n",
+ "\n",
+ "\n",
+ "def get_rel_diffs(df_actual: pd.DataFrame, df_reference: pd.DataFrame) -> pd.DataFrame:\n",
+ " \"\"\"Get the relative differences between two DataFrames.\n",
+ "\n",
+ " Formula: abs(actual - reference) / abs(actual)\n",
+ "\n",
+ " Parameters\n",
+ " ----------\n",
+ " df_actual : pd.DataFrame\n",
+ " The first DataFrame representing \"actual\" results (dev branch).\n",
+ " df_reference : pd.DataFrame\n",
+ " The second DataFrame representing \"reference\" results (main branch).\n",
+ "\n",
+ " Returns\n",
+ " -------\n",
+ " pd.DataFrame\n",
+ " The DataFrame containing the relative differences between the two\n",
+ " metrics DataFrames, with \" DIFF (%)\" appended to each column name.\n",
+ " \"\"\"\n",
+ " df_diff = abs(df_actual - df_reference) / abs(df_actual)\n",
+ " df_diff = df_diff.add_suffix(\" DIFF (%)\")\n",
+ "\n",
+ " return df_diff\n",
+ "\n",
+ "\n",
+ "def sort_columns(df: pd.DataFrame) -> pd.DataFrame:\n",
+ " \"\"\"Sorts the order of the columns for the final DataFrame output.\n",
+ "\n",
+ " Parameters\n",
+ " ----------\n",
+ " df : pd.DataFrame\n",
+ " The final DataFrame output.\n",
+ "\n",
+ " Returns\n",
+ " -------\n",
+ " pd.DataFrame\n",
+ " The final DataFrame output with sorted columns.\n",
+ " \"\"\"\n",
+ " columns = [\n",
+ " \"test_dev\",\n",
+ " \"test_prod\",\n",
+ " \"test DIFF (%)\",\n",
+ " \"ref_dev\",\n",
+ " \"ref_prod\",\n",
+ " \"ref DIFF (%)\",\n",
+ " \"test_regrid_dev\",\n",
+ " \"test_regrid_prod\",\n",
+ " \"test_regrid DIFF (%)\",\n",
+ " \"ref_regrid_dev\",\n",
+ " \"ref_regrid_prod\",\n",
+ " \"ref_regrid DIFF (%)\",\n",
+ " \"diff_dev\",\n",
+ " \"diff_prod\",\n",
+ " \"diff DIFF (%)\",\n",
+ " \"misc_dev\",\n",
+ " \"misc_prod\",\n",
+ " \"misc DIFF (%)\",\n",
+ " ]\n",
+ "\n",
+ " df_new = df.copy()\n",
+ " df_new = df_new[columns]\n",
+ "\n",
+ " return df_new\n",
+ "\n",
+ "\n",
+ "def update_diffs_to_pct(df: pd.DataFrame):\n",
+ " \"\"\"Update relative diff columns from float to string percentage.\n",
+ "\n",
+ " Parameters\n",
+ " ----------\n",
+ " df : pd.DataFrame\n",
+ " The final DataFrame containing metrics and diffs (floats).\n",
+ "\n",
+ " Returns\n",
+ " -------\n",
+ " pd.DataFrame\n",
+ " The final DataFrame containing metrics and diffs (str percentage).\n",
+ " \"\"\"\n",
+ " df_new = df.copy()\n",
+ " df_new[PERCENTAGE_COLUMNS] = df_new[PERCENTAGE_COLUMNS].map(\n",
+ " lambda x: \"{0:.2f}%\".format(x * 100) if not math.isnan(x) else x\n",
+ " )\n",
+ "\n",
+ " return df_new"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 1. Get the DataFrame containing development and production metrics.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "df_metrics_dev = get_metrics(DEV_GLOB)\n",
+ "df_metrics_prod = get_metrics(PROD_GLOB)\n",
+ "df_metrics_all = pd.concat(\n",
+ " [df_metrics_dev.add_suffix(\"_dev\"), df_metrics_prod.add_suffix(\"_prod\")],\n",
+ " axis=1,\n",
+ " join=\"outer\",\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 2. Get DataFrame for differences >= 2%.\n",
+ "\n",
+ "- Get the relative differences for all metrics\n",
+ "- Filter down metrics to those with differences >= 2%\n",
+ " - If all cells in a row are NaN (< 2%), the entire row is dropped to make the results easier to parse.\n",
+ " - Any remaining NaN cells are either below 2% difference or have no value for that metric, and **should be ignored**.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "df_metrics_diffs = get_rel_diffs(df_metrics_dev, df_metrics_prod)\n",
+ "df_metrics_diffs_thres = df_metrics_diffs[df_metrics_diffs >= 0.02]\n",
+ "df_metrics_diffs_thres = df_metrics_diffs_thres.dropna(\n",
+ " axis=0, how=\"all\", ignore_index=False\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 3. Combine both DataFrames to get the final result.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "df_final = df_metrics_diffs_thres.join(df_metrics_all)\n",
+ "df_final = sort_columns(df_final)\n",
+ "df_final = update_diffs_to_pct(df_final)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 4. Display final DataFrame and review results.\n",
+ "\n",
+ "- Red cells are differences >= 2%\n",
+ "- `nan` cells are differences < 2% and **should be ignored**\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " var_key | \n",
+ " metric | \n",
+ " test_dev | \n",
+ " test_prod | \n",
+ " test DIFF (%) | \n",
+ " ref_dev | \n",
+ " ref_prod | \n",
+ " ref DIFF (%) | \n",
+ " test_regrid_dev | \n",
+ " test_regrid_prod | \n",
+ " test_regrid DIFF (%) | \n",
+ " ref_regrid_dev | \n",
+ " ref_regrid_prod | \n",
+ " ref_regrid DIFF (%) | \n",
+ " diff_dev | \n",
+ " diff_prod | \n",
+ " diff DIFF (%) | \n",
+ " misc_dev | \n",
+ " misc_prod | \n",
+ " misc DIFF (%) | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 0 | \n",
+ " FLUT | \n",
+ " max | \n",
+ " 299.911864 | \n",
+ " 299.355074 | \n",
+ " nan | \n",
+ " 300.162128 | \n",
+ " 299.776167 | \n",
+ " nan | \n",
+ " 299.911864 | \n",
+ " 299.355074 | \n",
+ " nan | \n",
+ " 300.162128 | \n",
+ " 299.776167 | \n",
+ " nan | \n",
+ " 9.492359 | \n",
+ " 9.788809 | \n",
+ " 3.12% | \n",
+ " nan | \n",
+ " nan | \n",
+ " nan | \n",
+ "
\n",
+ " \n",
+ " 1 | \n",
+ " FLUT | \n",
+ " min | \n",
+ " 124.610884 | \n",
+ " 125.987072 | \n",
+ " nan | \n",
+ " 122.878196 | \n",
+ " 124.148986 | \n",
+ " nan | \n",
+ " 124.610884 | \n",
+ " 125.987072 | \n",
+ " nan | \n",
+ " 122.878196 | \n",
+ " 124.148986 | \n",
+ " nan | \n",
+ " -15.505809 | \n",
+ " -17.032325 | \n",
+ " 9.84% | \n",
+ " nan | \n",
+ " nan | \n",
+ " nan | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " FSNS | \n",
+ " max | \n",
+ " 269.789702 | \n",
+ " 269.798166 | \n",
+ " nan | \n",
+ " 272.722362 | \n",
+ " 272.184917 | \n",
+ " nan | \n",
+ " 269.789702 | \n",
+ " 269.798166 | \n",
+ " nan | \n",
+ " 272.722362 | \n",
+ " 272.184917 | \n",
+ " nan | \n",
+ " 20.647929 | \n",
+ " 24.859852 | \n",
+ " 20.40% | \n",
+ " nan | \n",
+ " nan | \n",
+ " nan | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " FSNS | \n",
+ " min | \n",
+ " 16.897423 | \n",
+ " 17.760889 | \n",
+ " 5.11% | \n",
+ " 16.710134 | \n",
+ " 16.237061 | \n",
+ " 2.83% | \n",
+ " 16.897423 | \n",
+ " 17.760889 | \n",
+ " 5.11% | \n",
+ " 16.710134 | \n",
+ " 16.237061 | \n",
+ " 2.83% | \n",
+ " -28.822277 | \n",
+ " -28.324921 | \n",
+ " nan | \n",
+ " nan | \n",
+ " nan | \n",
+ " nan | \n",
+ "
\n",
+ " \n",
+ " 4 | \n",
+ " FSNTOA | \n",
+ " max | \n",
+ " 360.624327 | \n",
+ " 360.209193 | \n",
+ " nan | \n",
+ " 362.188816 | \n",
+ " 361.778529 | \n",
+ " nan | \n",
+ " 360.624327 | \n",
+ " 360.209193 | \n",
+ " nan | \n",
+ " 362.188816 | \n",
+ " 361.778529 | \n",
+ " nan | \n",
+ " 18.602276 | \n",
+ " 22.624266 | \n",
+ " 21.62% | \n",
+ " nan | \n",
+ " nan | \n",
+ " nan | \n",
+ "
\n",
+ " \n",
+ " 5 | \n",
+ " FSNTOA | \n",
+ " mean | \n",
+ " 239.859777 | \n",
+ " 240.001860 | \n",
+ " nan | \n",
+ " 241.439641 | \n",
+ " 241.544384 | \n",
+ " nan | \n",
+ " 239.859777 | \n",
+ " 240.001860 | \n",
+ " nan | \n",
+ " 241.439641 | \n",
+ " 241.544384 | \n",
+ " nan | \n",
+ " -1.579864 | \n",
+ " -1.542524 | \n",
+ " 2.36% | \n",
+ " nan | \n",
+ " nan | \n",
+ " nan | \n",
+ "
\n",
+ " \n",
+ " 6 | \n",
+ " FSNTOA | \n",
+ " min | \n",
+ " 44.907041 | \n",
+ " 48.256818 | \n",
+ " 7.46% | \n",
+ " 47.223502 | \n",
+ " 50.339608 | \n",
+ " 6.60% | \n",
+ " 44.907041 | \n",
+ " 48.256818 | \n",
+ " 7.46% | \n",
+ " 47.223502 | \n",
+ " 50.339608 | \n",
+ " 6.60% | \n",
+ " -23.576184 | \n",
+ " -23.171864 | \n",
+ " nan | \n",
+ " nan | \n",
+ " nan | \n",
+ " nan | \n",
+ "
\n",
+ " \n",
+ " 7 | \n",
+ " LHFLX | \n",
+ " max | \n",
+ " 282.280453 | \n",
+ " 289.079940 | \n",
+ " 2.41% | \n",
+ " 275.792933 | \n",
+ " 276.297281 | \n",
+ " nan | \n",
+ " 282.280453 | \n",
+ " 289.079940 | \n",
+ " 2.41% | \n",
+ " 275.792933 | \n",
+ " 276.297281 | \n",
+ " nan | \n",
+ " 47.535503 | \n",
+ " 53.168924 | \n",
+ " 11.85% | \n",
+ " nan | \n",
+ " nan | \n",
+ " nan | \n",
+ "
\n",
+ " \n",
+ " 8 | \n",
+ " LHFLX | \n",
+ " mean | \n",
+ " 88.379609 | \n",
+ " 88.470270 | \n",
+ " nan | \n",
+ " 88.969550 | \n",
+ " 88.976266 | \n",
+ " nan | \n",
+ " 88.379609 | \n",
+ " 88.470270 | \n",
+ " nan | \n",
+ " 88.969550 | \n",
+ " 88.976266 | \n",
+ " nan | \n",
+ " -0.589942 | \n",
+ " -0.505996 | \n",
+ " 14.23% | \n",
+ " nan | \n",
+ " nan | \n",
+ " nan | \n",
+ "
\n",
+ " \n",
+ " 9 | \n",
+ " LHFLX | \n",
+ " min | \n",
+ " -0.878371 | \n",
+ " -0.549248 | \n",
+ " 37.47% | \n",
+ " -1.176561 | \n",
+ " -0.946110 | \n",
+ " 19.59% | \n",
+ " -0.878371 | \n",
+ " -0.549248 | \n",
+ " 37.47% | \n",
+ " -1.176561 | \n",
+ " -0.946110 | \n",
+ " 19.59% | \n",
+ " -34.375924 | \n",
+ " -33.902769 | \n",
+ " nan | \n",
+ " nan | \n",
+ " nan | \n",
+ " nan | \n",
+ "
\n",
+ " \n",
+ " 10 | \n",
+ " LWCF | \n",
+ " max | \n",
+ " 78.493653 | \n",
+ " 77.473220 | \n",
+ " nan | \n",
+ " 86.121959 | \n",
+ " 84.993825 | \n",
+ " nan | \n",
+ " 78.493653 | \n",
+ " 77.473220 | \n",
+ " nan | \n",
+ " 86.121959 | \n",
+ " 84.993825 | \n",
+ " nan | \n",
+ " 9.616057 | \n",
+ " 10.796104 | \n",
+ " 12.27% | \n",
+ " nan | \n",
+ " nan | \n",
+ " nan | \n",
+ "
\n",
+ " \n",
+ " 11 | \n",
+ " LWCF | \n",
+ " mean | \n",
+ " 24.373224 | \n",
+ " 24.370539 | \n",
+ " nan | \n",
+ " 24.406697 | \n",
+ " 24.391579 | \n",
+ " nan | \n",
+ " 24.373224 | \n",
+ " 24.370539 | \n",
+ " nan | \n",
+ " 24.406697 | \n",
+ " 24.391579 | \n",
+ " nan | \n",
+ " -0.033473 | \n",
+ " -0.021040 | \n",
+ " 37.14% | \n",
+ " nan | \n",
+ " nan | \n",
+ " nan | \n",
+ "
\n",
+ " \n",
+ " 12 | \n",
+ " LWCF | \n",
+ " min | \n",
+ " -0.667812 | \n",
+ " -0.617107 | \n",
+ " 7.59% | \n",
+ " -1.360010 | \n",
+ " -1.181787 | \n",
+ " 13.10% | \n",
+ " -0.667812 | \n",
+ " -0.617107 | \n",
+ " 7.59% | \n",
+ " -1.360010 | \n",
+ " -1.181787 | \n",
+ " 13.10% | \n",
+ " -10.574643 | \n",
+ " -10.145188 | \n",
+ " 4.06% | \n",
+ " nan | \n",
+ " nan | \n",
+ " nan | \n",
+ "
\n",
+ " \n",
+ " 13 | \n",
+ " NETCF | \n",
+ " max | \n",
+ " 13.224604 | \n",
+ " 12.621825 | \n",
+ " 4.56% | \n",
+ " 13.715438 | \n",
+ " 13.232716 | \n",
+ " 3.52% | \n",
+ " 13.224604 | \n",
+ " 12.621825 | \n",
+ " 4.56% | \n",
+ " 13.715438 | \n",
+ " 13.232716 | \n",
+ " 3.52% | \n",
+ " 10.899344 | \n",
+ " 10.284825 | \n",
+ " 5.64% | \n",
+ " nan | \n",
+ " nan | \n",
+ " nan | \n",
+ "
\n",
+ " \n",
+ " 14 | \n",
+ " NETCF | \n",
+ " min | \n",
+ " -66.633044 | \n",
+ " -66.008633 | \n",
+ " nan | \n",
+ " -64.832041 | \n",
+ " -67.398047 | \n",
+ " 3.96% | \n",
+ " -66.633044 | \n",
+ " -66.008633 | \n",
+ " nan | \n",
+ " -64.832041 | \n",
+ " -67.398047 | \n",
+ " 3.96% | \n",
+ " -17.923932 | \n",
+ " -17.940099 | \n",
+ " nan | \n",
+ " nan | \n",
+ " nan | \n",
+ " nan | \n",
+ "
\n",
+ " \n",
+ " 15 | \n",
+ " NET_FLUX_SRF | \n",
+ " max | \n",
+ " 155.691338 | \n",
+ " 156.424180 | \n",
+ " nan | \n",
+ " 166.556120 | \n",
+ " 166.506173 | \n",
+ " nan | \n",
+ " 155.691338 | \n",
+ " 156.424180 | \n",
+ " nan | \n",
+ " 166.556120 | \n",
+ " 166.506173 | \n",
+ " nan | \n",
+ " 59.819449 | \n",
+ " 61.672824 | \n",
+ " 3.10% | \n",
+ " nan | \n",
+ " nan | \n",
+ " nan | \n",
+ "
\n",
+ " \n",
+ " 16 | \n",
+ " NET_FLUX_SRF | \n",
+ " mean | \n",
+ " 0.394016 | \n",
+ " 0.516330 | \n",
+ " 31.04% | \n",
+ " -0.068186 | \n",
+ " 0.068584 | \n",
+ " 200.58% | \n",
+ " 0.394016 | \n",
+ " 0.516330 | \n",
+ " 31.04% | \n",
+ " -0.068186 | \n",
+ " 0.068584 | \n",
+ " 200.58% | \n",
+ " 0.462202 | \n",
+ " 0.447746 | \n",
+ " 3.13% | \n",
+ " nan | \n",
+ " nan | \n",
+ " nan | \n",
+ "
\n",
+ " \n",
+ " 17 | \n",
+ " NET_FLUX_SRF | \n",
+ " min | \n",
+ " -284.505205 | \n",
+ " -299.505024 | \n",
+ " 5.27% | \n",
+ " -280.893287 | \n",
+ " -290.202934 | \n",
+ " 3.31% | \n",
+ " -284.505205 | \n",
+ " -299.505024 | \n",
+ " 5.27% | \n",
+ " -280.893287 | \n",
+ " -290.202934 | \n",
+ " 3.31% | \n",
+ " -75.857589 | \n",
+ " -85.852089 | \n",
+ " 13.18% | \n",
+ " nan | \n",
+ " nan | \n",
+ " nan | \n",
+ "
\n",
+ " \n",
+ " 18 | \n",
+ " PRECT | \n",
+ " max | \n",
+ " 17.289951 | \n",
+ " 17.071276 | \n",
+ " nan | \n",
+ " 20.264862 | \n",
+ " 20.138274 | \n",
+ " nan | \n",
+ " 17.289951 | \n",
+ " 17.071276 | \n",
+ " nan | \n",
+ " 20.264862 | \n",
+ " 20.138274 | \n",
+ " nan | \n",
+ " 2.344111 | \n",
+ " 2.406625 | \n",
+ " 2.67% | \n",
+ " nan | \n",
+ " nan | \n",
+ " nan | \n",
+ "
\n",
+ " \n",
+ " 19 | \n",
+ " PRECT | \n",
+ " mean | \n",
+ " 3.053802 | \n",
+ " 3.056760 | \n",
+ " nan | \n",
+ " 3.074885 | \n",
+ " 3.074978 | \n",
+ " nan | \n",
+ " 3.053802 | \n",
+ " 3.056760 | \n",
+ " nan | \n",
+ " 3.074885 | \n",
+ " 3.074978 | \n",
+ " nan | \n",
+ " -0.021083 | \n",
+ " -0.018218 | \n",
+ " 13.59% | \n",
+ " nan | \n",
+ " nan | \n",
+ " nan | \n",
+ "
\n",
+ " \n",
+ " 20 | \n",
+ " PSL | \n",
+ " min | \n",
+ " 970.981710 | \n",
+ " 971.390765 | \n",
+ " nan | \n",
+ " 973.198437 | \n",
+ " 973.235326 | \n",
+ " nan | \n",
+ " 970.981710 | \n",
+ " 971.390765 | \n",
+ " nan | \n",
+ " 973.198437 | \n",
+ " 973.235326 | \n",
+ " nan | \n",
+ " -6.328677 | \n",
+ " -6.104610 | \n",
+ " 3.54% | \n",
+ " nan | \n",
+ " nan | \n",
+ " nan | \n",
+ "
\n",
+ " \n",
+ " 21 | \n",
+ " PSL | \n",
+ " rmse | \n",
+ " nan | \n",
+ " nan | \n",
+ " nan | \n",
+ " nan | \n",
+ " nan | \n",
+ " nan | \n",
+ " nan | \n",
+ " nan | \n",
+ " nan | \n",
+ " nan | \n",
+ " nan | \n",
+ " nan | \n",
+ " nan | \n",
+ " nan | \n",
+ " nan | \n",
+ " 1.042884 | \n",
+ " 0.979981 | \n",
+ " 6.03% | \n",
+ "
\n",
+ " \n",
+ " 22 | \n",
+ " RESTOM | \n",
+ " max | \n",
+ " 84.295502 | \n",
+ " 83.821906 | \n",
+ " nan | \n",
+ " 87.707944 | \n",
+ " 87.451262 | \n",
+ " nan | \n",
+ " 84.295502 | \n",
+ " 83.821906 | \n",
+ " nan | \n",
+ " 87.707944 | \n",
+ " 87.451262 | \n",
+ " nan | \n",
+ " 17.396283 | \n",
+ " 21.423616 | \n",
+ " 23.15% | \n",
+ " nan | \n",
+ " nan | \n",
+ " nan | \n",
+ "
\n",
+ " \n",
+ " 23 | \n",
+ " RESTOM | \n",
+ " mean | \n",
+ " 0.481549 | \n",
+ " 0.656560 | \n",
+ " 36.34% | \n",
+ " 0.018041 | \n",
+ " 0.162984 | \n",
+ " 803.40% | \n",
+ " 0.481549 | \n",
+ " 0.656560 | \n",
+ " 36.34% | \n",
+ " 0.018041 | \n",
+ " 0.162984 | \n",
+ " 803.40% | \n",
+ " 0.463508 | \n",
+ " 0.493576 | \n",
+ " 6.49% | \n",
+ " nan | \n",
+ " nan | \n",
+ " nan | \n",
+ "
\n",
+ " \n",
+ " 24 | \n",
+ " RESTOM | \n",
+ " min | \n",
+ " -127.667181 | \n",
+ " -129.014673 | \n",
+ " nan | \n",
+ " -127.417586 | \n",
+ " -128.673508 | \n",
+ " nan | \n",
+ " -127.667181 | \n",
+ " -129.014673 | \n",
+ " nan | \n",
+ " -127.417586 | \n",
+ " -128.673508 | \n",
+ " nan | \n",
+ " -15.226249 | \n",
+ " -14.869614 | \n",
+ " 2.34% | \n",
+ " nan | \n",
+ " nan | \n",
+ " nan | \n",
+ "
\n",
+ " \n",
+ " 25 | \n",
+ " SHFLX | \n",
+ " max | \n",
+ " 114.036895 | \n",
+ " 112.859646 | \n",
+ " nan | \n",
+ " 116.870038 | \n",
+ " 116.432591 | \n",
+ " nan | \n",
+ " 114.036895 | \n",
+ " 112.859646 | \n",
+ " nan | \n",
+ " 116.870038 | \n",
+ " 116.432591 | \n",
+ " nan | \n",
+ " 28.320656 | \n",
+ " 27.556755 | \n",
+ " 2.70% | \n",
+ " nan | \n",
+ " nan | \n",
+ " nan | \n",
+ "
\n",
+ " \n",
+ " 26 | \n",
+ " SHFLX | \n",
+ " min | \n",
+ " -88.650312 | \n",
+ " -88.386947 | \n",
+ " nan | \n",
+ " -85.809438 | \n",
+ " -85.480377 | \n",
+ " nan | \n",
+ " -88.650312 | \n",
+ " -88.386947 | \n",
+ " nan | \n",
+ " -85.809438 | \n",
+ " -85.480377 | \n",
+ " nan | \n",
+ " -27.776625 | \n",
+ " -28.363053 | \n",
+ " 2.11% | \n",
+ " nan | \n",
+ " nan | \n",
+ " nan | \n",
+ "
\n",
+ " \n",
+ " 27 | \n",
+ " SST | \n",
+ " min | \n",
+ " -1.788055 | \n",
+ " -1.788055 | \n",
+ " nan | \n",
+ " -1.676941 | \n",
+ " -1.676941 | \n",
+ " nan | \n",
+ " -1.788055 | \n",
+ " -1.788055 | \n",
+ " nan | \n",
+ " -1.676941 | \n",
+ " -1.676941 | \n",
+ " nan | \n",
+ " -4.513070 | \n",
+ " -2.993272 | \n",
+ " 33.68% | \n",
+ " nan | \n",
+ " nan | \n",
+ " nan | \n",
+ "
\n",
+ " \n",
+ " 28 | \n",
+ " SWCF | \n",
+ " max | \n",
+ " -0.518025 | \n",
+ " -0.536844 | \n",
+ " 3.63% | \n",
+ " -0.311639 | \n",
+ " -0.331616 | \n",
+ " 6.41% | \n",
+ " -0.518025 | \n",
+ " -0.536844 | \n",
+ " 3.63% | \n",
+ " -0.311639 | \n",
+ " -0.331616 | \n",
+ " 6.41% | \n",
+ " 11.668939 | \n",
+ " 12.087077 | \n",
+ " 3.58% | \n",
+ " nan | \n",
+ " nan | \n",
+ " nan | \n",
+ "
\n",
+ " \n",
+ " 29 | \n",
+ " SWCF | \n",
+ " min | \n",
+ " -123.625017 | \n",
+ " -122.042043 | \n",
+ " nan | \n",
+ " -131.053537 | \n",
+ " -130.430161 | \n",
+ " nan | \n",
+ " -123.625017 | \n",
+ " -122.042043 | \n",
+ " nan | \n",
+ " -131.053537 | \n",
+ " -130.430161 | \n",
+ " nan | \n",
+ " -21.415249 | \n",
+ " -20.808973 | \n",
+ " 2.83% | \n",
+ " nan | \n",
+ " nan | \n",
+ " nan | \n",
+ "
\n",
+ " \n",
+ " 30 | \n",
+ " TREFHT | \n",
+ " max | \n",
+ " 31.141508 | \n",
+ " 31.058424 | \n",
+ " nan | \n",
+ " 29.819210 | \n",
+ " 29.721868 | \n",
+ " nan | \n",
+ " 31.141508 | \n",
+ " 31.058424 | \n",
+ " nan | \n",
+ " 29.819210 | \n",
+ " 29.721868 | \n",
+ " nan | \n",
+ " 4.981757 | \n",
+ " 5.126185 | \n",
+ " 2.90% | \n",
+ " nan | \n",
+ " nan | \n",
+ " nan | \n",
+ "
\n",
+ " \n",
+ " 31 | \n",
+ " TREFHT | \n",
+ " max | \n",
+ " 31.141508 | \n",
+ " 31.058424 | \n",
+ " nan | \n",
+ " 29.819210 | \n",
+ " 29.721868 | \n",
+ " nan | \n",
+ " 31.141508 | \n",
+ " 31.058424 | \n",
+ " nan | \n",
+ " 29.819210 | \n",
+ " 29.721868 | \n",
+ " nan | \n",
+ " 4.867855 | \n",
+ " 5.126185 | \n",
+ " 2.90% | \n",
+ " nan | \n",
+ " nan | \n",
+ " nan | \n",
+ "
\n",
+ " \n",
+ " 32 | \n",
+ " TREFHT | \n",
+ " max | \n",
+ " 31.141508 | \n",
+ " 31.058424 | \n",
+ " nan | \n",
+ " 29.819210 | \n",
+ " 29.721868 | \n",
+ " nan | \n",
+ " 31.141508 | \n",
+ " 31.058424 | \n",
+ " nan | \n",
+ " 29.819210 | \n",
+ " 29.721868 | \n",
+ " nan | \n",
+ " 4.981757 | \n",
+ " 5.126185 | \n",
+ " 5.31% | \n",
+ " nan | \n",
+ " nan | \n",
+ " nan | \n",
+ "
\n",
+ " \n",
+ " 33 | \n",
+ " TREFHT | \n",
+ " max | \n",
+ " 31.141508 | \n",
+ " 31.058424 | \n",
+ " nan | \n",
+ " 29.819210 | \n",
+ " 29.721868 | \n",
+ " nan | \n",
+ " 31.141508 | \n",
+ " 31.058424 | \n",
+ " nan | \n",
+ " 29.819210 | \n",
+ " 29.721868 | \n",
+ " nan | \n",
+ " 4.867855 | \n",
+ " 5.126185 | \n",
+ " 5.31% | \n",
+ " nan | \n",
+ " nan | \n",
+ " nan | \n",
+ "
\n",
+ " \n",
+ " 34 | \n",
+ " TREFHT | \n",
+ " mean | \n",
+ " 14.769946 | \n",
+ " 14.741707 | \n",
+ " nan | \n",
+ " 13.842013 | \n",
+ " 13.800258 | \n",
+ " nan | \n",
+ " 14.769946 | \n",
+ " 14.741707 | \n",
+ " nan | \n",
+ " 13.842013 | \n",
+ " 13.800258 | \n",
+ " nan | \n",
+ " 0.927933 | \n",
+ " 0.941449 | \n",
+ " 2.28% | \n",
+ " nan | \n",
+ " nan | \n",
+ " nan | \n",
+ "
\n",
+ " \n",
+ " 35 | \n",
+ " TREFHT | \n",
+ " mean | \n",
+ " 9.214224 | \n",
+ " 9.114572 | \n",
+ " nan | \n",
+ " 8.083349 | \n",
+ " 7.957917 | \n",
+ " nan | \n",
+ " 9.214224 | \n",
+ " 9.114572 | \n",
+ " nan | \n",
+ " 8.083349 | \n",
+ " 7.957917 | \n",
+ " nan | \n",
+ " 1.130876 | \n",
+ " 1.156655 | \n",
+ " 2.28% | \n",
+ " nan | \n",
+ " nan | \n",
+ " nan | \n",
+ "
\n",
+ " \n",
+ " 36 | \n",
+ " TREFHT | \n",
+ " min | \n",
+ " -56.266677 | \n",
+ " -55.623001 | \n",
+ " nan | \n",
+ " -58.159250 | \n",
+ " -57.542053 | \n",
+ " nan | \n",
+ " -56.266677 | \n",
+ " -55.623001 | \n",
+ " nan | \n",
+ " -58.159250 | \n",
+ " -57.542053 | \n",
+ " nan | \n",
+ " -0.681558 | \n",
+ " -0.624371 | \n",
+ " 8.39% | \n",
+ " nan | \n",
+ " nan | \n",
+ " nan | \n",
+ "
\n",
+ " \n",
+ " 37 | \n",
+ " TREFHT | \n",
+ " min | \n",
+ " -56.266677 | \n",
+ " -55.623001 | \n",
+ " nan | \n",
+ " -58.159250 | \n",
+ " -57.542053 | \n",
+ " nan | \n",
+ " -56.266677 | \n",
+ " -55.623001 | \n",
+ " nan | \n",
+ " -58.159250 | \n",
+ " -57.542053 | \n",
+ " nan | \n",
+ " -0.681558 | \n",
+ " -0.624371 | \n",
+ " 8.39% | \n",
+ " nan | \n",
+ " nan | \n",
+ " nan | \n",
+ "
\n",
+ " \n",
+ " 38 | \n",
+ " TREFHT | \n",
+ " min | \n",
+ " -56.266677 | \n",
+ " -55.623001 | \n",
+ " nan | \n",
+ " -58.159250 | \n",
+ " -57.542053 | \n",
+ " nan | \n",
+ " -56.266677 | \n",
+ " -55.623001 | \n",
+ " nan | \n",
+ " -58.159250 | \n",
+ " -57.542053 | \n",
+ " nan | \n",
+ " -0.681558 | \n",
+ " -0.624371 | \n",
+ " 8.39% | \n",
+ " nan | \n",
+ " nan | \n",
+ " nan | \n",
+ "
\n",
+ " \n",
+ " 39 | \n",
+ " TREFHT | \n",
+ " min | \n",
+ " -56.266677 | \n",
+ " -55.623001 | \n",
+ " nan | \n",
+ " -58.159250 | \n",
+ " -57.542053 | \n",
+ " nan | \n",
+ " -56.266677 | \n",
+ " -55.623001 | \n",
+ " nan | \n",
+ " -58.159250 | \n",
+ " -57.542053 | \n",
+ " nan | \n",
+ " -0.681558 | \n",
+ " -0.624371 | \n",
+ " 8.39% | \n",
+ " nan | \n",
+ " nan | \n",
+ " nan | \n",
+ "
\n",
+ " \n",
+ " 40 | \n",
+ " TREFHT | \n",
+ " rmse | \n",
+ " nan | \n",
+ " nan | \n",
+ " nan | \n",
+ " nan | \n",
+ " nan | \n",
+ " nan | \n",
+ " nan | \n",
+ " nan | \n",
+ " nan | \n",
+ " nan | \n",
+ " nan | \n",
+ " nan | \n",
+ " nan | \n",
+ " nan | \n",
+ " nan | \n",
+ " 1.160718 | \n",
+ " 1.179995 | \n",
+ " 2.68% | \n",
+ "
\n",
+ " \n",
+ " 41 | \n",
+ " TREFHT | \n",
+ " rmse | \n",
+ " nan | \n",
+ " nan | \n",
+ " nan | \n",
+ " nan | \n",
+ " nan | \n",
+ " nan | \n",
+ " nan | \n",
+ " nan | \n",
+ " nan | \n",
+ " nan | \n",
+ " nan | \n",
+ " nan | \n",
+ " nan | \n",
+ " nan | \n",
+ " nan | \n",
+ " 1.343169 | \n",
+ " 1.379141 | \n",
+ " 2.68% | \n",
+ "
\n",
+ " \n",
+ "
\n"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 6,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "df_final.reset_index(names=[\"var_key\", \"metric\"]).style.map(\n",
+ " lambda x: \"background-color : red\" if isinstance(x, str) else \"\",\n",
+ " subset=pd.IndexSlice[:, PERCENTAGE_COLUMNS],\n",
+ ")"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "cdat_regression_test",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.11.6"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/e3sm_diags/driver/aerosol_budget_driver.py b/e3sm_diags/driver/aerosol_budget_driver.py
index 9c1de7d00..faf0c4005 100644
--- a/e3sm_diags/driver/aerosol_budget_driver.py
+++ b/e3sm_diags/driver/aerosol_budget_driver.py
@@ -3,6 +3,7 @@
script is integrated in e3sm_diags by Jill Zhang, with input from Kai Zhang,
Taufiq Hassan, Xue Zheng, Ziming Ke, Susannah Burrows, and Naser Mahfouz.
"""
+
from __future__ import annotations
import csv
diff --git a/e3sm_diags/driver/qbo_driver.py b/e3sm_diags/driver/qbo_driver.py
index 3379f4c46..5bc0c5ec2 100644
--- a/e3sm_diags/driver/qbo_driver.py
+++ b/e3sm_diags/driver/qbo_driver.py
@@ -125,6 +125,11 @@ def run_diag(parameter: QboParameter) -> QboParameter:
test_dict["name"] = test_ds._get_test_name()
ref_dict["name"] = ref_ds._get_ref_name()
+ try:
+ ref_dict["name"] = ref_ds._get_ref_name()
+ except AttributeError:
+ ref_dict["name"] = parameter.ref_name
+
_save_metrics_to_json(parameter, test_dict, "test") # type: ignore
_save_metrics_to_json(parameter, ref_dict, "ref") # type: ignore
diff --git a/e3sm_diags/driver/utils/climo_xr.py b/e3sm_diags/driver/utils/climo_xr.py
index bb229048c..acbe73fa2 100644
--- a/e3sm_diags/driver/utils/climo_xr.py
+++ b/e3sm_diags/driver/utils/climo_xr.py
@@ -1,8 +1,8 @@
"""This module stores climatology functions operating on Xarray objects.
-
This file will eventually be refactored to use xCDAT's climatology API.
"""
+
from typing import Dict, List, Literal, get_args
import numpy as np
diff --git a/e3sm_diags/driver/utils/dataset_xr.py b/e3sm_diags/driver/utils/dataset_xr.py
index 94b109e66..94c3f69c6 100644
--- a/e3sm_diags/driver/utils/dataset_xr.py
+++ b/e3sm_diags/driver/utils/dataset_xr.py
@@ -1,6 +1,5 @@
"""This module stores the Dataset class, which is the primary class for I/O.
-
This Dataset class operates on `xr.Dataset` objects, which are created using
netCDF files. These `xr.Dataset` contain either the reference or test variable.
This variable can either be from a climatology file or a time series file.
@@ -8,6 +7,7 @@
calculated. Reference and test variables can also be derived using other
variables from dataset files.
"""
+
from __future__ import annotations
import collections
diff --git a/e3sm_diags/metrics/metrics.py b/e3sm_diags/metrics/metrics.py
index 333980643..d98fe519d 100644
--- a/e3sm_diags/metrics/metrics.py
+++ b/e3sm_diags/metrics/metrics.py
@@ -1,4 +1,5 @@
"""This module stores functions to calculate metrics using Xarray objects."""
+
from __future__ import annotations
from typing import List, Literal
diff --git a/e3sm_diags/plot/cartopy/aerosol_aeronet_plot.py b/e3sm_diags/plot/cartopy/aerosol_aeronet_plot.py
new file mode 100644
index 000000000..765235095
--- /dev/null
+++ b/e3sm_diags/plot/cartopy/aerosol_aeronet_plot.py
@@ -0,0 +1,132 @@
+import os
+
+import cartopy.crs as ccrs
+import matplotlib
+import numpy as np
+
+from e3sm_diags.driver.utils.general import get_output_dir
+from e3sm_diags.logger import custom_logger
+from e3sm_diags.metrics import mean
+from e3sm_diags.plot.cartopy.deprecated_lat_lon_plot import plot_panel
+
+matplotlib.use("Agg")
+import matplotlib.pyplot as plt # isort:skip # noqa: E402
+
+logger = custom_logger(__name__)
+
+# Font-size keyword dicts for plot titles. NOTE(review): neither name is
+# referenced elsewhere in this module - confirm whether other modules import
+# them before removing.
+plotTitle = {"fontsize": 11.5}
+plotSideTitle = {"fontsize": 9.5}
+
+
+def plot(test, test_site, ref_site, parameter):
+    """Plot the test field panel and a log-log test-vs-ref site scatter plot.
+
+    The figure is saved once per entry in ``parameter.output_format`` and,
+    for each entry in ``parameter.output_format_subplot``, once per panel
+    cropped to that panel's extent. The figure is closed afterwards so that
+    repeated calls do not accumulate open pyplot figures.
+
+    Parameters
+    ----------
+    test
+        The test field passed to ``plot_panel``; must provide ``max()`` and
+        ``min()``.
+    test_site
+        Test values for the scatter plot (y-axis).
+    ref_site
+        Reference values for the scatter plot (x-axis).
+    parameter
+        Object supplying the plot configuration (``figsize``, ``dpi``,
+        ``var_id``, contour levels, colormap, names) and the output settings
+        (``current_set``, ``output_file``, ``output_format``,
+        ``output_format_subplot``).
+    """
+    # Position and sizes of subplot axes in page coordinates (0 to 1)
+    # (left, bottom, width, height) in page coordinates
+    panel = [
+        (0.09, 0.40, 0.72, 0.30),
+        (0.19, 0.2, 0.62, 0.30),
+    ]
+    # Border padding relative to subplot axes for saving individual panels
+    # (left, bottom, right, top) in page coordinates
+    border = (-0.06, -0.03, 0.13, 0.03)
+
+    fig = plt.figure(figsize=parameter.figsize, dpi=parameter.dpi)
+    fig.suptitle(parameter.var_id, x=0.5, y=0.97)
+    proj = ccrs.PlateCarree()
+    max1 = test.max()
+    min1 = test.min()
+    mean1 = mean(test)
+    # TODO: Replace this function call with `e3sm_diags.plot.utils._add_colormap()`.
+    plot_panel(
+        0,
+        fig,
+        proj,
+        test,
+        parameter.contour_levels,
+        parameter.test_colormap,
+        (parameter.test_name_yrs, None, None),
+        parameter,
+        stats=(max1, mean1, min1),
+    )
+
+    ax = fig.add_axes(panel[1])
+    ax.set_title(f"{parameter.var_id} from AERONET sites")
+
+    # define 1:1 line, and x y axis limits
+
+    if parameter.var_id == "AODVIS":
+        x1 = np.arange(0.01, 3.0, 0.1)
+        y1 = np.arange(0.01, 3.0, 0.1)
+        plt.xlim(0.03, 1)
+        plt.ylim(0.03, 1)
+    else:
+        x1 = np.arange(0.0001, 1.0, 0.01)
+        y1 = np.arange(0.0001, 1.0, 0.01)
+        plt.xlim(0.001, 0.3)
+        plt.ylim(0.001, 0.3)
+
+    # Solid 1:1 line plus dashed lines at half the slope on either side.
+    plt.loglog(x1, y1, "-k", linewidth=0.5)
+    plt.loglog(x1, y1 * 0.5, "--k", linewidth=0.5)
+    plt.loglog(x1 * 0.5, y1, "--k", linewidth=0.5)
+
+    corr = np.corrcoef(ref_site, test_site)
+    xmean = np.mean(ref_site)
+    ymean = np.mean(test_site)
+    ax.text(
+        0.3,
+        0.9,
+        f"Mean (test): {ymean:.3f} \n Mean (ref): {xmean:.3f}\n Corr: {corr[0, 1]:.2f}",
+        horizontalalignment="right",
+        verticalalignment="top",
+        transform=ax.transAxes,
+    )
+
+    # axis ticks
+    plt.tick_params(axis="both", which="major")
+    plt.tick_params(axis="both", which="minor")
+
+    # axis labels
+    plt.xlabel(f"ref: {parameter.ref_name_yrs}")
+    plt.ylabel(f"test: {parameter.test_name_yrs}")
+
+    plt.loglog(ref_site, test_site, "kx", markersize=3.0, mfc="none")
+
+    # legend
+    # NOTE(review): none of the `loglog` calls above pass `label=`, so this
+    # legend has no entries - confirm whether labels are missing.
+    plt.legend(frameon=False, prop={"size": 5})
+
+    # TODO: This section can be refactored to use `plot.utils._save_plot()`.
+    for f in parameter.output_format:
+        f = f.lower().split(".")[-1]
+        fnm = os.path.join(
+            get_output_dir(parameter.current_set, parameter),
+            f"{parameter.output_file}" + "." + f,
+        )
+        plt.savefig(fnm)
+        logger.info(f"Plot saved in: {fnm}")
+
+    # Figure size in inches; scales page-fraction extents for cropping.
+    page = fig.get_size_inches()
+    for f in parameter.output_format_subplot:
+        fnm = os.path.join(
+            get_output_dir(parameter.current_set, parameter),
+            parameter.output_file,
+        )
+        for i, p in enumerate(panel):
+            # Extent of subplot: convert (left, bottom, width, height) into
+            # (left, bottom, right, top), pad by `border`, scale to inches.
+            subpage = np.array(p).reshape(2, 2)
+            subpage[1, :] = subpage[0, :] + subpage[1, :]
+            subpage = subpage + np.array(border).reshape(2, 2)
+            subpage = list(((subpage) * page).flatten())  # type: ignore
+            extent = matplotlib.transforms.Bbox.from_extents(*subpage)
+            # Save subplot
+            fname = f"{fnm}.{i}.{f}"
+            plt.savefig(fname, bbox_inches=extent)
+            logger.info(f"Sub-plot saved in: {fname}")
+
+    # Release the figure; pyplot otherwise keeps it alive across calls.
+    plt.close(fig)
diff --git a/e3sm_diags/plot/cartopy/zonal_mean_2d_plot.py b/e3sm_diags/plot/cartopy/zonal_mean_2d_plot.py
new file mode 100644
index 000000000..a72bf5dce
--- /dev/null
+++ b/e3sm_diags/plot/cartopy/zonal_mean_2d_plot.py
@@ -0,0 +1,187 @@
+from typing import List, Optional, Tuple
+
+import matplotlib
+import numpy as np
+import xarray as xr
+import xcdat as xc
+
+from e3sm_diags.driver.utils.type_annotations import MetricsDict
+from e3sm_diags.logger import custom_logger
+from e3sm_diags.parameter.core_parameter import CoreParameter
+from e3sm_diags.parameter.zonal_mean_2d_parameter import DEFAULT_PLEVS
+from e3sm_diags.plot.utils import (
+ DEFAULT_PANEL_CFG,
+ _add_colorbar,
+ _add_contour_plot,
+ _add_min_mean_max_text,
+ _add_rmse_corr_text,
+ _configure_titles,
+ _configure_x_and_y_axes,
+ _get_c_levels_and_norm,
+ _save_plot,
+)
+
+matplotlib.use("Agg")
+import matplotlib.pyplot as plt # isort:skip # noqa: E402
+
+logger = custom_logger(__name__)
+
+
+# Configs for x axis ticks and x axis limits.
+X_TICKS = np.array([-90, -60, -30, 0, 30, 60, 90])
+X_LIM = -90, 90
+
+
+def plot(
+ parameter: CoreParameter,
+ da_test: xr.DataArray,
+ da_ref: xr.DataArray,
+ da_diff: xr.DataArray,
+ metrics_dict: MetricsDict,
+):
+ """Plot the variable's metrics generated by the zonal_mean_2d set.
+
+ Parameters
+ ----------
+ parameter : CoreParameter
+ The CoreParameter object containing plot configurations.
+ da_test : xr.DataArray
+ The test data.
+ da_ref : xr.DataArray
+ The reference data.
+ da_diff : xr.DataArray
+ The difference between `da_test` and `da_ref` (both are regridded to
+ the lower resolution of the two beforehand).
+ metrics_dict : MetricsDict
+ The metrics ("units" plus "test", "ref", "diff", and "misc" entries).
+ """
+ fig = plt.figure(figsize=parameter.figsize, dpi=parameter.dpi)
+ fig.suptitle(parameter.main_title, x=0.5, y=0.96, fontsize=18)
+
+ # The variable units.
+ units = metrics_dict["units"]
+
+ # Add the first subplot for test data.
+ min1 = metrics_dict["test"]["min"] # type: ignore
+ mean1 = metrics_dict["test"]["mean"] # type: ignore
+ max1 = metrics_dict["test"]["max"] # type: ignore
+
+ _add_colormap(
+ 0,
+ da_test,
+ fig,
+ parameter,
+ parameter.test_colormap,
+ parameter.contour_levels,
+ title=(parameter.test_name_yrs, parameter.test_title, units), # type: ignore
+ metrics=(max1, mean1, min1), # type: ignore
+ )
+
+ # Add the second and third subplots for ref data and the differences,
+ # respectively.
+ min2 = metrics_dict["ref"]["min"] # type: ignore
+ mean2 = metrics_dict["ref"]["mean"] # type: ignore
+ max2 = metrics_dict["ref"]["max"] # type: ignore
+
+ _add_colormap(
+ 1,
+ da_ref,
+ fig,
+ parameter,
+ parameter.reference_colormap,
+ parameter.contour_levels,
+ title=(parameter.ref_name_yrs, parameter.reference_title, units), # type: ignore
+ metrics=(max2, mean2, min2), # type: ignore
+ )
+
+ min3 = metrics_dict["diff"]["min"] # type: ignore
+ mean3 = metrics_dict["diff"]["mean"] # type: ignore
+ max3 = metrics_dict["diff"]["max"] # type: ignore
+ r = metrics_dict["misc"]["rmse"] # type: ignore
+ c = metrics_dict["misc"]["corr"] # type: ignore
+
+ _add_colormap(
+ 2,
+ da_diff,
+ fig,
+ parameter,
+ parameter.diff_colormap,
+ parameter.diff_levels,
+ title=(None, parameter.diff_title, da_diff.attrs["units"]), # Units are read from da_diff's attrs here, not from metrics_dict.
+ metrics=(max3, mean3, min3, r, c), # type: ignore
+ )
+
+ _save_plot(fig, parameter)
+
+ plt.close()  # Close pyplot's current figure once it has been saved.
+
+
+def _add_colormap(  # Draw one panel on `fig`: contour plot, axes config, colorbar, and metrics text.
+ subplot_num: int,  # Index into DEFAULT_PANEL_CFG selecting this panel's axes rectangle.
+ var: xr.DataArray,
+ fig: plt.Figure,
+ parameter: CoreParameter,
+ color_map: str,
+ contour_levels: List[float],  # Passed to _get_c_levels_and_norm.
+ title: Tuple[Optional[str], str, str],  # Passed to _configure_titles.
+ metrics: Tuple[float, ...],  # (max, mean, min) or (max, mean, min, rmse, corr), per the callers in plot().
+):
+ lat = xc.get_dim_coords(var, axis="Y")
+ plev = xc.get_dim_coords(var, axis="Z")
+ var = var.squeeze()  # Drop length-1 dimensions before contouring.
+
+ # Configure contour levels
+ # --------------------------------------------------------------------------
+ c_levels, norm = _get_c_levels_and_norm(contour_levels)
+
+ # Add the contour plot
+ # --------------------------------------------------------------------------
+ ax = fig.add_axes(DEFAULT_PANEL_CFG[subplot_num], projection=None)
+
+ contour_plot = _add_contour_plot(
+ ax, parameter, var, lat, plev, color_map, None, norm, c_levels
+ )
+
+ # Configure the aspect ratio and plot titles.
+ # --------------------------------------------------------------------------
+ ax.set_aspect("auto")
+ _configure_titles(ax, title)
+
+ # Configure x and y axis.
+ # --------------------------------------------------------------------------
+ _configure_x_and_y_axes(ax, X_TICKS, None, None, parameter.current_set)
+ ax.set_xlim(X_LIM)
+
+ if parameter.plot_log_plevs:
+ ax.set_yscale("log")
+
+ if parameter.plot_plevs:
+ plev_ticks = parameter.plevs
+ plt.yticks(plev_ticks, plev_ticks)  # NOTE(review): acts on pyplot's current axes, presumably `ax` added above; confirm.
+
+ # For default plevs, specify the pressure axis and show the 50 mb tick
+ # at the top.
+ if (
+ not parameter.plot_log_plevs
+ and not parameter.plot_plevs
+ and parameter.plevs == DEFAULT_PLEVS
+ ):
+ plev_ticks = parameter.plevs
+ new_ticks = [plev_ticks[0]] + plev_ticks[1::2]  # First level, then every other level starting from the second.
+ new_ticks = [int(x) for x in new_ticks]
+ plt.yticks(new_ticks, new_ticks)
+
+ plt.ylabel("pressure (mb)")
+ ax.invert_yaxis()
+
+ # Add and configure the color bar.
+ # --------------------------------------------------------------------------
+ _add_colorbar(fig, subplot_num, DEFAULT_PANEL_CFG, contour_plot, c_levels)
+
+ # Add metrics text.
+ # --------------------------------------------------------------------------
+ # Min, Mean, Max
+ _add_min_mean_max_text(fig, subplot_num, DEFAULT_PANEL_CFG, metrics)
+
+ if len(metrics) == 5:  # Five values means RMSE and CORR were supplied (the diff panel in plot()).
+ _add_rmse_corr_text(fig, subplot_num, DEFAULT_PANEL_CFG, metrics)
diff --git a/e3sm_diags/plot/cartopy/zonal_mean_2d_stratosphere_plot.py b/e3sm_diags/plot/cartopy/zonal_mean_2d_stratosphere_plot.py
new file mode 100644
index 000000000..004f3c93d
--- /dev/null
+++ b/e3sm_diags/plot/cartopy/zonal_mean_2d_stratosphere_plot.py
@@ -0,0 +1,15 @@
+import xarray as xr
+
+from e3sm_diags.driver.utils.type_annotations import MetricsDict
+from e3sm_diags.parameter.core_parameter import CoreParameter
+from e3sm_diags.plot.cartopy.zonal_mean_2d_plot import plot as base_plot
+
+
+def plot(  # Thin wrapper: forwards every argument unchanged to zonal_mean_2d_plot.plot.
+ parameter: CoreParameter,
+ da_test: xr.DataArray,
+ da_ref: xr.DataArray,
+ da_diff: xr.DataArray,
+ metrics_dict: MetricsDict,
+):
+ return base_plot(parameter, da_test, da_ref, da_diff, metrics_dict)
diff --git a/e3sm_diags/plot/deprecated_lat_lon_plot.py b/e3sm_diags/plot/deprecated_lat_lon_plot.py
new file mode 100644
index 000000000..4eaebcf80
--- /dev/null
+++ b/e3sm_diags/plot/deprecated_lat_lon_plot.py
@@ -0,0 +1,360 @@
+"""
+WARNING: This module has been deprecated and replaced by
+`e3sm_diags.plot.lat_lon_plot.py`. This file is temporarily kept because
+`e3sm_diags.plot.cartopy.aerosol_aeronet_plot.plot` references the
+`plot_panel()` function. Once the aerosol_aeronet set is refactored, this
+file can be deleted.
+"""
+from __future__ import print_function
+
+import os
+
+import cartopy.crs as ccrs
+import cartopy.feature as cfeature
+import cdutil
+import matplotlib
+import numpy as np
+import numpy.ma as ma
+from cartopy.mpl.ticker import LatitudeFormatter, LongitudeFormatter
+
+from e3sm_diags.derivations.default_regions import regions_specs
+from e3sm_diags.driver.utils.general import get_output_dir
+from e3sm_diags.logger import custom_logger
+from e3sm_diags.plot import get_colormap
+
+matplotlib.use("Agg")
+import matplotlib.colors as colors # isort:skip # noqa: E402
+import matplotlib.pyplot as plt # isort:skip # noqa: E402
+
+logger = custom_logger(__name__)
+
+plotTitle = {"fontsize": 11.5}
+plotSideTitle = {"fontsize": 9.5}
+
+# Position and sizes of subplot axes in page coordinates (0 to 1)
+panel = [
+ (0.1691, 0.6810, 0.6465, 0.2258),
+ (0.1691, 0.3961, 0.6465, 0.2258),
+ (0.1691, 0.1112, 0.6465, 0.2258),
+]
+
+# Border padding relative to subplot axes for saving individual panels
+# (left, bottom, right, top) in page coordinates
+border = (-0.06, -0.03, 0.13, 0.03)
+
+
+def add_cyclic(var):  # Re-select `var` over a 360-degree longitude span starting at its first longitude value.
+ lon = var.getLongitude()  # NOTE(review): presumably a cdms2 variable/axis API; confirm.
+ return var(longitude=(lon[0], lon[0] + 360.0, "coe"))  # NOTE(review): "coe" looks like a cdms2 interval indicator; confirm its exact meaning.
+
+
+def get_ax_size(fig, ax):  # Return the (width, height) of `ax` in pixels.
+ bbox = ax.get_window_extent().transformed(fig.dpi_scale_trans.inverted())  # Window extent converted to inches.
+ width, height = bbox.width, bbox.height
+ width *= fig.dpi  # Inches -> pixels.
+ height *= fig.dpi
+ return width, height
+
+
+def determine_tick_step(degrees_covered):  # Pick a tick spacing for an axis spanning `degrees_covered` degrees.
+ if degrees_covered > 180:
+ return 60
+ if degrees_covered > 60:
+ return 30
+ elif degrees_covered > 30:
+ return 10
+ elif degrees_covered > 20:
+ return 5
+ else:  # 20 degrees or less.
+ return 1
+
+
+def plot_panel( # noqa: C901
+ n, fig, proj, var, clevels, cmap, title, parameters, stats=None
+):  # Draw contour panel `n` on `fig`, with its colorbar and stats text.
+ var = add_cyclic(var)
+ lon = var.getLongitude()
+ lat = var.getLatitude()
+ var = ma.squeeze(var.asma())
+
+ # Contour levels
+ levels = None
+ norm = None
+ if len(clevels) > 0:
+ levels = [-1.0e8] + clevels + [1.0e8]  # Bracket the requested levels with -1e8/+1e8 outer boundaries.
+ norm = colors.BoundaryNorm(boundaries=levels, ncolors=256)
+
+ # ax.set_global()
+ region_str = parameters.regions[0]  # Only the first region is used.
+ region = regions_specs[region_str]
+ global_domain = True
+ full_lon = True
+ if "domain" in region.keys(): # type: ignore
+ # Get domain to plot
+ domain = region["domain"] # type: ignore
+ global_domain = False
+ else:
+ # Assume global domain
+ domain = cdutil.region.domain(latitude=(-90.0, 90, "ccb"))
+ kargs = domain.components()[0].kargs
+ lon_west, lon_east, lat_south, lat_north = (0, 360, -90, 90)
+ if "longitude" in kargs:
+ full_lon = False
+ lon_west, lon_east, _ = kargs["longitude"]
+ # NOTE: cartopy has a problem with gridlines across the dateline (https://github.com/SciTools/cartopy/issues/821), so regions crossing the dateline are not supported yet.
+ if lon_west > 180 and lon_east > 180:
+ lon_west = lon_west - 360
+ lon_east = lon_east - 360
+
+ if "latitude" in kargs:
+ lat_south, lat_north, _ = kargs["latitude"]
+ lon_covered = lon_east - lon_west
+ lon_step = determine_tick_step(lon_covered)
+ xticks = np.arange(lon_west, lon_east, lon_step)
+ # Subtract 0.50 to get 0 W to show up on the right side of the plot.
+ # If less than 0.50 is subtracted, then 0 W will overlap 0 E on the left side of the plot.
+ # If a number is added, then the value won't show up at all.
+ if global_domain or full_lon:
+ xticks = np.append(xticks, lon_east - 0.50)
+ proj = ccrs.PlateCarree(central_longitude=180)  # Replaces the `proj` argument for global/full-longitude plots.
+ else:
+ xticks = np.append(xticks, lon_east)
+ lat_covered = lat_north - lat_south
+ lat_step = determine_tick_step(lat_covered)
+ yticks = np.arange(lat_south, lat_north, lat_step)
+ yticks = np.append(yticks, lat_north)
+
+ # Contour plot
+ ax = fig.add_axes(panel[n], projection=proj)
+ ax.set_extent([lon_west, lon_east, lat_south, lat_north], crs=proj)
+ cmap = get_colormap(cmap, parameters)
+ p1 = ax.contourf(
+ lon,
+ lat,
+ var,
+ transform=ccrs.PlateCarree(),
+ norm=norm,
+ levels=levels,
+ cmap=cmap,
+ extend="both",
+ )
+
+ # ax.set_aspect('auto')
+ # Full world would be aspect 360/(2*180) = 1
+ ax.set_aspect((lon_east - lon_west) / (2 * (lat_north - lat_south)))
+ ax.coastlines(lw=0.3)
+ if not global_domain and "RRM" in region_str:
+ ax.coastlines(resolution="50m", color="black", linewidth=1)
+ state_borders = cfeature.NaturalEarthFeature(
+ category="cultural",
+ name="admin_1_states_provinces_lakes",
+ scale="50m",
+ facecolor="none",
+ )
+ ax.add_feature(state_borders, edgecolor="black")
+ if title[0] is not None:
+ ax.set_title(title[0], loc="left", fontdict=plotSideTitle)
+ if title[1] is not None:
+ ax.set_title(title[1], fontdict=plotTitle)
+ if title[2] is not None:
+ ax.set_title(title[2], loc="right", fontdict=plotSideTitle)
+ ax.set_xticks(xticks, crs=ccrs.PlateCarree())
+ ax.set_yticks(yticks, crs=ccrs.PlateCarree())
+ lon_formatter = LongitudeFormatter(zero_direction_label=True, number_format=".0f")
+ lat_formatter = LatitudeFormatter()
+ ax.xaxis.set_major_formatter(lon_formatter)
+ ax.yaxis.set_major_formatter(lat_formatter)
+ ax.tick_params(labelsize=8.0, direction="out", width=1)
+ ax.xaxis.set_ticks_position("bottom")
+ ax.yaxis.set_ticks_position("left")
+
+ # Color bar
+ cbax = fig.add_axes((panel[n][0] + 0.6635, panel[n][1] + 0.0215, 0.0326, 0.1792))
+ cbar = fig.colorbar(p1, cax=cbax)
+ w, h = get_ax_size(fig, cbax)  # NOTE(review): `w` and `h` are not used again in this function.
+
+ if levels is None:
+ cbar.ax.tick_params(labelsize=9.0, length=0)
+
+ else:
+ maxval = np.amax(np.absolute(levels[1:-1]))  # Largest magnitude among the requested levels (outer boundaries excluded).
+ if maxval < 0.2:
+ fmt = "%5.3f"
+ pad = 28
+ elif maxval < 10.0:
+ fmt = "%5.2f"
+ pad = 25
+ elif maxval < 100.0:
+ fmt = "%5.1f"
+ pad = 25
+ elif maxval > 9999.0:
+ fmt = "%.0f"
+ pad = 40
+ else:
+ fmt = "%6.1f"
+ pad = 30
+
+ cbar.set_ticks(levels[1:-1])
+ labels = [fmt % level for level in levels[1:-1]]
+ cbar.ax.set_yticklabels(labels, ha="right")
+ cbar.ax.tick_params(labelsize=9.0, pad=pad, length=0)
+
+ # Min, Mean, Max
+ fig.text(
+ panel[n][0] + 0.6635,
+ panel[n][1] + 0.2107,
+ "Max\nMean\nMin",
+ ha="left",
+ fontdict=plotSideTitle,
+ )
+
+ fmt_m = []
+ # Use scientific notation for values greater than 10^5.
+ for i in range(len(stats[0:3])):  # NOTE(review): `stats` defaults to None, which would raise here; callers in this file always pass it.
+ fs = "1e" if stats[i] > 100000.0 else "2f"
+ fmt_m.append(fs)
+ fmt_metrics = f"%.{fmt_m[0]}\n%.{fmt_m[1]}\n%.{fmt_m[2]}"
+
+ fig.text(
+ panel[n][0] + 0.7635,
+ panel[n][1] + 0.2107,
+ # "%.2f\n%.2f\n%.2f" % stats[0:3],
+ fmt_metrics % stats[0:3],
+ ha="right",
+ fontdict=plotSideTitle,
+ )
+
+ # RMSE, CORR
+ if len(stats) == 5:
+ fig.text(
+ panel[n][0] + 0.6635,
+ panel[n][1] - 0.0105,
+ "RMSE\nCORR",
+ ha="left",
+ fontdict=plotSideTitle,
+ )
+ fig.text(
+ panel[n][0] + 0.7635,
+ panel[n][1] - 0.0105,
+ "%.2f\n%.2f" % stats[3:5],
+ ha="right",
+ fontdict=plotSideTitle,
+ )
+
+ # grid resolution info:
+ if n == 2 and "RRM" in region_str:
+ dlat = lat[2] - lat[1]  # Spacing between the second and third grid points.
+ dlon = lon[2] - lon[1]
+ fig.text(
+ panel[n][0] + 0.4635,
+ panel[n][1] - 0.04,
+ "Resolution: {:.2f}x{:.2f}".format(dlat, dlon),
+ ha="left",
+ fontdict=plotSideTitle,
+ )
+
+
+def plot(reference, test, diff, metrics_dict, parameter):  # Draw the test/ref/diff panels, then save the figure and per-panel images.
+ # Create figure, projection
+ fig = plt.figure(figsize=parameter.figsize, dpi=parameter.dpi)
+ proj = ccrs.PlateCarree()
+
+ # Figure title
+ fig.suptitle(parameter.main_title, x=0.5, y=0.96, fontsize=18)
+
+ # First two panels
+ min1 = metrics_dict["test"]["min"]
+ mean1 = metrics_dict["test"]["mean"]
+ max1 = metrics_dict["test"]["max"]
+
+ plot_panel(
+ 0,
+ fig,
+ proj,
+ test,
+ parameter.contour_levels,
+ parameter.test_colormap,
+ (parameter.test_name_yrs, parameter.test_title, test.units),
+ parameter,
+ stats=(max1, mean1, min1),
+ )
+
+ if not parameter.model_only:  # The reference panel is skipped for model-only runs.
+ min2 = metrics_dict["ref"]["min"]
+ mean2 = metrics_dict["ref"]["mean"]
+ max2 = metrics_dict["ref"]["max"]
+
+ plot_panel(
+ 1,
+ fig,
+ proj,
+ reference,
+ parameter.contour_levels,
+ parameter.reference_colormap,
+ (parameter.ref_name_yrs, parameter.reference_title, reference.units),
+ parameter,
+ stats=(max2, mean2, min2),
+ )
+
+ # Third panel
+ min3 = metrics_dict["diff"]["min"]
+ mean3 = metrics_dict["diff"]["mean"]
+ max3 = metrics_dict["diff"]["max"]
+ r = metrics_dict["misc"]["rmse"]
+ c = metrics_dict["misc"]["corr"]
+ plot_panel(
+ 2,
+ fig,
+ proj,
+ diff,
+ parameter.diff_levels,
+ parameter.diff_colormap,
+ (None, parameter.diff_title, test.units),
+ parameter,
+ stats=(max3, mean3, min3, r, c),
+ )
+
+ # Save figure
+ for f in parameter.output_format:
+ f = f.lower().split(".")[-1]  # Lowercase, keeping only the text after the last dot.
+ fnm = os.path.join(
+ get_output_dir(parameter.current_set, parameter),
+ parameter.output_file + "." + f,
+ )
+ plt.savefig(fnm)
+ logger.info(f"Plot saved in: {fnm}")
+
+ # Save individual subplots
+ if parameter.ref_name == "":  # With no reference name, only the first (test) panel is saved separately.
+ panels = [panel[0]]
+ else:
+ panels = panel
+
+ for f in parameter.output_format_subplot:
+ fnm = os.path.join(
+ get_output_dir(parameter.current_set, parameter),
+ parameter.output_file,
+ )
+ page = fig.get_size_inches()
+ i = 0
+ for p in panels:
+ # Extent of subplot
+ subpage = np.array(p).reshape(2, 2)
+ subpage[1, :] = subpage[0, :] + subpage[1, :]  # Second row: sizes -> far-corner coordinates (origin + size).
+ subpage = subpage + np.array(border).reshape(2, 2)
+ subpage = list(((subpage) * page).flatten()) # type: ignore
+ extent = matplotlib.transforms.Bbox.from_extents(*subpage)
+ # Save subplot
+ fname = fnm + ".%i." % (i) + f
+ plt.savefig(fname, bbox_inches=extent)
+
+ orig_fnm = os.path.join(  # Same path expression as `fnm`; recomputed for the log message.
+ get_output_dir(parameter.current_set, parameter),
+ parameter.output_file,
+ )
+ fname = orig_fnm + ".%i." % (i) + f
+ logger.info(f"Sub-plot saved in: {fname}")
+
+ i += 1
+
+ plt.close()
diff --git a/tests/e3sm_diags/driver/utils/test_dataset_xr.py b/tests/e3sm_diags/driver/utils/test_dataset_xr.py
index 2c89c92cb..03501dc95 100644
--- a/tests/e3sm_diags/driver/utils/test_dataset_xr.py
+++ b/tests/e3sm_diags/driver/utils/test_dataset_xr.py
@@ -561,6 +561,63 @@ def test_returns_climo_dataset_using_test_file_variable_ref_name_and_season_nest
@pytest.mark.xfail(
reason="Need to figure out why to create dummy incorrect time scalar variable with Xarray."
)
+ def test_returns_climo_dataset_with_derived_variable(self):
+ # We will derive the "PRECT" variable using the "pr" variable.
+ ds_pr = xr.Dataset(
+ coords={
+ **spatial_coords,
+ "time": xr.DataArray(
+ dims="time",
+ data=np.array(
+ [
+ cftime.DatetimeGregorian(
+ 2000, 1, 16, 12, 0, 0, 0, has_year_zero=False
+ ),
+ ],
+ dtype=object,
+ ),
+ attrs={
+ "axis": "T",
+ "long_name": "time",
+ "standard_name": "time",
+ "bounds": "time_bnds",
+ },
+ ),
+ },
+ data_vars={
+ **spatial_bounds,
+ "pr": xr.DataArray(
+ xr.DataArray(
+ data=np.array(
+ [
+ [[1.0, 1.0], [1.0, 1.0]],
+ ]
+ ),
+ dims=["time", "lat", "lon"],
+ attrs={"units": "mm/s"},
+ )
+ ),
+ },
+ )
+
+ parameter = _create_parameter_object(
+ "ref", "climo", self.data_path, "2000", "2001"
+ )
+ parameter.ref_file = "pr_200001_200112.nc"
+ ds_pr.to_netcdf(f"{self.data_path}/{parameter.ref_file}")  # Write the input file named by parameter.ref_file.
+
+ ds = Dataset(parameter, data_type="ref")
+
+ result = ds.get_climo_dataset("PRECT", season="ANN")
+ expected = ds_pr.copy()
+ expected = expected.squeeze(dim="time").drop_vars("time")  # Expected result has no time dimension or coordinate.
+ expected["PRECT"] = expected["pr"] * 3600 * 24  # mm/s -> mm/day (3600 * 24 seconds per day).
+ expected["PRECT"].attrs["units"] = "mm/day"
+ expected = expected.drop_vars("pr")
+
+ xr.testing.assert_identical(result, expected)
+
+ @pytest.mark.xfail
def test_returns_climo_dataset_using_derived_var_directly_from_dataset_and_replaces_scalar_time_var(
self,
):
diff --git a/tests/e3sm_diags/driver/utils/test_regrid.py b/tests/e3sm_diags/driver/utils/test_regrid.py
index 6dc33fcda..c02451345 100644
--- a/tests/e3sm_diags/driver/utils/test_regrid.py
+++ b/tests/e3sm_diags/driver/utils/test_regrid.py
@@ -231,7 +231,6 @@ def test_regrids_to_first_dataset_with_equal_latitude_points(self, tool):
expected_a = ds_a.copy()
expected_b = ds_a.copy()
-
if tool in ["esmf", "xesmf"]:
expected_b.so.attrs["regrid_method"] = "conservative"