diff --git a/docs/api.rst b/docs/api.rst index 05598896..f0d2c5e9 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -21,7 +21,9 @@ Top-level functions .. autosummary:: :toctree: api/ + load_dataarray open_dataarray + load_dataset open_dataset open_mfdataset open_zarr diff --git a/docs/changelog.rst b/docs/changelog.rst index ba7b3f69..fb264745 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -32,6 +32,7 @@ Added ``x_bounds`` and ``y_bounds`` are provided. - Added :attr:`xugrid.Ugrid1d.directed_edge_edge_connectivity` and :attr:`xugrid.Ugrid2d.directed_edge_edge_connectivity`. +- Added :func:`xugrid.load_dataset` and :func:`xugrid.load_dataarray`. [0.12.2] 2025-01-31 ------------------- diff --git a/tests/test_ugrid_dataset.py b/tests/test_ugrid_dataset.py index 3327506c..bb56fe37 100644 --- a/tests/test_ugrid_dataset.py +++ b/tests/test_ugrid_dataset.py @@ -1015,6 +1015,18 @@ def test_open_dataset(tmp_path): assert "mesh2d_face_nodes" not in back.ugrid.obj +def test_load_dataset(tmp_path): + path = tmp_path / "ugrid-dataset.nc" + uds = xugrid.UgridDataset(UGRID_DS()) + uds.ugrid.to_netcdf(path) + + back = xugrid.load_dataset(path) + assert isinstance(back, xugrid.UgridDataset) + assert "b" in back + assert "mesh2d_face_nodes" in back.ugrid.grids[0].to_dataset() + assert "mesh2d_face_nodes" not in back.ugrid.obj + + def test_open_dataset_cast_invalid(tmp_path): grid = GRID() vorgrid = grid.tesselate_centroidal_voronoi() @@ -1040,6 +1052,20 @@ def test_open_dataarray_roundtrip(tmp_path): assert back.name == "a" +def test_load_dataarray_roundtrip(tmp_path): + path = tmp_path / "ugrid-dataset.nc" + uds = xugrid.UgridDataset(UGRID_DS()) + uds.ugrid.to_netcdf(path) + with pytest.raises(ValueError, match="Given file dataset contains more than one"): + xugrid.load_dataarray(path) + + path = tmp_path / "ugrid-dataarray.nc" + uds["a"].ugrid.to_netcdf(path) + back = xugrid.load_dataarray(path) + assert isinstance(back, xugrid.UgridDataArray) + assert back.name == "a" + + def test_open_mfdataset(tmp_path): path1 = tmp_path / "ugrid-dataset_1.nc" path2 = tmp_path / "ugrid-dataset_2.nc" diff --git a/xugrid/__init__.py b/xugrid/__init__.py index 77290c4c..cf65efab 100644 --- a/xugrid/__init__.py +++ b/xugrid/__init__.py @@ -2,6 +2,7 @@ from xugrid.core.common import ( concat, full_like, + load_dataarray, load_dataset, merge, ones_like, @@ -39,11 +40,12 @@ "data", "concat", "full_like", + "load_dataarray", + "load_dataset", "merge", "ones_like", "open_dataarray", "open_dataset", - "load_dataset", "open_mfdataset", "open_zarr", "zeros_like", diff --git a/xugrid/core/common.py b/xugrid/core/common.py index 5106cc06..e2bbdbf6 100644 --- a/xugrid/core/common.py +++ b/xugrid/core/common.py @@ -19,10 +19,8 @@ def load_dataset(*args, **kwargs): return UgridDataset(ds) -def open_dataarray(*args, **kwargs): - ds = xr.open_dataset(*args, **kwargs) +def _dataarray_helper(ds: xr.Dataset): dataset = UgridDataset(ds) - if len(dataset.data_vars) != 1: raise ValueError( "Given file dataset contains more than one data " @@ -46,6 +44,16 @@ def open_dataarray(*args, **kwargs): return UgridDataArray(data_array, dataset.grid) +def load_dataarray(*args, **kwargs): + ds = xr.load_dataset(*args, **kwargs) + return _dataarray_helper(ds) + + +def open_dataarray(*args, **kwargs): + ds = xr.open_dataset(*args, **kwargs) + return _dataarray_helper(ds) + + def open_mfdataset(*args, **kwargs): if "data_vars" in kwargs: raise ValueError("data_vars kwargs is not supported in xugrid.open_mfdataset") @@ -59,7 +67,9 @@ def open_zarr(*args, **kwargs): return UgridDataset(ds) +load_dataset.__doc__ = xr.load_dataset.__doc__ open_dataset.__doc__ = xr.open_dataset.__doc__ +load_dataarray.__doc__ = xr.load_dataarray.__doc__ open_dataarray.__doc__ = xr.open_dataarray.__doc__ open_mfdataset.__doc__ = xr.open_mfdataset.__doc__ open_zarr.__doc__ = xr.open_zarr.__doc__ diff --git a/xugrid/core/wrap.py b/xugrid/core/wrap.py index 04d66bd8..60dd007f 100644 --- a/xugrid/core/wrap.py +++ b/xugrid/core/wrap.py @@ -252,8 +252,8 @@ def from_structured2d( 2. "longitude" and "latitude" dimensions 3. "axis" attributes of "X" or "Y" on coordinates 4. "standard_name" attributes of "longitude", "latitude", - "projection_x_coordinate", or "projection_y_coordinate" on coordinate - variables + "projection_x_coordinate", or "projection_y_coordinate" on coordinate + variables Parameters ----------