Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

update prepare_variable #119

Merged
merged 16 commits into from
Mar 8, 2023
1 change: 1 addition & 0 deletions ci/requirements/environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ dependencies:
- python-cdo
- cmor
- pint-xarray
- flox
# for testing
- pytest
- pytest-cov
Expand Down
3 changes: 2 additions & 1 deletion cordex/cmor/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from .cmor import cmorize_variable, prepare_variable
from .config import set_options
from .config import options, set_options
from .utils import (
mid_of_month,
mid_of_season,
Expand Down Expand Up @@ -30,4 +30,5 @@
"season_bounds",
"to_cftime",
"set_options",
"options",
]
103 changes: 70 additions & 33 deletions cordex/cmor/cmor.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import json
import os
from os import path as op
from warnings import warn

import cf_xarray as cfxr
Expand Down Expand Up @@ -31,14 +31,18 @@
_get_cfvarinfo,
_get_cordex_pole,
_get_pole,
_read_table,
_strip_time_cell_method,
_tmp_table,
mid_of_month,
month_bounds,
time_bounds_name,
)

xr.set_options(keep_attrs=True)

flox_method = "blockwise"


def resample_both_closed(ds, hfreq, op, **kwargs):
rolling = getattr(ds.rolling(time=hfreq + 1, center=True), op)()
Expand Down Expand Up @@ -67,17 +71,18 @@ def _resample(
ds, time, time_cell_method="point", label="left", time_offset=True, **kwargs
):
"""Resample a variable."""
# freq = "{}H".format(hfreq)
if time_cell_method == "point":
return ds.resample(
time=time, label=label, **kwargs
).nearest() # .interpolate("nearest") # use as_freq?
return ds.resample(time=time, label=label, **kwargs).nearest(
method=flox_method
) # .interpolate("nearest") # use as_freq?
elif time_cell_method == "mean":
if time_offset is True:
loffset = _get_loffset(time)
else:
loffset = None
return ds.resample(time=time, label=label, loffset=loffset, **kwargs).mean()
return ds.resample(time=time, label=label, loffset=loffset, **kwargs).mean(
method=flox_method
)
else:
raise Exception("unknown time_cell_method: {}".format(time_cell_method))

Expand Down Expand Up @@ -130,7 +135,7 @@ def _get_time_axis_name(time_cell_method):

def _define_axes(ds, table_id):
if "CORDEX_domain" in ds.attrs:
grid = cordex_domain(ds.attrs["CORDEX_domain"], add_vertices=True)
grid = cordex_domain(ds.attrs["CORDEX_domain"], bounds=True)
lon_vertices = grid.lon_vertices.to_numpy()
lat_vertices = grid.lat_vertices.to_numpy()
else:
Expand Down Expand Up @@ -222,7 +227,11 @@ def _cmor_write(da, table_id, cmorTime, cmorGrid, file_name=True):
else:
coords = [cmorTime, cmorGrid]
cmor_var = cmor.variable(da.name, da.units, coords)
cmor.write(cmor_var, da.values)
if "time" in da.coords:
ntimes_passed = da.time.size
else:
ntimes_passed = None
cmor.write(cmor_var, da.to_numpy(), ntimes_passed=ntimes_passed)
return cmor.close(cmor_var, file_name=file_name)


Expand All @@ -237,15 +246,14 @@ def _units_convert(da, cf_units, format=None):
return da


def _cf_units_convert(da, table_file, mapping_table={}):
def _cf_units_convert(da, table, mapping_table={}):
"""Convert units.

Convert units according to the rules in units_convert_rules dict.
Maybe metpy can do this also: https://unidata.github.io/MetPy/latest/tutorials/unit_tutorial.html

"""
with open(table_file) as f:
table = json.load(f)

if da.name in mapping_table:
map_units = mapping_table[da.name].get("units")
atr_units = da.attrs.get("units")
Expand Down Expand Up @@ -365,10 +373,13 @@ def _add_time_bounds(ds, cf_freq):
ds = _add_month_bounds(ds)
else:
try:
ds = ds.convert_calendar(ds.time.dt.calendar).cf.add_bounds("time")
ds = ds.convert_calendar(
ds.time.dt.calendar, use_cftime=False
).cf.add_bounds("time")
except Exception:
# wait for cftime arithmetic in xarray here:
warn("could not add time bounds.")

