From 87a30ce5c023312cb41340fe4ee07689227a2c64 Mon Sep 17 00:00:00 2001 From: Pascal Bourgault Date: Wed, 26 Feb 2025 17:07:24 -0500 Subject: [PATCH] Do not add scalar coords from the target grid to the regridding output (#418) --- .pre-commit-config.yaml | 1 + CHANGES.rst | 1 + pyproject.toml | 2 +- setup.cfg | 2 +- xesmf/frontend.py | 33 +++++++++++++++++++++------------ xesmf/tests/test_frontend.py | 9 +++++++-- 6 files changed, 32 insertions(+), 16 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 0f2ebd39..c4e63398 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -26,6 +26,7 @@ repos: rev: 7.1.1 hooks: - id: flake8 + - repo: https://github.com/PyCQA/isort rev: 6.0.0 hooks: diff --git a/CHANGES.rst b/CHANGES.rst index a55028e2..2c82fdc8 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -4,6 +4,7 @@ What's new 0.8.9 (unreleased) ------------------ * Destroy grids explicitly once weights are computed. Do not store them in `grid_in` and `grid_out` attributes. This fixes segmentation faults introduced by the memory fix of last version. By `Pascal Bourgault `_. +* Do not add scalar coordinates of the target grid to the regridded output (:issue:`417`, :pull:`418`). `xe.Regridder.out_coords` is now a dataset instead of a dictionary. By `Pascal Bourgault `_. 0.8.8 (2024-11-01) ------------------ diff --git a/pyproject.toml b/pyproject.toml index 98908b07..32e32517 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -61,7 +61,7 @@ tag_regex = "^(?Pv)?(?P[^\\+]+)(?P.*)?$" [tool.black] line-length = 100 target-version = [ - 'py310', + 'py311', ] skip-string-normalization = true diff --git a/setup.cfg b/setup.cfg index e6774ae8..4e2cd68e 100644 --- a/setup.cfg +++ b/setup.cfg @@ -7,4 +7,4 @@ ignore = max-line-length = 100 max-complexity = 18 select = B,C,E,F,W,T4,B9 -extend-ignore = E203,E501,E402,W605 +extend-ignore = E203,E501,E402,W503,W605 diff --git a/xesmf/frontend.py b/xesmf/frontend.py index 989662a9..a5bd611b 100644 --- a/xesmf/frontend.py +++ b/xesmf/frontend.py @@ -962,18 +962,23 @@ def __init__( self.out_horiz_dims = (lat_out.dims[0], lon_out.dims[0]) if isinstance(ds_out, Dataset): - self.out_coords = { - name: crd - for name, crd in ds_out.coords.items() - if set(self.out_horiz_dims).issuperset(crd.dims) - } + out_coords = ds_out.coords.to_dataset() grid_mapping = { var.attrs['grid_mapping'] for var in ds_out.data_vars.values() if 'grid_mapping' in var.attrs } - if grid_mapping: - self.out_coords.update({gm: ds_out[gm] for gm in grid_mapping if gm in ds_out}) + # to keep : grid_mappings and non-scalar coords that have the spatial dims + self.out_coords = out_coords.drop_vars( + [ + name + for name, crd in out_coords.coords.items() + if not ( + (name in grid_mapping) + or (len(crd.dims) > 0 and set(self.out_horiz_dims).issuperset(crd.dims)) + ) + ] + ) else: self.out_coords = {lat_out.name: lat_out, lon_out.name: lon_out} @@ -1055,10 +1060,14 @@ def _init_para_regrid(self, ds_in, ds_out, kwargs): chunks = out_chunks | in_chunks # Rename coords to avoid issues in xr.map_blocks - for coord in list(self.out_coords.keys()): - # If coords and dims are the same, renaming has already been done. - if coord not in self.out_horiz_dims: - ds_out = ds_out.rename({coord: coord + '_out'}) + # If coords and dims are the same, renaming has already been done. + ds_out = ds_out.rename( + { + coord: coord + '_out' + for coord in self.out_coords.coords.keys() + if coord not in self.out_horiz_dims + } + ) weights_dims = ('y_out', 'x_out', 'y_in', 'x_in') templ = sps.zeros((self.shape_out + self.shape_in)) @@ -1102,7 +1111,7 @@ def _format_xroutput(self, out, new_dims=None): # rename dimension name to match output grid out = out.rename({nd: od for nd, od in zip(new_dims, self.out_horiz_dims)}) - out = out.assign_coords(**self.out_coords) + out = out.assign_coords(self.out_coords.coords) out.attrs['regrid_method'] = self.method if self.sequence_out: diff --git a/xesmf/tests/test_frontend.py b/xesmf/tests/test_frontend.py index 4a5a595c..1a394a73 100644 --- a/xesmf/tests/test_frontend.py +++ b/xesmf/tests/test_frontend.py @@ -732,17 +732,22 @@ def test_regrid_dataset_extracoords(): x=np.arange(24), y=np.arange(20), # coords to be transfered latitude_longitude=xr.DataArray(), # grid_mapping - bogus=ds_out.lev * ds_out.lon, # coord not to be transfered + bogus=ds_out.lev * ds_out.lon, # coords not to be transfered + scalar1=1, # + scalar2=1, # ) ds_out2['data_ref'].attrs['grid_mapping'] = 'latitude_longitude' ds_out2['data4D_ref'].attrs['grid_mapping'] = 'latitude_longitude' + ds_in2 = ds_in.assign_coords(scalar2=5) regridder = xe.Regridder(ds_in, ds_out2, 'conservative') - ds_result = regridder(ds_in) + ds_result = regridder(ds_in2) assert 'x' in ds_result.coords assert 'y' in ds_result.coords assert 'bogus' not in ds_result.coords + assert 'scalar1' not in ds_result.coords + assert ds_result.scalar2 == 5 assert 'latitude_longitude' in ds_result.coords