import builtins
import functools
from collections.abc import Callable
from typing import Any, TypeVar, cast
import hypothesis.extra.numpy as npst
import hypothesis.strategies as st
import numpy as np
from hypothesis import assume, note
from hypothesis.stateful import (
RuleBasedStateMachine,
initialize,
invariant,
precondition,
rule,
)
from hypothesis.strategies import DataObject
import zarr
from zarr import Array
from zarr.abc.store import Store
from zarr.codecs.bytes import BytesCodec
from zarr.core.buffer import Buffer, BufferPrototype, cpu, default_buffer_prototype
from zarr.core.sync import SyncMixin
from zarr.storage import LocalStore, MemoryStore
from zarr.testing.strategies import (
basic_indices,
chunk_paths,
dimension_names,
key_ranges,
node_names,
np_array_and_chunks,
orthogonal_indices,
)
from zarr.testing.strategies import keys as zarr_keys
# Upper bound (in bytes) on the random binary payloads written by ZarrStoreStateMachine.
MAX_BINARY_SIZE: int = 100
# Type variable for decorators that return the decorated callable unchanged in type.
F = TypeVar("F", bound=Callable[..., Any])
def with_frequency(frequency: float) -> Callable[[F], F]:
    """Enable a state-machine rule on only a fraction of precondition checks.

    The gating must be deterministic so that Hypothesis can replay a failing
    run, hence a counter is used instead of randomness.

    Each time Hypothesis evaluates the precondition, a counter stored on the
    object it receives (the state machine instance) is incremented. The
    attribute is named after the decorated function. The rule is enabled when
    ``(count * frequency) % 1.0 >= 1.0 - frequency``; for ``frequency=0.25``
    that is one evaluation in four.
    """

    def decorator(func: F) -> F:
        attr_name = f"__{func.__name__}_counter"

        @functools.wraps(func)
        def passthrough(*args: Any, **kwargs: Any) -> Any:
            return func(*args, **kwargs)

        def should_run(machine: Any) -> Any:
            calls = getattr(machine, attr_name, 0) + 1
            setattr(machine, attr_name, calls)
            return (calls * frequency) % 1.0 >= (1.0 - frequency)

        # precondition(pred) returns a decorator that attaches `pred` to the rule.
        gated = precondition(should_run)(passthrough)
        return cast(F, gated)

    return decorator
def split_prefix_name(path: str) -> tuple[str, str]:
    """Split ``path`` at its last ``/`` into ``(prefix, name)``.

    A path containing no ``/`` has an empty prefix: ``"a/b/c"`` yields
    ``("a/b", "c")`` and ``"c"`` yields ``("", "c")``.
    """
    head, _separator, tail = path.rpartition("/")
    return head, tail
