Fix group kwarg (#338)
* clarify when to update the release notes

* say to check PyPI and conda-forge

* social media

* can release at any time

* convince myself that opening a single group does work fine

* reproduce errors from issue

* change default behaviour to try to open root group

* simplify new tests

* should fix group issue...

* double check we can definitely open empty groups correctly

* found and fixed the bug

* change test to reflect new expected behaviour

* remove now-nonsensical test

* fix failing kerchunk parquet test

* remove print statements

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* mark tests as needing kerchunk dependency

* release notes

* same fix but for the other kerchunk-based readers

* note breaking change

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
TomNicholas and pre-commit-ci[bot] authored Dec 10, 2024
1 parent af9c374 commit fcdd5e4
Showing 9 changed files with 124 additions and 48 deletions.
23 changes: 23 additions & 0 deletions conftest.py
@@ -24,6 +24,17 @@ def pytest_runtest_setup(item):
)


@pytest.fixture
def empty_netcdf4_file(tmpdir):
# Set up example xarray dataset
ds = xr.Dataset()

# Save it to disk as netCDF (in temporary directory)
filepath = f"{tmpdir}/empty.nc"
ds.to_netcdf(filepath, format="NETCDF4")
ds.close()

return filepath


@pytest.fixture
def netcdf4_file(tmpdir):
# Set up example xarray dataset
@@ -37,6 +48,18 @@ def netcdf4_file(tmpdir):
return filepath


@pytest.fixture
def netcdf4_file_with_data_in_multiple_groups(tmpdir):
filepath = str(tmpdir / "test.nc")

ds1 = xr.DataArray([1, 2, 3], name="foo").to_dataset()
ds1.to_netcdf(filepath)
ds2 = xr.DataArray([4, 5], name="bar").to_dataset()
ds2.to_netcdf(filepath, group="subgroup", mode="a")

return filepath


@pytest.fixture
def netcdf4_files_factory(tmpdir) -> callable:
def create_netcdf4_files(
7 changes: 7 additions & 0 deletions docs/releases.rst
@@ -12,12 +12,19 @@ New Features
Breaking changes
~~~~~~~~~~~~~~~~

- Passing ``group=None`` (the default) to ``open_virtual_dataset`` for a file with multiple groups no longer raises an error; instead it gives you the root group.
This new behaviour is more consistent with ``xarray.open_dataset``.
(:issue:`336`, :pull:`337`) By `Tom Nicholas <https://github.com/TomNicholas>`_.

Deprecations
~~~~~~~~~~~~

Bug fixes
~~~~~~~~~

- Fix bug preventing generating references for the root group of a file when a subgroup exists.
(:issue:`336`, :pull:`337`) By `Tom Nicholas <https://github.com/TomNicholas>`_.

Documentation
~~~~~~~~~~~~~

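A minimal sketch of the new default behaviour described in the breaking change above (the file path is hypothetical; the group and indexes keywords are the ones exercised by the new tests in this commit):

    from virtualizarr import open_virtual_dataset

    # Assumed: a netCDF4 file holding a variable in the root group plus a "subgroup".
    filepath = "multi_group.nc"  # hypothetical path

    # group=None (the default) now returns the root group instead of raising an error.
    vds_root = open_virtual_dataset(filepath, indexes={})

    # group="" is treated the same way as group=None.
    vds_root_again = open_virtual_dataset(filepath, group="", indexes={})

    # A named subgroup still has to be requested explicitly.
    vds_sub = open_virtual_dataset(filepath, group="subgroup", indexes={})
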
4 changes: 3 additions & 1 deletion virtualizarr/readers/fits.py
@@ -37,7 +37,9 @@ def open_virtual_dataset(
# handle inconsistency in kerchunk, see GH issue https://github.com/zarr-developers/VirtualiZarr/issues/160
refs = KerchunkStoreRefs({"refs": process_file(filepath, **reader_options)})

refs = extract_group(refs, group)
# both group=None and group='' mean to read root group
if group:
refs = extract_group(refs, group)

virtual_vars, attrs, coord_names = virtual_vars_and_metadata_from_kerchunk_refs(
refs,
4 changes: 3 additions & 1 deletion virtualizarr/readers/hdf5.py
@@ -43,7 +43,9 @@ def open_virtual_dataset(
filepath, inline_threshold=0, **reader_options
).translate()

refs = extract_group(refs, group)
# both group=None and group='' mean to read root group
if group:
refs = extract_group(refs, group)

virtual_vars, attrs, coord_names = virtual_vars_and_metadata_from_kerchunk_refs(
refs,
7 changes: 5 additions & 2 deletions virtualizarr/readers/netcdf3.py
@@ -9,7 +9,6 @@
open_loadable_vars_and_indexes,
)
from virtualizarr.translators.kerchunk import (
extract_group,
virtual_vars_and_metadata_from_kerchunk_refs,
)
from virtualizarr.utils import check_for_collisions
@@ -41,7 +40,11 @@ def open_virtual_dataset(

refs = NetCDF3ToZarr(filepath, inline_threshold=0, **reader_options).translate()

refs = extract_group(refs, group)
# both group=None and group='' mean to read root group
if group:
raise ValueError(
"group kwarg passed, but netCDF3 files can't have multiple groups!"
)

virtual_vars, attrs, coord_names = virtual_vars_and_metadata_from_kerchunk_refs(
refs,
4 changes: 3 additions & 1 deletion virtualizarr/readers/tiff.py
@@ -52,7 +52,9 @@ def open_virtual_dataset(
# handle inconsistency in kerchunk, see GH issue https://github.com/zarr-developers/VirtualiZarr/issues/160
refs = KerchunkStoreRefs({"refs": tiff_to_zarr(filepath, **reader_options)})

refs = extract_group(refs, group)
# both group=None and group='' mean to read root group
if group:
refs = extract_group(refs, group)

virtual_vars, attrs, coord_names = virtual_vars_and_metadata_from_kerchunk_refs(
refs,
41 changes: 38 additions & 3 deletions virtualizarr/tests/test_backend.py
@@ -5,7 +5,7 @@
import pytest
import xarray as xr
import xarray.testing as xrt
from xarray import open_dataset
from xarray import Dataset, open_dataset
from xarray.core.indexes import Index

from virtualizarr import open_virtual_dataset
@@ -309,6 +309,43 @@ def test_virtualizarr_vs_local_nisar(self, hdf_backend):
xrt.assert_equal(dsXR, dsV)


@requires_kerchunk
def test_open_empty_group(empty_netcdf4_file):
vds = open_virtual_dataset(empty_netcdf4_file, indexes={})
assert isinstance(vds, xr.Dataset)
expected = Dataset()
xrt.assert_identical(vds, expected)


@requires_kerchunk
class TestOpenVirtualDatasetHDFGroup:
def test_open_subgroup(self, netcdf4_file_with_data_in_multiple_groups):
vds = open_virtual_dataset(
netcdf4_file_with_data_in_multiple_groups, group="subgroup", indexes={}
)
assert list(vds.variables) == ["bar"]
assert isinstance(vds["bar"].data, ManifestArray)
assert vds["bar"].shape == (2,)

def test_open_root_group_manually(self, netcdf4_file_with_data_in_multiple_groups):
vds = open_virtual_dataset(
netcdf4_file_with_data_in_multiple_groups, group="", indexes={}
)
assert list(vds.variables) == ["foo"]
assert isinstance(vds["foo"].data, ManifestArray)
assert vds["foo"].shape == (3,)

def test_open_root_group_by_default(
self, netcdf4_file_with_data_in_multiple_groups
):
vds = open_virtual_dataset(
netcdf4_file_with_data_in_multiple_groups, indexes={}
)
assert list(vds.variables) == ["foo"]
assert isinstance(vds["foo"].data, ManifestArray)
assert vds["foo"].shape == (3,)


@requires_kerchunk
class TestLoadVirtualDataset:
@pytest.mark.parametrize("hdf_backend", [HDF5VirtualBackend, HDFVirtualBackend])
@@ -356,8 +393,6 @@ def test_group_kwarg(self, hdf5_groups_file, hdf_backend):
hdf5_groups_file, group="doesnt_exist", backend=hdf_backend
)
if hdf_backend == HDF5VirtualBackend:
with pytest.raises(ValueError, match="Multiple HDF Groups found"):
open_virtual_dataset(hdf5_groups_file)
with pytest.raises(ValueError, match="not found in"):
open_virtual_dataset(hdf5_groups_file, group="doesnt_exist")

7 changes: 0 additions & 7 deletions virtualizarr/tests/test_integration.py
@@ -13,7 +13,6 @@
from virtualizarr.tests import requires_kerchunk
from virtualizarr.translators.kerchunk import (
dataset_from_kerchunk_refs,
find_var_names,
)
from virtualizarr.zarr import ZArray

@@ -49,12 +48,6 @@ def test_kerchunk_roundtrip_in_memory_no_concat():
xrt.assert_equal(roundtrip, ds)


def test_no_duplicates_find_var_names():
"""Verify that we get a deduplicated list of var names"""
ref_dict = {"refs": {"x/something": {}, "x/otherthing": {}}}
assert len(find_var_names(ref_dict)) == 1


@requires_kerchunk
@pytest.mark.parametrize(
"inline_threshold, vars_to_inline",
75 changes: 42 additions & 33 deletions virtualizarr/translators/kerchunk.py
@@ -43,42 +43,43 @@ def virtual_vars_and_metadata_from_kerchunk_refs(
return virtual_vars, ds_attrs, coord_names


def extract_group(vds_refs: KerchunkStoreRefs, group: str | None) -> KerchunkStoreRefs:
"""Extract only the part of the kerchunk reference dict that is relevant to a single HDF group"""
def extract_group(vds_refs: KerchunkStoreRefs, group: str) -> KerchunkStoreRefs:
"""
Extract only the part of the kerchunk reference dict that is relevant to a single HDF group.
Parameters
----------
vds_refs : KerchunkStoreRefs
group : str
Should be a non-empty string
"""
hdf_groups = [
k.removesuffix(".zgroup") for k in vds_refs["refs"].keys() if ".zgroup" in k
]
if len(hdf_groups) == 1:
return vds_refs
else:
if group is None:
raise ValueError(
f"Multiple HDF Groups found. Must specify group= keyword to select one of {hdf_groups}"
)
else:
# Ensure supplied group kwarg is consistent with kerchunk keys
if not group.endswith("/"):
group += "/"
if group.startswith("/"):
group = group.removeprefix("/")

if group not in hdf_groups:
raise ValueError(f'Group "{group}" not found in {hdf_groups}')

# Filter by group prefix and remove prefix from all keys
groupdict = {
k.removeprefix(group): v
for k, v in vds_refs["refs"].items()
if k.startswith(group)
}
# Also remove group prefix from _ARRAY_DIMENSIONS
for k, v in groupdict.items():
if isinstance(v, str):
groupdict[k] = v.replace("\\/", "/").replace(group, "")

vds_refs["refs"] = groupdict
# Ensure supplied group kwarg is consistent with kerchunk keys
if not group.endswith("/"):
group += "/"
if group.startswith("/"):
group = group.removeprefix("/")

if group not in hdf_groups:
raise ValueError(f'Group "{group}" not found in {hdf_groups}')

# Filter by group prefix and remove prefix from all keys
groupdict = {
k.removeprefix(group): v
for k, v in vds_refs["refs"].items()
if k.startswith(group)
}
# Also remove group prefix from _ARRAY_DIMENSIONS
for k, v in groupdict.items():
if isinstance(v, str):
groupdict[k] = v.replace("\\/", "/").replace(group, "")

vds_refs["refs"] = groupdict

return KerchunkStoreRefs(vds_refs)
return KerchunkStoreRefs(vds_refs)


def virtual_vars_from_kerchunk_refs(
@@ -222,9 +223,17 @@ def find_var_names(ds_reference_dict: KerchunkStoreRefs) -> list[str]:
"""Find the names of zarr variables in this store/group."""

refs = ds_reference_dict["refs"]
found_var_names = {key.split("/")[0] for key in refs.keys() if "/" in key}

return list(found_var_names)
found_var_names = []
for key in refs.keys():
# has to capture "foo/.zarray", but ignore ".zgroup", ".zattrs", and "subgroup/bar/.zarray"
# TODO this might be a sign that we should introduce a KerchunkGroupRefs type and cut down the references before getting to this point...
if key not in (".zgroup", ".zattrs", ".zmetadata"):
first_part, second_part, *_ = key.split("/")
if second_part == ".zarray":
found_var_names.append(first_part)

return found_var_names


def extract_array_refs(
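
A rough illustration of the filtering performed by the rewritten find_var_names (the toy reference dict is made up; only its keys matter, and the expected output assumes the version of the function introduced in this commit):

    from virtualizarr.translators.kerchunk import find_var_names

    # Toy kerchunk store refs; the values are placeholders.
    refs = {
        "refs": {
            ".zgroup": "{}",
            ".zattrs": "{}",
            "foo/.zarray": "{}",
            "foo/0": ["file.nc", 0, 100],
            "subgroup/.zgroup": "{}",
            "subgroup/bar/.zarray": "{}",
        }
    }

    # "foo" is the only root-group variable: the bare ".zgroup"/".zattrs" keys are
    # group-level metadata, and "subgroup/bar/.zarray" belongs to a subgroup.
    print(find_var_names(refs))  # expected output: ['foo']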