Skip to content

Commit

Permalink
add abortive dtype class
Browse files Browse the repository at this point in the history
  • Loading branch information
d-v-b committed Feb 27, 2025
1 parent 64b9a37 commit 7784f21
Showing 1 changed file with 34 additions and 5 deletions.
39 changes: 34 additions & 5 deletions src/zarr/core/metadata/v3.py
Original file line number Diff line number Diff line change
Expand Up @@ -610,6 +610,9 @@ class DataType(Enum):
float64 = "float64"
complex64 = "complex64"
complex128 = "complex128"
fixed_length_byte_string = "fixed_length_byte_string"
fixed_length_unicode_string = "fixed_length_unicode_string"
fixed_length_bytes = "fixed_length_bytes"
string = "string"
bytes = "bytes"

Expand Down Expand Up @@ -637,11 +640,19 @@ def byte_count(self) -> int | None:
# string and bytes have variable length
return None

@property
def has_length(self) -> _bool:
return self in (
DataType.fixed_length_byte_string,
DataType.fixed_length_unicode_string,
DataType.fixed_length_bytes,
)

@property
def has_endianness(self) -> _bool:
return self.byte_count is not None and self.byte_count != 1

def to_numpy_shortname(self) -> str:
def to_numpy_shortname(self, *, length_bytes: int | None = None) -> str:
data_type_to_numpy = {
DataType.bool: "bool",
DataType.int8: "i1",
Expand All @@ -657,10 +668,23 @@ def to_numpy_shortname(self) -> str:
DataType.float64: "f8",
DataType.complex64: "c8",
DataType.complex128: "c16",
DataType.fixed_length_byte_string: "S",
DataType.fixed_length_unicode_string: "U",
DataType.fixed_length_bytes: "r",
}

if self.has_length:
if length_bytes is None:
raise ValueError(
f"Must provide length in bytes to create a numpy dtype from {self}"
)
else:
return data_type_to_numpy[self] + str(length_bytes)
return data_type_to_numpy[self]

def to_numpy(self) -> np.dtypes.StringDType | np.dtypes.ObjectDType | np.dtype[Any]:
def to_numpy(
self, length_bytes: int | None = None
) -> np.dtypes.StringDType | np.dtypes.ObjectDType | np.dtype[Any]:
# note: it is not possible to round trip DataType <-> np.dtype
# due to the fact that DataType.string and DataType.bytes both
# generally return np.dtype("O") from this function, even though
Expand All @@ -672,14 +696,19 @@ def to_numpy(self) -> np.dtypes.StringDType | np.dtypes.ObjectDType | np.dtype[A
# TODO: consider whether we can use fixed-width types (e.g. '|S5') instead
return np.dtype("O")
else:
return np.dtype(self.to_numpy_shortname())
if self == DataType.fixed_length_bytes:
return np.dtype(self.to_numpy_shortname(length_bytes=length_bytes * 8))
else:
return np.dtype(self.to_numpy_shortname(length_bytes=length_bytes))

@classmethod
def from_numpy(cls, dtype: np.dtype[Any]) -> DataType:
if dtype.kind in "UT":
if dtype.kind == "U":
return DataType.fixed_length_unicode_string
elif dtype.kind == "T":
return DataType.string
elif dtype.kind == "S":
return DataType.bytes
return DataType.fixed_length_byte_string
elif not _NUMPY_SUPPORTS_VLEN_STRING and dtype.kind == "O":
# numpy < 2.0 does not support vlen string dtype
# so we fall back on object array of strings
Expand Down

0 comments on commit 7784f21

Please sign in to comment.