Skip to content

Commit

Permalink
Do not add scalar coords from the target grid to the regridding output (
Browse files Browse the repository at this point in the history
  • Loading branch information
aulemahal authored Feb 26, 2025
1 parent fceddc8 commit 87a30ce
Show file tree
Hide file tree
Showing 6 changed files with 32 additions and 16 deletions.
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ repos:
rev: 7.1.1
hooks:
- id: flake8

- repo: https://github.com/PyCQA/isort
rev: 6.0.0
hooks:
Expand Down
1 change: 1 addition & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ What's new
0.8.9 (unreleased)
------------------
* Destroy grids explicitly once weights are computed. Do not store them in `grid_in` and `grid_out` attributes. This fixes segmentation faults introduced by the memory fix of last version. By `Pascal Bourgault <https://github.com/aulemahal>`_.
* Do not add scalar coordinates of the target grid to the regridded output (:issue:`417`, :pull:`418`). `xe.Regridder.out_coords` is now a dataset instead of a dictionary. By `Pascal Bourgault <https://github.com/aulemahal>`_.

0.8.8 (2024-11-01)
------------------
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ tag_regex = "^(?P<prefix>v)?(?P<version>[^\\+]+)(?P<suffix>.*)?$"
[tool.black]
line-length = 100
target-version = [
'py310',
'py311',
]
skip-string-normalization = true

Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,4 @@ ignore =
max-line-length = 100
max-complexity = 18
select = B,C,E,F,W,T4,B9
extend-ignore = E203,E501,E402,W605
extend-ignore = E203,E501,E402,W503,W605
33 changes: 21 additions & 12 deletions xesmf/frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -962,18 +962,23 @@ def __init__(
self.out_horiz_dims = (lat_out.dims[0], lon_out.dims[0])

if isinstance(ds_out, Dataset):
self.out_coords = {
name: crd
for name, crd in ds_out.coords.items()
if set(self.out_horiz_dims).issuperset(crd.dims)
}
out_coords = ds_out.coords.to_dataset()
grid_mapping = {
var.attrs['grid_mapping']
for var in ds_out.data_vars.values()
if 'grid_mapping' in var.attrs
}
if grid_mapping:
self.out_coords.update({gm: ds_out[gm] for gm in grid_mapping if gm in ds_out})
# to keep : grid_mappings and non-scalar coords that have the spatial dims
self.out_coords = out_coords.drop_vars(
[
name
for name, crd in out_coords.coords.items()
if not (
(name in grid_mapping)
or (len(crd.dims) > 0 and set(self.out_horiz_dims).issuperset(crd.dims))
)
]
)
else:
self.out_coords = {lat_out.name: lat_out, lon_out.name: lon_out}

Expand Down Expand Up @@ -1055,10 +1060,14 @@ def _init_para_regrid(self, ds_in, ds_out, kwargs):
chunks = out_chunks | in_chunks

# Rename coords to avoid issues in xr.map_blocks
for coord in list(self.out_coords.keys()):
# If coords and dims are the same, renaming has already been done.
if coord not in self.out_horiz_dims:
ds_out = ds_out.rename({coord: coord + '_out'})
# If coords and dims are the same, renaming has already been done.
ds_out = ds_out.rename(
{
coord: coord + '_out'
for coord in self.out_coords.coords.keys()
if coord not in self.out_horiz_dims
}
)

weights_dims = ('y_out', 'x_out', 'y_in', 'x_in')
templ = sps.zeros((self.shape_out + self.shape_in))
Expand Down Expand Up @@ -1102,7 +1111,7 @@ def _format_xroutput(self, out, new_dims=None):
# rename dimension name to match output grid
out = out.rename({nd: od for nd, od in zip(new_dims, self.out_horiz_dims)})

out = out.assign_coords(**self.out_coords)
out = out.assign_coords(self.out_coords.coords)
out.attrs['regrid_method'] = self.method

if self.sequence_out:
Expand Down
9 changes: 7 additions & 2 deletions xesmf/tests/test_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -732,17 +732,22 @@ def test_regrid_dataset_extracoords():
x=np.arange(24),
y=np.arange(20), # coords to be transfered
latitude_longitude=xr.DataArray(), # grid_mapping
bogus=ds_out.lev * ds_out.lon, # coord not to be transfered
bogus=ds_out.lev * ds_out.lon, # coords not to be transfered
scalar1=1, #
scalar2=1, #
)
ds_out2['data_ref'].attrs['grid_mapping'] = 'latitude_longitude'
ds_out2['data4D_ref'].attrs['grid_mapping'] = 'latitude_longitude'

ds_in2 = ds_in.assign_coords(scalar2=5)
regridder = xe.Regridder(ds_in, ds_out2, 'conservative')
ds_result = regridder(ds_in)
ds_result = regridder(ds_in2)

assert 'x' in ds_result.coords
assert 'y' in ds_result.coords
assert 'bogus' not in ds_result.coords
assert 'scalar1' not in ds_result.coords
assert ds_result.scalar2 == 5
assert 'latitude_longitude' in ds_result.coords


Expand Down

0 comments on commit 87a30ce

Please sign in to comment.