From 3d0535ce04f352641a5acf1647440f61ded4367d Mon Sep 17 00:00:00 2001
From: JSKenyon
Date: Wed, 26 Jul 2023 11:15:21 +0200
Subject: [PATCH 01/26] Use nearest-neighbour interpolation in regions where
 extrapolation is required. (#285)

* Fix version drift.

* Bump to 0.2.0

* Use nearest-neighbour interpolation for points requiring extrapolation.

---
 pyproject.toml                          |  2 +-
 quartical/interpolation/interpolants.py | 23 ++++++++++++++++++++++-
 tbump.toml                              |  2 +-
 3 files changed, 24 insertions(+), 3 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index a2196b67..68c1f4ee 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "quartical"
-version = "0.1.10"
+version = "0.2.0"
 description = "Fast and flexible calibration suite for radio interferometer data."
 repository = "https://github.com/ratt-ru/QuartiCal"
 documentation = "https://quartical.readthedocs.io"
diff --git a/quartical/interpolation/interpolants.py b/quartical/interpolation/interpolants.py
index 0c384645..07771157 100644
--- a/quartical/interpolation/interpolants.py
+++ b/quartical/interpolation/interpolants.py
@@ -35,11 +35,32 @@ def linear2d_interpolate_gains(source_xds, target_xds):
     if i_f_dim > 1:
         interp_axes[i_f_axis] = target_xds[t_f_axis].data
 
-    target_xda = source_xds.params.interp(
+    # NOTE: The below is the path of least resistance but may not be the most
+    # efficient method for this mixed-mode interpolation - it may be possible
+    # to do better using multiple RegularGridInterpolator objects.
+
+    # Interpolate using linear interpolation, filling points outside the
+    # domain with NaNs.
+    in_domain_xda = source_xds.params.interp(
+        interp_axes,
+        kwargs={"fill_value": np.nan}
+    )
+
+    # Interpolate using nearest-neighbour interpolation, extrapolating points
+    # outside the domain.
+    out_domain_xda = source_xds.params.interp(
         interp_axes,
+        method="nearest",
         kwargs={"fill_value": "extrapolate"}
     )
 
+    # Combine the linear and nearest-neighbour interpolation done above, i.e.
+    # use linear interpolation inside the domain and nearest-neighbour
+    # interpolation anywhere extrapolation was required.
+    target_xda = in_domain_xda.where(
+        da.isfinite(in_domain_xda), out_domain_xda
+    )
+
     if i_t_dim == 1:
         target_xda = target_xda.reindex(
             {i_t_axis: target_xds[t_t_axis].data},
diff --git a/tbump.toml b/tbump.toml
index e481d3d3..78ae3488 100644
--- a/tbump.toml
+++ b/tbump.toml
@@ -2,7 +2,7 @@
 github_url = "https://github.com/ratt-ru/QuartiCal/"
 
 [version]
-current = "0.1.11"
+current = "0.2.0"
 
 # Example of a semver regexp.
 # Make sure this matches current_version before

From 6b22750e309dc3d715c2512f3afc38218809dc66 Mon Sep 17 00:00:00 2001
From: JSKenyon
Date: Fri, 28 Jul 2023 12:23:54 +0200
Subject: [PATCH 02/26] Add interface for degrid models.
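For context, a degrid model is specified inline in the recipe as an .mds
path followed by an "@"-separated tuple of options. A minimal sketch of the
parsing introduced below (file name and option values are hypothetical):

    from ast import literal_eval

    ingredient = "model.mds@(1024, 1024, 5e-6, 5e-6, 0.0, 0.0, 8, 64)"
    filename, _, options = ingredient.partition("@")
    options = literal_eval(options)
    # options unpack into degrid_model_nt as
    # (nxo, nyo, cellxo, cellyo, x0o, y0o, ipi, cpi).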
---
 quartical/config/preprocess.py | 49 +++++++++++++++++++++++++++++++---
 1 file changed, 46 insertions(+), 3 deletions(-)

diff --git a/quartical/config/preprocess.py b/quartical/config/preprocess.py
index d7fa2081..ce75f009 100644
--- a/quartical/config/preprocess.py
+++ b/quartical/config/preprocess.py
@@ -6,15 +6,31 @@
 import os.path
 from dataclasses import dataclass
 from typing import List, Dict, Set, Any
+from ast import literal_eval
 
 
 sky_model_nt = namedtuple("sky_model_nt", ("name", "tags"))
 
+degrid_model_nt = namedtuple(
+    "degrid_model_nt",
+    (
+        "name",
+        "nxo",
+        "nyo",
+        "cellxo",
+        "cellyo",
+        "x0o",
+        "y0o",
+        "ipi",
+        "cpi"
+    )
+)
+
 
 @dataclass
 class Ingredients:
     model_columns: Set[Any]
     sky_models: Set[sky_model_nt]
+    degrid_models: Set[degrid_model_nt]
 
 
 @dataclass
@@ -49,6 +65,8 @@ def transcribe_recipe(user_recipe):
 
     model_columns = set()
     sky_models = set()
+    degrid_models = set()
+
     instructions = {}
 
     # Strip accidental whitespace from the input recipe and split on ":".
@@ -92,6 +110,18 @@ def transcribe_recipe(user_recipe):
             sky_models.add(sky_model)
             instructions[recipe_index].append(sky_model)
 
+        elif ".mds" in ingredient:
+
+            filename, _, options = ingredient.partition("@")
+            options = literal_eval(options)  # TODO: Fail on missing options.
+
+            if not os.path.exists(filename):
+                raise FileNotFoundError("{} not found.".format(filename))
+
+            degrid_model = degrid_model_nt(filename, *options)
+            degrid_models.add(degrid_model)
+            instructions[recipe_index].append(degrid_model)
+
         elif ingredient != "":
             model_columns.add(ingredient)
             instructions[recipe_index].append(ingredient)
@@ -99,19 +129,32 @@ def transcribe_recipe(user_recipe):
         else:
             instructions[recipe_index].append(ingredient)
 
+    # TODO: Add message to log.
     logger.info("The following model sources were obtained from "
                 "--input-model-recipe: \n"
                 "   Columns: {} \n"
-                "   Sky Models: {}",
+                "   Sky Models: {} \n"
+                "   Degrid Models: {}",
                 model_columns or 'None',
-                {sm.name for sm in sky_models} or 'None')
+                {sm.name for sm in sky_models} or 'None',
+                {dm.name for dm in degrid_models} or 'None')
 
     # Generate a named tuple containing all the information required to
     # build the model visibilities.
-    model_recipe = Recipe(Ingredients(model_columns, sky_models), instructions)
+    model_recipe = Recipe(
+        Ingredients(
+            model_columns,
+            sky_models,
+            degrid_models
+        ),
+        instructions
+    )
 
     if model_recipe.ingredients.sky_models:
         logger.info("Recipe contains sky models - enabling prediction step.")
 
+    if model_recipe.ingredients.degrid_models:
+        logger.info("Recipe contains degrid models - enabling degridding.")
+
     return model_recipe

From e1e3dcb48ccbb8933beee66140a6bd9c75d4684d Mon Sep 17 00:00:00 2001
From: JSKenyon
Date: Fri, 28 Jul 2023 17:40:33 +0200
Subject: [PATCH 03/26] Commit initial, semi-working code for generating model
 visibilities from pfb-style models.
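The degridder reconstructs an image from a symbolic parametrisation stored
in the .mds and evaluates it at the mean time/frequency of each bin. A
minimal sketch of the sympy machinery relied on below (the expression is a
stand-in; the real one is read from the model attributes):

    import sympy as sm
    from sympy.parsing.sympy_parser import parse_expr
    from sympy.utilities.lambdify import lambdify

    params = sm.symbols(('t', 'f')) + sm.symbols(('a', 'b'))
    model_func = lambdify(params, parse_expr("a + b*f"))
    model_func(0.0, 1.4e9, 1.0, 2.0)  # -> 1.0 + 2.0 * 1.4e9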
---
 quartical/config/preprocess.py           |   2 +-
 quartical/data_handling/degridder.py     | 196 +++++++++++++++++++++++
 quartical/data_handling/model_handler.py |  40 ++++-
 3 files changed, 230 insertions(+), 8 deletions(-)
 create mode 100644 quartical/data_handling/degridder.py

diff --git a/quartical/config/preprocess.py b/quartical/config/preprocess.py
index ce75f009..ea11d82b 100644
--- a/quartical/config/preprocess.py
+++ b/quartical/config/preprocess.py
@@ -155,6 +155,6 @@ def transcribe_recipe(user_recipe):
         logger.info("Recipe contains sky models - enabling prediction step.")
 
     if model_recipe.ingredients.degrid_models:
-        logger.info("Recipe contains degrid models - enabling degridding.")
+        logger.info("Recipe contains degrid models - enabling degridding.")
 
     return model_recipe
diff --git a/quartical/data_handling/degridder.py b/quartical/data_handling/degridder.py
new file mode 100644
index 00000000..24711897
--- /dev/null
+++ b/quartical/data_handling/degridder.py
@@ -0,0 +1,196 @@
+from collections import defaultdict
+import numpy as np
+import dask.array as da
+from daskms.experimental.zarr import xds_from_zarr
+from scipy.interpolate import RegularGridInterpolator
+import sympy as sm
+from sympy.utilities.lambdify import lambdify
+from sympy.parsing.sympy_parser import parse_expr
+from ducc0.wgridder.experimental import dirty2vis
+from quartical.utils.collections import freeze_default_dict
+
+
+def _degrid(time, freq, uvw, model):
+
+    name, nxo, nyo, cellxo, cellyo, x0o, y0o, ipi, cpi = model
+
+    # TODO: Dodgy pattern, as this will produce dask arrays which will
+    # need to be evaluated i.e. task-in-task which is a bad idea. OK for
+    # the purposes of a prototype.
+    model_xds = xds_from_zarr(model.name)[0]
+
+    # TODO: We want to go from the model to an image cube appropriate for
+    # this chunk of data. The dimensions of the image cube will be determined
+    # by the time and frequency dimensions of the chunk in conjunction with
+    # the integrations per chunk and channels per chunk arguments.
+
+    utime, utime_inv = np.unique(time, return_inverse=True)
+    n_utime = utime.size
+    ipi = ipi or n_utime  # Catch zero case. NOTE: -1?
+    n_mean_times = int(np.ceil(n_utime / ipi))
+
+    n_freq = freq.size
+    n_mean_freqs = int(np.ceil(n_freq / cpi))
+
+    # Let's start off by simply reconstructing the model image as generated
+    # by pfb-clean.
+
+    nxi, nyi = model_xds.npix_x, model_xds.npix_y
+    image_in = np.zeros((nxi, nyi), dtype=float)
+    params = sm.symbols(('t', 'f'))
+    params += sm.symbols(tuple(model_xds.params.values))
+    symexpr = parse_expr(model_xds.parametrisation)
+    modelf = lambdify(params, symexpr)
+    texpr = parse_expr(model_xds.texpr)
+    tfunc = lambdify(params[0], texpr)
+    fexpr = parse_expr(model_xds.fexpr)
+    ffunc = lambdify(params[1], fexpr)
+    Ix = model_xds.location_x.values
+    Iy = model_xds.location_y.values
+    coeffs = model_xds.coefficients.values
+
+    vis = np.empty((time.size, freq.size, 4), dtype=np.complex128)
+
+    for ti in range(n_mean_times):
+        for fi in range(n_mean_freqs):
+
+            freq_sel = slice(fi * cpi, (fi + 1) * cpi)
+            degrid_freq = freq[freq_sel]
+            mean_freq = degrid_freq.mean()
+
+            time_sel = slice(ti * ipi, (ti + 1) * ipi)
+            degrid_time = utime[time_sel]
+            mean_time = degrid_time.mean()
+
+            image_in[Ix, Iy] = modelf(
+                tfunc(mean_time), ffunc(mean_freq), *coeffs
+            )
+
+            # NOTE: Select out appropriate rows for ipi and make selection
+            # consistent.
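+            # A sketch of the selection this NOTE intends (an assumption,
+            # not yet implemented): take the rows whose unique-time index
+            # falls in this time bin, e.g.
+            #   row_mask = np.isin(
+            #       utime_inv, np.arange(ti * ipi, (ti + 1) * ipi)
+            #   )
+            # with the same mask applied to uvw.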
+ row_sel = slice(None) + + # Degrid + dirty2vis( + vis=vis[row_sel, freq_sel, 0], + uvw=uvw, + freq=degrid_freq, + dirty=image_in, + pixsize_x=model_xds.cell_rad_x, # Should be output value. + pixsize_y=model_xds.cell_rad_y, # Should be output value. + center_x=model_xds.center_x, # Should be output value. + center_y=model_xds.center_y, # Should be output value. + epsilon=1e-7, # TODO: Is this too high? + do_wgridding=True, # Should be ok to leave True. + divide_by_n=False, # Until otherwise informed. + nthreads=6 # Should be equivalent to solver threads. + ) + + # Zero the image array between image slices. + image_in[:, :] = 0 + + # Degridder only produces I - will need to be more sophisticated. + vis[..., -1] = vis[..., 0] + + return vis + + # TODO: This was omitted for the sake of simplicity but we ultimately will + # want this functionality. Need to be very cautious with regard to which + # parameters get used. + + cellxi, cellyi = model_xds.cell_rad_x, model_xds.cell_rad_y + x0i, y0i = model_xds.center_x, model_xds.center_y + + xin = (-(nxi//2) + np.arange(nxi))*cellxi + x0i + yin = (-(nyi//2) + np.arange(nyi))*cellyi + y0i + xo = (-(nxo//2) + np.arange(nxo))*cellxo + x0o + yo = (-(nyo//2) + np.arange(nyo))*cellyo + y0o + + # how many pixels to pad by to extrapolate with zeros + xldiff = xin.min() - xo.min() + if xldiff > 0.0: + npadxl = int(np.ceil(xldiff/cellxi)) + else: + npadxl = 0 + yldiff = yin.min() - yo.min() + if yldiff > 0.0: + npadyl = int(np.ceil(yldiff/cellyi)) + else: + npadyl = 0 + + xudiff = xo.max() - xin.max() + if xudiff > 0.0: + npadxu = int(np.ceil(xudiff/cellxi)) + else: + npadxu = 0 + yudiff = yo.max() - yin.max() + if yudiff > 0.0: + npadyu = int(np.ceil(yudiff/cellyi)) + else: + npadyu = 0 + + do_pad = npadxl > 0 + do_pad |= npadxu > 0 + do_pad |= npadyl > 0 + do_pad |= npadyu > 0 + if do_pad: + image_in = np.pad( + image_in, + ((npadxl, npadxu), (npadyl, npadyu)), + mode='constant' + ) + + xin = (-(nxi//2+npadxl) + np.arange(nxi + npadxl + npadxu))*cellxi + x0i + nxi = nxi + npadxl + npadxu + yin = (-(nyi//2+npadyl) + np.arange(nyi + npadyl + npadyu))*cellyi + y0i + nyi = nyi + npadyl + npadyu + + do_interp = cellxi != cellxo + do_interp |= cellyi != cellyo + do_interp |= x0i != x0o + do_interp |= y0i != y0o + do_interp |= nxi != nxo + do_interp |= nyi != nyo + if do_interp: + interpo = RegularGridInterpolator((xin, yin), image_in, + bounds_error=True, method='linear') + xx, yy = np.meshgrid(xo, yo, indexing='ij') + return interpo((xx, yy)) + # elif (nxi != nxo) or (nyi != nyo): + # # only need the overlap in this case + # _, idx0, idx1 = np.intersect1d(xin, xo, assume_unique=True, return_indices=True) + # _, idy0, idy1 = np.intersect1d(yin, yo, assume_unique=True, return_indices=True) + # return image[idx0, idy0] + else: + return image_in + + +def degrid(data_xds_list, model_vis_recipe, ms_path, model_opts): + + degrid_models = model_vis_recipe.ingredients.degrid_models + + degrid_list = [] + + for data_xds in data_xds_list: + + model_vis = defaultdict(list) + + for degrid_model in degrid_models: + + degrid_vis = da.blockwise( + _degrid, ("rowlike", "chan", "corr"), + data_xds.TIME.data, ("rowlike",), + data_xds.CHAN_FREQ.data, ("chan",), + data_xds.UVW.data, ("rowlike", "uvw"), + degrid_model, None, + concatenate=True, + align_arrays=False, + meta=np.empty([0, 0, 0], dtype=np.complex128), + new_axes={"corr": 4} # TODO: Shouldn't be hardcoded. 
+ ) + + model_vis[degrid_model].append(degrid_vis) + + degrid_list.append(freeze_default_dict(model_vis)) + + return degrid_list diff --git a/quartical/data_handling/model_handler.py b/quartical/data_handling/model_handler.py index 9871c5ad..e5b6943e 100644 --- a/quartical/data_handling/model_handler.py +++ b/quartical/data_handling/model_handler.py @@ -2,6 +2,7 @@ import dask.array as da import numpy as np from quartical.data_handling.predict import predict +from quartical.data_handling.degridder import degrid from quartical.data_handling.angles import apply_parangles from quartical.config.preprocess import IdentityRecipe, Ingredients from quartical.utils.array import flat_ident_like @@ -36,13 +37,38 @@ def add_model_graph( # Generates a predicition scheme (graph) per-xds. If no predict is # required, it is a list of empty dictionaries. - if model_vis_recipe.ingredients.sky_models: - predict_schemes = predict(data_xds_list, - model_vis_recipe, - ms_path, - model_opts) - else: - predict_schemes = [{}]*len(data_xds_list) + # TODO: Add handling for mds inputs. This will need to read the model + # and figure out the relevant steps to take. + + predict_required = bool(model_vis_recipe.ingredients.sky_models) + degrid_required = bool(model_vis_recipe.ingredients.degrid_models) + + # TODO: Ensure that things work correctly when we have a mixture of the + # below. + + predict_schemes = [{}]*len(data_xds_list) + + if predict_required: + rime_schemes = predict( + data_xds_list, + model_vis_recipe, + ms_path, + model_opts + ) + predict_schemes = [ + {**ps, **rs} for ps, rs in zip(predict_schemes, rime_schemes) + ] + + if degrid_required: + degrid_schemes = degrid( + data_xds_list, + model_vis_recipe, + ms_path, + model_opts + ) + predict_schemes = [ + {**ps, **ds} for ps, ds in zip(predict_schemes, degrid_schemes) + ] # Special case: in the event that we have an IdentityRecipe, modify the # datasets and model appropriately. From d7e69e9821675dc056fbc8bcf69e227691baecaa Mon Sep 17 00:00:00 2001 From: JSKenyon Date: Mon, 31 Jul 2023 09:45:05 +0200 Subject: [PATCH 04/26] Add ducc0 and sympy as extras under degrid banner in pyproject.toml. --- pyproject.toml | 5 +++++ quartical/data_handling/model_handler.py | 12 +++++++++++- 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 68c1f4ee..ceee48d2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,6 +40,11 @@ pytest = "^7.3.1" omegaconf = "^2.3.0" colorama = "^0.4.6" stimela = "2.0rc4" +ducc0 = "^0.31.0" +sympy = "^1.12" + +[tool.poetry.extras] +degrid = ["ducc0", "sympy"] [tool.poetry.scripts] goquartical = 'quartical.executor:execute' diff --git a/quartical/data_handling/model_handler.py b/quartical/data_handling/model_handler.py index e5b6943e..9c237809 100644 --- a/quartical/data_handling/model_handler.py +++ b/quartical/data_handling/model_handler.py @@ -2,7 +2,6 @@ import dask.array as da import numpy as np from quartical.data_handling.predict import predict -from quartical.data_handling.degridder import degrid from quartical.data_handling.angles import apply_parangles from quartical.config.preprocess import IdentityRecipe, Ingredients from quartical.utils.array import flat_ident_like @@ -60,6 +59,17 @@ def add_model_graph( ] if degrid_required: + + try: + from quartical.data_handling.degridder import degrid + except ImportError: + raise ImportError( + "QuartiCal was unable to import the degrid module. This may " + "indicate that QuartiCal was installed without the necessary " + "extras. 
Please try 'pip install quartical[degrid]'. If the " + "error persists, please raise an issue." + ) + degrid_schemes = degrid( data_xds_list, model_vis_recipe, From 0b26002afa922af766ed0d5d71a438a8798e3d9b Mon Sep 17 00:00:00 2001 From: JSKenyon Date: Mon, 31 Jul 2023 14:59:56 +0200 Subject: [PATCH 05/26] Add/change functionlity to enable new-style models. --- quartical/config/__init__.py | 22 ++- quartical/config/argument_schema.yaml | 6 + quartical/config/config_classes.py | 9 ++ quartical/config/external.py | 22 ++- quartical/config/helper.py | 11 +- quartical/config/internal.py | 10 ++ quartical/config/model_component_schema.yaml | 61 +++++++++ quartical/config/preprocess.py | 134 ++++++++++++++++++- quartical/executor.py | 11 +- 9 files changed, 274 insertions(+), 12 deletions(-) create mode 100644 quartical/config/model_component_schema.yaml diff --git a/quartical/config/__init__.py b/quartical/config/__init__.py index b4167112..b66a5315 100644 --- a/quartical/config/__init__.py +++ b/quartical/config/__init__.py @@ -27,8 +27,8 @@ class _GainSchema(object): gain: Dict[str, Parameter] - # The gain section is loaded explicitly, since we need to form up multiple - # instances. + # The gain and model sections are loaded explicitly, since we need to form + # up multiple instances. gain_schema = oc.merge( oc.structured(_GainSchema), oc.load(f"{dirname}/gain_schema.yaml") @@ -42,3 +42,21 @@ class _GainSchema(object): bases=(BaseConfigSection,), post_init=POST_INIT_MAP['gain'] ) + + @dataclass + class _ModelComponentSchema(object): + model_component: Dict[str, Parameter] + + model_component_schema = oc.merge( + oc.structured(_ModelComponentSchema), + oc.load(f"{dirname}/model_component_schema.yaml") + ) + model_component_schema = model_component_schema.model_component + + # Create model dataclass. + ModelComponent = schema_utils.schema_to_dataclass( + model_component_schema, + "ModelComponent", + bases=(BaseConfigSection,), + post_init=POST_INIT_MAP['model_component'] + ) diff --git a/quartical/config/argument_schema.yaml b/quartical/config/argument_schema.yaml index 5e5ad044..28644749 100644 --- a/quartical/config/argument_schema.yaml +++ b/quartical/config/argument_schema.yaml @@ -119,6 +119,12 @@ input_model: given by the dE tagged clusters in LSM1. Leaving this value unset (the default) will use an identity model. + advanced_recipe: + dtype: Optional[str] + info: + Optional advanced recipe specification for use with degridder/more + complicated modes of model construction. + beam: dtype: Optional[str] info: diff --git a/quartical/config/config_classes.py b/quartical/config/config_classes.py index 5f67e0f7..9baffab9 100644 --- a/quartical/config/config_classes.py +++ b/quartical/config/config_classes.py @@ -111,6 +111,14 @@ def __input_model_post_init__(self): self.__validate_choices__() self.__validate_element_choices__() + assert not (self.recipe and self.advanced_recipe), \ + "recipe and advanced_recipe are mutually exclusive." 
+ + +def __model_component_post_init__(self): + self.__validate_choices__() + self.__validate_element_choices__() + def __output_post_init__(self): @@ -182,6 +190,7 @@ def __gain_post_init__(self): POST_INIT_MAP = { "input_ms": __input_ms_post_init__, "input_model": __input_model_post_init__, + "model_component": __model_component_post_init__, "output": __output_post_init__, "mad_flags": __mad_flags_post_init__, "solver": __solver_post_init__, diff --git a/quartical/config/external.py b/quartical/config/external.py index ce426e05..5c3de452 100644 --- a/quartical/config/external.py +++ b/quartical/config/external.py @@ -1,8 +1,9 @@ +import re from dataclasses import make_dataclass from omegaconf import OmegaConf as oc from typing import Dict, Any from scabha.cargo import Parameter -from quartical.config import Gain, BaseConfig, gain_schema +from quartical.config import Gain, ModelComponent, BaseConfig, gain_schema def finalize_structure(additional_config): @@ -18,9 +19,26 @@ def finalize_structure(additional_config): # Use the default terms if no alternative is specified. terms = terms or BaseConfig.solver.terms + recipe = None + models = [] # No components by default. + + # Get last specified version of input_model.recipe. + for cfg in additional_config[::-1]: + recipe = oc.select(cfg, "input_model.advanced_recipe") + if recipe is not None: + ingredients = re.split(r'([\+~:])', recipe) + ingredients = [ + i for i in ingredients if not bool(re.search(r'([\+~:])', i)) + ] + models = set(i.split("@")[0] for i in ingredients) + break + FinalConfig = make_dataclass( "FinalConfig", - [(t, Gain, Gain()) for t in terms], + [ + *[(m, ModelComponent, ModelComponent()) for m in models], + *[(t, Gain, Gain()) for t in terms] + ], bases=(BaseConfig,) ) diff --git a/quartical/config/helper.py b/quartical/config/helper.py index f4885da6..00ebecdb 100644 --- a/quartical/config/helper.py +++ b/quartical/config/helper.py @@ -79,8 +79,15 @@ def help(): if len(sys.argv) != 1 and not help_arg: return - # Include a generic gain term in the help config. - additional_config = [oc.from_dotlist(["solver.terms=['gain']"])] + # Include a generic gain term and model component in the help config. + additional_config = [ + oc.from_dotlist( + [ + "input_model.advanced_recipe=model_component", + "solver.terms=['gain']" + ] + ) + ] HelpConfig = finalize_structure(additional_config) if len(sys.argv) == 1 or help_arg == "help": diff --git a/quartical/config/internal.py b/quartical/config/internal.py index e86e6a9f..a8b06231 100644 --- a/quartical/config/internal.py +++ b/quartical/config/internal.py @@ -1,5 +1,15 @@ from daskms.fsspec_store import DaskMSStore from quartical.gains import TERM_TYPES +from quartical.config import ModelComponent + + +def get_component_dict(opts): + + return { + k: getattr(opts, k) + for k in opts.__dataclass_fields__.keys() + if isinstance(getattr(opts, k), ModelComponent) + } def gains_to_chain(opts): diff --git a/quartical/config/model_component_schema.yaml b/quartical/config/model_component_schema.yaml new file mode 100644 index 00000000..b1629c23 --- /dev/null +++ b/quartical/config/model_component_schema.yaml @@ -0,0 +1,61 @@ +model_component: + path_or_name: + dtype: str + required: true + info: + Path to model/name of column. + type: + dtype: str + choices: + - mds # Replace with awesome acronym eventually. + - tigger-lsm + - column + default: column + info: + Type of model component + tags: + dtype: Optional[List[str]] + info: + Tag for use in the predict. 
+ region: + dtype: Optional[str] + info: + Name of region file to use in degridder. + npix_x: + dtype: Optional[int] + info: + Image x size in pixels for use in degridding. + npix_y: + dtype: Optional[int] + info: + Image y size in pixels for use in degridding. + cellsize_x: + dtype: Optional[float] + info: + Pixel x cellsize in radians. + cellsize_y: + dtype: Optional[float] + info: + Pixel y cellsize in radians. + centre_x: + dtype: int + default: 0 + info: + x coordinate of central pixel. + centre_y: + dtype: int + default: 0 + info: + y coordinate of central pixel. + integrations_per_image: + dtype: int + default: 0 + info: + Number of integrations per image to use in degridding. The default + (zero) is all times in a chunk. + channels_per_image: + dtype: int + default: 0 + info: + Number of channels per image to use in degridding. The default + (zero) is all channels in a chunk. diff --git a/quartical/config/preprocess.py b/quartical/config/preprocess.py index ea11d82b..c016067d 100644 --- a/quartical/config/preprocess.py +++ b/quartical/config/preprocess.py @@ -44,7 +44,7 @@ class IdentityRecipe(Recipe): pass -def transcribe_recipe(user_recipe): +def transcribe_legacy_recipe(user_recipe): """Interpret the model recipe string. Given the config object, create an internal recipe implementing the user @@ -72,10 +72,6 @@ def transcribe_recipe(user_recipe): # Strip accidental whitepsace from input recipe and splits on ":". input_recipes = user_recipe.replace(" ", "").split(":") - if input_recipes == ['']: - raise ValueError("No model recipe was specified. Please set/check " - "--input-model-recipe.") - for recipe_index, recipe in enumerate(input_recipes): instructions[recipe_index] = [] @@ -158,3 +154,131 @@ def transcribe_recipe(user_recipe): logger.info("Recipe contains degrid models - enabling degridding.") return model_recipe + + +def transcribe_recipe(user_recipe, model_components): + """Interpret the model recipe string. + + Given the config object, create an internal recipe implementing the user + specified recipe. + + Args: + model_opts: An ModelInputs configuration object. + + Returns: + model_Recipe: A Recipe object. + """ + + if user_recipe is None: + logger.warning( + "input_model.recipe was not supplied. Assuming identity model." + ) + return IdentityRecipe(Ingredients(set(), set()), dict()) + + model_columns = set() + sky_models = set() + degrid_models = set() + + instructions = {} + + # Strip accidental whitepspace from input recipe and splits on ":". + input_recipes = user_recipe.replace(" ", "").split(":") + + for recipe_index, recipe in enumerate(input_recipes): + + instructions[recipe_index] = [] + + # A raw string is required to avoid insane escape characters. Splits + # on understood operators, ~ for subtract, + for add. + + ingredients = re.split(r'([\+~])', recipe) + + # Behaviour of re.split guarantees every second term is either a column + # or .lsm. This may lead to the first element being an empty string. + + # Split the ingredients into operations and model sources. We preserve + # empty strings in the recipe to avoid more complicated code elsewhere. 
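+        # e.g. re.split(r'([\+~])', "COL1~COL2+COL3") gives
+        # ['COL1', '~', 'COL2', '+', 'COL3'], while a leading operator
+        # yields a leading empty string:
+        # re.split(r'([\+~])', "~COL1") -> ['', '~', 'COL1'].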
+ + for ingredient in ingredients: + + if ingredient in "~+" and ingredient != "": + + operation = da.add if ingredient == "+" else da.subtract + instructions[recipe_index].append(operation) + continue + + component = model_components.get(ingredient) + + if component.type == "tigger-lsm": + + filename = component.path_or_name + tags = component.tags or () + + if not os.path.isfile(filename): + raise FileNotFoundError("{} not found.".format(filename)) + + sky_model = sky_model_nt(filename, tags) + sky_models.add(sky_model) + instructions[recipe_index].append(sky_model) + + elif component.type == "mds": + + filename = component.path_or_name + + options = ( + component.npix_x, + component.npix_y, + component.cellsize_x, + component.cellsize_y, + component.centre_x, + component.centre_y, + component.integrations_per_image, + component.channels_per_image, + ) + + if not os.path.exists(filename): + raise FileNotFoundError("{} not found.".format(filename)) + + degrid_model = degrid_model_nt(filename, *options) + degrid_models.add(degrid_model) + instructions[recipe_index].append(degrid_model) + + elif component.type == "column": + + column_name = component.path_or_name + + model_columns.add(column_name) + instructions[recipe_index].append(column_name) + + else: + instructions[recipe_index].append(ingredient) + + # TODO: Add message to log. + logger.info("The following model sources were obtained from " + "--input-model-recipe: \n" + " Columns: {} \n" + " Sky Models: {} \n" + " Degrid Models: {}", + model_columns or 'None', + {sm.name for sm in sky_models} or 'None', + {dm.name for dm in degrid_models} or 'None') + + # Generate a named tuple containing all the information required to + # build the model visibilities. + + model_recipe = Recipe( + Ingredients( + model_columns, + sky_models, + degrid_models + ), + instructions + ) + + if model_recipe.ingredients.sky_models: + logger.info("Recipe contains sky models - enabling prediction step.") + + if model_recipe.ingredients.degrid_models: + logger.info("Recipe contains degrid models - enabling degridding.") + + return model_recipe diff --git a/quartical/executor.py b/quartical/executor.py index ed0bcec7..1addfed7 100644 --- a/quartical/executor.py +++ b/quartical/executor.py @@ -44,6 +44,7 @@ def _execute(exitstack): output_opts = opts.output mad_flag_opts = opts.mad_flags dask_opts = opts.dask + model_components = internal.get_component_dict(opts) chain = internal.gains_to_chain(opts) # Special handling. # Init the logging proxy - an object which helps us ensure that logging @@ -60,7 +61,15 @@ def _execute(exitstack): # Now that we know where to put the log, log the final config state. parser.log_final_config(opts, config_files) - model_vis_recipe = preprocess.transcribe_recipe(model_opts.recipe) + # TODO: Deprecate legacy models. + if model_opts.recipe: + model_vis_recipe = preprocess.transcribe_legacy_recipe( + model_opts.recipe + ) + else: + model_vis_recipe = preprocess.transcribe_recipe( + model_opts.advanced_recipe, model_components + ) if dask_opts.scheduler == "distributed": From aed942a23a9aa0f81867f9dd103cd8aebf6372ff Mon Sep 17 00:00:00 2001 From: JSKenyon Date: Mon, 31 Jul 2023 15:17:44 +0200 Subject: [PATCH 06/26] Change advanced model to a simple boolean flag. 
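With this change an advanced model is enabled via a boolean, with the
components named directly in input_model.recipe. A hypothetical invocation
(component name and values are illustrative only):

    input_model.advanced_recipe=true
    input_model.recipe=MODEL_DATA+comp1
    comp1.path_or_name=model.mds
    comp1.type=mds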
--- quartical/config/argument_schema.yaml | 5 +++-- quartical/config/config_classes.py | 3 --- quartical/config/external.py | 7 ++++--- quartical/executor.py | 10 +++++----- 4 files changed, 12 insertions(+), 13 deletions(-) diff --git a/quartical/config/argument_schema.yaml b/quartical/config/argument_schema.yaml index 28644749..8deb2d72 100644 --- a/quartical/config/argument_schema.yaml +++ b/quartical/config/argument_schema.yaml @@ -120,9 +120,10 @@ input_model: (the default) will use an identity model. advanced_recipe: - dtype: Optional[str] + dtype: bool + default: false info: - Optional advanced recipe specification for use with degridder/more + Enable advanced recipe specification for use with degridder/more complicated modes of model construction. beam: diff --git a/quartical/config/config_classes.py b/quartical/config/config_classes.py index 9baffab9..d06b7d39 100644 --- a/quartical/config/config_classes.py +++ b/quartical/config/config_classes.py @@ -111,9 +111,6 @@ def __input_model_post_init__(self): self.__validate_choices__() self.__validate_element_choices__() - assert not (self.recipe and self.advanced_recipe), \ - "recipe and advanced_recipe are mutually exclusive." - def __model_component_post_init__(self): self.__validate_choices__() diff --git a/quartical/config/external.py b/quartical/config/external.py index 5c3de452..f3246d7b 100644 --- a/quartical/config/external.py +++ b/quartical/config/external.py @@ -24,13 +24,14 @@ def finalize_structure(additional_config): # Get last specified version of input_model.recipe. for cfg in additional_config[::-1]: - recipe = oc.select(cfg, "input_model.advanced_recipe") - if recipe is not None: + advanced_recipe = oc.select(cfg, "input_model.advanced_recipe") + recipe = oc.select(cfg, "input_model.recipe") + if recipe is not None and advanced_recipe: ingredients = re.split(r'([\+~:])', recipe) ingredients = [ i for i in ingredients if not bool(re.search(r'([\+~:])', i)) ] - models = set(i.split("@")[0] for i in ingredients) + models = list(dict.fromkeys(i.split("@")[0] for i in ingredients)) break FinalConfig = make_dataclass( diff --git a/quartical/executor.py b/quartical/executor.py index 1addfed7..69f91a11 100644 --- a/quartical/executor.py +++ b/quartical/executor.py @@ -62,13 +62,13 @@ def _execute(exitstack): parser.log_final_config(opts, config_files) # TODO: Deprecate legacy models. - if model_opts.recipe: - model_vis_recipe = preprocess.transcribe_legacy_recipe( - model_opts.recipe + if model_opts.advanced_recipe: + model_vis_recipe = preprocess.transcribe_recipe( + model_opts.recipe, model_components ) else: - model_vis_recipe = preprocess.transcribe_recipe( - model_opts.advanced_recipe, model_components + model_vis_recipe = preprocess.transcribe_legacy_recipe( + model_opts.recipe ) if dask_opts.scheduler == "distributed": From 0d43f3e386c2187777de59791f3fbccfb23e8109 Mon Sep 17 00:00:00 2001 From: JSKenyon Date: Mon, 31 Jul 2023 16:16:06 +0200 Subject: [PATCH 07/26] Renaming. 
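Renames the degrid_model_nt fields and degridder variables to be
self-describing, e.g. nxo/nyo become npix_x/npix_y. As a worked example of
the binning arithmetic: with 40 unique times and integrations_per_image=8,
n_mean_times = ceil(40 / 8) = 5 images along the time axis. Zero-valued
options fall back to the full chunk extent in both time and frequency (the
frequency fallback is added in this commit).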
--- quartical/config/preprocess.py | 16 +++---- quartical/data_handling/degridder.py | 62 +++++++++++++++++----------- 2 files changed, 46 insertions(+), 32 deletions(-) diff --git a/quartical/config/preprocess.py b/quartical/config/preprocess.py index c016067d..125e86bb 100644 --- a/quartical/config/preprocess.py +++ b/quartical/config/preprocess.py @@ -14,14 +14,14 @@ "degrid_model_nt", ( "name", - "nxo", - "nyo", - "cellxo", - "cellyo", - "x0o", - "y0o", - "ipi", - "cpi" + "npix_x", + "npix_y", + "cellsize_x", + "cellsize_y", + "centre_x", + "centre_y", + "integrations_per_image", + "channels_per_image" ) ) diff --git a/quartical/data_handling/degridder.py b/quartical/data_handling/degridder.py index 24711897..4852c3a6 100644 --- a/quartical/data_handling/degridder.py +++ b/quartical/data_handling/degridder.py @@ -10,14 +10,22 @@ from quartical.utils.collections import freeze_default_dict -def _degrid(time, freq, uvw, model): - - name, nxo, nyo, cellxo, cellyo, x0o, y0o, ipi, cpi = model +def _degrid(time, freq, uvw, component): + + name = component.name + npix_x = component.npix_x # Will be used in interpolation mode. + npix_y = component.npix_y # Will be used in interpolation mode. + cellsize_x = component.cellsize_x + cellsize_y = component.cellsize_y + centre_x = component.centre_x + centre_y = component.centre_y + integrations_per_image = component.integrations_per_image + channels_per_image = component.channels_per_image # TODO: Dodgy pattern, as this will produce dask arrays which will # need to be evaluated i.e. task-in-task which is a bad idea. OK for # the purposes of a prototype. - model_xds = xds_from_zarr(model.name)[0] + model_xds = xds_from_zarr(name)[0] # TODO: We want to go from the model to an image cube appropriate for this # chunks of data. The dimensions of the image cube will be determined @@ -26,29 +34,35 @@ def _degrid(time, freq, uvw, model): utime, utime_inv = np.unique(time, return_inverse=True) n_utime = utime.size - ipi = ipi or n_utime # Catch zero case. NOTE: -1? + ipi = integrations_per_image or n_utime n_mean_times = int(np.ceil(n_utime / ipi)) n_freq = freq.size + cpi = channels_per_image or n_freq n_mean_freqs = int(np.ceil(n_freq / cpi)) # Let's start off by simply reconstructing the model image as generated # by pfb-clean. - nxi, nyi = model_xds.npix_x, model_xds.npix_y - image_in = np.zeros((nxi, nyi), dtype=float) + native_npix_x = model_xds.npix_x + native_npix_y = model_xds.npix_y + native_image = np.zeros((native_npix_x, native_npix_y), dtype=float) + + # Sey up sympy symbols for expression evaluation. params = sm.symbols(('t', 'f')) params += sm.symbols(tuple(model_xds.params.values)) symexpr = parse_expr(model_xds.parametrisation) - modelf = lambdify(params, symexpr) - texpr = parse_expr(model_xds.texpr) - tfunc = lambdify(params[0], texpr) - fexpr = parse_expr(model_xds.fexpr) - ffunc = lambdify(params[1], fexpr) - Ix = model_xds.location_x.values - Iy = model_xds.location_y.values - coeffs = model_xds.coefficients.values + model_func = lambdify(params, symexpr) + time_expr = parse_expr(model_xds.texpr) + time_func = lambdify(params[0], time_expr) + freq_expr = parse_expr(model_xds.fexpr) + freq_func = lambdify(params[1], freq_expr) + + pixel_xs = model_xds.location_x.values + pixel_ys = model_xds.location_y.values + pixel_coeffs = model_xds.coefficients.values + # TODO: How do we handle the correlation axis neatly? 
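+    # For now, four correlations are assumed below and, as the degridder
+    # only produces Stokes I, it is copied into the first and last
+    # correlations (XX/YY or RR/LL) further down - correct for an
+    # unpolarised model, for which XX = YY = I.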
vis = np.empty((time.size, freq.size, 4), dtype=np.complex128) for ti in range(n_mean_times): @@ -62,8 +76,8 @@ def _degrid(time, freq, uvw, model): degrid_time = utime[time_sel] mean_time = degrid_time.mean() - image_in[Ix, Iy] = modelf( - tfunc(mean_time), ffunc(mean_freq), *coeffs + native_image[pixel_xs, pixel_ys] = model_func( + time_func(mean_time), freq_func(mean_freq), *pixel_coeffs ) # NOTE: Select out appropriate rows for ipi and make selection @@ -75,19 +89,19 @@ def _degrid(time, freq, uvw, model): vis=vis[row_sel, freq_sel, 0], uvw=uvw, freq=degrid_freq, - dirty=image_in, - pixsize_x=model_xds.cell_rad_x, # Should be output value. - pixsize_y=model_xds.cell_rad_y, # Should be output value. - center_x=model_xds.center_x, # Should be output value. - center_y=model_xds.center_y, # Should be output value. + dirty=native_image, + pixsize_x=cellsize_x or model_xds.cell_rad_x, + pixsize_y=cellsize_y or model_xds.cell_rad_y, + center_x=centre_x or model_xds.center_x, + center_y=centre_y or model_xds.center_y, epsilon=1e-7, # TODO: Is this too high? do_wgridding=True, # Should be ok to leave True. divide_by_n=False, # Until otherwise informed. nthreads=6 # Should be equivalent to solver threads. ) - # Zero the image array between image slices. - image_in[:, :] = 0 + # Zero the image array between image slices as a precaution. + native_image[:, :] = 0 # Degridder only produces I - will need to be more sophisticated. vis[..., -1] = vis[..., 0] From 6aa67f28ee6d54ea5e5ea87caa8bfdfa83c37766 Mon Sep 17 00:00:00 2001 From: JSKenyon Date: Mon, 31 Jul 2023 16:46:22 +0200 Subject: [PATCH 08/26] Fix tag format. --- quartical/config/model_component_schema.yaml | 8 ++++---- quartical/config/preprocess.py | 3 +-- quartical/data_handling/degridder.py | 4 ++-- 3 files changed, 7 insertions(+), 8 deletions(-) diff --git a/quartical/config/model_component_schema.yaml b/quartical/config/model_component_schema.yaml index b1629c23..1ed5b6a6 100644 --- a/quartical/config/model_component_schema.yaml +++ b/quartical/config/model_component_schema.yaml @@ -17,10 +17,10 @@ model_component: dtype: Optional[List[str]] info: Tag for use in the predict. - region: - dtype: Optional[str] - info: - Name of region file to use in degridder. + # region: # Add when implemented. + # dtype: Optional[str] + # info: + # Name of region file to use in degridder. npix_x: dtype: Optional[int] info: diff --git a/quartical/config/preprocess.py b/quartical/config/preprocess.py index 125e86bb..f8960031 100644 --- a/quartical/config/preprocess.py +++ b/quartical/config/preprocess.py @@ -217,7 +217,7 @@ def transcribe_recipe(user_recipe, model_components): if not os.path.isfile(filename): raise FileNotFoundError("{} not found.".format(filename)) - sky_model = sky_model_nt(filename, tags) + sky_model = sky_model_nt(filename, tuple(tags)) sky_models.add(sky_model) instructions[recipe_index].append(sky_model) @@ -253,7 +253,6 @@ def transcribe_recipe(user_recipe, model_components): else: instructions[recipe_index].append(ingredient) - # TODO: Add message to log. logger.info("The following model sources were obtained from " "--input-model-recipe: \n" " Columns: {} \n" diff --git a/quartical/data_handling/degridder.py b/quartical/data_handling/degridder.py index 4852c3a6..cdf0ee4b 100644 --- a/quartical/data_handling/degridder.py +++ b/quartical/data_handling/degridder.py @@ -97,13 +97,13 @@ def _degrid(time, freq, uvw, component): epsilon=1e-7, # TODO: Is this too high? do_wgridding=True, # Should be ok to leave True. 
divide_by_n=False, # Until otherwise informed. - nthreads=6 # Should be equivalent to solver threads. + nthreads=6 # TODO: Should be equivalent to solver threads. ) # Zero the image array between image slices as a precaution. native_image[:, :] = 0 - # Degridder only produces I - will need to be more sophisticated. + # TODO: Degridder only produces I - will need to be more sophisticated. vis[..., -1] = vis[..., 0] return vis From bba12a2fd91a8f44f6b3ddd9113f68f61a56b481 Mon Sep 17 00:00:00 2001 From: JSKenyon Date: Mon, 31 Jul 2023 17:09:05 +0200 Subject: [PATCH 09/26] Make dask usage less unholy. --- quartical/data_handling/degridder.py | 44 +++++++++++++++------------- 1 file changed, 24 insertions(+), 20 deletions(-) diff --git a/quartical/data_handling/degridder.py b/quartical/data_handling/degridder.py index cdf0ee4b..fefc869a 100644 --- a/quartical/data_handling/degridder.py +++ b/quartical/data_handling/degridder.py @@ -1,6 +1,7 @@ from collections import defaultdict import numpy as np import dask.array as da +import xarray from daskms.experimental.zarr import xds_from_zarr from scipy.interpolate import RegularGridInterpolator import sympy as sm @@ -10,9 +11,8 @@ from quartical.utils.collections import freeze_default_dict -def _degrid(time, freq, uvw, component): +def _degrid(time, freq, uvw, pixel_coeffs, component, meta_xds): - name = component.name npix_x = component.npix_x # Will be used in interpolation mode. npix_y = component.npix_y # Will be used in interpolation mode. cellsize_x = component.cellsize_x @@ -22,11 +22,6 @@ def _degrid(time, freq, uvw, component): integrations_per_image = component.integrations_per_image channels_per_image = component.channels_per_image - # TODO: Dodgy pattern, as this will produce dask arrays which will - # need to be evaluated i.e. task-in-task which is a bad idea. OK for - # the purposes of a prototype. - model_xds = xds_from_zarr(name)[0] - # TODO: We want to go from the model to an image cube appropriate for this # chunks of data. The dimensions of the image cube will be determined # by the time and freqency dimensions of the chunk in conjunction with @@ -44,23 +39,22 @@ def _degrid(time, freq, uvw, component): # Let's start off by simply reconstructing the model image as generated # by pfb-clean. - native_npix_x = model_xds.npix_x - native_npix_y = model_xds.npix_y + native_npix_x = meta_xds.npix_x + native_npix_y = meta_xds.npix_y native_image = np.zeros((native_npix_x, native_npix_y), dtype=float) # Sey up sympy symbols for expression evaluation. params = sm.symbols(('t', 'f')) - params += sm.symbols(tuple(model_xds.params.values)) - symexpr = parse_expr(model_xds.parametrisation) + params += sm.symbols(tuple(meta_xds.params.values)) + symexpr = parse_expr(meta_xds.parametrisation) model_func = lambdify(params, symexpr) - time_expr = parse_expr(model_xds.texpr) + time_expr = parse_expr(meta_xds.texpr) time_func = lambdify(params[0], time_expr) - freq_expr = parse_expr(model_xds.fexpr) + freq_expr = parse_expr(meta_xds.fexpr) freq_func = lambdify(params[1], freq_expr) - pixel_xs = model_xds.location_x.values - pixel_ys = model_xds.location_y.values - pixel_coeffs = model_xds.coefficients.values + pixel_xs = meta_xds.location_x.values + pixel_ys = meta_xds.location_y.values # TODO: How do we handle the correlation axis neatly? 
vis = np.empty((time.size, freq.size, 4), dtype=np.complex128) @@ -90,10 +84,10 @@ def _degrid(time, freq, uvw, component): uvw=uvw, freq=degrid_freq, dirty=native_image, - pixsize_x=cellsize_x or model_xds.cell_rad_x, - pixsize_y=cellsize_y or model_xds.cell_rad_y, - center_x=centre_x or model_xds.center_x, - center_y=centre_y or model_xds.center_y, + pixsize_x=cellsize_x or meta_xds.cell_rad_x, + pixsize_y=cellsize_y or meta_xds.cell_rad_y, + center_x=centre_x or meta_xds.center_x, + center_y=centre_y or meta_xds.center_y, epsilon=1e-7, # TODO: Is this too high? do_wgridding=True, # Should be ok to leave True. divide_by_n=False, # Until otherwise informed. @@ -191,12 +185,22 @@ def degrid(data_xds_list, model_vis_recipe, ms_path, model_opts): for degrid_model in degrid_models: + model_xds = xds_from_zarr(degrid_model.name)[0] + + # NOTE: This is convenient but will result in some extra + # information being embedded in the graph. Shouldn't be a problem. + meta_xds = xarray.Dataset( + coords=model_xds.coords, attrs=model_xds.attrs + ) + degrid_vis = da.blockwise( _degrid, ("rowlike", "chan", "corr"), data_xds.TIME.data, ("rowlike",), data_xds.CHAN_FREQ.data, ("chan",), data_xds.UVW.data, ("rowlike", "uvw"), + model_xds.coefficients.data, ("params", "comps"), degrid_model, None, + meta_xds, None, concatenate=True, align_arrays=False, meta=np.empty([0, 0, 0], dtype=np.complex128), From 808e17a35ce8e32b7b0dc496643d2e644b00da7c Mon Sep 17 00:00:00 2001 From: JSKenyon Date: Mon, 31 Jul 2023 17:28:55 +0200 Subject: [PATCH 10/26] Add clone. --- quartical/data_handling/degridder.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/quartical/data_handling/degridder.py b/quartical/data_handling/degridder.py index fefc869a..c325385e 100644 --- a/quartical/data_handling/degridder.py +++ b/quartical/data_handling/degridder.py @@ -1,6 +1,7 @@ from collections import defaultdict import numpy as np import dask.array as da +from dask.graph_manipulation import clone import xarray from daskms.experimental.zarr import xds_from_zarr from scipy.interpolate import RegularGridInterpolator @@ -196,7 +197,7 @@ def degrid(data_xds_list, model_vis_recipe, ms_path, model_opts): degrid_vis = da.blockwise( _degrid, ("rowlike", "chan", "corr"), data_xds.TIME.data, ("rowlike",), - data_xds.CHAN_FREQ.data, ("chan",), + clone(data_xds.CHAN_FREQ.data), ("chan",), data_xds.UVW.data, ("rowlike", "uvw"), model_xds.coefficients.data, ("params", "comps"), degrid_model, None, From 519106b88cfce0b21bf3b2207eaceb0081c05465 Mon Sep 17 00:00:00 2001 From: JSKenyon Date: Mon, 31 Jul 2023 17:58:23 +0200 Subject: [PATCH 11/26] Don't init vis as empty - duh. --- quartical/data_handling/degridder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/quartical/data_handling/degridder.py b/quartical/data_handling/degridder.py index c325385e..82c8463c 100644 --- a/quartical/data_handling/degridder.py +++ b/quartical/data_handling/degridder.py @@ -58,7 +58,7 @@ def _degrid(time, freq, uvw, pixel_coeffs, component, meta_xds): pixel_ys = meta_xds.location_y.values # TODO: How do we handle the correlation axis neatly? - vis = np.empty((time.size, freq.size, 4), dtype=np.complex128) + vis = np.zeros((time.size, freq.size, 4), dtype=np.complex128) for ti in range(n_mean_times): for fi in range(n_mean_freqs): From 58eff1927ca800e912e8434a29f64f0b10111c82 Mon Sep 17 00:00:00 2001 From: JSKenyon Date: Tue, 1 Aug 2023 11:40:10 +0200 Subject: [PATCH 12/26] Fix correlation handling in degrid code. 
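The number of correlations is now taken from each dataset via
data_xds.dims["corr"] and threaded through blockwise into the degrid
kernel, rather than being hardcoded to 4. Note that this relies on the
earlier switch from np.empty to np.zeros: only the first and last
correlations are ever written, so the intermediate correlations must be
explicitly zeroed.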
--- quartical/data_handling/degridder.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/quartical/data_handling/degridder.py b/quartical/data_handling/degridder.py index 82c8463c..07aa2ed0 100644 --- a/quartical/data_handling/degridder.py +++ b/quartical/data_handling/degridder.py @@ -12,7 +12,7 @@ from quartical.utils.collections import freeze_default_dict -def _degrid(time, freq, uvw, pixel_coeffs, component, meta_xds): +def _degrid(time, freq, uvw, pixel_coeffs, component, meta_xds, n_corr): npix_x = component.npix_x # Will be used in interpolation mode. npix_y = component.npix_y # Will be used in interpolation mode. @@ -58,7 +58,7 @@ def _degrid(time, freq, uvw, pixel_coeffs, component, meta_xds): pixel_ys = meta_xds.location_y.values # TODO: How do we handle the correlation axis neatly? - vis = np.zeros((time.size, freq.size, 4), dtype=np.complex128) + vis = np.zeros((time.size, freq.size, n_corr), dtype=np.complex128) for ti in range(n_mean_times): for fi in range(n_mean_freqs): @@ -182,6 +182,8 @@ def degrid(data_xds_list, model_vis_recipe, ms_path, model_opts): for data_xds in data_xds_list: + n_corr = data_xds.dims["corr"] + model_vis = defaultdict(list) for degrid_model in degrid_models: @@ -202,10 +204,11 @@ def degrid(data_xds_list, model_vis_recipe, ms_path, model_opts): model_xds.coefficients.data, ("params", "comps"), degrid_model, None, meta_xds, None, + n_corr, None, concatenate=True, align_arrays=False, meta=np.empty([0, 0, 0], dtype=np.complex128), - new_axes={"corr": 4} # TODO: Shouldn't be hardcoded. + new_axes={"corr": n_corr} ) model_vis[degrid_model].append(degrid_vis) From 2371715cfb124a867e58c016eda6e8ea954103ec Mon Sep 17 00:00:00 2001 From: JSKenyon Date: Tue, 1 Aug 2023 11:54:12 +0200 Subject: [PATCH 13/26] Expose degrid (predict) threading in model arguments. --- quartical/config/argument_schema.yaml | 9 +++++++++ quartical/data_handling/degridder.py | 16 ++++++++++++++-- 2 files changed, 23 insertions(+), 2 deletions(-) diff --git a/quartical/config/argument_schema.yaml b/quartical/config/argument_schema.yaml index 8deb2d72..ac7d768c 100644 --- a/quartical/config/argument_schema.yaml +++ b/quartical/config/argument_schema.yaml @@ -126,6 +126,15 @@ input_model: Enable advanced recipe specification for use with degridder/more complicated modes of model construction. + threads: + dtype: int + default: 1 + info: + Controls the number of threads used internally by the various predict + methods. This should typically be set to the same value as + solver.threads if solver.threads is in use. Currently only supported + for degridding. + beam: dtype: Optional[str] info: diff --git a/quartical/data_handling/degridder.py b/quartical/data_handling/degridder.py index 07aa2ed0..f4630b37 100644 --- a/quartical/data_handling/degridder.py +++ b/quartical/data_handling/degridder.py @@ -12,7 +12,16 @@ from quartical.utils.collections import freeze_default_dict -def _degrid(time, freq, uvw, pixel_coeffs, component, meta_xds, n_corr): +def _degrid( + time, + freq, + uvw, + pixel_coeffs, + component, + meta_xds, + n_corr, + n_thread +): npix_x = component.npix_x # Will be used in interpolation mode. npix_y = component.npix_y # Will be used in interpolation mode. @@ -92,7 +101,7 @@ def _degrid(time, freq, uvw, pixel_coeffs, component, meta_xds, n_corr): epsilon=1e-7, # TODO: Is this too high? do_wgridding=True, # Should be ok to leave True. divide_by_n=False, # Until otherwise informed. - nthreads=6 # TODO: Should be equivalent to solver threads. 
+ nthreads=n_thread ) # Zero the image array between image slices as a precaution. @@ -178,6 +187,8 @@ def degrid(data_xds_list, model_vis_recipe, ms_path, model_opts): degrid_models = model_vis_recipe.ingredients.degrid_models + n_thread = model_opts.threads + degrid_list = [] for data_xds in data_xds_list: @@ -205,6 +216,7 @@ def degrid(data_xds_list, model_vis_recipe, ms_path, model_opts): degrid_model, None, meta_xds, None, n_corr, None, + n_thread, None, concatenate=True, align_arrays=False, meta=np.empty([0, 0, 0], dtype=np.complex128), From 7f1bc9aa7dab125474275c125bdc3f7d01465dce Mon Sep 17 00:00:00 2001 From: JSKenyon Date: Tue, 1 Aug 2023 12:42:11 +0200 Subject: [PATCH 14/26] Disable degrid inputs in non-advanced mode. --- quartical/config/preprocess.py | 17 ++++------------- 1 file changed, 4 insertions(+), 13 deletions(-) diff --git a/quartical/config/preprocess.py b/quartical/config/preprocess.py index f8960031..13dd69be 100644 --- a/quartical/config/preprocess.py +++ b/quartical/config/preprocess.py @@ -6,7 +6,6 @@ import os.path from dataclasses import dataclass from typing import List, Dict, Set, Any -from ast import literal_eval sky_model_nt = namedtuple("sky_model_nt", ("name", "tags")) @@ -108,15 +107,10 @@ def transcribe_legacy_recipe(user_recipe): elif ".mds" in ingredient: - filename, _, options = ingredient.partition("@") - options = literal_eval(options) # Add fail on missing option. - - if not os.path.exists(filename): - raise FileNotFoundError("{} not found.".format(filename)) - - degrid_model = degrid_model_nt(filename, *options) - degrid_models.add(degrid_model) - instructions[recipe_index].append(degrid_model) + # TODO: Add link to documentation. + raise ValueError( + ".mds inputs are only supported in advanced model mode." + ) elif ingredient != "": model_columns.add(ingredient) @@ -150,9 +144,6 @@ def transcribe_legacy_recipe(user_recipe): if model_recipe.ingredients.sky_models: logger.info("Recipe contains sky models - enabling prediction step.") - if model_recipe.ingredients.degrid_models: - logger.info("Recipe contains degrid models - enabling degridding.") - return model_recipe From 00af5188547223cf8391f47b5eaeac98fc4c821f Mon Sep 17 00:00:00 2001 From: JSKenyon Date: Tue, 1 Aug 2023 15:45:34 +0200 Subject: [PATCH 15/26] Use legacy models in tests until such time as I can reconfigure them. --- testing/fixtures/config.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/testing/fixtures/config.py b/testing/fixtures/config.py index 7850157c..bbc9830a 100644 --- a/testing/fixtures/config.py +++ b/testing/fixtures/config.py @@ -1,5 +1,5 @@ import pytest -from quartical.config.preprocess import transcribe_recipe +from quartical.config.preprocess import transcribe_legacy_recipe from quartical.config.internal import gains_to_chain @@ -40,4 +40,4 @@ def chain(opts): @pytest.fixture(scope="module") def recipe(model_opts): - return transcribe_recipe(model_opts.recipe) + return transcribe_legacy_recipe(model_opts.recipe) From aaaf76afb3978c2f292b1ab6d36263d648593e40 Mon Sep 17 00:00:00 2001 From: JSKenyon Date: Thu, 3 Aug 2023 18:35:25 +0200 Subject: [PATCH 16/26] Utilise environment variable when dask.address is unset. (#288) * Fix version drift. * Bump to 0.2.0 * Inspect envvar for scheduler address when one isn't specified. * Encode environment varraible as ascii. * Simplify. 
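A minimal sketch of the resulting precedence (the address is illustrative
only):

    import os

    address = dask_opts.address or os.environ.get("DASK_SCHEDULER_ADDRESS")
    # e.g. DASK_SCHEDULER_ADDRESS="tcp://127.0.0.1:8786" is only consulted
    # when dask.address is unset.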
---
 quartical/executor.py | 12 +++++++++---
 1 file changed, 9 insertions(+), 3 deletions(-)

diff --git a/quartical/executor.py b/quartical/executor.py
index ed0bcec7..06fc85ae 100644
--- a/quartical/executor.py
+++ b/quartical/executor.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 from contextlib import ExitStack
+import os
 import sys
 from loguru import logger
 import dask
@@ -68,9 +69,14 @@ def _execute(exitstack):
         # distributed environment. This *may* be dangerous. Monitor.
         dask.config.set({"distributed.worker.daemon": False})
 
-        if dask_opts.address:
-            logger.info("Initializing distributed client.")
-            client = exitstack.enter_context(Client(dask_opts.address))
+        address = dask_opts.address or os.environ.get("DASK_SCHEDULER_ADDRESS")
+
+        if address:
+            logger.info(
+                f"Initializing distributed client using scheduler address: "
+                f"{address}"
+            )
+            client = exitstack.enter_context(Client(address))
         else:
             logger.info("Initializing distributed client using LocalCluster.")
             cluster = LocalCluster(

From ad4f91bb2869eeedb975a8d5baed3e4ebdbbf5b0 Mon Sep 17 00:00:00 2001
From: JSKenyon
Date: Fri, 4 Aug 2023 11:31:23 +0200
Subject: [PATCH 17/26] Fix incorrect import in tests.

---
 testing/tests/config/test_preprocess.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/testing/tests/config/test_preprocess.py b/testing/tests/config/test_preprocess.py
index ed6032d3..6b983725 100644
--- a/testing/tests/config/test_preprocess.py
+++ b/testing/tests/config/test_preprocess.py
@@ -1,5 +1,5 @@
 import pytest
-from quartical.config.preprocess import transcribe_recipe, sky_model_nt
+from quartical.config.preprocess import transcribe_legacy_recipe, sky_model_nt
 import dask.array as da
 import os.path
@@ -51,7 +51,7 @@
 # we do not attempt to validate column names in the preprocess step.
 
 invalid_recipes = {
-    "": ValueError,
+    # "": ValueError,  # NOTE: This case may not be needed. Omitting.
     "dummy.lsm.html": FileNotFoundError
 }
@@ -73,7 +73,7 @@ def test_transcribe_recipe_valid(valid_recipe, monkeypatch):
 
     # Patch isfile functionality to allow use of fictitious files.
     monkeypatch.setattr(os.path, "isfile", lambda filename: True)
 
-    recipe = transcribe_recipe(input_recipe)
+    recipe = transcribe_legacy_recipe(input_recipe)
 
     # Check that the opts has been updated with the correct internal recipe.
     assert recipe.instructions == expected_output
@@ -87,4 +87,4 @@ def test_transcribe_recipe_invalid(invalid_recipe):
     input_recipe, expected_output = invalid_recipe
 
     with pytest.raises(expected_output):
-        transcribe_recipe(input_recipe)
+        transcribe_legacy_recipe(input_recipe)

From 232ae41d937491b3d88a0134a6bf5f3595fa67e2 Mon Sep 17 00:00:00 2001
From: JSKenyon
Date: Fri, 4 Aug 2023 16:00:55 +0200
Subject: [PATCH 18/26] Commit WIP code for experimental fragment branch.
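This switches the readers from the dask-ms storage wrappers to the fragment
API on the experimental multisource branch. A sketch of the assumed reader
interface, inferred from its use in the diff below (path and columns are
illustrative):

    from daskms.experimental.multisource import xds_from_ms_fragment

    data_xds_list = xds_from_ms_fragment(
        "path/to.ms",
        columns=("TIME", "FLAG"),
        index_cols=("TIME",),
        group_cols=("DATA_DESC_ID",),
    )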
--- pyproject.toml | 2 +- quartical/apps/summary.py | 31 +++++++++++++++------------ quartical/data_handling/angles.py | 8 +++---- quartical/data_handling/chunking.py | 9 +++++--- quartical/data_handling/ms_handler.py | 29 +++++++++++++++++-------- quartical/data_handling/predict.py | 20 ++++++++++------- 6 files changed, 60 insertions(+), 39 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index ceee48d2..df3ca683 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,7 +31,7 @@ columnar = "^1.4.1" "ruamel.yaml" = "^0.17.26" dask = {extras = ["diagnostics"], version = "^2023.1.0"} distributed = "^2023.1.0" -dask-ms = {extras = ["s3", "xarray", "zarr"], version = "^0.2.16"} +dask-ms = {git = "https://github.com/ratt-ru/dask-ms.git", branch = "multisource-experimental", extras = ["s3", "xarray", "zarr"]} codex-africanus = {extras = ["dask", "scipy", "astropy", "python-casacore"], version = "^0.3.4"} astro-tigger-lsm = "^1.7.2" loguru = "^0.7.0" diff --git a/quartical/apps/summary.py b/quartical/apps/summary.py index c6e5d915..1d0a8dd2 100644 --- a/quartical/apps/summary.py +++ b/quartical/apps/summary.py @@ -1,6 +1,9 @@ import argparse from pathlib import Path -from daskms import xds_from_storage_ms, xds_from_storage_table +from daskms.experimental.multisource import ( + xds_from_ms_fragment, + xds_from_table_fragment +) from daskms.fsspec_store import DaskMSStore import numpy as np import dask.array as da @@ -44,7 +47,7 @@ def configure_loguru(output_dir): def antenna_info(path): # NOTE: Assume one dataset for now. - ant_xds = xds_from_storage_table(path + "::ANTENNA")[0] + ant_xds = xds_from_table_fragment(path + "::ANTENNA")[0] antenna_names = ant_xds.NAME.values antenna_mounts = ant_xds.MOUNT.values @@ -64,7 +67,7 @@ def antenna_info(path): def data_desc_info(path): - dd_xds_list = xds_from_storage_table( # noqa + dd_xds_list = xds_from_table_fragment( # noqa path + "::DATA_DESCRIPTION", group_cols=["__row__"], chunks={"row": 1, "chan": -1} @@ -76,7 +79,7 @@ def data_desc_info(path): def feed_info(path): - feed_xds_list = xds_from_storage_table( + feed_xds_list = xds_from_table_fragment( path + "::FEED", group_cols=["SPECTRAL_WINDOW_ID"], chunks={"row": -1} @@ -106,7 +109,7 @@ def feed_info(path): def flag_cmd_info(path): - flag_cmd_xds = xds_from_storage_table(path + "::FLAG_CMD") # noqa + flag_cmd_xds = xds_from_table_fragment(path + "::FLAG_CMD") # noqa # Not printing any summary information for this subtable yet - not sure # what is relevant. @@ -114,7 +117,7 @@ def flag_cmd_info(path): def field_info(path): - field_xds = xds_from_storage_table(path + "::FIELD")[0] + field_xds = xds_from_table_fragment(path + "::FIELD")[0] ids = [i for i in field_xds.SOURCE_ID.values] names = [n for n in field_xds.NAME.values] @@ -141,7 +144,7 @@ def field_info(path): def history_info(path): - history_xds = xds_from_storage_table(path + "::HISTORY")[0] # noqa + history_xds = xds_from_table_fragment(path + "::HISTORY")[0] # noqa # Not printing any summary information for this subtable yet - not sure # what is relevant. @@ -149,7 +152,7 @@ def history_info(path): def observation_info(path): - observation_xds = xds_from_storage_table(path + "::OBSERVATION")[0] # noqa + observation_xds = xds_from_table_fragment(path + "::OBSERVATION")[0] # noqa # Not printing any summary information for this subtable yet - not sure # what is relevant. 
@@ -157,7 +160,7 @@ def observation_info(path): def polarization_info(path): - polarization_xds = xds_from_storage_table(path + "::POLARIZATION")[0] + polarization_xds = xds_from_table_fragment(path + "::POLARIZATION")[0] corr_types = polarization_xds.CORR_TYPE.values @@ -175,7 +178,7 @@ def polarization_info(path): def processor_info(path): - processor_xds = xds_from_storage_table(path + "::PROCESSOR")[0] # noqa + processor_xds = xds_from_table_fragment(path + "::PROCESSOR")[0] # noqa # Not printing any summary information for this subtable yet - not sure # what is relevant. @@ -183,7 +186,7 @@ def processor_info(path): def spw_info(path): - spw_xds_list = xds_from_storage_table( + spw_xds_list = xds_from_table_fragment( path + "::SPECTRAL_WINDOW", group_cols=["__row__"], chunks={"row": 1, "chan": -1} @@ -207,7 +210,7 @@ def spw_info(path): def state_info(path): - state_xds = xds_from_storage_table(path + "::STATE")[0] # noqa + state_xds = xds_from_table_fragment(path + "::STATE")[0] # noqa # Not printing any summary information for this subtable yet - not sure # what is relevant. @@ -226,7 +229,7 @@ def source_info(path): def pointing_info(path): - pointing_xds = xds_from_storage_table(path + "::POINTING")[0] # noqa + pointing_xds = xds_from_table_fragment(path + "::POINTING")[0] # noqa # Not printing any summary information for this subtable yet - not sure # what is relevant. @@ -355,7 +358,7 @@ def summary(): # Open the data, grouping by the usual columns. Use these datasets to # produce some useful summaries. - data_xds_list = xds_from_storage_ms( + data_xds_list = xds_from_ms_fragment( path, index_cols=("TIME",), columns=("TIME", "FLAG", "FLAG_ROW", "DATA"), diff --git a/quartical/data_handling/angles.py b/quartical/data_handling/angles.py index 343ab466..f7e6098e 100644 --- a/quartical/data_handling/angles.py +++ b/quartical/data_handling/angles.py @@ -2,7 +2,7 @@ import casacore.measures import casacore.quanta as pq -from daskms import xds_from_storage_table +from daskms.experimental.multisource import xds_from_table_fragment import dask.array as da import threading from dask.graph_manipulation import clone @@ -24,9 +24,9 @@ def make_parangle_xds_list(ms_path, data_xds_list): # This may need to be more sophisticated. TODO: Can we guarantee that # these only ever have one element? - anttab = xds_from_storage_table(ms_path + "::ANTENNA")[0] - feedtab = xds_from_storage_table(ms_path + "::FEED")[0] - fieldtab = xds_from_storage_table(ms_path + "::FIELD")[0] + anttab = xds_from_table_fragment(ms_path + "::ANTENNA")[0] + feedtab = xds_from_table_fragment(ms_path + "::FEED")[0] + fieldtab = xds_from_table_fragment(ms_path + "::FIELD")[0] # We do this eagerly to make life easier. feeds = feedtab.POLARIZATION_TYPE.values diff --git a/quartical/data_handling/chunking.py b/quartical/data_handling/chunking.py index ce6fe353..4003e737 100644 --- a/quartical/data_handling/chunking.py +++ b/quartical/data_handling/chunking.py @@ -1,7 +1,10 @@ import dask.delayed as dd import numpy as np import dask.array as da -from daskms import xds_from_storage_ms, xds_from_storage_table +from daskms.experimental.multisource import ( + xds_from_ms_fragment, + xds_from_table_fragment +) def compute_chunking(ms_opts, compute=True): @@ -10,7 +13,7 @@ def compute_chunking(ms_opts, compute=True): # necessary to determine initial chunking over row and chan. TODO: Test # multi-SPW/field cases. Implement a memory budget. 
- indexing_xds_list = xds_from_storage_ms( + indexing_xds_list = xds_from_ms_fragment( ms_opts.path, columns=("TIME", "INTERVAL"), index_cols=("TIME",), @@ -24,7 +27,7 @@ def compute_chunking(ms_opts, compute=True): compute=False ) - spw_xds_list = xds_from_storage_table( + spw_xds_list = xds_from_table_fragment( ms_opts.path + "::SPECTRAL_WINDOW", group_cols=["__row__"], columns=["CHAN_FREQ", "CHAN_WIDTH"], diff --git a/quartical/data_handling/ms_handler.py b/quartical/data_handling/ms_handler.py index 3b923432..edafa2fa 100644 --- a/quartical/data_handling/ms_handler.py +++ b/quartical/data_handling/ms_handler.py @@ -2,9 +2,12 @@ import warnings import dask.array as da import numpy as np -from daskms import (xds_from_storage_ms, - xds_from_storage_table, - xds_to_storage_table) +from daskms import xds_to_storage_table +from daskms.experimental.multisource import ( + xds_from_ms_fragment, + xds_from_table_fragment, + xds_to_table_fragment +) from dask.graph_manipulation import clone from loguru import logger from quartical.weights.weights import initialize_weights @@ -28,7 +31,7 @@ def read_xds_list(model_columns, ms_opts): data_xds_list: A list of appropriately chunked xarray datasets. """ - antenna_xds = xds_from_storage_table(ms_opts.path + "::ANTENNA")[0] + antenna_xds = xds_from_table_fragment(ms_opts.path + "::ANTENNA")[0] n_ant = antenna_xds.dims["row"] @@ -36,7 +39,7 @@ def read_xds_list(model_columns, ms_opts): "observation.", n_ant) # Determine the number/type of correlations present in the measurement set. - pol_xds = xds_from_storage_table(ms_opts.path + "::POLARIZATION")[0] + pol_xds = xds_from_table_fragment(ms_opts.path + "::POLARIZATION")[0] try: corr_types = [CORR_TYPES[ct] for ct in pol_xds.CORR_TYPE.values[0]] @@ -56,7 +59,7 @@ def read_xds_list(model_columns, ms_opts): # probably need to be done on a per xds basis. Can probably be accomplished # by merging the field xds grouped by DDID into data grouped by DDID. - field_xds = xds_from_storage_table(ms_opts.path + "::FIELD")[0] + field_xds = xds_from_table_fragment(ms_opts.path + "::FIELD")[0] phase_dir = np.squeeze(field_xds.PHASE_DIR.values) field_names = field_xds.NAME.values @@ -90,7 +93,7 @@ def read_xds_list(model_columns, ms_opts): schema[ms_opts.weight_column] = {'dims': ('chan', 'corr')} try: - data_xds_list = xds_from_storage_ms( + data_xds_list = xds_from_ms_fragment( ms_opts.path, columns=columns, index_cols=("TIME",), @@ -103,7 +106,7 @@ def read_xds_list(model_columns, ms_opts): f"Invalid/missing column specified. Underlying error: {e}." ) from e - spw_xds_list = xds_from_storage_table( + spw_xds_list = xds_from_table_fragment( ms_opts.path + "::SPECTRAL_WINDOW", group_cols=["__row__"], columns=["CHAN_FREQ", "CHAN_WIDTH"], @@ -213,7 +216,7 @@ def write_xds_list(xds_list, ref_xds_list, ms_path, output_opts): if not (output_opts.products or output_opts.flags): return [None] * len(xds_list) # Write nothing to the MS. - pol_xds = xds_from_storage_table(ms_path + "::POLARIZATION")[0] + pol_xds = xds_from_table_fragment(ms_path + "::POLARIZATION")[0] corr_types = [CORR_TYPES[ct] for ct in pol_xds.CORR_TYPE.values[0]] ms_n_corr = len(corr_types) @@ -302,6 +305,14 @@ def write_xds_list(xds_list, ref_xds_list, ms_path, output_opts): rechunk=True # Needed to ensure zarr chunks map correctly to disk. ) + # write_xds_list = xds_to_table_fragment( + # xds_list, + # "delta1.ms", + # ms_path, + # columns=output_cols, + # rechunk=True # Needed to ensure zarr chunks map correctly to disk. 
+    # )
+
     return write_xds_list
 
 
diff --git a/quartical/data_handling/predict.py b/quartical/data_handling/predict.py
index 6bac1ef4..cc8bda04 100644
--- a/quartical/data_handling/predict.py
+++ b/quartical/data_handling/predict.py
@@ -7,7 +7,8 @@
 import dask
 from xarray import DataArray, Dataset
 from dask.graph_manipulation import clone
-from daskms import xds_from_storage_table
+from daskms.experimental.multisource import xds_from_table_fragment
+
 from loguru import logger
 import numpy as np
 import Tigger
@@ -310,21 +311,24 @@ def get_support_tables(ms_path):
            "SPECTRAL_WINDOW", "POLARIZATION", "FEED")}
 
     # All rows at once
-    lazy_tables = {"ANTENNA": xds_from_storage_table(n["ANTENNA"]),
-                   "FEED": xds_from_storage_table(n["FEED"])}
+    lazy_tables = {"ANTENNA": xds_from_table_fragment(n["ANTENNA"]),
+                   "FEED": xds_from_table_fragment(n["FEED"])}
 
     compute_tables = {
         # NOTE: Even though this has a fixed shape, I have amended it to
         # also group by row. This just makes life fractionally easier.
-        "DATA_DESCRIPTION": xds_from_storage_table(n["DATA_DESCRIPTION"],
-                                                   group_cols="__row__"),
+        "DATA_DESCRIPTION": xds_from_table_fragment(
+            n["DATA_DESCRIPTION"], group_cols="__row__"
+        ),
         # Variably shaped, need a dataset per row
         "FIELD":
-            xds_from_storage_table(n["FIELD"], group_cols="__row__"),
+            xds_from_table_fragment(n["FIELD"], group_cols="__row__"),
         "SPECTRAL_WINDOW":
-            xds_from_storage_table(n["SPECTRAL_WINDOW"], group_cols="__row__"),
+            xds_from_table_fragment(
+                n["SPECTRAL_WINDOW"], group_cols="__row__"
+            ),
         "POLARIZATION":
-            xds_from_storage_table(n["POLARIZATION"], group_cols="__row__"),
+            xds_from_table_fragment(n["POLARIZATION"], group_cols="__row__"),
     }
 
     lazy_tables.update(dask.compute(compute_tables)[0])

From ec2f4967ea8e3ed5d0814b0e9b2db075d9dffe50 Mon Sep 17 00:00:00 2001
From: JSKenyon
Date: Fri, 4 Aug 2023 16:24:14 +0200
Subject: [PATCH 19/26] Import from fragments.
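All imports move from daskms.experimental.multisource to
daskms.experimental.fragments - only the module path changes. For
environments that might still carry the older experimental module, a hedged
compatibility sketch (not part of this patch):

    # Hedged sketch: tolerate both module locations during the rename,
    # assuming the function names are unchanged (as the diffs below show).
    try:
        from daskms.experimental.fragments import xds_from_table_fragment
    except ImportError:  # Older experimental builds exposed "multisource".
        from daskms.experimental.multisource import xds_from_table_fragment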
--- quartical/apps/summary.py | 2 +- quartical/data_handling/angles.py | 2 +- quartical/data_handling/chunking.py | 2 +- quartical/data_handling/ms_handler.py | 3 ++- quartical/data_handling/predict.py | 2 +- 5 files changed, 6 insertions(+), 5 deletions(-) diff --git a/quartical/apps/summary.py b/quartical/apps/summary.py index 1d0a8dd2..0d3dab90 100644 --- a/quartical/apps/summary.py +++ b/quartical/apps/summary.py @@ -1,6 +1,6 @@ import argparse from pathlib import Path -from daskms.experimental.multisource import ( +from daskms.experimental.fragments import ( xds_from_ms_fragment, xds_from_table_fragment ) diff --git a/quartical/data_handling/angles.py b/quartical/data_handling/angles.py index f7e6098e..7911a24f 100644 --- a/quartical/data_handling/angles.py +++ b/quartical/data_handling/angles.py @@ -2,7 +2,7 @@ import casacore.measures import casacore.quanta as pq -from daskms.experimental.multisource import xds_from_table_fragment +from daskms.experimental.fragments import xds_from_table_fragment import dask.array as da import threading from dask.graph_manipulation import clone diff --git a/quartical/data_handling/chunking.py b/quartical/data_handling/chunking.py index 4003e737..867a28a7 100644 --- a/quartical/data_handling/chunking.py +++ b/quartical/data_handling/chunking.py @@ -1,7 +1,7 @@ import dask.delayed as dd import numpy as np import dask.array as da -from daskms.experimental.multisource import ( +from daskms.experimental.fragments import ( xds_from_ms_fragment, xds_from_table_fragment ) diff --git a/quartical/data_handling/ms_handler.py b/quartical/data_handling/ms_handler.py index edafa2fa..e1682cfd 100644 --- a/quartical/data_handling/ms_handler.py +++ b/quartical/data_handling/ms_handler.py @@ -3,7 +3,7 @@ import dask.array as da import numpy as np from daskms import xds_to_storage_table -from daskms.experimental.multisource import ( +from daskms.experimental.fragments import ( xds_from_ms_fragment, xds_from_table_fragment, xds_to_table_fragment @@ -305,6 +305,7 @@ def write_xds_list(xds_list, ref_xds_list, ms_path, output_opts): rechunk=True # Needed to ensure zarr chunks map correctly to disk. ) + # TODO: This needs to be controlled in a sensible way. # write_xds_list = xds_to_table_fragment( # xds_list, # "delta1.ms", diff --git a/quartical/data_handling/predict.py b/quartical/data_handling/predict.py index cc8bda04..e3015e5e 100644 --- a/quartical/data_handling/predict.py +++ b/quartical/data_handling/predict.py @@ -7,7 +7,7 @@ import dask from xarray import DataArray, Dataset from dask.graph_manipulation import clone -from daskms.experimental.multisource import xds_from_table_fragment +from daskms.experimental.fragments import xds_from_table_fragment from loguru import logger import numpy as np From 1bb7448996ec60f273032df48876b4d54d0b2ee4 Mon Sep 17 00:00:00 2001 From: JSKenyon Date: Thu, 10 Aug 2023 14:55:40 +0200 Subject: [PATCH 20/26] Add output.fragment_path, which, if set, causes QC to write columns to a dask-ms fragment. --- quartical/config/argument_schema.yaml | 9 ++++++++ quartical/data_handling/ms_handler.py | 30 ++++++++++++++------------- 2 files changed, 25 insertions(+), 14 deletions(-) diff --git a/quartical/config/argument_schema.yaml b/quartical/config/argument_schema.yaml index ac7d768c..7fa416a0 100644 --- a/quartical/config/argument_schema.yaml +++ b/quartical/config/argument_schema.yaml @@ -193,6 +193,15 @@ output: Name of directory in which QuartiCal logging outputs will be stored. s3 is not currently supported for these outputs. 
+  fragment_path:
+    dtype: Optional[str]
+    info:
+      If set, QuartiCal will not mutate the input by e.g. writing flags;
+      it will instead write a fragment to this location. A fragment is a
+      zarr-backed data format that is read and dynamically combined with
+      any parent datasets. This allows QuartiCal to operate in an entirely
+      read-only fashion. This option is experimental.
+
   log_to_terminal:
     default: true
     dtype: bool
diff --git a/quartical/data_handling/ms_handler.py b/quartical/data_handling/ms_handler.py
index e1682cfd..4768d62e 100644
--- a/quartical/data_handling/ms_handler.py
+++ b/quartical/data_handling/ms_handler.py
@@ -298,21 +298,23 @@ def write_xds_list(xds_list, ref_xds_list, ms_path, output_opts):
 
     with warnings.catch_warnings():  # We anticipate spurious warnings.
         warnings.simplefilter("ignore")
-        write_xds_list = xds_to_storage_table(
-            xds_list,
-            ms_path,
-            columns=output_cols,
-            rechunk=True  # Needed to ensure zarr chunks map correctly to disk.
-        )
 
-        # TODO: This needs to be controlled in a sensible way.
-        # write_xds_list = xds_to_table_fragment(
-        #     xds_list,
-        #     "delta1.ms",
-        #     ms_path,
-        #     columns=output_cols,
-        #     rechunk=True  # Needed to ensure zarr chunks map correctly to disk.
-        # )
+        if output_opts.fragment_path:
+            write_xds_list = xds_to_table_fragment(
+                xds_list,
+                output_opts.fragment_path,
+                ms_path,
+                columns=output_cols,
+                rechunk=True  # Ensure zarr chunks map correctly to disk.
+            )
+
+        else:
+            write_xds_list = xds_to_storage_table(
+                xds_list,
+                ms_path,
+                columns=output_cols,
+                rechunk=True  # Ensure zarr chunks map correctly to disk.
+            )
 
     return write_xds_list

From e73450782faedf9234b048713f9356481abb69e8 Mon Sep 17 00:00:00 2001
From: JSKenyon
Date: Fri, 25 Aug 2023 16:03:11 +0200
Subject: [PATCH 21/26] Add plotting functionality (#290)

* Fix version drift.
* Bump to 0.2.0
* Initial commit of basic plotting functionality.
* Change naming convention.
* Improve transform argument.
* Simplify transform selection.
* Add rudimentary time and frequency selection.
* Checkpoint plotter changes. Can now handle scans and spws, but is very slow.
* More work on plotter - can now plot datasets in parallel.
* Some tidying.
* Slightly improve plot speed. Dominant cost is still saving the figures.
* Commit some minor changes which speed up figure saving.
* Lots of tiny fixes.
* Tiny cosmetic changes.
* Add custom tick formatter so that plots are the same size regardless.
* Add matplotlib dependency.
* Rework construction of plotting dictionary. Add a few utility functions which will likely be useful in other places in QC.
* Rename variable to avoid confusion.
* Fix bug affecting recursive grouping.
* Avoid copies in grouping code.
* Checkpoint work on extending functionality.
* Make plotter more powerful. Add colourization option. Begin simplifying interface.
* Allow user specification of colourmap.
* Add plotsize parameter.
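The new entry point can be invoked as, for example (paths are hypothetical;
the flags are those defined in the argparse setup below):

    goquartical-plot gains.qc/G plots --transform amplitude \
        --iter-axes antenna correlation --nworker 4

A hedged illustration of the flatten helper added to
quartical/utils/collections.py, which collapses the nested per-attribute
grouping into the tuple keys consumed by to_plot_dict:

    # Hedged sketch only - mirrors the behaviour of the helper added below.
    from quartical.utils.collections import flatten

    nested = {"FIELD_ID_0": {"SCAN_NUMBER_2": ["xds_a"],
                             "SCAN_NUMBER_3": ["xds_b"]}}
    flatten(nested)
    # {("FIELD_ID_0", "SCAN_NUMBER_2"): ["xds_a"],
    #  ("FIELD_ID_0", "SCAN_NUMBER_3"): ["xds_b"]}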
---
 pyproject.toml                 |   3 +-
 quartical/apps/plotter.py      | 335 +++++++++++++++++++++++++++++++++
 quartical/utils/collections.py |  15 ++
 quartical/utils/datasets.py    |  36 ++++
 4 files changed, 388 insertions(+), 1 deletion(-)
 create mode 100644 quartical/apps/plotter.py
 create mode 100644 quartical/utils/datasets.py

diff --git a/pyproject.toml b/pyproject.toml
index 68c1f4ee..a67c5df9 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -40,6 +40,7 @@ pytest = "^7.3.1"
 omegaconf = "^2.3.0"
 colorama = "^0.4.6"
 stimela = "2.0rc4"
+matplotlib = "^3.5.1"
 
 [tool.poetry.scripts]
 goquartical = 'quartical.executor:execute'
@@ -47,7 +48,7 @@ goquartical-config = 'quartical.config.parser:create_user_config'
 goquartical-backup = 'quartical.apps.backup:backup'
 goquartical-restore = 'quartical.apps.backup:restore'
 goquartical-summary = 'quartical.apps.summary:summary'
-
+goquartical-plot = 'quartical.apps.plotter:plot'
 
 [build-system]
 requires = ["poetry-core"]
diff --git a/quartical/apps/plotter.py b/quartical/apps/plotter.py
new file mode 100644
index 00000000..35f18977
--- /dev/null
+++ b/quartical/apps/plotter.py
@@ -0,0 +1,335 @@
+import argparse
+import math
+import xarray
+import numpy as np
+import matplotlib.pyplot as plt
+from matplotlib import ticker
+import matplotlib.cm as cm
+from itertools import product, chain
+from daskms.experimental.zarr import xds_from_zarr
+from daskms.fsspec_store import DaskMSStore
+from concurrent.futures import ProcessPoolExecutor, as_completed
+from quartical.utils.collections import flatten
+from quartical.utils.datasets import recursive_group_by_attr
+
+
+TRANSFORMS = {
+    "raw": np.array,
+    "amplitude": np.abs,
+    "phase": np.angle,
+    "real": np.real,
+    "imag": np.imag
+}
+
+
+class CustomFormatter(ticker.ScalarFormatter):
+
+    def __init__(self, *args, precision=None, **kwargs):
+
+        super().__init__(*args, **kwargs)
+
+        self.precision = precision
+
+    def _set_format(self):
+        # set the format string to format all the ticklabels
+        if len(self.locs) < 2:
+            # Temporarily augment the locations with the axis end points.
+            _locs = [*self.locs, *self.axis.get_view_interval()]
+        else:
+            _locs = self.locs
+        locs = (np.asarray(_locs) - self.offset) / 10. ** self.orderOfMagnitude
+        loc_range = np.ptp(locs)
+        # Curvilinear coordinates can yield two identical points.
+        if loc_range == 0:
+            loc_range = np.max(np.abs(locs))
+        # Both points might be zero.
+        if loc_range == 0:
+            loc_range = 1
+        if len(self.locs) < 2:
+            # We needed the end points only for the loc_range calculation.
+            locs = locs[:-2]
+        loc_range_oom = int(math.floor(math.log10(loc_range)))
+        # first estimate:
+        sigfigs = max(0, 3 - loc_range_oom)
+        # refined estimate:
+        thresh = 1e-3 * 10 ** loc_range_oom
+        while sigfigs >= 0:
+            if np.abs(locs - np.round(locs, decimals=sigfigs)).max() < thresh:
+                sigfigs -= 1
+            else:
+                break
+        sigfigs += 1
+        sigfigs = self.precision or sigfigs
+        self.format = f'%{sigfigs + 3}.{sigfigs}f'
+        if self._usetex or self._useMathText:
+            self.format = r'$\mathdefault{%s}$' % self.format
+
+
+def cli():
+
+    parser = argparse.ArgumentParser(
+        description="Rudimentary plotter for QuartiCal gain solutions."
+    )
+
+    parser.add_argument(
+        "input_path",
+        type=DaskMSStore,
+        help="Path to input gains, e.g. path/to/dir/G. Accepts valid s3 urls."
+    )
+    parser.add_argument(
+        "output_path",
+        type=DaskMSStore,
+        help="Path to desired output location."
+    )
+    parser.add_argument(
+        "--plot-var",
+        type=str,
+        default="gains",
+        help="Name of data variable to plot."
+    )
+    parser.add_argument(
+        "--flag-var",
+        type=str,
+        default="gain_flags",
+        help="Name of data variable to use as flags."
+    )
+    parser.add_argument(
+        "--xaxis",
+        type=str,
+        default="gain_time",
+        choices=("gain_time", "gain_freq", "param_time", "param_freq"),
+        help="Name of coordinate to use for x-axis."
+    )
+    parser.add_argument(
+        "--transform",
+        type=str,
+        default="raw",
+        choices=list(TRANSFORMS.keys()),
+        help="Transform to apply to data before plotting."
+    )
+    parser.add_argument(
+        "--iter-attrs",
+        type=str,
+        nargs="+",
+        default=["FIELD_ID", "DATA_DESC_ID", "SCAN_NUMBER"],
+        help=(
+            "Attributes (datasets) over which to iterate. Omission will "
+            "result in concatenation in the omitted axis i.e. omit "
+            "SCAN_NUMBER to include all scans in a single plot."
+        )
+    )
+    parser.add_argument(
+        "--iter-axes",
+        type=str,
+        nargs="+",
+        default=["antenna", "direction", "correlation"],
+        help=(
+            "Axes over which to iterate when generating plots i.e. produce a "
+            "plot per unique combination of the specified axes."
+        )
+    )
+    parser.add_argument(
+        "--mean-axis",
+        type=str,
+        default=None,
+        help=(
+            "If set, will plot a heavier line to indicate the mean of the "
+            "plotted quantity along this axis."
+        )
+    )
+    parser.add_argument(
+        "--colourize-axis",
+        type=str,
+        default=None,
+        help="Axis to colour by."
+    )
+    parser.add_argument(
+        "--time-range",
+        type=float,
+        nargs=2,
+        default=[None],
+        help="Time range to plot."
+    )
+    parser.add_argument(
+        "--freq-range",
+        type=float,
+        nargs=2,
+        default=[None],
+        help="Frequency range to plot."
+    )
+    parser.add_argument(
+        "--nworker",
+        type=int,
+        default=1,
+        help="Number of processes to use while plotting."
+    )
+    parser.add_argument(
+        "--colourmap",
+        type=str,
+        default="plasma",
+        help=(
+            "Colourmap to use with --colourize-axis. Supports all matplotlib "
+            "colourmaps."
+        )
+    )
+    parser.add_argument(
+        "--fig-size",
+        type=float,
+        nargs=2,
+        default=[5, 5],
+        help="Figure size in inches. Expects two values, width and height."
+    )
+    return parser.parse_args()
+
+
+def to_plot_dict(xdsl, iter_attrs):
+
+    grouped = recursive_group_by_attr(xdsl, iter_attrs)
+
+    return {
+        k: xarray.combine_by_coords(v, combine_attrs="drop_conflicts")
+        for k, v in flatten(grouped).items()
+    }
+
+
+def _plot(group, xds, args):
+
+    xds = xds.compute(scheduler="single-threaded")
+
+    if args.freq_range or args.time_range:
+        time_ax, freq_ax = xds[args.plot_var].dims[:2]
+        xds = xds.sel(
+            {
+                time_ax: slice(*args.time_range),
+                freq_ax: slice(*args.freq_range)
+            }
+        )
+
+    dims = xds[args.plot_var].dims  # Dimensions of plot quantity.
+    assert all(map(lambda x: x in dims, args.iter_axes)), (
+        f"Some or all of {args.iter_axes} are not present on "
+        f"{args.plot_var}."
+    )
+
+    # Grab the required transform from the dict.
+    transform = TRANSFORMS[args.transform]
+
+    # NOTE: This mutates the data variables in place.
+    data = xds[args.plot_var].values
+    flags = xds[args.flag_var].values
+    data[np.where(flags)] = np.nan  # Set flagged values to nan (not plotted).
+    xds = xds.drop_vars(args.flag_var)  # No more use for flags.
+
+    # Construct list of lists containing axes over which we iterate i.e.
+    # produce a plot per combination of these values.
+    iter_axes_itr = [xds[x].values.tolist() for x in args.iter_axes]
+
+    # Figure out axes included in a single plot.
+ excluded_dims = {*args.iter_axes, args.xaxis} + agg_axes = [d for d in dims if d not in excluded_dims] + agg_axes_itr = [range(xds.sizes[x]) for x in agg_axes] + + # Figure out axes included in a single plot after taking the mean. + excluded_dims = {*args.iter_axes, args.xaxis, args.mean_axis} + mean_agg_axes = [d for d in dims if d not in excluded_dims] + mean_agg_axes_itr = [range(xds.sizes[x]) for x in mean_agg_axes] + + if args.colourize_axis: + n_colour = xds.sizes[args.colourize_axis] + colourmap = cm.get_cmap(args.colourmap) + colours = [colourmap(i / n_colour) for i in range(n_colour)] + else: + n_colour = 2 + colours = ["k", "r"] + + fig, ax = plt.subplots(figsize=args.fig_size) + + for ia in product(*iter_axes_itr): + + sel = {ax: val for ax, val in zip(args.iter_axes, ia)} + + xda = xds.sel(sel)[args.plot_var] + + ax.clear() + + for aa in product(*agg_axes_itr): + + subsel = {ax: val for ax, val in zip(agg_axes, aa)} + pxda = xda.isel(subsel) + + ax.plot( + pxda[args.xaxis].values, + transform(pxda.values), + color=colours[subsel.get(args.colourize_axis, 0)], + linewidth=0.1 + ) + + if args.mean_axis: + + mxda = xda.mean(args.mean_axis) + + for ma in product(*mean_agg_axes_itr): + + subsel = {ax: val for ax, val in zip(mean_agg_axes, ma)} + pxda = mxda.isel(subsel) + + ax.plot( + pxda[args.xaxis].values, + transform(pxda.values), + color=colours[subsel.get(args.colourize_axis, 1)] + ) + + ax.title.set_text("\n".join([f"{k}: {v}" for k, v in sel.items()])) + ax.title.set_fontsize("medium") + ax.set_xlabel(f"{args.xaxis}") + ax.set_ylabel(f"{args.transform}({ia[-1]})") + + formatter = CustomFormatter(precision=2) + formatter.set_scientific(True) + formatter.set_powerlimits((0, 0)) + ax.yaxis.set_major_formatter(formatter) + ax.yaxis.set_major_locator(ticker.LinearLocator(numticks=5)) + + fig_name = "-".join(map(str, chain.from_iterable(sel.items()))) + + root_subdir = f"{xds.NAME}-{args.plot_var}-{args.transform}" + leaf_subdir = "-".join(group) + subdir_path = f"{root_subdir}/{leaf_subdir}" + + args.output_path.makedirs(subdir_path, exist_ok=True) + + fig.savefig( + f"{args.output_path.full_path}/{subdir_path}/{fig_name}.png", + bbox_inches="tight" # SLOW, but slightly unavoidable. + ) + + plt.close() + + +def plot(): + + args = cli() + + non_colourizable_axes = {*args.iter_axes, args.mean_axis, args.xaxis} + if args.colourize_axis and args.colourize_axis in non_colourizable_axes: + raise ValueError(f"Cannot colourize using axis {args.colourize_axis}.") + + # Path to gain location. + gain_path = DaskMSStore("::".join(args.input_path.url.rsplit("/", 1))) + + xdsl = xds_from_zarr(gain_path) + + # Select only the necessary fields for plotting on each dataset. + xdsl = [xds[[args.plot_var, args.flag_var]] for xds in xdsl] + + # Partitioned dictionary of xarray.Datasets. 
+ xdsd = to_plot_dict(xdsl, args.iter_attrs) + + with ProcessPoolExecutor(max_workers=args.nworker) as ppe: + futures = [ppe.submit(_plot, k, xds, args) for k, xds in xdsd.items()] + + for future in as_completed(futures): + try: + future.result() + except Exception as exc: + print(f"Exception raised in process pool: {exc}") diff --git a/quartical/utils/collections.py b/quartical/utils/collections.py index 31da7d34..42199925 100644 --- a/quartical/utils/collections.py +++ b/quartical/utils/collections.py @@ -1,3 +1,4 @@ +from collections.abc import MutableMapping from collections import defaultdict @@ -9,3 +10,17 @@ def freeze_default_dict(ddict): ddict[k] = freeze_default_dict(v) return dict(ddict) + + +def flatten(dictionary, parent_key=()): + """Flatten dictionary. Adapted from https://stackoverflow.com/a/6027615.""" + items = [] + + for key, value in dictionary.items(): + new_key = parent_key + (key,) if parent_key else (key,) + if isinstance(value, MutableMapping): + items.extend(flatten(value, new_key).items()) + else: + items.append((new_key, value)) + + return dict(items) diff --git a/quartical/utils/datasets.py b/quartical/utils/datasets.py new file mode 100644 index 00000000..537bc3a9 --- /dev/null +++ b/quartical/utils/datasets.py @@ -0,0 +1,36 @@ +from collections.abc import Iterable + + +def group_by_attr(xdsl, attr, default="?"): + """Group list of xarray datasets based on value of attribute.""" + + attr_vals = {xds.attrs.get(attr, default) for xds in xdsl} + + return { + f"{attr}_{attr_val}": [ + xds for xds in xdsl if xds.attrs.get(attr, default) == attr_val + ] + for attr_val in attr_vals + } + + +def _recursive_group_by_attr(partition_dict, attrs): + + for k, v in partition_dict.items(): + partition_dict[k] = group_by_attr(v, attrs[0]) + + if len(attrs[1:]): + _recursive_group_by_attr(partition_dict[k], attrs[1:]) + + +def recursive_group_by_attr(xdsl, keys): + + if not isinstance(keys, Iterable): + keys = [keys] + + group_dict = group_by_attr(xdsl, keys[0]) + + if len(keys[1:]): + _recursive_group_by_attr(group_dict, keys[1:]) + + return group_dict From 2e0b9e8c41a99a61d663781c177678437acb30e8 Mon Sep 17 00:00:00 2001 From: JSKenyon Date: Wed, 30 Aug 2023 12:16:04 +0200 Subject: [PATCH 22/26] Fix #293 - OOB access caused by `output.subtract_directions` (#294) * Fix version drift. * Bump to 0.2.0 * Fix #293. --- quartical/calibration/calibrate.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/quartical/calibration/calibrate.py b/quartical/calibration/calibrate.py index 83ed022b..dba5f2f2 100644 --- a/quartical/calibration/calibrate.py +++ b/quartical/calibration/calibrate.py @@ -253,6 +253,18 @@ def make_visibility_output( itr = enumerate(zip(data_xds_list, mapping_xds_list)) + if output_opts.subtract_directions: + n_dir = data_xds_list[0].dims['dir'] # Should be the same on all xdss. + requested = set(output_opts.subtract_directions) + valid = set(range(n_dir)) + invalid = requested - valid + if invalid: + raise ValueError( + f"User has specified output.subtract_directions as " + f"{requested} but the following directions are not present " + f"in the model: {invalid}." + ) + for xds_ind, (data_xds, mapping_xds) in itr: data_col = data_xds.DATA.data model_col = data_xds.MODEL_DATA.data From 9a2879466ecf4b9093c8f4d55da9876392e08fe9 Mon Sep 17 00:00:00 2001 From: Landman Bester Date: Wed, 13 Sep 2023 13:14:04 +0200 Subject: [PATCH 23/26] Namedbackups (#296) * Fix version drift. 
* Bump to 0.2.0 * Add optional label and single field selection to backup app * remove item instead of pop@index * do not .remove() from xds_list * Simplify using some existing functionality. --------- Co-authored-by: JSKenyon Co-authored-by: landmanbester --- quartical/apps/backup.py | 35 +++++++++++++++++++++++----- quartical/data_handling/selection.py | 12 +++++++--- 2 files changed, 38 insertions(+), 9 deletions(-) diff --git a/quartical/apps/backup.py b/quartical/apps/backup.py index fcb44a69..1a15ea8c 100644 --- a/quartical/apps/backup.py +++ b/quartical/apps/backup.py @@ -1,5 +1,6 @@ import argparse from math import prod, ceil +from quartical.data_handling.selection import filter_xds_list from daskms import xds_from_storage_ms, xds_to_storage_table from daskms.experimental.zarr import xds_to_zarr, xds_from_zarr from daskms.fsspec_store import DaskMSStore @@ -10,8 +11,9 @@ def backup(): parser = argparse.ArgumentParser( description='Backup any Measurement Set column to zarr. Backups will ' - 'be labelled automatically using the current datetime, ' - 'the Measurement Set name and the column name.' + 'be labelled using a combination of the passed in label ' + '(defaults to datetime), the Measurement Set name and ' + 'the column name.' ) parser.add_argument( @@ -33,19 +35,34 @@ def backup(): type=str, help='Name of column to be backed up.' ) + parser.add_argument( + '--label', + type=str, + help='An explicit label to include in the backup name. Defaults to ' + 'datetime at which the backup was created. Full name will be ' + 'given by [label]-[msname]-[column].bkp.qc.' + ) parser.add_argument( '--nthread', type=int, default=1, help='Number of threads to use.' ) + parser.add_argument( + '--field-id', + type=int, + help='Field ID to back up.' + ) args = parser.parse_args() ms_name = args.ms_path.full_path.rsplit("/", 1)[1] column_name = args.column_name - timestamp = time.strftime("%Y%m%d-%H%M%S") + if args.label: + label = args.label + else: + label = time.strftime("%Y%m%d-%H%M%S") # This call exists purely to get the relevant shape and dtype info. data_xds_list = xds_from_storage_ms( @@ -55,8 +72,11 @@ def backup(): group_cols=("FIELD_ID", "DATA_DESC_ID", "SCAN_NUMBER"), ) + # Use existing functionality. TODO: Improve and expose DDID selection. + xdso = filter_xds_list(data_xds_list, args.field_id) + # Compute appropriate chunks (256MB by default) to keep zarr happy. - chunks = [chunk_by_size(xds[column_name]) for xds in data_xds_list] + chunks = [chunk_by_size(xds[column_name]) for xds in xdso] # Repeat of above call but now with correct chunking information. data_xds_list = xds_from_storage_ms( @@ -67,9 +87,12 @@ def backup(): chunks=chunks ) + # Use existing functionality. TODO: Improve and expose DDID selection. + xdso = filter_xds_list(data_xds_list, args.field_id) + bkp_xds_list = xds_to_zarr( - data_xds_list, - f"{args.zarr_dir.url}::{timestamp}-{ms_name}-{column_name}.bkp.qc", + xdso, + f"{args.zarr_dir.url}::{label}-{ms_name}-{column_name}.bkp.qc", ) dask.compute(bkp_xds_list, num_workers=args.nthread) diff --git a/quartical/data_handling/selection.py b/quartical/data_handling/selection.py index 4f6f9301..bd0f66c3 100644 --- a/quartical/data_handling/selection.py +++ b/quartical/data_handling/selection.py @@ -1,7 +1,13 @@ -def filter_xds_list(xds_list, fields, ddids): +def filter_xds_list(xds_list, fields=[], ddids=[]): - filter_fields = {"FIELD_ID": fields, - "DATA_DESC_ID": ddids} + # If we specify an int, make it a list. Might be worth improving. 
+ fields = [fields] if isinstance(fields, int) else fields + ddids = [ddids] if isinstance(ddids, int) else ddids + + filter_fields = { + "FIELD_ID": fields, + "DATA_DESC_ID": ddids + } for k, v in filter_fields.items(): fil = filter(lambda xds: getattr(xds, k) in v, xds_list) From d8a971969c0774705493ef69141c082b99c4d89b Mon Sep 17 00:00:00 2001 From: JSKenyon Date: Thu, 21 Sep 2023 10:33:12 +0200 Subject: [PATCH 24/26] Depend on dask-ms master. --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index cba975d1..af3e5f49 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,7 +31,7 @@ columnar = "^1.4.1" "ruamel.yaml" = "^0.17.26" dask = {extras = ["diagnostics"], version = "^2023.1.0"} distributed = "^2023.1.0" -dask-ms = {git = "https://github.com/ratt-ru/dask-ms.git", branch = "multisource-experimental", extras = ["s3", "xarray", "zarr"]} +dask-ms = {git = "https://github.com/ratt-ru/dask-ms.git", extras = ["s3", "xarray", "zarr"]} codex-africanus = {extras = ["dask", "scipy", "astropy", "python-casacore"], version = "^0.3.4"} astro-tigger-lsm = "^1.7.2" loguru = "^0.7.0" From 2cdaf80ac219c2528746aaaa94101754f43eac20 Mon Sep 17 00:00:00 2001 From: Jonathan Kenyon Date: Fri, 29 Nov 2024 16:05:57 +0200 Subject: [PATCH 25/26] Fix old usage. --- quartical/data_handling/angles.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/quartical/data_handling/angles.py b/quartical/data_handling/angles.py index 09b4d277..6fb77c67 100644 --- a/quartical/data_handling/angles.py +++ b/quartical/data_handling/angles.py @@ -21,9 +21,9 @@ def assign_parangle_data(ms_path, data_xds_list): - anttab = xds_from_storage_table(ms_path + "::ANTENNA")[0] - feedtab = xds_from_storage_table(ms_path + "::FEED")[0] - fieldtab = xds_from_storage_table(ms_path + "::FIELD")[0] + anttab = xds_from_table_fragment(ms_path + "::ANTENNA")[0] + feedtab = xds_from_table_fragment(ms_path + "::FEED")[0] + fieldtab = xds_from_table_fragment(ms_path + "::FIELD")[0] # We do the following eagerly to reduce graph complexity. feeds = feedtab.POLARIZATION_TYPE.values From 7b9cf90904bdd001758f84c2e07460f8f89af63c Mon Sep 17 00:00:00 2001 From: Jonathan Kenyon Date: Fri, 29 Nov 2024 16:08:28 +0200 Subject: [PATCH 26/26] Fix uncommited pyproject. --- pyproject.toml | 24 ------------------------ 1 file changed, 24 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index f7975983..2783ce21 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,29 +24,6 @@ include = [ ] [tool.poetry.dependencies] -<<<<<<< HEAD -python = "^3.8" -tbump = "^6.10.0" -columnar = "^1.4.1" -"ruamel.yaml" = "^0.17.26" -dask = {extras = ["diagnostics"], version = "^2023.1.0"} -distributed = "^2023.1.0" -dask-ms = {git = "https://github.com/ratt-ru/dask-ms.git", extras = ["s3", "xarray", "zarr"]} -codex-africanus = {extras = ["dask", "scipy", "astropy", "python-casacore"], version = "^0.3.4"} -astro-tigger-lsm = "^1.7.2" -loguru = "^0.7.0" -requests = "^2.31.0" -pytest = "^7.3.1" -omegaconf = "^2.3.0" -colorama = "^0.4.6" -stimela = "2.0rc4" -ducc0 = "^0.31.0" -sympy = "^1.12" -matplotlib = "^3.5.1" - -[tool.poetry.extras] -degrid = ["ducc0", "sympy"] -======= python = ">=3.10, <3.13" astro-tigger-lsm = [ { version = ">=1.7.2, <=1.7.3", python = "<3.12" }, @@ -66,7 +43,6 @@ requests = ">=2.31.0, <=2.32.3" "ruamel.yaml" = ">=0.17.26, <=0.18.6" stimela = ">=2.0" tbump = ">=6.10.0, <=6.11.0" ->>>>>>> main [tool.poetry.scripts] goquartical = 'quartical.executor:execute'
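A hedged usage sketch of the named-backup workflow added in PATCH 23 (paths
and the label are hypothetical; the flags are those defined in
quartical/apps/backup.py):

    goquartical-backup path/to/data.ms path/to/backups DATA \
        --label pre-selfcal --field-id 0 --nthread 4

With --label set, the backup is written as pre-selfcal-data.ms-DATA.bkp.qc
rather than being prefixed with the creation datetime.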