class ZarrHierarchyStateMachine(SyncMixin, RuleBasedStateMachine):
    """
    This state machine models operations that modify a zarr store's
    hierarchy. That is, user actions that modify arrays/groups as well
    as list operations. It is intended to be used by external stores, and
    compares their results to a MemoryStore that is assumed to be perfect.
    """

    def __init__(self, store: Store) -> None:
        super().__init__()
        # Store under test; rules apply each operation to it and to the model.
        self.store = store
        # Reference store that the store under test is compared against.
        self.model = MemoryStore()
        zarr.group(store=self.model)

        # Track state of the hierarchy, these should contain fully qualified paths.
        # The root group created above is not added to either set.
        self.all_groups: set[str] = set()
        self.all_arrays: set[str] = set()

    @initialize()
    def init_store(self) -> None:
        """Empty the store under test and create its root group."""
        # This lets us reuse the fixture provided store.
        self._sync(self.store.clear())
        zarr.group(store=self.store)

    def can_add(self, path: str) -> bool:
        """Return True if no tracked group or array already uses ``path``."""
        return path not in self.all_groups and path not in self.all_arrays

    # -------------------- store operations -----------------------
    @rule(name=node_names, data=st.data())
    def add_group(self, name: str, data: DataObject) -> None:
        """Create group ``name`` under a tracked group (root when none exist) in both stores."""
        # Handle possible case-insensitive file systems (e.g. MacOS)
        if isinstance(self.store, LocalStore):
            name = name.lower()
        if self.all_groups:
            parent = data.draw(st.sampled_from(sorted(self.all_groups)), label="Group parent")
        else:
            parent = ""
        path = f"{parent}/{name}".lstrip("/")
        assume(self.can_add(path))
        note(f"Adding group: path='{path}'")
        self.all_groups.add(path)
        zarr.group(store=self.store, path=path)
        zarr.group(store=self.model, path=path)

    @rule(data=st.data(), name=node_names, array_and_chunks=np_array_and_chunks())
    def add_array(
        self,
        data: DataObject,
        name: str,
        array_and_chunks: tuple[np.ndarray[Any, Any], tuple[int, ...]],
    ) -> None:
        """Write an identical random array, with identical metadata, to both stores."""
        # Handle possible case-insensitive file systems (e.g. MacOS)
        if isinstance(self.store, LocalStore):
            name = name.lower()
        array, chunks = array_and_chunks
        fill_value = data.draw(npst.from_dtype(array.dtype))
        if self.all_groups:
            parent = data.draw(st.sampled_from(sorted(self.all_groups)), label="Array parent")
        else:
            parent = ""
        # TODO: support creating deeper paths
        # TODO: support overwriting potentially by just skipping `self.can_add`
        path = f"{parent}/{name}".lstrip("/")
        assume(self.can_add(path))
        note(f"Adding array: path='{path}' shape={array.shape} chunks={chunks}")
        # Draw the dimension names once: drawing inside the loop gave the store
        # under test and the model different metadata for the "same" array.
        dim_names = data.draw(dimension_names(ndim=array.ndim), label="dimension names")
        for store in [self.store, self.model]:
            zarr.array(
                array,
                chunks=chunks,
                path=path,
                store=store,
                fill_value=fill_value,
                zarr_format=3,
                dimension_names=dim_names,
                # Chose bytes codec to avoid wasting time compressing the data being written
                codecs=[BytesCodec()],
            )
        self.all_arrays.add(path)

    @rule()
    @with_frequency(0.25)
    def clear(self) -> None:
        """Wipe both stores, reset tracking, and recreate the root groups."""
        note("clearing")
        self._sync(self.store.clear())
        self._sync(self.model.clear())

        assert self._sync(self.store.is_empty("/"))
        assert self._sync(self.model.is_empty("/"))

        self.all_groups.clear()
        self.all_arrays.clear()

        zarr.group(store=self.store)
        zarr.group(store=self.model)

        # TODO: MemoryStore is broken?
        # assert not self._sync(self.store.is_empty("/"))
        # assert not self._sync(self.model.is_empty("/"))

    def draw_directory(self, data: DataObject) -> str:
        """Draw a tracked group or array path, or ``<array>/c/`` for an array."""
        group_st = st.sampled_from(sorted(self.all_groups)) if self.all_groups else st.nothing()
        array_st = st.sampled_from(sorted(self.all_arrays)) if self.all_arrays else st.nothing()
        array_or_group = data.draw(st.one_of(group_st, array_st))
        if data.draw(st.booleans()) and array_or_group in self.all_arrays:
            arr = zarr.open_array(path=array_or_group, store=self.model)
            path = data.draw(
                st.one_of(
                    st.sampled_from([array_or_group]),
                    # NOTE(review): the drawn chunk path `x` is discarded, so this
                    # branch always yields "<array>/c/". Possibly deliberate given the
                    # chunk-path inconsistencies described in check_list_dir; confirm.
                    chunk_paths(ndim=arr.ndim, numblocks=arr.cdata_shape).map(
                        lambda x: f"{array_or_group}/c/"
                    ),
                )
            )
        else:
            path = array_or_group
        return path

    @precondition(lambda self: bool(self.all_groups))
    # Consider .list_dir("path/to/array") for an array with a single chunk.
    # The MemoryStore model will return `"c", "zarr.json"` only if the chunk exists
    # If that chunk was deleted, then `"c"` is not returned.
    # LocalStore will not have this behaviour :/
    # There are similar consistency issues with delete_dir("/path/to/array/c/0/0")
    # Checked as a precondition so no draws are wasted on LocalStore.
    @precondition(lambda self: not isinstance(self.store, LocalStore))
    @rule(data=st.data())
    def check_list_dir(self, data: DataObject) -> None:
        """Assert ``list_dir`` agrees between model and store for a drawn directory."""
        path = self.draw_directory(data)
        note(f"list_dir for {path=!r}")
        model_ls = sorted(self._sync_iter(self.model.list_dir(path)))
        store_ls = sorted(self._sync_iter(self.store.list_dir(path)))
        assert model_ls == store_ls, (model_ls, store_ls)

    @precondition(lambda self: bool(self.all_arrays))
    @rule(data=st.data())
    def delete_chunk(self, data: DataObject) -> None:
        """Delete one full chunk key of a tracked array from both stores."""
        array = data.draw(st.sampled_from(sorted(self.all_arrays)))
        arr = zarr.open_array(path=array, store=self.model)
        chunk_path = data.draw(chunk_paths(ndim=arr.ndim, numblocks=arr.cdata_shape, subset=False))
        path = f"{array}/c/{chunk_path}"
        note(f"deleting chunk {path=!r}")
        self._sync(self.model.delete(path))
        self._sync(self.store.delete(path))

    @precondition(lambda self: bool(self.all_arrays))
    @rule(data=st.data())
    def check_array(self, data: DataObject) -> None:
        """Assert a tracked array reads back identically from both stores."""
        path = data.draw(st.sampled_from(sorted(self.all_arrays)))
        actual = zarr.open_array(self.store, path=path)[:]
        expected = zarr.open_array(self.model, path=path)[:]
        np.testing.assert_equal(actual, expected)

    @precondition(lambda self: bool(self.all_arrays))
    @rule(data=st.data())
    def overwrite_array_basic_indexing(self, data: DataObject) -> None:
        """Write the same random data through a basic indexer into both stores."""
        array = data.draw(st.sampled_from(sorted(self.all_arrays)))
        model_array = zarr.open_array(path=array, store=self.model)
        store_array = zarr.open_array(path=array, store=self.store)
        slicer = data.draw(basic_indices(shape=model_array.shape))
        note(f"overwriting array with basic indexer: {slicer=}")
        new_data = data.draw(
            npst.arrays(shape=np.shape(model_array[slicer]), dtype=model_array.dtype)
        )
        model_array[slicer] = new_data
        store_array[slicer] = new_data

    @precondition(lambda self: bool(self.all_arrays))
    @rule(data=st.data())
    def overwrite_array_orthogonal_indexing(self, data: DataObject) -> None:
        """Write the same random data through an orthogonal indexer into both stores."""
        array = data.draw(st.sampled_from(sorted(self.all_arrays)))
        model_array = zarr.open_array(path=array, store=self.model)
        store_array = zarr.open_array(path=array, store=self.store)
        indexer, _ = data.draw(orthogonal_indices(shape=model_array.shape))
        note(f"overwriting array orthogonal {indexer=}")
        new_data = data.draw(
            npst.arrays(shape=model_array.oindex[indexer].shape, dtype=model_array.dtype)  # type: ignore[union-attr]
        )
        model_array.oindex[indexer] = new_data
        store_array.oindex[indexer] = new_data

    @precondition(lambda self: bool(self.all_arrays))
    @rule(data=st.data())
    def resize_array(self, data: DataObject) -> None:
        """Resize a tracked array to a random shape of the same rank in both stores."""
        array = data.draw(st.sampled_from(sorted(self.all_arrays)))
        model_array = zarr.open_array(path=array, store=self.model)
        store_array = zarr.open_array(path=array, store=self.store)
        ndim = model_array.ndim
        # Dimensions that are currently 0 are kept at 0.
        new_shape = tuple(
            0 if oldsize == 0 else newsize
            for newsize, oldsize in zip(
                data.draw(npst.array_shapes(max_dims=ndim, min_dims=ndim, min_side=0)),
                model_array.shape,
                strict=True,
            )
        )
        note(f"resizing array from {model_array.shape} to {new_shape}")
        model_array.resize(new_shape)
        store_array.resize(new_shape)

    @precondition(lambda self: bool(self.all_arrays) or bool(self.all_groups))
    @rule(data=st.data())
    def delete_dir(self, data: DataObject) -> None:
        """Delete a drawn directory from both stores and untrack the nodes inside it."""
        path = self.draw_directory(data)
        note(f"delete_dir with {path=!r}")
        self._sync(self.model.delete_dir(path))
        self._sync(self.store.delete_dir(path))

        # Match on a path-component boundary: deleting "a" must untrack "a" and
        # "a/..." but not a sibling such as "ab". A bare `startswith(path)` used to
        # untrack such siblings while their keys were still in the stores.
        prefix = path if (not path or path.endswith("/")) else f"{path}/"
        matches = {
            node for node in self.all_groups | self.all_arrays if f"{node}/".startswith(prefix)
        }
        self.all_groups = self.all_groups - matches
        self.all_arrays = self.all_arrays - matches

    # NOTE: the following rules are disabled. They reference `self.repo` and
    # `self.model.can_add(...)`, neither of which exists on this class.
    # @precondition(lambda self: bool(self.all_groups))
    # @precondition(lambda self: bool(self.all_arrays))
    # @rule(data=st.data())
    # def move_array(self, data):
    #     array_path = data.draw(st.sampled_from(self.all_arrays), label="Array move source")
    #     to_group = data.draw(st.sampled_from(self.all_groups), label="Array move destination")
    #     # fixme renaming to self?
    #     array_name = os.path.basename(array_path)
    #     assume(self.model.can_add(to_group, array_name))
    #     new_path = f"{to_group}/{array_name}".lstrip("/")
    #     note(f"moving array '{array_path}' -> '{new_path}'")
    #     self.model.rename(array_path, new_path)
    #     self.repo.store.rename(array_path, new_path)

    # @precondition(lambda self: len(self.all_groups) >= 2)
    # @rule(data=st.data())
    # def move_group(self, data):
    #     from_group = data.draw(st.sampled_from(self.all_groups), label="Group move source")
    #     to_group = data.draw(st.sampled_from(self.all_groups), label="Group move destination")
    #     assume(not to_group.startswith(from_group))
    #     from_group_name = os.path.basename(from_group)
    #     assume(self.model.can_add(to_group, from_group_name))
    #     # fixme renaming to self?
    #     new_path = f"{to_group}/{from_group_name}".lstrip("/")
    #     note(f"moving group '{from_group}' -> '{new_path}'")
    #     self.model.rename(from_group, new_path)
    #     self.repo.store.rename(from_group, new_path)

    @precondition(lambda self: self.store.supports_deletes)
    @precondition(lambda self: len(self.all_arrays) >= 1)
    @rule(data=st.data())
    def delete_array_using_del(self, data: DataObject) -> None:
        """Delete a tracked array via ``del parent_group[name]`` in both stores."""
        array_path = data.draw(
            st.sampled_from(sorted(self.all_arrays)), label="Array deletion target"
        )
        prefix, array_name = split_prefix_name(array_path)
        note(f"Deleting array '{array_path}' ({prefix=!r}, {array_name=!r}) using del")
        for store in [self.model, self.store]:
            group = zarr.open_group(path=prefix, store=store)
            group[array_name]  # check that it exists
            del group[array_name]
        self.all_arrays.remove(array_path)

    @precondition(lambda self: self.store.supports_deletes)
    # Only nested groups (paths containing "/") are deletion candidates, so require
    # at least one. Requiring merely two groups allowed an empty candidate list
    # and an invalid draw.
    @precondition(lambda self: any("/" in g for g in self.all_groups))
    @rule(data=st.data())
    def delete_group_using_del(self, data: DataObject) -> None:
        """Delete a nested group via ``del parent_group[name]`` and untrack its subtree."""
        # ensure that we don't include the root group in the list of member names that we try
        # to delete
        member_names = tuple(g for g in sorted(self.all_groups) if "/" in g)
        group_path = data.draw(st.sampled_from(member_names), label="Group deletion target")
        prefix, group_name = split_prefix_name(group_path)
        note(f"Deleting group '{group_path=!r}', {prefix=!r}, {group_name=!r} using delete")
        # Untrack every descendant (read from the model before deletion).
        members = zarr.open_group(store=self.model, path=group_path).members(max_depth=None)
        for _, obj in members:
            if isinstance(obj, Array):
                self.all_arrays.remove(obj.path)
            else:
                self.all_groups.remove(obj.path)
        for store in [self.store, self.model]:
            group = zarr.open_group(store=store, path=prefix)
            group[group_name]  # check that it exists
            del group[group_name]
        # group_path always contains "/", so it is never the (untracked) root group.
        self.all_groups.remove(group_path)

    # # --------------- assertions -----------------
    # def check_group_arrays(self, group):
    #     # note(f"Checking arrays of '{group}'")
    #     g1 = self.model.get_group(group)
    #     g2 = zarr.open_group(path=group, mode="r", store=self.repo.store)
    #     model_arrays = sorted(g1.arrays(), key=itemgetter(0))
    #     our_arrays = sorted(g2.arrays(), key=itemgetter(0))
    #     for (n1, a1), (n2, a2) in zip_longest(model_arrays, our_arrays):
    #         assert n1 == n2
    #         assert_array_equal(a1, a2)

    # def check_subgroups(self, group_path):
    #     g1 = self.model.get_group(group_path)
    #     g2 = zarr.open_group(path=group_path, mode="r", store=self.repo.store)
    #     g1_children = [name for (name, _) in g1.groups()]
    #     g2_children = [name for (name, _) in g2.groups()]
    #     # note(f"Checking {len(g1_children)} subgroups of group '{group_path}'")
    #     assert g1_children == g2_children

    # def check_list_prefix_from_group(self, group):
    #     prefix = f"meta/root/{group}"
    #     model_list = sorted(self.model.list_prefix(prefix))
    #     al_list = sorted(self.repo.store.list_prefix(prefix))
    #     # note(f"Checking {len(model_list)} keys under '{prefix}'")
    #     assert model_list == al_list

    #     prefix = f"data/root/{group}"
    #     model_list = sorted(self.model.list_prefix(prefix))
    #     al_list = sorted(self.repo.store.list_prefix(prefix))
    #     # note(f"Checking {len(model_list)} keys under '{prefix}'")
    #     assert model_list == al_list

    # @precondition(lambda self: self.model.is_persistent_session())
    # @rule(data=st.data())
    # def check_group_path(self, data):
    #     t0 = time.time()
    #     group = data.draw(st.sampled_from(self.all_groups))
    #     self.check_list_prefix_from_group(group)
    #     self.check_subgroups(group)
    #     self.check_group_arrays(group)
    #     t1 = time.time()
    #     note(f"Checks took {t1 - t0} sec.")

    @invariant()
    def check_list_prefix_from_root(self) -> None:
        """Assert both stores hold the same keys and every tracked node has metadata."""
        model_list = self._sync_iter(self.model.list_prefix(""))
        store_list = self._sync_iter(self.store.list_prefix(""))
        note(f"Checking {len(model_list)} expected keys vs {len(store_list)} actual keys")
        model_sorted = sorted(model_list)
        store_sorted = sorted(store_list)
        assert model_sorted == store_sorted, (model_sorted, store_sorted)

        # check that our internal state matches that of the store and model
        tracked = self.all_groups | self.all_arrays
        model_keys = set(model_list)
        store_keys = set(store_list)
        assert all(f"{path}/zarr.json" in model_keys for path in tracked)
        assert all(f"{path}/zarr.json" in store_keys for path in tracked)
