Skip to content

Commit

Permalink
FIX: encoding typing
Browse files Browse the repository at this point in the history
  • Loading branch information
Abel Aoun committed Jan 11, 2024
1 parent 5c66563 commit d62ac29
Showing 1 changed file with 14 additions and 10 deletions.
24 changes: 14 additions & 10 deletions xarray/backends/netCDF4_.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def _check_encoding_dtype_is_vlen_string(dtype):

def _get_datatype(
var, nc_format="NETCDF4", raise_on_invalid_encoding=False
) -> np.dtype | None:
) -> np.dtype:
if nc_format == "NETCDF4":
return _nc4_dtype(var)
if "dtype" in var.encoding:
Expand Down Expand Up @@ -421,7 +421,7 @@ def open_store_variable(self, name: str, var):
dimensions = var.dimensions
attributes = {k: var.getncattr(k) for k in var.ncattrs()}
data = indexing.LazilyIndexedArray(NetCDF4ArrayWrapper(name, self))
encoding = {}
encoding: dict[str, Any] = {}
if isinstance(var.datatype, netCDF4.EnumType):
encoding["dtype"] = np.dtype(
data.dtype,
Expand Down Expand Up @@ -507,11 +507,11 @@ def prepare_variable(
)
# check enum metadata and use netCDF4.EnumType
if (
np.dtype(datatype).metadata
and datatype.metadata.get("enum_name")
and datatype.metadata.get("enum")
(meta := np.dtype(datatype).metadata)
and (e_name := meta.get("enum_name"))
and (e_dict := meta.get("enum"))
):
datatype = self._build_and_get_enum(name, datatype)
datatype = self._build_and_get_enum(name, datatype, e_name, e_dict)
encoding = _extract_nc4_variable_encoding(
variable, raise_on_invalid=check_encoding, unlimited_dims=unlimited_dims
)
Expand Down Expand Up @@ -542,10 +542,14 @@ def prepare_variable(

return target, variable.data

def _build_and_get_enum(self, var_name: str, dtype: np.dtype) -> object:
"""Add or get the netCDF4 Enum based on the dtype in encoding."""
enum_dict = dtype.metadata["enum"]
enum_name = dtype.metadata["enum_name"]
def _build_and_get_enum(
self, var_name: str, dtype: np.dtype, enum_name: str, enum_dict: dict[str, int]
) -> Any:
"""
Add or get the netCDF4 Enum based on the dtype in encoding.
The return type should be ``netCDF4.EnumType``,
but we avoid importing netCDF4 globally for performances.
"""
if enum_name not in self.ds.enumtypes:
return self.ds.createEnumType(
dtype,
Expand Down

0 comments on commit d62ac29

Please sign in to comment.