Skip to content

Commit 02f8040

Browse files
committed
Write _ARRAY_DIMENSIONS array attribute
1 parent c81ab2d commit 02f8040

File tree

2 files changed

+28
-0
lines changed

2 files changed

+28
-0
lines changed

daskms/experimental/zarr/__init__.py

+4
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
import numpy as np
1111
import warnings
1212

13+
ARRAY_DIMENSION = "_ARRAY_DIMENSIONS"
14+
1315
from daskms.constants import DASKMS_PARTITION_KEY
1416
from daskms.dataset import Dataset, Variable
1517
from daskms.dataset_schema import DatasetSchema, encode_type, decode_type, decode_attr
@@ -111,6 +113,8 @@ def create_array(ds_group, column, column_schema, schema_chunks, coordinate=Fals
111113
exact=True,
112114
)
113115

116+
array.attrs[ARRAY_DIMENSION] = column_schema.dims
117+
114118
array.attrs[DASKMS_ATTR_KEY] = {
115119
**column_schema.attrs,
116120
"dims": column_schema.dims,

daskms/experimental/zarr/tests/test_zarr.py

+24
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,7 @@ def test_multiprocess_create(ms, tmp_path_factory):
235235
@pytest.mark.skipif(xarray is None, reason="depends on xarray")
236236
def test_xarray_to_zarr(ms, tmp_path_factory):
237237
store = tmp_path_factory.mktemp("zarr_store")
238+
238239
datasets = xds_from_ms(ms)
239240

240241
for i, ds in enumerate(datasets):
@@ -407,3 +408,26 @@ def test_xds_from_zarr_assert_on_empty_store(tmp_path_factory, ms):
407408

408409
with pytest.raises(UnknownStoreTypeError, match="Unable to infer table type"):
409410
xds_from_zarr(path)
411+
412+
413+
@pytest.mark.skipif(xarray is None, reason="depends on xarray")
414+
def test_xarray_reading_daskms_written_dataset(ms, tmp_path_factory):
415+
store = tmp_path_factory.mktemp("zarr_store")
416+
417+
datasets = xds_from_ms(ms)
418+
419+
for i, ds in enumerate(datasets):
420+
chunks = ds.chunks
421+
row = sum(chunks["row"])
422+
chan = sum(chunks["chan"])
423+
corr = sum(chunks["corr"])
424+
425+
datasets[i] = ds.assign_coords(
426+
row=np.arange(row), chan=np.arange(chan), corr=np.arange(corr)
427+
)
428+
429+
path = store / "test.zarr"
430+
dask.compute(xds_to_zarr(datasets, path, consolidated=True))
431+
432+
ds = xarray.open_zarr(path / "MAIN" / "MAIN_0")
433+
assert ds == datasets[0]

0 commit comments

Comments
 (0)