diff --git a/quartical/apps/plotter.py b/quartical/apps/plotter.py
index 35f18977..45577001 100644
--- a/quartical/apps/plotter.py
+++ b/quartical/apps/plotter.py
@@ -193,6 +193,9 @@ def to_plot_dict(xdsl, iter_attrs):
 
 
 def _plot(group, xds, args):
 
+    # get rid of question marks
+    qstrip = lambda x: x.replace('?', 'N/A')
+    group = tuple(map(qstrip, group))
     xds = xds.compute(scheduler="single-threaded")
 
diff --git a/quartical/calibration/constructor.py b/quartical/calibration/constructor.py
index ec6b6412..34462fac 100644
--- a/quartical/calibration/constructor.py
+++ b/quartical/calibration/constructor.py
@@ -5,7 +5,7 @@
 term_spec_tup = namedtuple("term_spec_tup", "name type shape pshape")
 
-aux_info_fields = ("SCAN_NUMBER", "FIELD_ID", "DATA_DESC_ID")
+log_info_fields = ("SCAN_NUMBER", "FIELD_ID", "DATA_DESC_ID")
 
 
 def construct_solver(
@@ -54,9 +54,9 @@ def construct_solver(
         corr_mode = data_xds.sizes["corr"]
         block_id_arr = get_block_id_arr(data_col)
 
-        aux_block_info = {
-            k: data_xds.attrs.get(k, "?") for k in aux_info_fields
-        }
+        data_xds_meta = data_xds.attrs.copy()
+        for k in log_info_fields:
+            data_xds_meta[k] = data_xds_meta.get(k, "?")
 
         # Grab the number of input chunks - doing this on the data should be
         # safe.
@@ -87,7 +87,7 @@ def construct_solver(
         )
         blocker.add_input("term_spec_list", spec_list, ("row", "chan"))
         blocker.add_input("corr_mode", corr_mode)
-        blocker.add_input("aux_block_info", aux_block_info)
+        blocker.add_input("data_xds_meta", data_xds_meta)
         blocker.add_input("solver_opts", solver_opts)
         blocker.add_input("chain", chain)
 
diff --git a/quartical/calibration/solver.py b/quartical/calibration/solver.py
index 620874c3..ddf3452f 100644
--- a/quartical/calibration/solver.py
+++ b/quartical/calibration/solver.py
@@ -54,7 +54,7 @@ def solver_wrapper(
     solver_opts,
     chain,
     block_id_arr,
-    aux_block_info,
+    data_xds_meta,
     corr_mode,
     **kwargs
 ):
@@ -108,11 +108,11 @@ def solver_wrapper(
         # Perform term specific setup e.g. init gains and params.
         if term.is_parameterized:
             gains, gain_flags, params, param_flags = term.init_term(
-                term_spec, ref_ant, ms_kwargs, term_kwargs
+                term_spec, ref_ant, ms_kwargs, term_kwargs, meta=data_xds_meta
             )
         else:
             gains, gain_flags = term.init_term(
-                term_spec, ref_ant, ms_kwargs, term_kwargs
+                term_spec, ref_ant, ms_kwargs, term_kwargs, meta=data_xds_meta
             )
             # Dummy arrays with standard dtypes - aids compilation.
             params = np.empty(term_pshape, dtype=np.float64)
@@ -190,6 +190,7 @@ def solver_wrapper(
     for ind, (term, iters) in enumerate(zip(cycle(chain), iter_recipe)):
 
         active_term = chain.index(term)
+        active_spec = term_spec_list[active_term]
 
         ms_fields = term.ms_inputs._fields
         ms_inputs = term.ms_inputs(
@@ -219,13 +220,17 @@ def solver_wrapper(
             term.solve_per
         )
 
-        jhj, conv_iter, conv_perc = term.solver(
-            ms_inputs,
-            mapping_inputs,
-            chain_inputs,
-            meta_inputs,
-            corr_mode
-        )
+        if term.solver:
+            jhj, conv_iter, conv_perc = term.solver(
+                ms_inputs,
+                mapping_inputs,
+                chain_inputs,
+                meta_inputs,
+                corr_mode
+            )
+        else:
+            jhj = np.zeros(getattr(active_spec, "pshape", active_spec.shape))
+            conv_iter, conv_perc = 0, 1
 
         # If reweighting is enabled, do it when the epoch changes, except
         # for the final epoch - we don't reweight if we won't solve again.
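Illustrative aside (not part of the patch): the solver.py change above makes term.solver optional, so a term whose solver attribute is falsy (plausibly the new parallactic_angle term, though that is not shown here) still yields a correctly shaped, all-zero jhj and is reported as converged after zero iterations. A minimal standalone sketch of that branch, using a hypothetical run_term helper and made-up shapes rather than QuartiCal's real term objects:

import numpy as np
from collections import namedtuple

# Mirrors quartical.calibration.constructor.term_spec_tup.
TermSpec = namedtuple("term_spec_tup", "name type shape pshape")


def run_term(term_solver, spec):
    """Mimic the solvable/non-solvable branch in solver_wrapper."""
    if term_solver:
        return term_solver(spec)
    # Non-solvable term: zero jhj of the parameter (or gain) shape,
    # zero iterations performed, reported as fully converged.
    jhj = np.zeros(getattr(spec, "pshape", spec.shape))
    return jhj, 0, 1


spec = TermSpec("P", "parallactic_angle", (1, 1, 3, 1, 4), (1, 1, 3, 1, 1))
jhj, conv_iter, conv_perc = run_term(None, spec)
print(jhj.shape, conv_iter, conv_perc)  # (1, 1, 3, 1, 1) 0 1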
@@ -269,7 +274,7 @@ def solver_wrapper(
         corr_mode
     )
 
-    log_chisq(presolve_chisq, postsolve_chisq, aux_block_info, block_id)
+    log_chisq(presolve_chisq, postsolve_chisq, data_xds_meta, block_id)
 
     results_dict["presolve_chisq"] = presolve_chisq
     results_dict["postsolve_chisq"] = postsolve_chisq
 
diff --git a/quartical/config/gain_schema.yaml b/quartical/config/gain_schema.yaml
index 3c3d0fc9..266c23b7 100644
--- a/quartical/config/gain_schema.yaml
+++ b/quartical/config/gain_schema.yaml
@@ -14,7 +14,9 @@ gain:
       - rotation_measure
       - rotation
       - crosshand_phase
+      - crosshand_phase_null_v
       - leakage
+      - parallactic_angle
     info:
       Type of gain to solve for.
 
diff --git a/quartical/data_handling/angles.py b/quartical/data_handling/angles.py
index 62ca7c34..aa05b1be 100644
--- a/quartical/data_handling/angles.py
+++ b/quartical/data_handling/angles.py
@@ -19,6 +19,46 @@
 _thread_local = threading.local()
 
 
+def assign_parangle_data(ms_path, data_xds_list):
+
+    anttab = xds_from_storage_table(ms_path + "::ANTENNA")[0]
+    feedtab = xds_from_storage_table(ms_path + "::FEED")[0]
+    fieldtab = xds_from_storage_table(ms_path + "::FIELD")[0]
+
+    # We do the following eagerly to reduce graph complexity.
+    feeds = feedtab.POLARIZATION_TYPE.values
+    unique_feeds = np.unique(feeds)
+
+    if np.all([feed in "XxYy" for feed in unique_feeds]):
+        feed_type = "linear"
+    elif np.all([feed in "LlRr" for feed in unique_feeds]):
+        feed_type = "circular"
+    else:
+        raise ValueError("Unsupported feed type/configuration.")
+
+    phase_dirs = fieldtab.PHASE_DIR.values
+
+    updated_data_xds_list = []
+    for xds in data_xds_list:
+        xds = xds.assign(
+            {
+                "RECEPTOR_ANGLE": (
+                    ("ant", "feed"), clone(feedtab.RECEPTOR_ANGLE.data)
+                ),
+                "POSITION": (
+                    ("ant", "xyz"),
+                    clone(anttab.POSITION.data)
+                )
+            }
+        )
+        xds.attrs["FEED_TYPE"] = feed_type
+        xds.attrs["FIELD_CENTRE"] = tuple(phase_dirs[xds.FIELD_ID, 0])
+
+        updated_data_xds_list.append(xds)
+
+    return updated_data_xds_list
+
+
 def make_parangle_xds_list(ms_path, data_xds_list):
     """Create a list of xarray.Datasets containing the parallactic angles."""
 
@@ -266,7 +306,7 @@ def nb_apply_parangle_rot(data_col, parangles, utime_ind, ant1_col, ant2_col,
     v1_imul_v2 = factories.v1_imul_v2_factory(corr_mode)
     v1_imul_v2ct = factories.v1_imul_v2ct_factory(corr_mode)
     valloc = factories.valloc_factory(corr_mode)
-    rotmat = rotation_factory(corr_mode, feed_type)
+    rotmat = factories.rotation_factory(corr_mode, feed_type)
 
     def impl(data_col, parangles, utime_ind, ant1_col, ant2_col,
              corr_mode, feed_type):
@@ -299,43 +339,3 @@ def impl(data_col, parangles, utime_ind, ant1_col, ant2_col,
             return data_col
 
     return impl
-
-
-def rotation_factory(corr_mode, feed_type):
-
-    if feed_type.literal_value == "circular":
-        if corr_mode.literal_value == 4:
-            def impl(rot0, rot1, out):
-                out[0] = np.exp(-1j*rot0)
-                out[1] = 0
-                out[2] = 0
-                out[3] = np.exp(1j*rot1)
-        elif corr_mode.literal_value == 2:  # TODO: Is this sensible?
-            def impl(rot0, rot1, out):
-                out[0] = np.exp(-1j*rot0)
-                out[1] = np.exp(1j*rot1)
-        elif corr_mode.literal_value == 1:  # TODO: Is this sensible?
-            def impl(rot0, rot1, out):
-                out[0] = np.exp(-1j*rot0)
-        else:
-            raise ValueError("Unsupported number of correlations.")
-    elif feed_type.literal_value == "linear":
-        if corr_mode.literal_value == 4:
-            def impl(rot0, rot1, out):
-                out[0] = np.cos(rot0)
-                out[1] = np.sin(rot0)
-                out[2] = -np.sin(rot1)
-                out[3] = np.cos(rot1)
-        elif corr_mode.literal_value == 2:  # TODO: Is this sensible?
-            def impl(rot0, rot1, out):
-                out[0] = np.cos(rot0)
-                out[1] = np.cos(rot1)
-        elif corr_mode.literal_value == 1:  # TODO: Is this sensible?
-            def impl(rot0, rot1, out):
-                out[0] = np.cos(rot0)
-        else:
-            raise ValueError("Unsupported number of correlations.")
-    else:
-        raise ValueError("Unsupported feed type.")
-
-    return factories.qcjit(impl)
diff --git a/quartical/data_handling/ms_handler.py b/quartical/data_handling/ms_handler.py
index 3f5b38cf..13920fcb 100644
--- a/quartical/data_handling/ms_handler.py
+++ b/quartical/data_handling/ms_handler.py
@@ -15,6 +15,13 @@
 from quartical.data_handling.selection import filter_xds_list
 from quartical.data_handling.angles import apply_parangles
 
+DASKMS_ATTRS = {
+    "__daskms_partition_schema__",
+    "SCAN_NUMBER",
+    "FIELD_ID",
+    "DATA_DESC_ID"
+}
+
 
 def read_xds_list(model_columns, ms_opts):
     """Reads a measurement set and generates a list of xarray data sets.
@@ -293,14 +300,18 @@ def write_xds_list(xds_list, ref_xds_list, ms_path, output_opts):
 
     logger.info("Outputs will be written to {}.", ", ".join(output_cols))
 
+    # Select only the output columns to simplify datasets.
+    xds_list = [xds[list(output_cols)] for xds in xds_list]
+
+    # Remove all coords bar ROWID so that they do not get written.
+    xds_list = [
+        xds.drop_vars(set(xds.coords.keys()) - {"ROWID"}, errors='ignore')
+        for xds in xds_list
+    ]
+
     # Remove attrs added by QuartiCal so that they do not get written.
     for xds in xds_list:
-        xds.attrs.pop("UTIME_CHUNKS", None)
-        xds.attrs.pop("FIELD_NAME", None)
-
-    # Remove coords added by QuartiCal so that they do not get written.
-    xds_list = [xds.drop_vars(["chan", "corr"], errors='ignore')
-                for xds in xds_list]
+        xds.attrs = {k: v for k, v in xds.attrs.items() if k in DASKMS_ATTRS}
 
     with warnings.catch_warnings():  # We anticipate spurious warnings.
         warnings.simplefilter("ignore")
diff --git a/quartical/executor.py b/quartical/executor.py
index b79a9fe8..704ecdaf 100644
--- a/quartical/executor.py
+++ b/quartical/executor.py
@@ -13,7 +13,8 @@
     preprocess_xds_list,
     postprocess_xds_list)
 from quartical.data_handling.model_handler import add_model_graph
-from quartical.data_handling.angles import make_parangle_xds_list
+from quartical.data_handling.angles import (make_parangle_xds_list,
+                                            assign_parangle_data)
 from quartical.calibration.calibrate import add_calibration_graph
 from quartical.statistics.statistics import make_stats_xds_list
 from quartical.statistics.logging import log_summary_stats
@@ -110,6 +111,7 @@ def _execute(exitstack):
 
         # Preprocess the xds_list - initialise some values and fix bad data.
         data_xds_list = preprocess_xds_list(data_xds_list, ms_opts)
+        data_xds_list = assign_parangle_data(ms_opts.path, data_xds_list)
 
         # Make a list of datasets containing the parallactic angles as these
        # can be expensive to compute and may be used several times.
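Illustrative aside (not part of the patch): assign_parangle_data attaches FEED_TYPE, FIELD_CENTRE, RECEPTOR_ANGLE and POSITION to each dataset, while the new DASKMS_ATTRS filter in write_xds_list strips everything except the dask-ms partitioning attributes before writing. A toy sketch of that filtering on a hypothetical in-memory xarray.Dataset (not a real Measurement Set):

import xarray as xr

# Same whitelist as DASKMS_ATTRS in the ms_handler.py hunk above.
DASKMS_ATTRS = {
    "__daskms_partition_schema__",
    "SCAN_NUMBER",
    "FIELD_ID",
    "DATA_DESC_ID"
}

# Hypothetical dataset carrying both dask-ms attrs and QuartiCal-added attrs.
xds = xr.Dataset(
    attrs={
        "SCAN_NUMBER": 1,
        "FIELD_ID": 0,
        "DATA_DESC_ID": 0,
        "FEED_TYPE": "linear",       # added by assign_parangle_data
        "FIELD_CENTRE": (0.0, 0.0),  # added by assign_parangle_data
    }
)

# The write path keeps only the whitelisted attributes.
xds.attrs = {k: v for k, v in xds.attrs.items() if k in DASKMS_ATTRS}
print(sorted(xds.attrs))  # ['DATA_DESC_ID', 'FIELD_ID', 'SCAN_NUMBER']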
diff --git a/quartical/gains/__init__.py b/quartical/gains/__init__.py
index e799443c..e652419d 100644
--- a/quartical/gains/__init__.py
+++ b/quartical/gains/__init__.py
@@ -6,9 +6,10 @@
 from quartical.gains.tec_and_offset import TecAndOffset
 from quartical.gains.rotation import Rotation
 from quartical.gains.rotation_measure import RotationMeasure
-from quartical.gains.crosshand_phase import CrosshandPhase
+from quartical.gains.crosshand_phase import CrosshandPhase, CrosshandPhaseNullV
 from quartical.gains.leakage import Leakage
 from quartical.gains.delay_and_tec import DelayAndTec
+from quartical.gains.parallactic_angle import ParallacticAngle
 
 
 TERM_TYPES = {
@@ -22,6 +23,8 @@
     "rotation": Rotation,
     "rotation_measure": RotationMeasure,
     "crosshand_phase": CrosshandPhase,
+    "crosshand_phase_null_v": CrosshandPhaseNullV,
     "leakage": Leakage,
-    "delay_and_tec": DelayAndTec
+    "delay_and_tec": DelayAndTec,
+    "parallactic_angle": ParallacticAngle
 }
diff --git a/quartical/gains/amplitude/__init__.py b/quartical/gains/amplitude/__init__.py
index 5e04ba09..a2d2285a 100644
--- a/quartical/gains/amplitude/__init__.py
+++ b/quartical/gains/amplitude/__init__.py
@@ -38,7 +38,7 @@ def make_param_names(cls, correlations):
 
         return [f"amplitude_{c}" for c in param_corr]
 
-    def init_term(self, term_spec, ref_ant, ms_kwargs, term_kwargs):
+    def init_term(self, term_spec, ref_ant, ms_kwargs, term_kwargs, meta=None):
         """Initialise the gains (and parameters)."""
 
         gains, gain_flags, params, param_flags = super().init_term(
diff --git a/quartical/gains/crosshand_phase/__init__.py b/quartical/gains/crosshand_phase/__init__.py
index 80fd8e9c..f9296355 100644
--- a/quartical/gains/crosshand_phase/__init__.py
+++ b/quartical/gains/crosshand_phase/__init__.py
@@ -5,6 +5,9 @@
     crosshand_phase_solver,
     crosshand_params_to_gains
 )
+from quartical.gains.crosshand_phase.null_v_kernel import (
+    null_v_crosshand_phase_solver
+)
 from quartical.gains.general.flagging import (
     apply_gain_flags_to_gains,
     apply_param_flags_to_params
@@ -39,7 +42,7 @@ def make_param_names(cls, correlations):
 
         return [f"crosshand_phase_{c}" for c in param_corr]
 
-    def init_term(self, term_spec, ref_ant, ms_kwargs, term_kwargs):
+    def init_term(self, term_spec, ref_ant, ms_kwargs, term_kwargs, meta=None):
         """Initialise the gains (and parameters)."""
 
         gains, gain_flags, params, param_flags = super().init_term(
@@ -54,3 +57,8 @@ def init_term(self, term_spec, ref_ant, ms_kwargs, term_kwargs):
             apply_gain_flags_to_gains(gain_flags, gains)
 
         return gains, gain_flags, params, param_flags
+
+
+class CrosshandPhaseNullV(CrosshandPhase):
+
+    solver = staticmethod(null_v_crosshand_phase_solver)
\ No newline at end of file
diff --git a/quartical/gains/crosshand_phase/null_v_kernel.py b/quartical/gains/crosshand_phase/null_v_kernel.py
new file mode 100644
index 00000000..33fb4f78
--- /dev/null
+++ b/quartical/gains/crosshand_phase/null_v_kernel.py
@@ -0,0 +1,643 @@
+# -*- coding: utf-8 -*-
+import numpy as np
+from numba import prange, njit
+from numba.typed import List
+from numba.extending import overload
+from quartical.utils.numba import (coerce_literal,
+                                   JIT_OPTIONS,
+                                   PARALLEL_JIT_OPTIONS)
+from quartical.gains.general.generics import (native_intermediaries,
+                                              upsampled_itermediaries,
+                                              per_array_jhj_jhr,
+                                              resample_solints,
+                                              downsample_jhj_jhr,
+                                              invert_gains)
+from quartical.gains.general.flagging import (flag_intermediaries,
+                                              update_gain_flags,
+                                              finalize_gain_flags,
+                                              apply_gain_flags_to_flag_col,
+                                              update_param_flags)
+from quartical.gains.general.convenience import (get_row,
+                                                 get_extents)
+import quartical.gains.general.factories as factories
+from quartical.gains.general.inversion import (invert_factory,
+                                               inversion_buffer_factory)
+
+
+def get_identity_params(corr_mode):
+
+    if corr_mode.literal_value == 4:
+        return np.zeros((1,), dtype=np.float64)
+    else:
+        raise ValueError("Unsupported number of correlations.")
+
+
+@njit(**JIT_OPTIONS)
+def null_v_crosshand_phase_solver(
+    ms_inputs,
+    mapping_inputs,
+    chain_inputs,
+    meta_inputs,
+    corr_mode
+):
+    return null_v_crosshand_phase_solver_impl(
+        ms_inputs,
+        mapping_inputs,
+        chain_inputs,
+        meta_inputs,
+        corr_mode
+    )
+
+
+def null_v_crosshand_phase_solver_impl(
+    ms_inputs,
+    mapping_inputs,
+    chain_inputs,
+    meta_inputs,
+    corr_mode
+):
+    raise NotImplementedError
+
+
+@overload(null_v_crosshand_phase_solver_impl, jit_options=JIT_OPTIONS)
+def nb_null_v_crosshand_phase_solver_impl(
+    ms_inputs,
+    mapping_inputs,
+    chain_inputs,
+    meta_inputs,
+    corr_mode
+):
+
+    coerce_literal(nb_null_v_crosshand_phase_solver_impl, ["corr_mode"])
+
+    identity_params = get_identity_params(corr_mode)
+
+    def impl(
+        ms_inputs,
+        mapping_inputs,
+        chain_inputs,
+        meta_inputs,
+        corr_mode
+    ):
+
+        gains = chain_inputs.gains
+        gain_flags = chain_inputs.gain_flags
+
+        inverse_gains = List()
+        for gain_term in gains:
+            inverse_gains.append(np.empty_like(gain_term))
+        invert_gains(gains, inverse_gains, corr_mode)
+
+        active_term = meta_inputs.active_term
+        max_iter = meta_inputs.iters
+        solve_per = meta_inputs.solve_per
+        dd_term = meta_inputs.dd_term
+        n_thread = meta_inputs.threads
+
+        active_gain = gains[active_term]
+        active_gain_flags = gain_flags[active_term]
+        active_params = chain_inputs.params[active_term]
+
+        # Set up some intermediaries used for flagging.
+        km1_gain = active_gain.copy()
+        km1_abs2_diffs = np.zeros_like(active_gain_flags, dtype=np.float64)
+        abs2_diffs_trend = np.zeros_like(active_gain_flags, dtype=np.float64)
+        flag_imdry = \
+            flag_intermediaries(km1_gain, km1_abs2_diffs, abs2_diffs_trend)
+
+        # Set up some intermediaries used for solving.
+        real_dtype = active_gain.real.dtype
+        param_shape = active_params.shape
+
+        active_t_map_g = mapping_inputs.time_maps[active_term]
+        active_f_map_g = mapping_inputs.freq_maps[active_term]
+
+        # Create more work to do in parallel when needed, else no-op.
+        resampler = resample_solints(active_t_map_g, param_shape, n_thread)
+
+        # Determine the starts and stops of the rows and channels associated
+        # with each solution interval.
+        extents = get_extents(resampler.upsample_t_map, active_f_map_g)
+
+        upsample_shape = resampler.upsample_shape
+        upsampled_jhj = np.empty(upsample_shape + (upsample_shape[-1],),
+                                 dtype=real_dtype)
+        upsampled_jhr = np.empty(upsample_shape, dtype=real_dtype)
+        jhj = upsampled_jhj[:param_shape[0]]
+        jhr = upsampled_jhr[:param_shape[0]]
+        update = np.zeros(param_shape, dtype=real_dtype)
+
+        upsampled_imdry = upsampled_itermediaries(upsampled_jhj, upsampled_jhr)
+        native_imdry = native_intermediaries(jhj, jhr, update)
+
+        for loop_idx in range(max_iter or 1):
+
+            compute_jhj_jhr(
+                inverse_gains,
+                ms_inputs,
+                mapping_inputs,
+                chain_inputs,
+                meta_inputs,
+                upsampled_imdry,
+                extents,
+                corr_mode
+            )
+
+            if resampler.active:
+                downsample_jhj_jhr(upsampled_imdry, resampler.downsample_t_map)
+
+            if solve_per == "array":
+                per_array_jhj_jhr(native_imdry)
+
+            if not max_iter:  # Non-solvable term, we just want jhj.
+                conv_perc = 0  # Didn't converge.
+                loop_idx = -1  # Did zero iterations.
+                break
+
+            compute_update(native_imdry, corr_mode)
+
+            finalize_update(
+                chain_inputs,
+                meta_inputs,
+                native_imdry,
+                loop_idx,
+                corr_mode
+            )
+
+            # The parameters/gains are correct but we are solving for the
+            # inverse so we update the inverse term here.
+            inverse_gains[active_term][:] = active_gain.conj()
+
+            # Check for gain convergence. Produced as a side effect of
+            # flagging. The converged percentage is based on unflagged
+            # intervals.
+            conv_perc = update_gain_flags(
+                chain_inputs,
+                meta_inputs,
+                flag_imdry,
+                loop_idx,
+                corr_mode,
+                numbness=1e9
+            )
+
+            # Propagate gain flags to parameter flags.
+            update_param_flags(
+                mapping_inputs,
+                chain_inputs,
+                meta_inputs,
+                identity_params
+            )
+
+            if conv_perc >= meta_inputs.stop_frac:
+                break
+
+        # NOTE: Removes soft flags and flags points which have bad trends.
+        finalize_gain_flags(
+            chain_inputs,
+            meta_inputs,
+            flag_imdry,
+            corr_mode
+        )
+
+        # Call this one last time to ensure points flagged by finalize are
+        # propagated (in the DI case).
+        if not dd_term:
+            apply_gain_flags_to_flag_col(
+                ms_inputs,
+                mapping_inputs,
+                chain_inputs,
+                meta_inputs
+            )
+
+        return native_imdry.jhj, loop_idx + 1, conv_perc
+
+    return impl
+
+
+def compute_jhj_jhr(
+    inverse_gains,
+    ms_inputs,
+    mapping_inputs,
+    chain_inputs,
+    meta_inputs,
+    upsampled_imdry,
+    extents,
+    corr_mode
+):
+    raise NotImplementedError
+
+
+@overload(compute_jhj_jhr, jit_options=PARALLEL_JIT_OPTIONS)
+def nb_compute_jhj_jhr(
+    inverse_gains,
+    ms_inputs,
+    mapping_inputs,
+    chain_inputs,
+    meta_inputs,
+    upsampled_imdry,
+    extents,
+    corr_mode
+):
+
+    coerce_literal(nb_compute_jhj_jhr, ["corr_mode"])
+
+    # We want to dispatch based on this field so we need its type.
+    row_weights_idx = ms_inputs.fields.index('ROW_WEIGHTS')
+    row_weights_type = ms_inputs[row_weights_idx]
+
+    imul_rweight = factories.imul_rweight_factory(corr_mode, row_weights_type)
+    v1_imul_v2 = factories.v1_imul_v2_factory(corr_mode)
+    v1_imul_v2ct = factories.v1_imul_v2ct_factory(corr_mode)
+    v1ct_imul_v2 = factories.v1ct_imul_v2_factory(corr_mode)
+    absv1_idiv_absv2 = factories.absv1_idiv_absv2_factory(corr_mode)
+    iunpack = factories.iunpack_factory(corr_mode)
+    iunpackct = factories.iunpackct_factory(corr_mode)
+    imul = factories.imul_factory(corr_mode)
+    iadd = factories.iadd_factory(corr_mode)
+    isub = factories.isub_factory(corr_mode)
+    valloc = factories.valloc_factory(corr_mode)
+    make_loop_vars = factories.loop_var_factory(corr_mode)
+    set_identity = factories.set_identity_factory(corr_mode)
+    compute_jhwj_jhwr_elem = compute_jhwj_jhwr_elem_factory(corr_mode)
+
+    def impl(
+        inverse_gains,
+        ms_inputs,
+        mapping_inputs,
+        chain_inputs,
+        meta_inputs,
+        upsampled_imdry,
+        extents,
+        corr_mode
+    ):
+
+        data = ms_inputs.DATA
+        model = ms_inputs.MODEL_DATA
+        weights = ms_inputs.WEIGHT
+        flags = ms_inputs.FLAG
+        antenna1 = ms_inputs.ANTENNA1
+        antenna2 = ms_inputs.ANTENNA2
+        row_map = ms_inputs.ROW_MAP
+        row_weights = ms_inputs.ROW_WEIGHTS
+
+        # Reverse all the (inverse) gains and mappings.
+        time_maps = mapping_inputs.time_maps[::-1]
+        freq_maps = mapping_inputs.freq_maps[::-1]
+        dir_maps = mapping_inputs.dir_maps[::-1]
+
+        gains = inverse_gains[::-1]
+
+        # Active term in the REVERSED chain.
+        active_term = len(gains) - meta_inputs.active_term - 1
+
+        jhj = upsampled_imdry.jhj
+        jhr = upsampled_imdry.jhr
+
+        _, n_chan, n_dir, n_corr = model.shape
+
+        jhj[:] = 0
+        jhr[:] = 0
+
+        n_tint, n_fint, n_ant, n_gdir, n_param = jhr.shape
+        n_int = n_tint*n_fint
+
+        complex_dtype = gains[active_term].dtype
+        weight_dtype = weights.dtype
+
+        n_gains = len(gains)
+
+        row_starts = extents.row_starts
+        row_stops = extents.row_stops
+        chan_starts = extents.chan_starts
+        chan_stops = extents.chan_stops
+
+        # Determine loop variables based on where we are in the chain.
+        # gt means greater than (n>j) and lt means less than (n