19
19
################################################################################
20
20
21
21
from functools import partial
22
+ from operator import getitem
22
23
23
24
import dask .array as da
24
25
import numpy as np
@@ -126,14 +127,7 @@ def _main_xarray_factory(self, field_id, state_id, scan_index, scan_state, targe
126
127
127
128
flags = DaskLazyIndexer (dataset .flags , (), (rechunk , flag_transpose ))
128
129
weights = DaskLazyIndexer (dataset .weights , (), (rechunk , weight_transpose ))
129
- vis = DaskLazyIndexer (
130
- dataset .vis ,
131
- (),
132
- transforms = (
133
- rechunk ,
134
- vis_transpose ,
135
- ),
136
- )
130
+ vis = DaskLazyIndexer (dataset .vis , (), (rechunk , vis_transpose ))
137
131
138
132
time = da .from_array (time_mjds [:, None ], chunks = (t_chunks , 1 ))
139
133
ant1 = da .from_array (cp_info .ant1_index [None , :], chunks = (1 , cpi .shape [0 ]))
@@ -147,7 +141,32 @@ def _main_xarray_factory(self, field_id, state_id, scan_index, scan_state, targe
147
141
row = self ._row_view ,
148
142
)
149
143
150
- time , ant1 , ant2 = da .broadcast_arrays (time , ant1 , ant2 )
144
+ # Better graph than da.broadcast_arrays
145
+ bcast = da .blockwise (
146
+ np .broadcast_arrays ,
147
+ ("time" , "bl" ),
148
+ time ,
149
+ ("time" , "bl" ),
150
+ ant1 ,
151
+ ("time" , "bl" ),
152
+ ant2 ,
153
+ ("time" , "bl" ),
154
+ align_arrays = False ,
155
+ adjust_chunks = {"time" : time .chunks [0 ], "bl" : ant1 .chunks [1 ]},
156
+ meta = np .empty ((0 ,) * 2 , dtype = np .int32 ),
157
+ )
158
+
159
+ time = da .blockwise (
160
+ getitem , ("time" , "bl" ), bcast , ("time" , "bl" ), 0 , None , dtype = time .dtype
161
+ )
162
+
163
+ ant1 = da .blockwise (
164
+ getitem , ("time" , "bl" ), bcast , ("time" , "bl" ), 1 , None , dtype = ant1 .dtype
165
+ )
166
+
167
+ ant2 = da .blockwise (
168
+ getitem , ("time" , "bl" ), bcast , ("time" , "bl" ), 2 , None , dtype = ant2 .dtype
169
+ )
151
170
152
171
if self ._row_view :
153
172
primary_dims = ("row" ,)
0 commit comments