ds[time_bounds_name].encoding = ds.time.encoding
ds.time.attrs.update({"bounds": time_bounds_name})
return ds
Expand All @@ -383,7 +394,9 @@ def _adjust_frequency(ds, cf_freq, input_freq=None, time_cell_method=None):
pd_freq = freq_map[cf_freq]
if pd_freq != input_freq:
warn("resampling input data from {} to {}".format(input_freq, pd_freq))
resample = _resample(ds, pd_freq, time_cell_method=time_cell_method)
resample = _resample(
ds, pd_freq, time_cell_method=time_cell_method, **options["resample_kwargs"]
)
return resample
return ds

Expand All @@ -395,6 +408,14 @@ def cmorize_cmor(
if inpath is None:
inpath = os.path.dirname(cmor_table)

dataset_table_json = dataset_table
cmor_table_json = cmor_table

if isinstance(dataset_table, dict):
dataset_table_json = _tmp_table(dataset_table)
if isinstance(cmor_table, dict):
cmor_table_json = _tmp_table(cmor_table)

cfvarinfo = _get_cfvarinfo(out_name, cmor_table)

if cfvarinfo is None:
Expand All @@ -403,7 +424,7 @@ def cmorize_cmor(
time_cell_method = _strip_time_cell_method(cfvarinfo)

table_ids = _setup(
dataset_table, cmor_table, grids_table=grids_table, inpath=inpath
dataset_table_json, cmor_table_json, grids_table=grids_table, inpath=inpath
)

cmorTime, cmorGrid = _define_grid(ds, table_ids, time_cell_method=time_cell_method)
Expand All @@ -422,17 +443,29 @@ def prepare_variable(
input_freq=None,
CORDEX_domain=None,
time_units=None,
time_cell_method=None,
cf_freq=None,
rewrite_time_axis=False,
use_cftime=False,
squeeze=True,
):
"""prepares a variable for cmorization."""

ds = ds.copy(deep=False)

if isinstance(cmor_table, str):
cmor_table = _read_table(cmor_table)
cfvarinfo = _get_cfvarinfo(out_name, cmor_table)

cf_freq = cfvarinfo["frequency"]
time_cell_method = _strip_time_cell_method(cfvarinfo)

if isinstance(ds, xr.DataArray):
ds = ds.to_dataset()

# ensure that we propagate everything
# ds = xr.decode_cf(ds, decode_coords="all")

# no mapping table provided, we assume datasets has already correct out_names and units.
if mapping_table is None:
if out_name not in mapping_table:
try:
var_ds = ds[[out_name]]
except Exception:
Expand All @@ -448,11 +481,15 @@ def prepare_variable(
if squeeze is True:
var_ds = var_ds.squeeze(drop=True)
if CORDEX_domain is not None:
var_ds.attrs["CORDEX_domain"] = CORDEX_domain
var_ds = _crop_to_cordex_domain(var_ds, CORDEX_domain)
if replace_coords is True:
grid = cordex_domain(CORDEX_domain)
grid = cordex_domain(CORDEX_domain, bounds=True)
var_ds = var_ds.assign_coords(rlon=grid.rlon, rlat=grid.rlat)
var_ds = var_ds.assign_coords(lon=grid.lon, lat=grid.lat)
var_ds = var_ds.assign_coords(
lon_vertices=grid.lon_vertices, lat_vertices=grid.lat_vertices
)

if "time" in var_ds:
# ensure cftime
Expand All @@ -464,6 +501,8 @@ def prepare_variable(
if "time" not in ds.cf.bounds and time_cell_method == "mean":
warn("adding time bounds")
var_ds = _add_time_bounds(var_ds, cf_freq)
if use_cftime is False:
var_ds = var_ds.convert_calendar(ds.time.dt.calendar, use_cftime=False)
var_ds = _set_time_encoding(var_ds, time_units, ds)

if allow_units_convert is True:
Expand All @@ -477,7 +516,10 @@ def prepare_variable(
warn("adding pole from archive specs: {}".format(CORDEX_domain))
mapping = _get_cordex_pole(CORDEX_domain)

var_ds = xr.merge([var_ds, mapping])
if "time" in mapping.coords:
raise Exception("grid_mapping variable should have no time coordinate!")

var_ds[mapping.name] = mapping

return var_ds

Expand Down Expand Up @@ -509,10 +551,10 @@ def cmorize_variable(
out_name: str
CF out_name of the variable that should be cmorized. The corresponding variable name
in the dataset is looked up from the mapping_table if provided.
cmor_table : str
Filepath to cmor table.
dataset_table: str
Filepath to dataset cmor table.
cmor_table : str or dict
Cmor table dict or filepath to cmor table (json).
dataset_table: str or dict
Dataset table dict or filepath to dataset cmor table (json).
mapping_table: dict
Mapping of input variable names and meta data to CF out_name. Required if
the variable name in the input dataset is not equal to out_name.
Expand Down Expand Up @@ -548,7 +590,7 @@ def cmorize_variable(
Rewrite the time axis to CF compliant timestamps.
outpath: str
Root directory for output (can be either a relative or full path). This will override
the outpath defined in the dataset cmor input table.
the outpath defined in the dataset cmor input table (``dataset_table``).
**kwargs:
Arguments passed to prepare_variable.

Expand All @@ -573,14 +615,11 @@ def cmorize_variable(
if inpath is None:
inpath = os.path.dirname(cmor_table)

# get meta info from cmor table
cfvarinfo = _get_cfvarinfo(out_name, cmor_table)

if cfvarinfo is None:
raise Exception("{} not found in {}".format(out_name, cmor_table))
if op.isfile(dataset_table):
dataset_table = _read_table(dataset_table)

cf_freq = cfvarinfo["frequency"]
time_cell_method = _strip_time_cell_method(cfvarinfo)
if outpath:
dataset_table["outpath"] = outpath

ds_prep = prepare_variable(
ds,
Expand All @@ -590,8 +629,6 @@ def cmorize_variable(
mapping_table=mapping_table,
replace_coords=replace_coords,
input_freq=input_freq,
cf_freq=cf_freq,
time_cell_method=time_cell_method,
rewrite_time_axis=rewrite_time_axis,
time_units=time_units,
allow_resample=allow_resample,
Expand Down
6 changes: 5 additions & 1 deletion cordex/cmor/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,11 @@

import numpy as np

options = {"table_prefix": "CORDEX-CMIP6", "exit_control": "CMOR_NORMAL"}
options = {
"table_prefix": "CORDEX-CMIP6",
"exit_control": "CMOR_NORMAL",
"resample_kwargs": {"closed": "left"},
}


# time offsets relative to left labeling for resampling.
Expand Down
41 changes: 37 additions & 4 deletions cordex/cmor/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
"""
import datetime as dt
import json
import tempfile
from warnings import warn

import cftime as cfdt
Expand Down Expand Up @@ -290,7 +291,7 @@ def _encode_time(time):
return xr.conventions.encode_cf_variable(time)


def _read_cmor_table(table):
def _read_table(table):
return _read_json_file(table)


Expand All @@ -300,9 +301,41 @@ def _read_json_file(filename):
return data


def _get_cfvarinfo(cf_varname, table):
data = _read_cmor_table(table)
return data["variable_entry"].get(cf_varname, None)
def _write_json_file(filename, data):
with open(filename, "w") as fp:
json.dump(data, fp, indent=4)
return filename


def _get_cfvarinfo(out_name, table):
"""Returns variable entry from cmor table"""
if isinstance(table, str):
table = _read_table(table)
info = table["variable_entry"].get(out_name, None)
if info is None:
raise Exception(
"{} not found in table {}".format(out_name, get_table_id(table))
)
return info


def get_table_id(table):
    """Return the table id found in the ``Header`` of a cmor table dict.

    When the stored ``table_id`` contains a space (e.g. ``"Table day"``), the
    second space-separated token is returned; otherwise the stored value is
    returned unchanged. Raises ``Exception`` if the header has no ``table_id``.
    """
    raw_id = table["Header"].get("table_id")
    if raw_id is None:
        raise Exception("no table_id in Header")
    if " " not in raw_id:
        return raw_id
    return raw_id.split(" ")[1]


def _tmp_table(table, format="json"):
"""creates a temporay table json file"""
_, filename = tempfile.mkstemp()
warn(f"writing temporary table to {filename}")
if format == "json":
return _write_json_file(filename, table)


def _get_time_cell_method(cf_varname, table):
Expand Down
9 changes: 9 additions & 0 deletions docs/whats_new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,15 @@ What's New

import cordex

v0.6.0 (Unreleased)
-------------------

Internal Changes
~~~~~~~~~~~~~~~~

- CMOR updates: fixes time step warnings, adds resampling options, and adds the option to use `flox <https://flox.readthedocs.io>`_ in resampling operations (:pull:`116`).


v0.5.1 (2 March 2023)
---------------------

Expand Down
1 change: 1 addition & 0 deletions tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,4 @@ def _importorskip(modname):
has_regionmask, requires_regionmask = _importorskip("regionmask")
has_xesmf, requires_xesmf = _importorskip("xesmf")
has_geopandas, requires_geopandas = _importorskip("geopandas")
has_pint_xarray, requires_pint_xarray = _importorskip("pint_xarray")
Loading