diff --git a/.github/workflows/build_workflow.yml b/.github/workflows/build_workflow.yml index d19f29a5b..f3557732c 100644 --- a/.github/workflows/build_workflow.yml +++ b/.github/workflows/build_workflow.yml @@ -5,7 +5,7 @@ on: branches: [main] pull_request: - branches: [main] + branches: [main, cdat-migration-fy24] workflow_dispatch: diff --git a/auxiliary_tools/template_cdat_regression_test.ipynb b/auxiliary_tools/template_cdat_regression_test.ipynb new file mode 100644 index 000000000..8b4d00bd1 --- /dev/null +++ b/auxiliary_tools/template_cdat_regression_test.ipynb @@ -0,0 +1,1333 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# CDAT Migration Regression Test (FY24)\n", + "\n", + "This notebook is used to perform regression testing between the development and\n", + "production versions of a diagnostic set.\n", + "\n", + "## How it works\n", + "\n", + "It compares the relative differences (%) between two sets of `.json` files in two\n", + "separate directories, one for the refactored code and the other for the `main` branch.\n", + "\n", + "It will display metrics values with relative differences >= 2%. Relative differences are used instead of absolute differences because:\n", + "\n", + "- Relative differences are in percentages, which shows the scale of the differences.\n", + "- Absolute differences are just a raw number that doesn't factor in\n", + " floating point size (e.g., 100.00 vs. 0.0001), which can be misleading.\n", + "\n", + "## How to use\n", + "\n", + "PREREQUISITE: The diagnostic set's metrics stored in `.json` files in two directories\n", + "(dev and `main` branches).\n", + "\n", + "1. Make a copy of this notebook.\n", + "2. Run `mamba create -n cdat_regression_test -y -c conda-forge \"python<3.12\" pandas matplotlib-base ipykernel`\n", + "3. Run `mamba activate cdat_regression_test`\n", + "4. Update `DEV_PATH` and `PROD_PATH` in the copy of your notebook.\n", + "5. Run all cells IN ORDER.\n", + "6. 
Review results for any outstanding differences (>= 2%).\n", + " - Debug these differences (e.g., bug in metrics functions, incorrect variable references, etc.)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Setup Code\n" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import glob\n", + "import math\n", + "from typing import List\n", + "\n", + "import pandas as pd\n", + "\n", + "# TODO: Update DEV_RESULTS and PROD_RESULTS to your diagnostic sets.\n", + "DEV_PATH = \"/global/cfs/cdirs/e3sm/www/vo13/examples_658/ex1_modTS_vs_modTS_3years/lat_lon/model_vs_model\"\n", + "PROD_PATH = \"/global/cfs/cdirs/e3sm/www/vo13/examples/ex1_modTS_vs_modTS_3years/lat_lon/model_vs_model\"\n", + "\n", + "DEV_GLOB = sorted(glob.glob(DEV_PATH + \"/*.json\"))\n", + "PROD_GLOB = sorted(glob.glob(PROD_PATH + \"/*.json\"))\n", + "\n", + "# The names of the columns that store percentage difference values.\n", + "PERCENTAGE_COLUMNS = [\n", + " \"test DIFF (%)\",\n", + " \"ref DIFF (%)\",\n", + " \"test_regrid DIFF (%)\",\n", + " \"ref_regrid DIFF (%)\",\n", + " \"diff DIFF (%)\",\n", + " \"misc DIFF (%)\",\n", + "]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Core Functions\n" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "def get_metrics(filepaths: List[str]) -> pd.DataFrame:\n", + " \"\"\"Get the metrics using a glob of `.json` metric files in a directory.\n", + "\n", + " Parameters\n", + " ----------\n", + " filepaths : List[str]\n", + " The filepaths for metrics `.json` files.\n", + "\n", + " Returns\n", + " -------\n", + " pd.DataFrame\n", + " The DataFrame containing the metrics for all of the variables in\n", + " the results directory.\n", + " \"\"\"\n", + " metrics = []\n", + "\n", + " for filepath in filepaths:\n", + " df = pd.read_json(filepath)\n", + "\n", + " filename = filepath.split(\"/\")[-1]\n", + " var_key = filename.split(\"-\")[1]\n", + "\n", + " # Add the variable key to the MultiIndex and update the index\n", + " # before stacking to make the DataFrame easier to parse.\n", + " multiindex = pd.MultiIndex.from_product([[var_key], [*df.index]])\n", + " df = df.set_index(multiindex)\n", + " df.stack()\n", + "\n", + " metrics.append(df)\n", + "\n", + " df_final = pd.concat(metrics)\n", + "\n", + " # Reorder columns and drop \"unit\" column (string dtype breaks Pandas\n", + " # arithmetic).\n", + " df_final = df_final[[\"test\", \"ref\", \"test_regrid\", \"ref_regrid\", \"diff\", \"misc\"]]\n", + "\n", + " return df_final\n", + "\n", + "\n", + "def get_rel_diffs(df_actual: pd.DataFrame, df_reference: pd.DataFrame) -> pd.DataFrame:\n", + " \"\"\"Get the relative differences between two DataFrames.\n", + "\n", + " Formula: abs(actual - reference) / abs(actual)\n", + "\n", + " Parameters\n", + " ----------\n", + " df_actual : pd.DataFrame\n", + " The first DataFrame representing \"actual\" results (dev branch).\n", + " df_reference : pd.DataFrame\n", + " The second DataFrame representing \"reference\" results (main branch).\n", + "\n", + " Returns\n", + " -------\n", + " pd.DataFrame\n", + " The DataFrame containing absolute and relative differences between\n", + " the metrics DataFrames.\n", + " \"\"\"\n", + " df_diff = abs(df_actual - df_reference) / abs(df_actual)\n", + " df_diff = df_diff.add_suffix(\" DIFF (%)\")\n", + "\n", + " return df_diff\n", + "\n", + "\n", + "def sort_columns(df: pd.DataFrame) -> 
pd.DataFrame:\n", + " \"\"\"Sorts the order of the columns for the final DataFrame output.\n", + "\n", + " Parameters\n", + " ----------\n", + " df : pd.DataFrame\n", + " The final DataFrame output.\n", + "\n", + " Returns\n", + " -------\n", + " pd.DataFrame\n", + " The final DataFrame output with sorted columns.\n", + " \"\"\"\n", + " columns = [\n", + " \"test_dev\",\n", + " \"test_prod\",\n", + " \"test DIFF (%)\",\n", + " \"ref_dev\",\n", + " \"ref_prod\",\n", + " \"ref DIFF (%)\",\n", + " \"test_regrid_dev\",\n", + " \"test_regrid_prod\",\n", + " \"test_regrid DIFF (%)\",\n", + " \"ref_regrid_dev\",\n", + " \"ref_regrid_prod\",\n", + " \"ref_regrid DIFF (%)\",\n", + " \"diff_dev\",\n", + " \"diff_prod\",\n", + " \"diff DIFF (%)\",\n", + " \"misc_dev\",\n", + " \"misc_prod\",\n", + " \"misc DIFF (%)\",\n", + " ]\n", + "\n", + " df_new = df.copy()\n", + " df_new = df_new[columns]\n", + "\n", + " return df_new\n", + "\n", + "\n", + "def update_diffs_to_pct(df: pd.DataFrame):\n", + " \"\"\"Update relative diff columns from float to string percentage.\n", + "\n", + " Parameters\n", + " ----------\n", + " df : pd.DataFrame\n", + " The final DataFrame containing metrics and diffs (floats).\n", + "\n", + " Returns\n", + " -------\n", + " pd.DataFrame\n", + " The final DataFrame containing metrics and diffs (str percentage).\n", + " \"\"\"\n", + " df_new = df.copy()\n", + " df_new[PERCENTAGE_COLUMNS] = df_new[PERCENTAGE_COLUMNS].map(\n", + " lambda x: \"{0:.2f}%\".format(x * 100) if not math.isnan(x) else x\n", + " )\n", + "\n", + " return df_new" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 1. Get the DataFrame containing development and production metrics.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "df_metrics_dev = get_metrics(DEV_GLOB)\n", + "df_metrics_prod = get_metrics(PROD_GLOB)\n", + "df_metrics_all = pd.concat(\n", + " [df_metrics_dev.add_suffix(\"_dev\"), df_metrics_prod.add_suffix(\"_prod\")],\n", + " axis=1,\n", + " join=\"outer\",\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2. Get DataFrame for differences >= 2%.\n", + "\n", + "- Get the relative differences for all metrics\n", + "- Filter down metrics to those with differences >= 2%\n", + " - If all cells in a row are NaN (< 2%), the entire row is dropped to make the results easier to parse.\n", + " - Any remaining NaN cells are below < 2% difference and **should be ignored**.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "df_metrics_diffs = get_rel_diffs(df_metrics_dev, df_metrics_prod)\n", + "df_metrics_diffs_thres = df_metrics_diffs[df_metrics_diffs >= 0.02]\n", + "df_metrics_diffs_thres = df_metrics_diffs_thres.dropna(\n", + " axis=0, how=\"all\", ignore_index=False\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 3. Combine both DataFrames to get the final result.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "df_final = df_metrics_diffs_thres.join(df_metrics_all)\n", + "df_final = sort_columns(df_final)\n", + "df_final = update_diffs_to_pct(df_final)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 4. 
Display final DataFrame and review results.\n", + "\n", + "- Red cells are differences >= 2%\n", + "- `nan` cells are differences < 2% and **should be ignored**\n" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [
[Rendered DataFrame output — the HTML table markup was stripped; the recoverable values are reproduced in the table below.]
| | var_key | metric | test_dev | test_prod | test DIFF (%) | ref_dev | ref_prod | ref DIFF (%) | test_regrid_dev | test_regrid_prod | test_regrid DIFF (%) | ref_regrid_dev | ref_regrid_prod | ref_regrid DIFF (%) | diff_dev | diff_prod | diff DIFF (%) | misc_dev | misc_prod | misc DIFF (%) |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| 0 | FLUT | max | 299.911864 | 299.355074 | nan | 300.162128 | 299.776167 | nan | 299.911864 | 299.355074 | nan | 300.162128 | 299.776167 | nan | 9.492359 | 9.788809 | 3.12% | nan | nan | nan |
| 1 | FLUT | min | 124.610884 | 125.987072 | nan | 122.878196 | 124.148986 | nan | 124.610884 | 125.987072 | nan | 122.878196 | 124.148986 | nan | -15.505809 | -17.032325 | 9.84% | nan | nan | nan |
| 2 | FSNS | max | 269.789702 | 269.798166 | nan | 272.722362 | 272.184917 | nan | 269.789702 | 269.798166 | nan | 272.722362 | 272.184917 | nan | 20.647929 | 24.859852 | 20.40% | nan | nan | nan |
| 3 | FSNS | min | 16.897423 | 17.760889 | 5.11% | 16.710134 | 16.237061 | 2.83% | 16.897423 | 17.760889 | 5.11% | 16.710134 | 16.237061 | 2.83% | -28.822277 | -28.324921 | nan | nan | nan | nan |
| 4 | FSNTOA | max | 360.624327 | 360.209193 | nan | 362.188816 | 361.778529 | nan | 360.624327 | 360.209193 | nan | 362.188816 | 361.778529 | nan | 18.602276 | 22.624266 | 21.62% | nan | nan | nan |
| 5 | FSNTOA | mean | 239.859777 | 240.001860 | nan | 241.439641 | 241.544384 | nan | 239.859777 | 240.001860 | nan | 241.439641 | 241.544384 | nan | -1.579864 | -1.542524 | 2.36% | nan | nan | nan |
| 6 | FSNTOA | min | 44.907041 | 48.256818 | 7.46% | 47.223502 | 50.339608 | 6.60% | 44.907041 | 48.256818 | 7.46% | 47.223502 | 50.339608 | 6.60% | -23.576184 | -23.171864 | nan | nan | nan | nan |
| 7 | LHFLX | max | 282.280453 | 289.079940 | 2.41% | 275.792933 | 276.297281 | nan | 282.280453 | 289.079940 | 2.41% | 275.792933 | 276.297281 | nan | 47.535503 | 53.168924 | 11.85% | nan | nan | nan |
| 8 | LHFLX | mean | 88.379609 | 88.470270 | nan | 88.969550 | 88.976266 | nan | 88.379609 | 88.470270 | nan | 88.969550 | 88.976266 | nan | -0.589942 | -0.505996 | 14.23% | nan | nan | nan |
| 9 | LHFLX | min | -0.878371 | -0.549248 | 37.47% | -1.176561 | -0.946110 | 19.59% | -0.878371 | -0.549248 | 37.47% | -1.176561 | -0.946110 | 19.59% | -34.375924 | -33.902769 | nan | nan | nan | nan |
| 10 | LWCF | max | 78.493653 | 77.473220 | nan | 86.121959 | 84.993825 | nan | 78.493653 | 77.473220 | nan | 86.121959 | 84.993825 | nan | 9.616057 | 10.796104 | 12.27% | nan | nan | nan |
| 11 | LWCF | mean | 24.373224 | 24.370539 | nan | 24.406697 | 24.391579 | nan | 24.373224 | 24.370539 | nan | 24.406697 | 24.391579 | nan | -0.033473 | -0.021040 | 37.14% | nan | nan | nan |
| 12 | LWCF | min | -0.667812 | -0.617107 | 7.59% | -1.360010 | -1.181787 | 13.10% | -0.667812 | -0.617107 | 7.59% | -1.360010 | -1.181787 | 13.10% | -10.574643 | -10.145188 | 4.06% | nan | nan | nan |
| 13 | NETCF | max | 13.224604 | 12.621825 | 4.56% | 13.715438 | 13.232716 | 3.52% | 13.224604 | 12.621825 | 4.56% | 13.715438 | 13.232716 | 3.52% | 10.899344 | 10.284825 | 5.64% | nan | nan | nan |
| 14 | NETCF | min | -66.633044 | -66.008633 | nan | -64.832041 | -67.398047 | 3.96% | -66.633044 | -66.008633 | nan | -64.832041 | -67.398047 | 3.96% | -17.923932 | -17.940099 | nan | nan | nan | nan |
| 15 | NET_FLUX_SRF | max | 155.691338 | 156.424180 | nan | 166.556120 | 166.506173 | nan | 155.691338 | 156.424180 | nan | 166.556120 | 166.506173 | nan | 59.819449 | 61.672824 | 3.10% | nan | nan | nan |
| 16 | NET_FLUX_SRF | mean | 0.394016 | 0.516330 | 31.04% | -0.068186 | 0.068584 | 200.58% | 0.394016 | 0.516330 | 31.04% | -0.068186 | 0.068584 | 200.58% | 0.462202 | 0.447746 | 3.13% | nan | nan | nan |
| 17 | NET_FLUX_SRF | min | -284.505205 | -299.505024 | 5.27% | -280.893287 | -290.202934 | 3.31% | -284.505205 | -299.505024 | 5.27% | -280.893287 | -290.202934 | 3.31% | -75.857589 | -85.852089 | 13.18% | nan | nan | nan |
| 18 | PRECT | max | 17.289951 | 17.071276 | nan | 20.264862 | 20.138274 | nan | 17.289951 | 17.071276 | nan | 20.264862 | 20.138274 | nan | 2.344111 | 2.406625 | 2.67% | nan | nan | nan |
| 19 | PRECT | mean | 3.053802 | 3.056760 | nan | 3.074885 | 3.074978 | nan | 3.053802 | 3.056760 | nan | 3.074885 | 3.074978 | nan | -0.021083 | -0.018218 | 13.59% | nan | nan | nan |
| 20 | PSL | min | 970.981710 | 971.390765 | nan | 973.198437 | 973.235326 | nan | 970.981710 | 971.390765 | nan | 973.198437 | 973.235326 | nan | -6.328677 | -6.104610 | 3.54% | nan | nan | nan |
| 21 | PSL | rmse | nan | nan | nan | nan | nan | nan | nan | nan | nan | nan | nan | nan | nan | nan | nan | 1.042884 | 0.979981 | 6.03% |
| 22 | RESTOM | max | 84.295502 | 83.821906 | nan | 87.707944 | 87.451262 | nan | 84.295502 | 83.821906 | nan | 87.707944 | 87.451262 | nan | 17.396283 | 21.423616 | 23.15% | nan | nan | nan |
| 23 | RESTOM | mean | 0.481549 | 0.656560 | 36.34% | 0.018041 | 0.162984 | 803.40% | 0.481549 | 0.656560 | 36.34% | 0.018041 | 0.162984 | 803.40% | 0.463508 | 0.493576 | 6.49% | nan | nan | nan |
| 24 | RESTOM | min | -127.667181 | -129.014673 | nan | -127.417586 | -128.673508 | nan | -127.667181 | -129.014673 | nan | -127.417586 | -128.673508 | nan | -15.226249 | -14.869614 | 2.34% | nan | nan | nan |
| 25 | SHFLX | max | 114.036895 | 112.859646 | nan | 116.870038 | 116.432591 | nan | 114.036895 | 112.859646 | nan | 116.870038 | 116.432591 | nan | 28.320656 | 27.556755 | 2.70% | nan | nan | nan |
| 26 | SHFLX | min | -88.650312 | -88.386947 | nan | -85.809438 | -85.480377 | nan | -88.650312 | -88.386947 | nan | -85.809438 | -85.480377 | nan | -27.776625 | -28.363053 | 2.11% | nan | nan | nan |
| 27 | SST | min | -1.788055 | -1.788055 | nan | -1.676941 | -1.676941 | nan | -1.788055 | -1.788055 | nan | -1.676941 | -1.676941 | nan | -4.513070 | -2.993272 | 33.68% | nan | nan | nan |
| 28 | SWCF | max | -0.518025 | -0.536844 | 3.63% | -0.311639 | -0.331616 | 6.41% | -0.518025 | -0.536844 | 3.63% | -0.311639 | -0.331616 | 6.41% | 11.668939 | 12.087077 | 3.58% | nan | nan | nan |
| 29 | SWCF | min | -123.625017 | -122.042043 | nan | -131.053537 | -130.430161 | nan | -123.625017 | -122.042043 | nan | -131.053537 | -130.430161 | nan | -21.415249 | -20.808973 | 2.83% | nan | nan | nan |
| 30 | TREFHT | max | 31.141508 | 31.058424 | nan | 29.819210 | 29.721868 | nan | 31.141508 | 31.058424 | nan | 29.819210 | 29.721868 | nan | 4.981757 | 5.126185 | 2.90% | nan | nan | nan |
| 31 | TREFHT | max | 31.141508 | 31.058424 | nan | 29.819210 | 29.721868 | nan | 31.141508 | 31.058424 | nan | 29.819210 | 29.721868 | nan | 4.867855 | 5.126185 | 2.90% | nan | nan | nan |
| 32 | TREFHT | max | 31.141508 | 31.058424 | nan | 29.819210 | 29.721868 | nan | 31.141508 | 31.058424 | nan | 29.819210 | 29.721868 | nan | 4.981757 | 5.126185 | 5.31% | nan | nan | nan |
| 33 | TREFHT | max | 31.141508 | 31.058424 | nan | 29.819210 | 29.721868 | nan | 31.141508 | 31.058424 | nan | 29.819210 | 29.721868 | nan | 4.867855 | 5.126185 | 5.31% | nan | nan | nan |
| 34 | TREFHT | mean | 14.769946 | 14.741707 | nan | 13.842013 | 13.800258 | nan | 14.769946 | 14.741707 | nan | 13.842013 | 13.800258 | nan | 0.927933 | 0.941449 | 2.28% | nan | nan | nan |
| 35 | TREFHT | mean | 9.214224 | 9.114572 | nan | 8.083349 | 7.957917 | nan | 9.214224 | 9.114572 | nan | 8.083349 | 7.957917 | nan | 1.130876 | 1.156655 | 2.28% | nan | nan | nan |
| 36 | TREFHT | min | -56.266677 | -55.623001 | nan | -58.159250 | -57.542053 | nan | -56.266677 | -55.623001 | nan | -58.159250 | -57.542053 | nan | -0.681558 | -0.624371 | 8.39% | nan | nan | nan |
| 37 | TREFHT | min | -56.266677 | -55.623001 | nan | -58.159250 | -57.542053 | nan | -56.266677 | -55.623001 | nan | -58.159250 | -57.542053 | nan | -0.681558 | -0.624371 | 8.39% | nan | nan | nan |
| 38 | TREFHT | min | -56.266677 | -55.623001 | nan | -58.159250 | -57.542053 | nan | -56.266677 | -55.623001 | nan | -58.159250 | -57.542053 | nan | -0.681558 | -0.624371 | 8.39% | nan | nan | nan |
| 39 | TREFHT | min | -56.266677 | -55.623001 | nan | -58.159250 | -57.542053 | nan | -56.266677 | -55.623001 | nan | -58.159250 | -57.542053 | nan | -0.681558 | -0.624371 | 8.39% | nan | nan | nan |
| 40 | TREFHT | rmse | nan | nan | nan | nan | nan | nan | nan | nan | nan | nan | nan | nan | nan | nan | nan | 1.160718 | 1.179995 | 2.68% |
| 41 | TREFHT | rmse | nan | nan | nan | nan | nan | nan | nan | nan | nan | nan | nan | nan | nan | nan | nan | 1.343169 | 1.379141 | 2.68% |
\n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df_final.reset_index(names=[\"var_key\", \"metric\"]).style.map(\n", + " lambda x: \"background-color : red\" if isinstance(x, str) else \"\",\n", + " subset=pd.IndexSlice[:, PERCENTAGE_COLUMNS],\n", + ")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "cdat_regression_test", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.6" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/e3sm_diags/driver/aerosol_budget_driver.py b/e3sm_diags/driver/aerosol_budget_driver.py index 9c1de7d00..faf0c4005 100644 --- a/e3sm_diags/driver/aerosol_budget_driver.py +++ b/e3sm_diags/driver/aerosol_budget_driver.py @@ -3,6 +3,7 @@ script is integrated in e3sm_diags by Jill Zhang, with input from Kai Zhang, Taufiq Hassan, Xue Zheng, Ziming Ke, Susannah Burrows, and Naser Mahfouz. """ + from __future__ import annotations import csv diff --git a/e3sm_diags/driver/qbo_driver.py b/e3sm_diags/driver/qbo_driver.py index 3379f4c46..5bc0c5ec2 100644 --- a/e3sm_diags/driver/qbo_driver.py +++ b/e3sm_diags/driver/qbo_driver.py @@ -125,6 +125,11 @@ def run_diag(parameter: QboParameter) -> QboParameter: test_dict["name"] = test_ds._get_test_name() ref_dict["name"] = ref_ds._get_ref_name() + try: + ref_dict["name"] = ref_ds._get_ref_name() + except AttributeError: + ref_dict["name"] = parameter.ref_name + _save_metrics_to_json(parameter, test_dict, "test") # type: ignore _save_metrics_to_json(parameter, ref_dict, "ref") # type: ignore diff --git a/e3sm_diags/driver/utils/climo_xr.py b/e3sm_diags/driver/utils/climo_xr.py index bb229048c..acbe73fa2 100644 --- a/e3sm_diags/driver/utils/climo_xr.py +++ b/e3sm_diags/driver/utils/climo_xr.py @@ -1,8 +1,8 @@ """This module stores climatology functions operating on Xarray objects. - This file will eventually be refactored to use xCDAT's climatology API. """ + from typing import Dict, List, Literal, get_args import numpy as np diff --git a/e3sm_diags/driver/utils/dataset_xr.py b/e3sm_diags/driver/utils/dataset_xr.py index 94b109e66..94c3f69c6 100644 --- a/e3sm_diags/driver/utils/dataset_xr.py +++ b/e3sm_diags/driver/utils/dataset_xr.py @@ -1,6 +1,5 @@ """This module stores the Dataset class, which is the primary class for I/O. - This Dataset class operates on `xr.Dataset` objects, which are created using netCDF files. These `xr.Dataset` contain either the reference or test variable. This variable can either be from a climatology file or a time series file. @@ -8,6 +7,7 @@ calculated. Reference and test variables can also be derived using other variables from dataset files. 
""" + from __future__ import annotations import collections diff --git a/e3sm_diags/metrics/metrics.py b/e3sm_diags/metrics/metrics.py index 333980643..d98fe519d 100644 --- a/e3sm_diags/metrics/metrics.py +++ b/e3sm_diags/metrics/metrics.py @@ -1,4 +1,5 @@ """This module stores functions to calculate metrics using Xarray objects.""" + from __future__ import annotations from typing import List, Literal diff --git a/e3sm_diags/plot/cartopy/aerosol_aeronet_plot.py b/e3sm_diags/plot/cartopy/aerosol_aeronet_plot.py new file mode 100644 index 000000000..765235095 --- /dev/null +++ b/e3sm_diags/plot/cartopy/aerosol_aeronet_plot.py @@ -0,0 +1,132 @@ +import os + +import cartopy.crs as ccrs +import matplotlib +import numpy as np + +from e3sm_diags.driver.utils.general import get_output_dir +from e3sm_diags.logger import custom_logger +from e3sm_diags.metrics import mean +from e3sm_diags.plot.cartopy.deprecated_lat_lon_plot import plot_panel + +matplotlib.use("Agg") +import matplotlib.pyplot as plt # isort:skip # noqa: E402 + +logger = custom_logger(__name__) + +plotTitle = {"fontsize": 11.5} +plotSideTitle = {"fontsize": 9.5} + + +def plot(test, test_site, ref_site, parameter): + # Plot scatter plot + # Position and sizes of subplot axes in page coordinates (0 to 1) + # (left, bottom, width, height) in page coordinates + panel = [ + (0.09, 0.40, 0.72, 0.30), + (0.19, 0.2, 0.62, 0.30), + ] + # Border padding relative to subplot axes for saving individual panels + # (left, bottom, right, top) in page coordinates + border = (-0.06, -0.03, 0.13, 0.03) + + fig = plt.figure(figsize=parameter.figsize, dpi=parameter.dpi) + fig.suptitle(parameter.var_id, x=0.5, y=0.97) + proj = ccrs.PlateCarree() + max1 = test.max() + min1 = test.min() + mean1 = mean(test) + # TODO: Replace this function call with `e3sm_diags.plot.utils._add_colormap()`. + plot_panel( + 0, + fig, + proj, + test, + parameter.contour_levels, + parameter.test_colormap, + (parameter.test_name_yrs, None, None), + parameter, + stats=(max1, mean1, min1), + ) + + ax = fig.add_axes(panel[1]) + ax.set_title(f"{parameter.var_id} from AERONET sites") + + # define 1:1 line, and x y axis limits + + if parameter.var_id == "AODVIS": + x1 = np.arange(0.01, 3.0, 0.1) + y1 = np.arange(0.01, 3.0, 0.1) + plt.xlim(0.03, 1) + plt.ylim(0.03, 1) + else: + x1 = np.arange(0.0001, 1.0, 0.01) + y1 = np.arange(0.0001, 1.0, 0.01) + plt.xlim(0.001, 0.3) + plt.ylim(0.001, 0.3) + + plt.loglog(x1, y1, "-k", linewidth=0.5) + plt.loglog(x1, y1 * 0.5, "--k", linewidth=0.5) + plt.loglog(x1 * 0.5, y1, "--k", linewidth=0.5) + + corr = np.corrcoef(ref_site, test_site) + xmean = np.mean(ref_site) + ymean = np.mean(test_site) + ax.text( + 0.3, + 0.9, + f"Mean (test): {ymean:.3f} \n Mean (ref): {xmean:.3f}\n Corr: {corr[0, 1]:.2f}", + horizontalalignment="right", + verticalalignment="top", + transform=ax.transAxes, + ) + + # axis ticks + plt.tick_params(axis="both", which="major") + plt.tick_params(axis="both", which="minor") + + # axis labels + plt.xlabel(f"ref: {parameter.ref_name_yrs}") + plt.ylabel(f"test: {parameter.test_name_yrs}") + + plt.loglog(ref_site, test_site, "kx", markersize=3.0, mfc="none") + + # legend + plt.legend(frameon=False, prop={"size": 5}) + + # TODO: This section can be refactored to use `plot.utils._save_plot()`. + for f in parameter.output_format: + f = f.lower().split(".")[-1] + fnm = os.path.join( + get_output_dir(parameter.current_set, parameter), + f"{parameter.output_file}" + "." 
+ f, + ) + plt.savefig(fnm) + logger.info(f"Plot saved in: {fnm}") + + for f in parameter.output_format_subplot: + fnm = os.path.join( + get_output_dir(parameter.current_set, parameter), + parameter.output_file, + ) + page = fig.get_size_inches() + i = 0 + for p in panel: + # Extent of subplot + subpage = np.array(p).reshape(2, 2) + subpage[1, :] = subpage[0, :] + subpage[1, :] + subpage = subpage + np.array(border).reshape(2, 2) + subpage = list(((subpage) * page).flatten()) # type: ignore + extent = matplotlib.transforms.Bbox.from_extents(*subpage) + # Save subplot + fname = fnm + ".%i." % (i) + f + plt.savefig(fname, bbox_inches=extent) + + orig_fnm = os.path.join( + get_output_dir(parameter.current_set, parameter), + parameter.output_file, + ) + fname = orig_fnm + ".%i." % (i) + f + logger.info(f"Sub-plot saved in: {fname}") + + i += 1 diff --git a/e3sm_diags/plot/cartopy/zonal_mean_2d_plot.py b/e3sm_diags/plot/cartopy/zonal_mean_2d_plot.py new file mode 100644 index 000000000..a72bf5dce --- /dev/null +++ b/e3sm_diags/plot/cartopy/zonal_mean_2d_plot.py @@ -0,0 +1,187 @@ +from typing import List, Optional, Tuple + +import matplotlib +import numpy as np +import xarray as xr +import xcdat as xc + +from e3sm_diags.driver.utils.type_annotations import MetricsDict +from e3sm_diags.logger import custom_logger +from e3sm_diags.parameter.core_parameter import CoreParameter +from e3sm_diags.parameter.zonal_mean_2d_parameter import DEFAULT_PLEVS +from e3sm_diags.plot.utils import ( + DEFAULT_PANEL_CFG, + _add_colorbar, + _add_contour_plot, + _add_min_mean_max_text, + _add_rmse_corr_text, + _configure_titles, + _configure_x_and_y_axes, + _get_c_levels_and_norm, + _save_plot, +) + +matplotlib.use("Agg") +import matplotlib.pyplot as plt # isort:skip # noqa: E402 + +logger = custom_logger(__name__) + + +# Configs for x axis ticks and x axis limits. +X_TICKS = np.array([-90, -60, -30, 0, 30, 60, 90]) +X_LIM = -90, 90 + + +def plot( + parameter: CoreParameter, + da_test: xr.DataArray, + da_ref: xr.DataArray, + da_diff: xr.DataArray, + metrics_dict: MetricsDict, +): + """Plot the variable's metrics generated by the zonal_mean_2d set. + + Parameters + ---------- + parameter : CoreParameter + The CoreParameter object containing plot configurations. + da_test : xr.DataArray + The test data. + da_ref : xr.DataArray + The reference data. + da_diff : xr.DataArray + The difference between `da_test` and `da_ref` (both are regridded to + the lower resolution of the two beforehand). + metrics_dict : Metrics + The metrics. + """ + fig = plt.figure(figsize=parameter.figsize, dpi=parameter.dpi) + fig.suptitle(parameter.main_title, x=0.5, y=0.96, fontsize=18) + + # The variable units. + units = metrics_dict["units"] + + # Add the first subplot for test data. + min1 = metrics_dict["test"]["min"] # type: ignore + mean1 = metrics_dict["test"]["mean"] # type: ignore + max1 = metrics_dict["test"]["max"] # type: ignore + + _add_colormap( + 0, + da_test, + fig, + parameter, + parameter.test_colormap, + parameter.contour_levels, + title=(parameter.test_name_yrs, parameter.test_title, units), # type: ignore + metrics=(max1, mean1, min1), # type: ignore + ) + + # Add the second and third subplots for ref data and the differences, + # respectively. 
+ min2 = metrics_dict["ref"]["min"] # type: ignore + mean2 = metrics_dict["ref"]["mean"] # type: ignore + max2 = metrics_dict["ref"]["max"] # type: ignore + + _add_colormap( + 1, + da_ref, + fig, + parameter, + parameter.reference_colormap, + parameter.contour_levels, + title=(parameter.ref_name_yrs, parameter.reference_title, units), # type: ignore + metrics=(max2, mean2, min2), # type: ignore + ) + + min3 = metrics_dict["diff"]["min"] # type: ignore + mean3 = metrics_dict["diff"]["mean"] # type: ignore + max3 = metrics_dict["diff"]["max"] # type: ignore + r = metrics_dict["misc"]["rmse"] # type: ignore + c = metrics_dict["misc"]["corr"] # type: ignore + + _add_colormap( + 2, + da_diff, + fig, + parameter, + parameter.diff_colormap, + parameter.diff_levels, + title=(None, parameter.diff_title, da_diff.attrs["units"]), # + metrics=(max3, mean3, min3, r, c), # type: ignore + ) + + _save_plot(fig, parameter) + + plt.close() + + +def _add_colormap( + subplot_num: int, + var: xr.DataArray, + fig: plt.Figure, + parameter: CoreParameter, + color_map: str, + contour_levels: List[float], + title: Tuple[Optional[str], str, str], + metrics: Tuple[float, ...], +): + lat = xc.get_dim_coords(var, axis="Y") + plev = xc.get_dim_coords(var, axis="Z") + var = var.squeeze() + + # Configure contour levels + # -------------------------------------------------------------------------- + c_levels, norm = _get_c_levels_and_norm(contour_levels) + + # Add the contour plot + # -------------------------------------------------------------------------- + ax = fig.add_axes(DEFAULT_PANEL_CFG[subplot_num], projection=None) + + contour_plot = _add_contour_plot( + ax, parameter, var, lat, plev, color_map, None, norm, c_levels + ) + + # Configure the aspect ratio and plot titles. + # -------------------------------------------------------------------------- + ax.set_aspect("auto") + _configure_titles(ax, title) + + # Configure x and y axis. + # -------------------------------------------------------------------------- + _configure_x_and_y_axes(ax, X_TICKS, None, None, parameter.current_set) + ax.set_xlim(X_LIM) + + if parameter.plot_log_plevs: + ax.set_yscale("log") + + if parameter.plot_plevs: + plev_ticks = parameter.plevs + plt.yticks(plev_ticks, plev_ticks) + + # For default plevs, specify the pressure axis and show the 50 mb tick + # at the top. + if ( + not parameter.plot_log_plevs + and not parameter.plot_plevs + and parameter.plevs == DEFAULT_PLEVS + ): + plev_ticks = parameter.plevs + new_ticks = [plev_ticks[0]] + plev_ticks[1::2] + new_ticks = [int(x) for x in new_ticks] + plt.yticks(new_ticks, new_ticks) + + plt.ylabel("pressure (mb)") + ax.invert_yaxis() + + # Add and configure the color bar. + # -------------------------------------------------------------------------- + _add_colorbar(fig, subplot_num, DEFAULT_PANEL_CFG, contour_plot, c_levels) + + # Add metrics text. 
+ # -------------------------------------------------------------------------- + # Min, Mean, Max + _add_min_mean_max_text(fig, subplot_num, DEFAULT_PANEL_CFG, metrics) + + if len(metrics) == 5: + _add_rmse_corr_text(fig, subplot_num, DEFAULT_PANEL_CFG, metrics) diff --git a/e3sm_diags/plot/cartopy/zonal_mean_2d_stratosphere_plot.py b/e3sm_diags/plot/cartopy/zonal_mean_2d_stratosphere_plot.py new file mode 100644 index 000000000..004f3c93d --- /dev/null +++ b/e3sm_diags/plot/cartopy/zonal_mean_2d_stratosphere_plot.py @@ -0,0 +1,15 @@ +import xarray as xr + +from e3sm_diags.driver.utils.type_annotations import MetricsDict +from e3sm_diags.parameter.core_parameter import CoreParameter +from e3sm_diags.plot.cartopy.zonal_mean_2d_plot import plot as base_plot + + +def plot( + parameter: CoreParameter, + da_test: xr.DataArray, + da_ref: xr.DataArray, + da_diff: xr.DataArray, + metrics_dict: MetricsDict, +): + return base_plot(parameter, da_test, da_ref, da_diff, metrics_dict) diff --git a/e3sm_diags/plot/deprecated_lat_lon_plot.py b/e3sm_diags/plot/deprecated_lat_lon_plot.py new file mode 100644 index 000000000..4eaebcf80 --- /dev/null +++ b/e3sm_diags/plot/deprecated_lat_lon_plot.py @@ -0,0 +1,360 @@ +""" +WARNING: This module has been deprecated and replaced by +`e3sm_diags.plot.lat_lon_plot.py`. This file temporarily kept because +`e3sm_diags.plot.cartopy.aerosol_aeronet_plot.plot` references the +`plot_panel()` function. Once the aerosol_aeronet set is refactored, this +file can be deleted. +""" +from __future__ import print_function + +import os + +import cartopy.crs as ccrs +import cartopy.feature as cfeature +import cdutil +import matplotlib +import numpy as np +import numpy.ma as ma +from cartopy.mpl.ticker import LatitudeFormatter, LongitudeFormatter + +from e3sm_diags.derivations.default_regions import regions_specs +from e3sm_diags.driver.utils.general import get_output_dir +from e3sm_diags.logger import custom_logger +from e3sm_diags.plot import get_colormap + +matplotlib.use("Agg") +import matplotlib.colors as colors # isort:skip # noqa: E402 +import matplotlib.pyplot as plt # isort:skip # noqa: E402 + +logger = custom_logger(__name__) + +plotTitle = {"fontsize": 11.5} +plotSideTitle = {"fontsize": 9.5} + +# Position and sizes of subplot axes in page coordinates (0 to 1) +panel = [ + (0.1691, 0.6810, 0.6465, 0.2258), + (0.1691, 0.3961, 0.6465, 0.2258), + (0.1691, 0.1112, 0.6465, 0.2258), +] + +# Border padding relative to subplot axes for saving individual panels +# (left, bottom, right, top) in page coordinates +border = (-0.06, -0.03, 0.13, 0.03) + + +def add_cyclic(var): + lon = var.getLongitude() + return var(longitude=(lon[0], lon[0] + 360.0, "coe")) + + +def get_ax_size(fig, ax): + bbox = ax.get_window_extent().transformed(fig.dpi_scale_trans.inverted()) + width, height = bbox.width, bbox.height + width *= fig.dpi + height *= fig.dpi + return width, height + + +def determine_tick_step(degrees_covered): + if degrees_covered > 180: + return 60 + if degrees_covered > 60: + return 30 + elif degrees_covered > 30: + return 10 + elif degrees_covered > 20: + return 5 + else: + return 1 + + +def plot_panel( # noqa: C901 + n, fig, proj, var, clevels, cmap, title, parameters, stats=None +): + var = add_cyclic(var) + lon = var.getLongitude() + lat = var.getLatitude() + var = ma.squeeze(var.asma()) + + # Contour levels + levels = None + norm = None + if len(clevels) > 0: + levels = [-1.0e8] + clevels + [1.0e8] + norm = colors.BoundaryNorm(boundaries=levels, ncolors=256) + + # 
ax.set_global() + region_str = parameters.regions[0] + region = regions_specs[region_str] + global_domain = True + full_lon = True + if "domain" in region.keys(): # type: ignore + # Get domain to plot + domain = region["domain"] # type: ignore + global_domain = False + else: + # Assume global domain + domain = cdutil.region.domain(latitude=(-90.0, 90, "ccb")) + kargs = domain.components()[0].kargs + lon_west, lon_east, lat_south, lat_north = (0, 360, -90, 90) + if "longitude" in kargs: + full_lon = False + lon_west, lon_east, _ = kargs["longitude"] + # Note cartopy Problem with gridlines across the dateline:https://github.com/SciTools/cartopy/issues/821. Region cross dateline is not supported yet. + if lon_west > 180 and lon_east > 180: + lon_west = lon_west - 360 + lon_east = lon_east - 360 + + if "latitude" in kargs: + lat_south, lat_north, _ = kargs["latitude"] + lon_covered = lon_east - lon_west + lon_step = determine_tick_step(lon_covered) + xticks = np.arange(lon_west, lon_east, lon_step) + # Subtract 0.50 to get 0 W to show up on the right side of the plot. + # If less than 0.50 is subtracted, then 0 W will overlap 0 E on the left side of the plot. + # If a number is added, then the value won't show up at all. + if global_domain or full_lon: + xticks = np.append(xticks, lon_east - 0.50) + proj = ccrs.PlateCarree(central_longitude=180) + else: + xticks = np.append(xticks, lon_east) + lat_covered = lat_north - lat_south + lat_step = determine_tick_step(lat_covered) + yticks = np.arange(lat_south, lat_north, lat_step) + yticks = np.append(yticks, lat_north) + + # Contour plot + ax = fig.add_axes(panel[n], projection=proj) + ax.set_extent([lon_west, lon_east, lat_south, lat_north], crs=proj) + cmap = get_colormap(cmap, parameters) + p1 = ax.contourf( + lon, + lat, + var, + transform=ccrs.PlateCarree(), + norm=norm, + levels=levels, + cmap=cmap, + extend="both", + ) + + # ax.set_aspect('auto') + # Full world would be aspect 360/(2*180) = 1 + ax.set_aspect((lon_east - lon_west) / (2 * (lat_north - lat_south))) + ax.coastlines(lw=0.3) + if not global_domain and "RRM" in region_str: + ax.coastlines(resolution="50m", color="black", linewidth=1) + state_borders = cfeature.NaturalEarthFeature( + category="cultural", + name="admin_1_states_provinces_lakes", + scale="50m", + facecolor="none", + ) + ax.add_feature(state_borders, edgecolor="black") + if title[0] is not None: + ax.set_title(title[0], loc="left", fontdict=plotSideTitle) + if title[1] is not None: + ax.set_title(title[1], fontdict=plotTitle) + if title[2] is not None: + ax.set_title(title[2], loc="right", fontdict=plotSideTitle) + ax.set_xticks(xticks, crs=ccrs.PlateCarree()) + ax.set_yticks(yticks, crs=ccrs.PlateCarree()) + lon_formatter = LongitudeFormatter(zero_direction_label=True, number_format=".0f") + lat_formatter = LatitudeFormatter() + ax.xaxis.set_major_formatter(lon_formatter) + ax.yaxis.set_major_formatter(lat_formatter) + ax.tick_params(labelsize=8.0, direction="out", width=1) + ax.xaxis.set_ticks_position("bottom") + ax.yaxis.set_ticks_position("left") + + # Color bar + cbax = fig.add_axes((panel[n][0] + 0.6635, panel[n][1] + 0.0215, 0.0326, 0.1792)) + cbar = fig.colorbar(p1, cax=cbax) + w, h = get_ax_size(fig, cbax) + + if levels is None: + cbar.ax.tick_params(labelsize=9.0, length=0) + + else: + maxval = np.amax(np.absolute(levels[1:-1])) + if maxval < 0.2: + fmt = "%5.3f" + pad = 28 + elif maxval < 10.0: + fmt = "%5.2f" + pad = 25 + elif maxval < 100.0: + fmt = "%5.1f" + pad = 25 + elif maxval > 9999.0: + fmt = 
"%.0f" + pad = 40 + else: + fmt = "%6.1f" + pad = 30 + + cbar.set_ticks(levels[1:-1]) + labels = [fmt % level for level in levels[1:-1]] + cbar.ax.set_yticklabels(labels, ha="right") + cbar.ax.tick_params(labelsize=9.0, pad=pad, length=0) + + # Min, Mean, Max + fig.text( + panel[n][0] + 0.6635, + panel[n][1] + 0.2107, + "Max\nMean\nMin", + ha="left", + fontdict=plotSideTitle, + ) + + fmt_m = [] + # printing in scientific notation if value greater than 10^5 + for i in range(len(stats[0:3])): + fs = "1e" if stats[i] > 100000.0 else "2f" + fmt_m.append(fs) + fmt_metrics = f"%.{fmt_m[0]}\n%.{fmt_m[1]}\n%.{fmt_m[2]}" + + fig.text( + panel[n][0] + 0.7635, + panel[n][1] + 0.2107, + # "%.2f\n%.2f\n%.2f" % stats[0:3], + fmt_metrics % stats[0:3], + ha="right", + fontdict=plotSideTitle, + ) + + # RMSE, CORR + if len(stats) == 5: + fig.text( + panel[n][0] + 0.6635, + panel[n][1] - 0.0105, + "RMSE\nCORR", + ha="left", + fontdict=plotSideTitle, + ) + fig.text( + panel[n][0] + 0.7635, + panel[n][1] - 0.0105, + "%.2f\n%.2f" % stats[3:5], + ha="right", + fontdict=plotSideTitle, + ) + + # grid resolution info: + if n == 2 and "RRM" in region_str: + dlat = lat[2] - lat[1] + dlon = lon[2] - lon[1] + fig.text( + panel[n][0] + 0.4635, + panel[n][1] - 0.04, + "Resolution: {:.2f}x{:.2f}".format(dlat, dlon), + ha="left", + fontdict=plotSideTitle, + ) + + +def plot(reference, test, diff, metrics_dict, parameter): + # Create figure, projection + fig = plt.figure(figsize=parameter.figsize, dpi=parameter.dpi) + proj = ccrs.PlateCarree() + + # Figure title + fig.suptitle(parameter.main_title, x=0.5, y=0.96, fontsize=18) + + # First two panels + min1 = metrics_dict["test"]["min"] + mean1 = metrics_dict["test"]["mean"] + max1 = metrics_dict["test"]["max"] + + plot_panel( + 0, + fig, + proj, + test, + parameter.contour_levels, + parameter.test_colormap, + (parameter.test_name_yrs, parameter.test_title, test.units), + parameter, + stats=(max1, mean1, min1), + ) + + if not parameter.model_only: + min2 = metrics_dict["ref"]["min"] + mean2 = metrics_dict["ref"]["mean"] + max2 = metrics_dict["ref"]["max"] + + plot_panel( + 1, + fig, + proj, + reference, + parameter.contour_levels, + parameter.reference_colormap, + (parameter.ref_name_yrs, parameter.reference_title, reference.units), + parameter, + stats=(max2, mean2, min2), + ) + + # Third panel + min3 = metrics_dict["diff"]["min"] + mean3 = metrics_dict["diff"]["mean"] + max3 = metrics_dict["diff"]["max"] + r = metrics_dict["misc"]["rmse"] + c = metrics_dict["misc"]["corr"] + plot_panel( + 2, + fig, + proj, + diff, + parameter.diff_levels, + parameter.diff_colormap, + (None, parameter.diff_title, test.units), + parameter, + stats=(max3, mean3, min3, r, c), + ) + + # Save figure + for f in parameter.output_format: + f = f.lower().split(".")[-1] + fnm = os.path.join( + get_output_dir(parameter.current_set, parameter), + parameter.output_file + "." 
+ f, + ) + plt.savefig(fnm) + logger.info(f"Plot saved in: {fnm}") + + # Save individual subplots + if parameter.ref_name == "": + panels = [panel[0]] + else: + panels = panel + + for f in parameter.output_format_subplot: + fnm = os.path.join( + get_output_dir(parameter.current_set, parameter), + parameter.output_file, + ) + page = fig.get_size_inches() + i = 0 + for p in panels: + # Extent of subplot + subpage = np.array(p).reshape(2, 2) + subpage[1, :] = subpage[0, :] + subpage[1, :] + subpage = subpage + np.array(border).reshape(2, 2) + subpage = list(((subpage) * page).flatten()) # type: ignore + extent = matplotlib.transforms.Bbox.from_extents(*subpage) + # Save subplot + fname = fnm + ".%i." % (i) + f + plt.savefig(fname, bbox_inches=extent) + + orig_fnm = os.path.join( + get_output_dir(parameter.current_set, parameter), + parameter.output_file, + ) + fname = orig_fnm + ".%i." % (i) + f + logger.info(f"Sub-plot saved in: {fname}") + + i += 1 + + plt.close() diff --git a/tests/e3sm_diags/driver/utils/test_dataset_xr.py b/tests/e3sm_diags/driver/utils/test_dataset_xr.py index 2c89c92cb..03501dc95 100644 --- a/tests/e3sm_diags/driver/utils/test_dataset_xr.py +++ b/tests/e3sm_diags/driver/utils/test_dataset_xr.py @@ -561,6 +561,63 @@ def test_returns_climo_dataset_using_test_file_variable_ref_name_and_season_nest @pytest.mark.xfail( reason="Need to figure out why to create dummy incorrect time scalar variable with Xarray." ) + def test_returns_climo_dataset_with_derived_variable(self): + # We will derive the "PRECT" variable using the "pr" variable. + ds_pr = xr.Dataset( + coords={ + **spatial_coords, + "time": xr.DataArray( + dims="time", + data=np.array( + [ + cftime.DatetimeGregorian( + 2000, 1, 16, 12, 0, 0, 0, has_year_zero=False + ), + ], + dtype=object, + ), + attrs={ + "axis": "T", + "long_name": "time", + "standard_name": "time", + "bounds": "time_bnds", + }, + ), + }, + data_vars={ + **spatial_bounds, + "pr": xr.DataArray( + xr.DataArray( + data=np.array( + [ + [[1.0, 1.0], [1.0, 1.0]], + ] + ), + dims=["time", "lat", "lon"], + attrs={"units": "mm/s"}, + ) + ), + }, + ) + + parameter = _create_parameter_object( + "ref", "climo", self.data_path, "2000", "2001" + ) + parameter.ref_file = "pr_200001_200112.nc" + ds_pr.to_netcdf(f"{self.data_path}/{parameter.ref_file}") + + ds = Dataset(parameter, data_type="ref") + + result = ds.get_climo_dataset("PRECT", season="ANN") + expected = ds_pr.copy() + expected = expected.squeeze(dim="time").drop_vars("time") + expected["PRECT"] = expected["pr"] * 3600 * 24 + expected["PRECT"].attrs["units"] = "mm/day" + expected = expected.drop_vars("pr") + + xr.testing.assert_identical(result, expected) + + @pytest.mark.xfail def test_returns_climo_dataset_using_derived_var_directly_from_dataset_and_replaces_scalar_time_var( self, ): diff --git a/tests/e3sm_diags/driver/utils/test_regrid.py b/tests/e3sm_diags/driver/utils/test_regrid.py index 6dc33fcda..c02451345 100644 --- a/tests/e3sm_diags/driver/utils/test_regrid.py +++ b/tests/e3sm_diags/driver/utils/test_regrid.py @@ -231,7 +231,6 @@ def test_regrids_to_first_dataset_with_equal_latitude_points(self, tool): expected_a = ds_a.copy() expected_b = ds_a.copy() - if tool in ["esmf", "xesmf"]: expected_b.so.attrs["regrid_method"] = "conservative"