diff --git a/shade_ms/data_plots.py b/shade_ms/data_plots.py index 8a791c5..d1ae3c6 100644 --- a/shade_ms/data_plots.py +++ b/shade_ms/data_plots.py @@ -151,15 +151,21 @@ def get_plot_data(msinfo, group_cols, mytaql, chan_freqs, # always read flags -- easier that way flag = group.FLAG if not noflags else None flag_row = group.FLAG_ROW if not noflags else None - - a1 = da.minimum(group.ANTENNA1.data, group.ANTENNA2.data) - a2 = da.maximum(group.ANTENNA1.data, group.ANTENNA2.data) - baselines = msinfo.baseline_number(a1, a2) - + if not iter_baseline: + # if group by then these are attributes, not data arrays + a1 = da.minimum(group.ANTENNA1.data, group.ANTENNA2.data) + a2 = da.maximum(group.ANTENNA1.data, group.ANTENNA2.data) + baselines = msinfo.baseline_number(a1, a2) + else: + baselines = None freqs = chan_freqs[ddid] chans = xarray.DataArray(range(len(freqs)), dims=("chan",)) wavel = freq_to_wavel(freqs) - extras = dict(chans=chans, freqs=freqs, wavel=wavel, rows=group.row, baselines=baselines) + extras = dict(chans=chans, + freqs=freqs, + wavel=wavel, + rows=group.row, + baselines=baselines if baselines else np.array([baseline])) nchan = len(group.chan) if flag is not None: @@ -468,10 +474,12 @@ def create_plot(ddf, index_subsets, xdatum, ydatum, adatum, ared, cdatum, cmap, log.debug('done') # Set plot limits based on data extent or user values for axis labels - - xmin, xmax = bounds[xaxis] - ymin, ymax = bounds[yaxis] - + limits = { + "xmin": bounds[xaxis][0], + "xmax": bounds[xaxis][1], + "ymin": bounds[yaxis][0], + "ymax": bounds[yaxis][1] + } log.debug('rendering image') fig = pylab.figure(figsize=(figx, figy)) @@ -479,9 +487,20 @@ def create_plot(ddf, index_subsets, xdatum, ydatum, adatum, ared, cdatum, cmap, for funcname, args, kwargs in extra_markup: getattr(ax, funcname)(*args, **kwargs) - - ax.imshow(X=rgb.data, extent=[xmin, xmax, ymin, ymax], - aspect='auto', origin='lower', interpolation='nearest') + + # any 1D arrays like freq and WAVEL that is dask arrays at this point needs + # compute called + compute_arrays = dict(filter(lambda x: isinstance(x[1], da.Array), limits.items())) + limits.update(dict(zip(compute_arrays.keys(), da.compute(*compute_arrays.values())))) + + ax.imshow(X=rgb.data, + extent=[limits['xmin'], + limits['xmax'], + limits['ymin'], + limits['ymax']], + aspect='auto', + origin='lower', + interpolation='nearest') ax.set_title("\n".join(textwrap.wrap(title, 90)), loc='center', fontdict=dict(fontsize=options.fontsize)) ax.set_xlabel(xlabel, fontdict=dict(fontsize=options.fontsize)) @@ -489,9 +508,9 @@ def create_plot(ddf, index_subsets, xdatum, ydatum, adatum, ared, cdatum, cmap, # ax.plot(xmin,ymin,'.',alpha=0.0) # ax.plot(xmax,ymax,'.',alpha=0.0) - dx, dy = xmax - xmin, ymax - ymin - ax.set_xlim([xmin - dx/100, xmax + dx/100]) - ax.set_ylim([ymin - dy/100, ymax + dy/100]) + dx, dy = limits['xmax'] - limits['xmin'], limits['ymax'] - limits['ymin'] + ax.set_xlim([limits['xmin'] - dx/100, limits['xmax'] + dx/100]) + ax.set_ylim([limits['ymin'] - dy/100, limits['ymax'] + dy/100]) def decimate_list(x, maxel): """Helper function to reduce a list to < given max number of elements, dividing it by decimal factors of 2 and 5"""