from __future__ import annotations
from dataclasses import dataclass
from typing import (
TYPE_CHECKING,
ClassVar,
Literal,
Self,
TypeGuard,
overload,
)
import numpy as np
from zarr.core.dtype.common import (
DataTypeValidationError,
DTypeConfig_V2,
DTypeJSON,
HasEndianness,
HasItemSize,
check_dtype_spec_v2,
)
from zarr.core.dtype.npy.common import (
ComplexLike,
TComplexDType_co,
TComplexScalar_co,
check_json_complex_float_v2,
check_json_complex_float_v3,
complex_float_from_json_v2,
complex_float_from_json_v3,
complex_float_to_json_v2,
complex_float_to_json_v3,
endianness_to_numpy_str,
get_endianness_from_numpy_dtype,
)
from zarr.core.dtype.wrapper import TBaseDType, ZDType
if TYPE_CHECKING:
from zarr.core.common import JSON, ZarrFormat
@dataclass(frozen=True)
class BaseComplex(ZDType[TComplexDType_co, TComplexScalar_co], HasEndianness, HasItemSize):
"""
A base class for Zarr data types that wrap NumPy complex float data types.
"""
# This attribute holds the possible zarr v2 JSON names for the data type
_zarr_v2_names: ClassVar[tuple[str, ...]]
@classmethod
def from_native_dtype(cls, dtype: TBaseDType) -> Self:
"""
Create an instance of this data type from a NumPy complex dtype.
Parameters
----------
dtype : TBaseDType
The native dtype to convert.
Returns
-------
Self
An instance of this data type.
Raises
------
DataTypeValidationError
If the dtype is not compatible with this data type.
"""
if cls._check_native_dtype(dtype):
return cls(endianness=get_endianness_from_numpy_dtype(dtype))
raise DataTypeValidationError(
f"Invalid data type: {dtype}. Expected an instance of {cls.dtype_cls}"
)
def to_native_dtype(self) -> TComplexDType_co:
"""
Convert this class to a NumPy complex dtype with the appropriate byte order.
Returns
-------
TComplexDType_co
A NumPy data type object representing the complex data type with the specified byte order.
"""
byte_order = endianness_to_numpy_str(self.endianness)
return self.dtype_cls().newbyteorder(byte_order) # type: ignore[return-value]
@classmethod
def _check_json_v2(cls, data: DTypeJSON) -> TypeGuard[DTypeConfig_V2[str, None]]:
"""
Check that the input is a valid JSON representation of this data type.
The input data must be a mapping that contains a "name" key that is one of
the strings from cls._zarr_v2_names and an "object_codec_id" key that is None.
Parameters
----------
data : DTypeJSON
The JSON data to check.
Returns
-------
bool
True if the input is a valid JSON representation, False otherwise.
"""
return (
check_dtype_spec_v2(data)
and data["name"] in cls._zarr_v2_names
and data["object_codec_id"] is None
)
@classmethod
def _check_json_v3(cls, data: DTypeJSON) -> TypeGuard[str]:
"""
Check that the input is a valid JSON representation of this data type in Zarr V3.
This method verifies that the provided data matches the expected Zarr V3
representation, which is the string specified by the class-level attribute _zarr_v3_name.
Parameters
----------
data : DTypeJSON
The JSON data to check.
Returns
-------
TypeGuard[str]
True if the input is a valid representation of this class in Zarr V3, False otherwise.
"""
return data == cls._zarr_v3_name
@classmethod
def _from_json_v2(cls, data: DTypeJSON) -> Self:
"""
Create an instance of this class from Zarr V2-flavored JSON.
Parameters
----------
data : DTypeJSON
The JSON data.
Returns
-------
Self
An instance of this class.
Raises
------
DataTypeValidationError
If the input JSON is not a valid representation of this class.
"""
if cls._check_json_v2(data):
# Going via numpy ensures that we get the endianness correct without
# annoying string parsing.
name = data["name"]
return cls.from_native_dtype(np.dtype(name))
msg = f"Invalid JSON representation of {cls.__name__}. Got {data!r}, expected one of the strings {cls._zarr_v2_names}."
raise DataTypeValidationError(msg)
@classmethod
def _from_json_v3(cls, data: DTypeJSON) -> Self:
"""
Create an instance of this class from Zarr V3-flavored JSON.
Parameters
----------
data : DTypeJSON
The JSON data.
Returns
-------
Self
An instance of this data type.
Raises
------
DataTypeValidationError
If the input JSON is not a valid representation of this class.
"""
if cls._check_json_v3(data):
return cls()
msg = f"Invalid JSON representation of {cls.__name__}. Got {data!r}, expected {cls._zarr_v3_name}."
raise DataTypeValidationError(msg)
@overload
def to_json(self, zarr_format: Literal[2]) -> DTypeConfig_V2[str, None]: ...
@overload
def to_json(self, zarr_format: Literal[3]) -> str: ...
def to_json(self, zarr_format: ZarrFormat) -> DTypeConfig_V2[str, None] | str:
"""
Serialize this object to a JSON-serializable representation.
Parameters
----------
zarr_format : ZarrFormat
The Zarr format version. Supported values are 2 and 3.
Returns
-------
DTypeConfig_V2[str, None] | str
If ``zarr_format`` is 2, a dictionary with ``"name"`` and ``"object_codec_id"`` keys is
returned.
If ``zarr_format`` is 3, a string representation of the complex data type is returned.
Raises
------
ValueError
If `zarr_format` is not 2 or 3.
"""
if zarr_format == 2:
return {"name": self.to_native_dtype().str, "object_codec_id": None}
elif zarr_format == 3:
return self._zarr_v3_name
raise ValueError(f"zarr_format must be 2 or 3, got {zarr_format}") # pragma: no cover
def _check_scalar(self, data: object) -> TypeGuard[ComplexLike]:
"""
Check that the input is a scalar complex value.
Parameters
----------
data : object
The value to check.
Returns
-------
TypeGuard[ComplexLike]
True if the input is a scalar complex value, False otherwise.
"""
return isinstance(data, ComplexLike)
def _cast_scalar_unchecked(self, data: ComplexLike) -> TComplexScalar_co:
"""
Cast the provided scalar data to the native scalar type of this class.
Parameters
----------
data : ComplexLike
The data to cast.
Returns
-------
TComplexScalar_co
The casted data as a numpy complex scalar.
Notes
-----
This method does not perform any type checking.
The input data must be a scalar complex value.
"""
return self.to_native_dtype().type(data) # type: ignore[return-value]
def cast_scalar(self, data: object) -> TComplexScalar_co:
"""
Attempt to cast a given object to a numpy complex scalar.
Parameters
----------
data : object
The data to be cast to a numpy complex scalar.
Returns
-------
TComplexScalar_co
The data cast as a numpy complex scalar.
Raises
------
TypeError
If the data cannot be converted to a numpy complex scalar.
"""
if self._check_scalar(data):
return self._cast_scalar_unchecked(data)
msg = (
f"Cannot convert object {data!r} with type {type(data)} to a scalar compatible with the "
f"data type {self}."
)
raise TypeError(msg)
def default_scalar(self) -> TComplexScalar_co:
"""
Get the default value, which is 0 cast to this dtype
Returns
-------
Int scalar
The default value.
"""
return self._cast_scalar_unchecked(0)
def from_json_scalar(self, data: JSON, *, zarr_format: ZarrFormat) -> TComplexScalar_co:
"""
Read a JSON-serializable value as a numpy float.
Parameters
----------
data : JSON
The JSON-serializable value.
zarr_format : ZarrFormat
The zarr format version.
Returns
-------
TScalar_co
The numpy float.
"""
if zarr_format == 2:
if check_json_complex_float_v2(data):
return self._cast_scalar_unchecked(complex_float_from_json_v2(data))
raise TypeError(
f"Invalid type: {data}. Expected a float or a special string encoding of a float."
)
elif zarr_format == 3:
if check_json_complex_float_v3(data):
return self._cast_scalar_unchecked(complex_float_from_json_v3(data))
raise TypeError(
f"Invalid type: {data}. Expected a float or a special string encoding of a float."
)
raise ValueError(f"zarr_format must be 2 or 3, got {zarr_format}") # pragma: no cover
def to_json_scalar(self, data: object, *, zarr_format: ZarrFormat) -> JSON:
"""
Convert an object to a JSON-serializable float.
Parameters
----------
data : _BaseScalar
The value to convert.
zarr_format : ZarrFormat
The zarr format version.
Returns
-------
JSON
The JSON-serializable form of the complex number, which is a list of two floats,
each of which is encoding according to a zarr-format-specific encoding.
"""
if zarr_format == 2:
return complex_float_to_json_v2(self.cast_scalar(data))
elif zarr_format == 3:
return complex_float_to_json_v3(self.cast_scalar(data))
raise ValueError(f"zarr_format must be 2 or 3, got {zarr_format}") # pragma: no cover
[docs]
@dataclass(frozen=True, kw_only=True)
class Complex64(BaseComplex[np.dtypes.Complex64DType, np.complex64]):
"""
A Zarr data type for arrays containing 64 bit complex floats.
Wraps the ``np.dtypes.Complex64DType`` data type. Scalars for this data type
are instances of ``np.complex64``.
Attributes
----------
dtype_cls : Type[np.dtypes.Complex64DType]
The numpy dtype class for this data type.
_zarr_v3_name : ClassVar[Literal["complex64"]]
The name of this data type in Zarr V3.
_zarr_v2_names : ClassVar[tuple[Literal[">c8"], Literal["<c8"]]]
The names of this data type in Zarr V2.
"""
dtype_cls = np.dtypes.Complex64DType
_zarr_v3_name: ClassVar[Literal["complex64"]] = "complex64"
_zarr_v2_names: ClassVar[tuple[Literal[">c8"], Literal["<c8"]]] = (">c8", "<c8")
@property
def item_size(self) -> int:
"""
The size of a single scalar in bytes.
Returns
-------
int
The size of a single scalar in bytes.
"""
return 8
[docs]
@dataclass(frozen=True, kw_only=True)
class Complex128(BaseComplex[np.dtypes.Complex128DType, np.complex128], HasEndianness):
"""
A Zarr data type for arrays containing 64 bit complex floats.
Wraps the ``np.dtypes.Complex128DType`` data type. Scalars for this data type
are instances of ``np.complex128``.
Attributes
----------
dtype_cls : Type[np.dtypes.Complex128DType]
The numpy dtype class for this data type.
_zarr_v3_name : ClassVar[Literal["complex128"]]
The name of this data type in Zarr V3.
_zarr_v2_names : ClassVar[tuple[Literal[">c16"], Literal["<c16"]]]
The names of this data type in Zarr V2.
"""
dtype_cls = np.dtypes.Complex128DType
_zarr_v3_name: ClassVar[Literal["complex128"]] = "complex128"
_zarr_v2_names: ClassVar[tuple[Literal[">c16"], Literal["<c16"]]] = (">c16", "<c16")
@property
def item_size(self) -> int:
"""
The size of a single scalar in bytes.
Returns
-------
int
The size of a single scalar in bytes.
"""
return 16