class SyncStoreWrapper(zarr.core.sync.SyncMixin):
    """Synchronous Store wrapper

    This class holds synchronous methods that map to async methods of Store classes.
    The synchronous wrapper is needed because hypothesis' stateful testing infra does
    not support asyncio so we redefine sync versions of the Store API.

    https://github.com/HypothesisWorks/hypothesis/issues/3712#issuecomment-1668999041
    """

    def __init__(self, store: Store) -> None:
        # The wrapped async store; every method below forwards to it.
        self.store = store

    @property
    def read_only(self) -> bool:
        """Whether the wrapped store is read-only."""
        return self.store.read_only

    def set(self, key: str, data_buffer: Buffer) -> None:
        """Write ``data_buffer`` under ``key``."""
        return self._sync(self.store.set(key, data_buffer))

    # NOTE: from here on, ``list`` in this class body refers to the method above's
    # sibling below, so annotations spell the builtin as ``builtins.list``.
    def list(self) -> builtins.list[str]:
        """Return all keys of the wrapped store."""
        return self._sync_iter(self.store.list())

    def get(self, key: str, prototype: BufferPrototype) -> Buffer | None:
        """Read the value stored under ``key`` using ``prototype``."""
        return self._sync(self.store.get(key, prototype=prototype))

    def get_partial_values(
        self, key_ranges: builtins.list[Any], prototype: BufferPrototype
    ) -> builtins.list[Buffer | None]:
        """Read the requested ``(key, byte_range)`` pairs."""
        return self._sync(self.store.get_partial_values(prototype=prototype, key_ranges=key_ranges))

    def delete(self, path: str) -> None:
        """Delete the key ``path``."""
        return self._sync(self.store.delete(path))

    def is_empty(self, prefix: str) -> bool:
        """Return whether no keys exist under ``prefix``."""
        return self._sync(self.store.is_empty(prefix=prefix))

    def clear(self) -> None:
        """Remove every key from the wrapped store."""
        return self._sync(self.store.clear())

    def exists(self, key: str) -> bool:
        """Return whether ``key`` exists."""
        return self._sync(self.store.exists(key))

    def list_dir(self, prefix: str) -> builtins.list[str]:
        """Return the wrapped store's ``list_dir(prefix)`` results as a list."""
        return self._sync_iter(self.store.list_dir(prefix))

    def list_prefix(self, prefix: str) -> builtins.list[str]:
        """Return the wrapped store's ``list_prefix(prefix)`` results as a list."""
        return self._sync_iter(self.store.list_prefix(prefix))

    def set_partial_values(self, key_start_values: Any) -> None:
        raise NotImplementedError

    @property
    def supports_listing(self) -> bool:
        return self.store.supports_listing

    @property
    def supports_partial_writes(self) -> bool:
        # Forward to the wrapped store; reading `self.supports_partial_writes`
        # here would re-enter this property and recurse forever.
        return self.store.supports_partial_writes

    @property
    def supports_writes(self) -> bool:
        return self.store.supports_writes

    @property
    def supports_deletes(self) -> bool:
        return self.store.supports_deletes
