Skip to content

Commit

Permalink
Merge pull request #104 from ratt-ru/issue-97
Browse files Browse the repository at this point in the history
Ensure matplotlib bounds are computed for baseline arrays and fix regression in iter-baselines (#97 #105)
  • Loading branch information
bennahugo authored Oct 6, 2022
2 parents d96bf15 + f2d5e16 commit 4026ff4
Showing 1 changed file with 35 additions and 16 deletions.
51 changes: 35 additions & 16 deletions shade_ms/data_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,15 +151,21 @@ def get_plot_data(msinfo, group_cols, mytaql, chan_freqs,
# always read flags -- easier that way
flag = group.FLAG if not noflags else None
flag_row = group.FLAG_ROW if not noflags else None

a1 = da.minimum(group.ANTENNA1.data, group.ANTENNA2.data)
a2 = da.maximum(group.ANTENNA1.data, group.ANTENNA2.data)
baselines = msinfo.baseline_number(a1, a2)

if not iter_baseline:
# if group by then these are attributes, not data arrays
a1 = da.minimum(group.ANTENNA1.data, group.ANTENNA2.data)
a2 = da.maximum(group.ANTENNA1.data, group.ANTENNA2.data)
baselines = msinfo.baseline_number(a1, a2)
else:
baselines = None
freqs = chan_freqs[ddid]
chans = xarray.DataArray(range(len(freqs)), dims=("chan",))
wavel = freq_to_wavel(freqs)
extras = dict(chans=chans, freqs=freqs, wavel=wavel, rows=group.row, baselines=baselines)
extras = dict(chans=chans,
freqs=freqs,
wavel=wavel,
rows=group.row,
baselines=baselines if baselines else np.array([baseline]))

nchan = len(group.chan)
if flag is not None:
Expand Down Expand Up @@ -468,30 +474,43 @@ def create_plot(ddf, index_subsets, xdatum, ydatum, adatum, ared, cdatum, cmap,
log.debug('done')

# Set plot limits based on data extent or user values for axis labels

xmin, xmax = bounds[xaxis]
ymin, ymax = bounds[yaxis]

limits = {
"xmin": bounds[xaxis][0],
"xmax": bounds[xaxis][1],
"ymin": bounds[yaxis][0],
"ymax": bounds[yaxis][1]
}
log.debug('rendering image')

fig = pylab.figure(figsize=(figx, figy))
ax = fig.add_subplot(111, facecolor=bgcol)

for funcname, args, kwargs in extra_markup:
getattr(ax, funcname)(*args, **kwargs)

ax.imshow(X=rgb.data, extent=[xmin, xmax, ymin, ymax],
aspect='auto', origin='lower', interpolation='nearest')

# any 1D arrays like freq and WAVEL that is dask arrays at this point needs
# compute called
compute_arrays = dict(filter(lambda x: isinstance(x[1], da.Array), limits.items()))
limits.update(dict(zip(compute_arrays.keys(), da.compute(*compute_arrays.values()))))

ax.imshow(X=rgb.data,
extent=[limits['xmin'],
limits['xmax'],
limits['ymin'],
limits['ymax']],
aspect='auto',
origin='lower',
interpolation='nearest')

ax.set_title("\n".join(textwrap.wrap(title, 90)), loc='center', fontdict=dict(fontsize=options.fontsize))
ax.set_xlabel(xlabel, fontdict=dict(fontsize=options.fontsize))
ax.set_ylabel(ylabel, fontdict=dict(fontsize=options.fontsize))
# ax.plot(xmin,ymin,'.',alpha=0.0)
# ax.plot(xmax,ymax,'.',alpha=0.0)

dx, dy = xmax - xmin, ymax - ymin
ax.set_xlim([xmin - dx/100, xmax + dx/100])
ax.set_ylim([ymin - dy/100, ymax + dy/100])
dx, dy = limits['xmax'] - limits['xmin'], limits['ymax'] - limits['ymin']
ax.set_xlim([limits['xmin'] - dx/100, limits['xmax'] + dx/100])
ax.set_ylim([limits['ymin'] - dy/100, limits['ymax'] + dy/100])

def decimate_list(x, maxel):
"""Helper function to reduce a list to < given max number of elements, dividing it by decimal factors of 2 and 5"""
Expand Down

0 comments on commit 4026ff4

Please sign in to comment.