From 90f853c0b40c4a7572582476763f0372c726bd75 Mon Sep 17 00:00:00 2001 From: landmanbester Date: Sat, 30 Nov 2024 16:08:57 +0200 Subject: [PATCH] update smoovie for remote client --- pfb/parser/smoovie.yaml | 4 ++++ pfb/workers/smoovie.py | 39 ++++++++++++++------------------------- 2 files changed, 18 insertions(+), 25 deletions(-) diff --git a/pfb/parser/smoovie.yaml b/pfb/parser/smoovie.yaml index ff2d4f7b..5079cd8d 100644 --- a/pfb/parser/smoovie.yaml +++ b/pfb/parser/smoovie.yaml @@ -73,6 +73,10 @@ inputs: default: 12 info: Constant rate factor for controlling mp4 output quality. + scratch-dir: + dtype: URI + info: + Streamjoy scratch directory. _include: - (.)out.yml diff --git a/pfb/workers/smoovie.py b/pfb/workers/smoovie.py index a4ac61c8..5cb547d1 100644 --- a/pfb/workers/smoovie.py +++ b/pfb/workers/smoovie.py @@ -72,26 +72,22 @@ def _smoovie(**kw): import xarray as xr import numpy as np - from pfb.utils.fits import dds2fits from pfb.utils.naming import xds_from_url - from distributed import get_client, as_completed - from PIL import Image, ImageDraw, ImageFont + from distributed import get_client import matplotlib.pyplot as plt - from mpl_toolkits.axes_grid1 import make_axes_locatable from streamjoy import stream, wrap_matplotlib - from distributed.diagnostics.progressbar import progress from daskms.fsspec_store import DaskMSStore try: client = get_client() - names = list(client.scheduler_info()['workers'].keys()) except: - from pfb.utils.dist import fake_client - client = fake_client() - names = [0] - as_completed = lambda x: x + client = None basename = opts.output_filename + if opts.scratch_dir is not None: + scratch_dir = opts.scratch_dir + else: + scratch_dir = basename.rsplit('/', 1)[0] # xds contains vis products, no imaging weights applied fds_name = f'{basename}.fds' if opts.fds is None else opts.fds @@ -151,7 +147,7 @@ def plot_frame(ds): xycoords='axes fraction', textcoords='axes fraction', ha='left', va='bottom', - fontsize=20, + fontsize=15, color=opts.text_colour) return fig @@ -164,7 +160,7 @@ def plot_frame(ds): fds_dict.setdefault(b, []) fds_dict[b].append(ds) - for b, dslist in fdso.items(): + for b, dslist in fds_dict.items(): rmss = [ds.rms for ds in dslist] medrms = np.median(rmss) nframe = len(dslist) @@ -172,17 +168,6 @@ def plot_frame(ds): ds.attrs['rms'] = medrms ds.attrs['ffrac'] = f'{i}/{nframe}' - # results should have - # 0 - out image - # 1 - median rms - # 2 - utc - # 3 - scan number - # 4 - frame fraction - # 5 - band id - - # TODO: - # - progressbar - # - investigate writing frames to disk as xarray dataset and passing instead of frames idfy = f'fps{opts.fps}' if opts.out_format.lower() == 'gif': outim = stream( @@ -193,7 +178,9 @@ def plot_frame(ds): threads_per_worker=1, fps=opts.fps, max_frames=-1, - uri=f'{basename}_band{b}_{idfy}.gif' + uri=f'{basename}_band{b}_{idfy}.gif', + scratch_dir=f'{scratch_dir}/streamjoy_scratch', + client=client ) elif opts.out_format.lower() == 'mp4': outim = stream( @@ -204,7 +191,9 @@ def plot_frame(ds): threads_per_worker=1, fps=opts.fps, max_frames=-1, - uri=f'{basename}_band{b}_{idfy}.mp4' + uri=f'{basename}_band{b}_{idfy}.mp4', + scratch_dir=f'{scratch_dir}/streamjoy_scratch', + client=client ) else: raise ValueError(f"Unsupported format {opts.out_format}")