Skip to content

Commit

Permalink
ENH: Avoid encoding enum to CF
Browse files Browse the repository at this point in the history
  • Loading branch information
Abel Aoun committed Jan 9, 2024
1 parent 26bb8ce commit d21d73a
Show file tree
Hide file tree
Showing 5 changed files with 12 additions and 37 deletions.
Binary file added toto.nc
Binary file not shown.
18 changes: 7 additions & 11 deletions xarray/backends/netCDF4_.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import os
from collections.abc import Iterable
from contextlib import suppress
from enum import Enum
from typing import TYPE_CHECKING, Any

import numpy as np
Expand Down Expand Up @@ -421,11 +422,7 @@ def open_store_variable(self, name: str, var):
attributes = {k: var.getncattr(k) for k in var.ncattrs()}
data = indexing.LazilyIndexedArray(NetCDF4ArrayWrapper(name, self))
if isinstance(var.datatype, netCDF4.EnumType):
enum_dict = var.datatype.enum_dict
enum_name = var.datatype.name
attributes["enum"] = enum_name
attributes["flag_values"] = tuple(enum_dict.values())
attributes["flag_meanings"] = tuple(enum_dict.keys())
attributes["enum"] = Enum(var.datatype.name, var.datatype.enum_dict)
_ensure_fill_value_valid(data, attributes)
# netCDF4 specific encoding; save _FillValue for later
encoding = {}
Expand Down Expand Up @@ -537,20 +534,19 @@ def prepare_variable(
def _build_and_get_enum(
self, var_name: str, attributes: dict, dtype: np.dtype
) -> object:
flag_meanings = attributes.pop("flag_meanings")
flag_values = attributes.pop("flag_values")
enum_name = attributes.pop("enum")
enum_dict = {k: v for k, v in zip(flag_meanings, flag_values)}
enum = attributes.pop("enum")
enum_dict = {e.name: e.value for e in enum}
enum_name = enum.__name__
if enum_name in self.ds.enumtypes:
datatype = self.ds.enumtypes[enum_name]
if datatype.enum_dict != enum_dict:
error_msg = (
f"Cannot save variable `{var_name}` because an enum"
f" `{enum_name}` already exists in the Dataset but have"
" a different definition. To fix this error, make sure"
" each variable have a unique name for their `attrs['enum']`"
" each variable have a unique name in `attrs['enum']`"
" or, if they should share same enum type, make sure"
" their flag_values and flag_meanings are identical."
" the enums are identical."
)
raise ValueError(error_msg)
else:
Expand Down
16 changes: 4 additions & 12 deletions xarray/coding/variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import warnings
from collections.abc import Hashable, MutableMapping
from enum import Enum, EnumMeta
from enum import Enum
from functools import partial
from typing import TYPE_CHECKING, Any, Callable, Union

Expand Down Expand Up @@ -567,7 +567,7 @@ def decode(self):

class ObjectVLenStringCoder(VariableCoder):
def encode(self):
return NotImplementedError
raise NotImplementedError

def decode(self, variable: Variable, name: T_Name = None) -> Variable:
if variable.dtype == object and variable.encoding.get("dtype", False) == str:
Expand All @@ -578,18 +578,10 @@ def decode(self, variable: Variable, name: T_Name = None) -> Variable:


class EnumCoder(VariableCoder):
"""Encode and decode Enum to CF"""
"""Decode CF flag_* to python Enum"""

def encode(self, variable: Variable, name: T_Name = None) -> Variable:
"""From python Enum to CF flag_*"""
dims, data, attrs, encoding = unpack_for_encoding(variable)
if isinstance(attrs.get("enum"), EnumMeta):
enum = attrs.pop("enum")
enum_name = enum.__name__
attrs["flag_meanings"] = " ".join(i.name for i in enum)
attrs["flag_values"] = ", ".join(str(i.value) for i in enum)
attrs["enum"] = enum_name
return Variable(dims, data, attrs, encoding, fastpath=True)
raise NotImplementedError

def decode(self, variable: Variable, name: T_Name = None) -> Variable:
"""From CF flag_* to python Enum"""
Expand Down
2 changes: 0 additions & 2 deletions xarray/conventions.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,6 @@ def encode_cf_variable(
- Rescaling via: scale_factor and add_offset
- datetimes are converted to the CF 'units since time' format
- dtype encodings are enforced.
- enum is turned into flag_values and flag_meanings
Parameters
----------
Expand All @@ -188,7 +187,6 @@ def encode_cf_variable(
variables.NonStringCoder(),
variables.DefaultFillvalueCoder(),
variables.BooleanCoder(),
variables.EnumCoder(),
]:
var = coder.encode(var, name=name)

Expand Down
13 changes: 1 addition & 12 deletions xarray/tests/test_coding.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from contextlib import suppress
from enum import Enum, EnumMeta
from enum import EnumMeta

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -164,14 +164,3 @@ def test_decode_enum() -> None:
assert isinstance(decoded.attrs["enum"], EnumMeta)
assert decoded.attrs["enum"].flag.value == 0
assert decoded.attrs["enum"].galf.value == 1


def test_encode_enum() -> None:
decoded = xr.Variable(
("x",), [42], attrs={"enum": Enum("an_enum", {"flag": 0, "galf": 1})}
)
coder = variables.EnumCoder()
encoded = coder.encode(decoded)
assert encoded.attrs["enum"] == "an_enum"
assert encoded.attrs["flag_values"] == "0, 1"
assert encoded.attrs["flag_meanings"] == "flag galf"

0 comments on commit d21d73a

Please sign in to comment.