class ZarrStoreStateMachine(RuleBasedStateMachine):
    """
    Zarr store state machine

    This is a subclass of a Hypothesis RuleBasedStateMachine.
    It checks that the state of a Zarr store matches an expected state after a
    sequence of random operations. It holds the store under test (the ``Store``
    passed to the constructor, wrapped in a ``SyncStoreWrapper``) and a model:
    a simplified stand-in for a zarr store (in this case, a dict mapping keys to
    buffers). Rules apply an action to both the store and the model, and
    invariants assert that the state of the model equals the state of the store.
    Hypothesis generates sequences of rules, running invariants after each rule,
    and raises an error if a sequence makes the model and the store diverge
    (i.e. an invariant is violated).

    https://hypothesis.readthedocs.io/en/latest/stateful.html
    """

    def __init__(self, store: Store) -> None:
        super().__init__()
        # Expected contents: key -> the exact Buffer written under it.
        self.model: dict[str, Buffer] = {}
        self.store = SyncStoreWrapper(store)
        self.prototype = default_buffer_prototype()

    @initialize()
    def init_store(self) -> None:
        """Start every run from an empty store (the model starts empty too)."""
        self.store.clear()

    @rule(key=zarr_keys(), data=st.binary(min_size=0, max_size=MAX_BINARY_SIZE))
    def set(self, key: str, data: bytes) -> None:
        """Write random bytes under a random key in both store and model."""
        note(f"(set) Setting {key!r} with {data!r}")
        assert not self.store.read_only
        data_buf = cpu.Buffer.from_bytes(data)
        self.store.set(key, data_buf)
        self.model[key] = data_buf

    @precondition(lambda self: bool(self.model))
    @rule(key=zarr_keys(), data=st.data())
    def get(self, key: str, data: DataObject) -> None:
        """Assert an existing key reads back the buffer recorded in the model."""
        # The `key` argument drawn by the rule is not used; an existing key is
        # sampled from the model instead (hypothesis wants to sample from a sorted list).
        key = data.draw(st.sampled_from(sorted(self.model.keys())))
        note("(get)")
        store_value = self.store.get(key, self.prototype)
        # The model stores the Buffer objects written in set(), so Buffers are compared directly.
        assert self.model[key] == store_value

    @rule(key=zarr_keys(), data=st.data())
    def get_invalid_zarr_keys(self, key: str, data: DataObject) -> None:
        """Assert reading a key absent from the model returns None."""
        note("(get_invalid)")
        assume(key not in self.model)
        assert self.store.get(key, self.prototype) is None

    @precondition(lambda self: bool(self.model))
    @rule(data=st.data())
    def get_partial_values(self, data: DataObject) -> None:
        """Assert byte-range reads of existing keys match slices of the model buffers."""
        key_range = data.draw(
            key_ranges(keys=st.sampled_from(sorted(self.model.keys())), max_size=MAX_BINARY_SIZE)
        )
        note(f"(get partial) {key_range=}")
        obs_maybe = self.store.get_partial_values(key_range, self.prototype)
        observed = []

        for obs in obs_maybe:
            assert obs is not None
            observed.append(obs.to_bytes())

        model_vals_ls = [
            self.model[key][byte_range.start : byte_range.end] for key, byte_range in key_range
        ]

        assert all(
            obs == exp.to_bytes() for obs, exp in zip(observed, model_vals_ls, strict=True)
        ), (
            observed,
            model_vals_ls,
        )

    @precondition(lambda self: self.store.supports_deletes)
    @precondition(lambda self: bool(self.model))
    @rule(data=st.data())
    def delete(self, data: DataObject) -> None:
        """Delete an existing key from both store and model."""
        key = data.draw(st.sampled_from(sorted(self.model.keys())))
        note(f"(delete) Deleting {key=}")

        self.store.delete(key)
        del self.model[key]

    @rule()
    def clear(self) -> None:
        """Clear both store and model and assert both are empty."""
        assert not self.store.read_only
        note("(clear)")
        self.store.clear()
        self.model.clear()

        assert self.store.is_empty("")

        assert len(self.model.keys()) == len(list(self.store.list())) == 0

    @rule()
    # Local store can be non-empty when there are subdirectories but no files
    @precondition(lambda self: not isinstance(self.store.store, LocalStore))
    def is_empty(self) -> None:
        """Assert store and model agree on emptiness."""
        note("(is_empty)")

        # make sure they either both are or both aren't empty (same state)
        assert self.store.is_empty("") == (not self.model)

    @rule(key=zarr_keys())
    def exists(self, key: str) -> None:
        """Assert ``exists`` matches model membership for a random key."""
        note("(exists)")

        assert self.store.exists(key) == (key in self.model)

    @invariant()
    def check_paths_equal(self) -> None:
        """Store and model contain the same set of keys."""
        note("Checking that paths are equal")
        paths = sorted(self.store.list())

        assert sorted(self.model.keys()) == paths

    @invariant()
    def check_vals_equal(self) -> None:
        """Every model key reads back its recorded buffer from the store."""
        note("Checking values equal")
        for key, val in self.model.items():
            store_item = self.store.get(key, self.prototype)
            assert val == store_item

    @invariant()
    def check_num_zarr_keys_equal(self) -> None:
        """Store and model hold the same number of keys."""
        note("check num zarr_keys equal")

        assert len(self.model) == len(list(self.store.list()))

    @invariant()
    def check_zarr_keys(self) -> None:
        """``list``, ``is_empty`` and ``exists`` are mutually consistent on the store."""
        # Emit the note first so it is recorded even when an assertion below fails.
        note("checking keys / exists / empty")
        keys = list(self.store.list())

        if not keys:
            assert self.store.is_empty("") is True
        else:
            assert self.store.is_empty("") is False

            for key in keys:
                assert self.store.exists(key) is True