Skip to content

Commit

Permalink
Fixes preserving coordinates in regrid2 (#716)
Browse files Browse the repository at this point in the history
Co-authored-by: Tom Vo <tomvothecoder@gmail.com>
  • Loading branch information
jasonb5 and tomvothecoder authored Jan 29, 2025
1 parent 5cc9d23 commit 84e4a3f
Showing 1 changed file with 76 additions and 22 deletions.
98 changes: 76 additions & 22 deletions xcdat/regridder/regrid2.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from typing import Any, List, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import xarray as xr

import xcdat as xc
from xcdat.axis import get_dim_keys
from xcdat.regridder.base import BaseRegridder, _preserve_bounds

Expand Down Expand Up @@ -105,8 +106,6 @@ def horizontal(self, data_var: str, ds: xr.Dataset) -> xr.Dataset:
ds,
data_var,
output_data,
dst_lat_bnds,
dst_lon_bnds,
self._input_grid,
self._output_grid,
)
Expand Down Expand Up @@ -228,38 +227,90 @@ def _build_dataset(
ds: xr.Dataset,
data_var: str,
output_data: np.ndarray,
dst_lat_bnds,
dst_lon_bnds,
input_grid: xr.Dataset,
output_grid: xr.Dataset,
) -> xr.Dataset:
input_data_var = ds[data_var]
"""Build a new xarray Dataset with the given output data and coordinates.
Parameters
----------
ds : xr.Dataset
The input dataset containing the data variable to be regridded.
data_var : str
The name of the data variable in the input dataset to be regridded.
output_data : np.ndarray
The regridded data to be included in the output dataset.
input_grid : xr.Dataset
The input grid dataset containing the original grid information.
output_grid : xr.Dataset
The output grid dataset containing the new grid information.
output_coords: dict[str, xr.DataArray] = {}
output_data_vars: dict[str, xr.DataArray] = {}
Returns
-------
xr.Dataset
A new dataset containing the regridded data variable with updated
coordinates and attributes.
"""
dv_input = ds[data_var]

dims = list(input_data_var.dims)
output_coords = _get_output_coords(dv_input, output_grid)

output_da = xr.DataArray(
output_data,
dims=dims,
dims=dv_input.dims,
coords=output_coords,
attrs=ds[data_var].attrs.copy(),
name=data_var,
)

output_data_vars[data_var] = output_da

output_ds = xr.Dataset(
output_data_vars,
attrs=input_grid.attrs.copy(),
)

output_ds = output_da.to_dataset()
output_ds.attrs = input_grid.attrs.copy()
output_ds = _preserve_bounds(ds, output_grid, output_ds, ["X", "Y"])

return output_ds


def _get_output_coords(
    dv_input: xr.DataArray, output_grid: xr.Dataset
) -> Dict[str, xr.DataArray]:
    """
    Assemble the coordinates for a regridded copy of ``dv_input``.

    The "X" and "Y" coordinates are looked up on ``output_grid`` with
    ``xc.get_dim_coords`` and stored under the name of the matching
    coordinate on ``dv_input``. Every other dimension of ``dv_input``
    (e.g., "time") keeps the coordinate it already has on ``dv_input``.

    Parameters
    ----------
    dv_input : xr.DataArray
        The input data variable whose dimensions define the keys and the
        ordering of the result.
    output_grid : xr.Dataset
        The dataset providing the target "X" and "Y" coordinates.

    Returns
    -------
    Dict[str, xr.DataArray]
        One entry per dimension of ``dv_input``, keyed by ``str(dim)`` and
        inserted in the same order as ``dv_input.dims``.
    """
    coords_by_name: Dict[str, xr.DataArray] = {}

    # Horizontal axes: values come from the output grid, but the key is the
    # name used by the input data variable.
    for axis in ("X", "Y"):
        src_coord = xc.get_dim_coords(dv_input, axis)  # type: ignore
        dst_coord = xc.get_dim_coords(output_grid, axis)  # type: ignore
        coords_by_name[str(src_coord.name)] = dst_coord  # type: ignore

    # Remaining axes are taken from the input data variable (e.g., "time").
    for dim in dv_input.dims:
        if dim in coords_by_name:
            continue
        coords_by_name[str(dim)] = dv_input[dim]

    # Rebuild the mapping so its insertion order follows the input data
    # variable dims.
    ordered_coords: Dict[str, xr.DataArray] = {}
    for dim in dv_input.dims:
        dim_name = str(dim)
        ordered_coords[dim_name] = coords_by_name[dim_name]

    return ordered_coords


def _map_latitude(
src: np.ndarray, dst: np.ndarray
) -> Tuple[List[np.ndarray], List[np.ndarray]]:
Expand Down Expand Up @@ -553,12 +604,15 @@ def _get_dimension(input_data_var, cf_axis_name):


def _get_bounds_ensure_dtype(ds, axis):
bounds = None

try:
name = ds.cf.bounds[axis][0]
except (KeyError, IndexError) as e:
raise RuntimeError(f"Could not determine {axis!r} bounds") from e
else:
bounds = ds[name]
bounds = ds.bounds.get_bounds(axis)
except KeyError:
pass

if bounds is None:
raise RuntimeError(f"Could not determine {axis!r} bounds")

if bounds.dtype != np.float32:
bounds = bounds.astype(np.float32)
Expand Down

0 comments on commit 84e4a3f

Please sign in to comment.