Skip to content

Commit 9690ecd

Browse files
authored
Optimise broadcast_arrays in katdal import (#326)
1 parent 350415c commit 9690ecd

File tree

2 files changed

+29
-9
lines changed

2 files changed

+29
-9
lines changed

HISTORY.rst

+1
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ History
44

55
X.Y.Z (YYYY-MM-DD)
66
------------------
7+
* Optimise `broadcast_arrays` in katdal import (:pr:`326`)
78
* Change `dask-ms katdal import` to `dask-ms import katdal` (:pr:`325`)
89
* Configure dependabot (:pr:`319`)
910
* Add chunk specification to ``dask-ms katdal import`` (:pr:`318`)

daskms/experimental/katdal/msv2_facade.py

+28-9
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
################################################################################
2020

2121
from functools import partial
22+
from operator import getitem
2223

2324
import dask.array as da
2425
import numpy as np
@@ -126,14 +127,7 @@ def _main_xarray_factory(self, field_id, state_id, scan_index, scan_state, targe
126127

127128
flags = DaskLazyIndexer(dataset.flags, (), (rechunk, flag_transpose))
128129
weights = DaskLazyIndexer(dataset.weights, (), (rechunk, weight_transpose))
129-
vis = DaskLazyIndexer(
130-
dataset.vis,
131-
(),
132-
transforms=(
133-
rechunk,
134-
vis_transpose,
135-
),
136-
)
130+
vis = DaskLazyIndexer(dataset.vis, (), (rechunk, vis_transpose))
137131

138132
time = da.from_array(time_mjds[:, None], chunks=(t_chunks, 1))
139133
ant1 = da.from_array(cp_info.ant1_index[None, :], chunks=(1, cpi.shape[0]))
@@ -147,7 +141,32 @@ def _main_xarray_factory(self, field_id, state_id, scan_index, scan_state, targe
147141
row=self._row_view,
148142
)
149143

150-
time, ant1, ant2 = da.broadcast_arrays(time, ant1, ant2)
144+
# Better graph than da.broadcast_arrays
145+
bcast = da.blockwise(
146+
np.broadcast_arrays,
147+
("time", "bl"),
148+
time,
149+
("time", "bl"),
150+
ant1,
151+
("time", "bl"),
152+
ant2,
153+
("time", "bl"),
154+
align_arrays=False,
155+
adjust_chunks={"time": time.chunks[0], "bl": ant1.chunks[1]},
156+
meta=np.empty((0,) * 2, dtype=np.int32),
157+
)
158+
159+
time = da.blockwise(
160+
getitem, ("time", "bl"), bcast, ("time", "bl"), 0, None, dtype=time.dtype
161+
)
162+
163+
ant1 = da.blockwise(
164+
getitem, ("time", "bl"), bcast, ("time", "bl"), 1, None, dtype=ant1.dtype
165+
)
166+
167+
ant2 = da.blockwise(
168+
getitem, ("time", "bl"), bcast, ("time", "bl"), 2, None, dtype=ant2.dtype
169+
)
151170

152171
if self._row_view:
153172
primary_dims = ("row",)

0 commit comments

Comments
 (0)