Skip to content

Commit

Permalink
Merge pull request #323 from maartenvanormondt/load_dataset-added
Browse files Browse the repository at this point in the history
added load_dataset method
  • Loading branch information
Huite authored Feb 17, 2025
2 parents cce3c48 + a1dc5dd commit 3e2f8cf
Show file tree
Hide file tree
Showing 6 changed files with 53 additions and 5 deletions.
2 changes: 2 additions & 0 deletions docs/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@ Top-level functions
.. autosummary::
:toctree: api/

load_dataarray
open_dataarray
load_dataset
open_dataset
open_mfdataset
open_zarr
Expand Down
1 change: 1 addition & 0 deletions docs/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ Added
``x_bounds`` and ``y_bounds`` are provided.
- Added :attr:`xugrid.Ugrid1d.directed_edge_edge_connectivity` and
:attr:`xugrid.Ugrid2d.directed_edge_edge_connectivity`.
- Added :func:`xugrid.load_dataset` and :func:`xugrid.load_dataarray`.

[0.12.2] 2025-01-31
-------------------
Expand Down
26 changes: 26 additions & 0 deletions tests/test_ugrid_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1015,6 +1015,18 @@ def test_open_dataset(tmp_path):
assert "mesh2d_face_nodes" not in back.ugrid.obj


def test_load_dataset(tmp_path):
path = tmp_path / "ugrid-dataset.nc"
uds = xugrid.UgridDataset(UGRID_DS())
uds.ugrid.to_netcdf(path)

back = xugrid.load_dataset(path)
assert isinstance(back, xugrid.UgridDataset)
assert "b" in back
assert "mesh2d_face_nodes" in back.ugrid.grids[0].to_dataset()
assert "mesh2d_face_nodes" not in back.ugrid.obj


def test_open_dataset_cast_invalid(tmp_path):
grid = GRID()
vorgrid = grid.tesselate_centroidal_voronoi()
Expand All @@ -1040,6 +1052,20 @@ def test_open_dataarray_roundtrip(tmp_path):
assert back.name == "a"


def test_load_dataarray_roundtrip(tmp_path):
path = tmp_path / "ugrid-dataset.nc"
uds = xugrid.UgridDataset(UGRID_DS())
uds.ugrid.to_netcdf(path)
with pytest.raises(ValueError, match="Given file dataset contains more than one"):
xugrid.load_dataarray(path)

path = tmp_path / "ugrid-dataarray.nc"
uds["a"].ugrid.to_netcdf(path)
back = xugrid.load_dataarray(path)
assert isinstance(back, xugrid.UgridDataArray)
assert back.name == "a"


def test_open_mfdataset(tmp_path):
path1 = tmp_path / "ugrid-dataset_1.nc"
path2 = tmp_path / "ugrid-dataset_2.nc"
Expand Down
4 changes: 4 additions & 0 deletions xugrid/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from xugrid.core.common import (
concat,
full_like,
load_dataarray,
load_dataset,
merge,
ones_like,
open_dataarray,
Expand Down Expand Up @@ -38,6 +40,8 @@
"data",
"concat",
"full_like",
"load_dataarray",
"load_dataset",
"merge",
"ones_like",
"open_dataarray",
Expand Down
21 changes: 18 additions & 3 deletions xugrid/core/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,13 @@ def open_dataset(*args, **kwargs):
return UgridDataset(ds)


def open_dataarray(*args, **kwargs):
ds = xr.open_dataset(*args, **kwargs)
dataset = UgridDataset(ds)
def load_dataset(*args, **kwargs):
ds = xr.load_dataset(*args, **kwargs)
return UgridDataset(ds)


def _dataarray_helper(ds: xr.Dataset):
dataset = UgridDataset(ds)
if len(dataset.data_vars) != 1:
raise ValueError(
"Given file dataset contains more than one data "
Expand All @@ -41,6 +44,16 @@ def open_dataarray(*args, **kwargs):
return UgridDataArray(data_array, dataset.grid)


def load_dataarray(*args, **kwargs):
ds = xr.load_dataset(*args, **kwargs)
return _dataarray_helper(ds)


def open_dataarray(*args, **kwargs):
ds = xr.open_dataset(*args, **kwargs)
return _dataarray_helper(ds)


def open_mfdataset(*args, **kwargs):
if "data_vars" in kwargs:
raise ValueError("data_vars kwargs is not supported in xugrid.open_mfdataset")
Expand All @@ -54,7 +67,9 @@ def open_zarr(*args, **kwargs):
return UgridDataset(ds)


load_dataset.__doc__ = xr.load_dataset.__doc__
open_dataset.__doc__ = xr.open_dataset.__doc__
load_dataarray.__doc__ = xr.load_dataarray.__doc__
open_dataarray.__doc__ = xr.open_dataarray.__doc__
open_mfdataset.__doc__ = xr.open_mfdataset.__doc__
open_zarr.__doc__ = xr.open_zarr.__doc__
Expand Down
4 changes: 2 additions & 2 deletions xugrid/core/wrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,8 +252,8 @@ def from_structured2d(
2. "longitude" and "latitude" dimensions
3. "axis" attributes of "X" or "Y" on coordinates
4. "standard_name" attributes of "longitude", "latitude",
"projection_x_coordinate", or "projection_y_coordinate" on coordinate
variables
"projection_x_coordinate", or "projection_y_coordinate" on coordinate
variables
Parameters
----------
Expand Down

0 comments on commit 3e2f8cf

Please sign in to comment.