diff --git a/quartical/config/external.py b/quartical/config/external.py index 2c5db43c..4fe4f71d 100644 --- a/quartical/config/external.py +++ b/quartical/config/external.py @@ -208,6 +208,7 @@ class Gain(Input): "pure_delay", "phase", "tec", + "rotation", "rotation_measure", "crosshand_phase", "leakage"]) diff --git a/quartical/gains/__init__.py b/quartical/gains/__init__.py index 5a17d2c7..d133ab4b 100644 --- a/quartical/gains/__init__.py +++ b/quartical/gains/__init__.py @@ -3,6 +3,7 @@ from quartical.gains.phase import Phase from quartical.gains.delay import Delay, PureDelay from quartical.gains.tec import TEC +from quartical.gains.rotation import Rotation from quartical.gains.rotation_measure import RotationMeasure from quartical.gains.crosshand_phase import CrosshandPhase from quartical.gains.leakage import Leakage @@ -15,6 +16,7 @@ "delay": Delay, "pure_delay": PureDelay, "tec": TEC, + "rotation": Rotation, "rotation_measure": RotationMeasure, "crosshand_phase": CrosshandPhase, "leakage": Leakage} diff --git a/quartical/gains/rotation/__init__.py b/quartical/gains/rotation/__init__.py new file mode 100644 index 00000000..bf7a5b3c --- /dev/null +++ b/quartical/gains/rotation/__init__.py @@ -0,0 +1,42 @@ +from quartical.gains.gain import Gain, gain_spec_tup, param_spec_tup +from quartical.gains.rotation.kernel import rotation_solver, rotation_args +import numpy as np + + +class Rotation(Gain): + + solver = rotation_solver + term_args = rotation_args + + def __init__(self, term_name, term_opts, data_xds, coords, tipc, fipc): + + Gain.__init__(self, term_name, term_opts, data_xds, coords, tipc, fipc) + + self.n_param = 1 # This term only makes sense in a 2x2 chain. + self.gain_chunk_spec = gain_spec_tup(self.n_tipc_g, + self.n_fipc_g, + (self.n_ant,), + (self.n_dir,), + (self.n_corr,)) + self.param_chunk_spec = param_spec_tup(self.n_tipc_g, # Check! + self.n_fipc_g, + (self.n_ant,), + (self.n_dir,), + (self.n_param,)) + + self.gain_axes = ("gain_t", "gain_f", "ant", "dir", "corr") + self.param_axes = ("param_t", "param_f", "ant", "dir", "param") + + def make_xds(self): + + xds = Gain.make_xds(self) + + xds = xds.assign_coords({"param": np.array(["rotation"]), + "param_t": self.gain_times, + "param_f": self.gain_freqs}) + xds = xds.assign_attrs({"GAIN_SPEC": self.gain_chunk_spec, + "PARAM_SPEC": self.param_chunk_spec, + "GAIN_AXES": self.gain_axes, + "PARAM_AXES": self.param_axes}) + + return xds diff --git a/quartical/gains/rotation/kernel.py b/quartical/gains/rotation/kernel.py new file mode 100644 index 00000000..cdd40b62 --- /dev/null +++ b/quartical/gains/rotation/kernel.py @@ -0,0 +1,570 @@ +# -*- coding: utf-8 -*- +import numpy as np +from numba import prange, generated_jit +from quartical.utils.numba import coerce_literal +from quartical.gains.general.generics import (native_intermediaries, + upsampled_itermediaries, + per_array_jhj_jhr, + resample_solints, + downsample_jhj_jhr) +from quartical.gains.general.flagging import (flag_intermediaries, + update_gain_flags, + finalize_gain_flags, + apply_gain_flags, + update_param_flags) +from quartical.gains.general.convenience import (get_row, + get_extents) +import quartical.gains.general.factories as factories +from quartical.gains.general.inversion import (invert_factory, + inversion_buffer_factory) +from collections import namedtuple + + +# This can be done without a named tuple now. TODO: Add unpacking to +# constructor. +stat_fields = {"conv_iters": np.int64, + "conv_perc": np.float64} + +term_conv_info = namedtuple("term_conv_info", " ".join(stat_fields.keys())) + +rotation_args = namedtuple( + "rotation_args", + ( + "params", + "param_flags", + "t_bin_arr" + ) +) + + +def get_identity_params(corr_mode): + + if corr_mode.literal_value == 4: + return np.zeros((1,), dtype=np.float64) + else: + raise ValueError("Unsupported number of correlations.") + + +@generated_jit(nopython=True, + fastmath=True, + parallel=False, + cache=True, + nogil=True) +def rotation_solver(base_args, term_args, meta_args, corr_mode): + + coerce_literal(rotation_solver, ["corr_mode"]) + + identity_params = get_identity_params(corr_mode) + + def impl(base_args, term_args, meta_args, corr_mode): + + gains = base_args.gains + gain_flags = base_args.gain_flags + + active_term = meta_args.active_term + max_iter = meta_args.iters + solve_per = meta_args.solve_per + dd_term = meta_args.dd_term + n_thread = meta_args.threads + + active_gain = gains[active_term] + active_gain_flags = gain_flags[active_term] + active_params = term_args.params[active_term] + + # Set up some intemediaries used for flagging. TODO: Move? + km1_gain = active_gain.copy() + km1_abs2_diffs = np.zeros_like(active_gain_flags, dtype=np.float64) + abs2_diffs_trend = np.zeros_like(active_gain_flags, dtype=np.float64) + flag_imdry = \ + flag_intermediaries(km1_gain, km1_abs2_diffs, abs2_diffs_trend) + + # Set up some intemediaries used for solving. + real_dtype = active_gain.real.dtype + param_shape = active_params.shape + + active_t_map_g = base_args.t_map_arr[0, :, active_term] + active_f_map_p = base_args.f_map_arr[1, :, active_term] + + # Create more work to do in paralllel when needed, else no-op. + resampler = resample_solints(active_t_map_g, param_shape, n_thread) + + # Determine the starts and stops of the rows and channels associated + # with each solution interval. + extents = get_extents(resampler.upsample_t_map, active_f_map_p) + + upsample_shape = resampler.upsample_shape + upsampled_jhj = np.empty(upsample_shape + (upsample_shape[-1],), + dtype=real_dtype) + upsampled_jhr = np.empty(upsample_shape, dtype=real_dtype) + jhj = upsampled_jhj[:param_shape[0]] + jhr = upsampled_jhr[:param_shape[0]] + update = np.zeros(param_shape, dtype=real_dtype) + + upsampled_imdry = upsampled_itermediaries(upsampled_jhj, upsampled_jhr) + native_imdry = native_intermediaries(jhj, jhr, update) + + for loop_idx in range(max_iter): + + compute_jhj_jhr(base_args, + term_args, + meta_args, + upsampled_imdry, + extents, + corr_mode) + + if resampler.active: + downsample_jhj_jhr(upsampled_imdry, resampler.downsample_t_map) + + if solve_per == "array": + per_array_jhj_jhr(native_imdry) + + compute_update(native_imdry, + corr_mode) + + finalize_update(base_args, + term_args, + meta_args, + native_imdry, + loop_idx, + corr_mode) + + # Check for gain convergence. Produced as a side effect of + # flagging. The converged percentage is based on unflagged + # intervals. + conv_perc = update_gain_flags(base_args, + term_args, + meta_args, + flag_imdry, + loop_idx, + corr_mode, + numbness=1e9) + + # Propagate gain flags to parameter flags. + update_param_flags(base_args, + term_args, + meta_args, + identity_params) + + if conv_perc >= meta_args.stop_frac: + break + + # NOTE: Removes soft flags and flags points which have bad trends. + finalize_gain_flags(base_args, + meta_args, + flag_imdry, + corr_mode) + + # Call this one last time to ensure points flagged by finialize are + # propagated (in the DI case). + if not dd_term: + apply_gain_flags(base_args, + meta_args) + + return jhj, term_conv_info(loop_idx + 1, conv_perc) + + return impl + + +@generated_jit(nopython=True, + fastmath=True, + parallel=True, + cache=True, + nogil=True) +def compute_jhj_jhr( + base_args, + term_args, + meta_args, + upsampled_imdry, + extents, + corr_mode +): + + # We want to dispatch based on this field so we need its type. + row_weights = base_args[base_args.fields.index('row_weights')] + + imul_rweight = factories.imul_rweight_factory(corr_mode, row_weights) + v1_imul_v2 = factories.v1_imul_v2_factory(corr_mode) + v1_imul_v2ct = factories.v1_imul_v2ct_factory(corr_mode) + v1ct_imul_v2 = factories.v1ct_imul_v2_factory(corr_mode) + iunpack = factories.iunpack_factory(corr_mode) + iunpackct = factories.iunpackct_factory(corr_mode) + imul = factories.imul_factory(corr_mode) + iadd = factories.iadd_factory(corr_mode) + isub = factories.isub_factory(corr_mode) + valloc = factories.valloc_factory(corr_mode) + make_loop_vars = factories.loop_var_factory(corr_mode) + set_identity = factories.set_identity_factory(corr_mode) + compute_jhwj_jhwr_elem = compute_jhwj_jhwr_elem_factory(corr_mode) + + def impl( + base_args, + term_args, + meta_args, + upsampled_imdry, + extents, + corr_mode + ): + + active_term = meta_args.active_term + + data = base_args.data + model = base_args.model + weights = base_args.weights + flags = base_args.flags + antenna1 = base_args.a1 + antenna2 = base_args.a2 + row_map = base_args.row_map + row_weights = base_args.row_weights + + gains = base_args.gains + params = term_args.params[active_term] + t_map_arr = base_args.t_map_arr[0] # We only need the gain mappings. + f_map_arr_g = base_args.f_map_arr[0] + f_map_arr_p = base_args.f_map_arr[1] + d_map_arr = base_args.d_map_arr + + jhj = upsampled_imdry.jhj + jhr = upsampled_imdry.jhr + + _, n_chan, n_dir, n_corr = model.shape + + jhj[:] = 0 + jhr[:] = 0 + + n_tint, n_fint, n_ant, n_gdir, n_param = jhr.shape + n_int = n_tint*n_fint + + complex_dtype = gains[active_term].dtype + weight_dtype = weights.dtype + + n_gains = len(gains) + + row_starts = extents.row_starts + row_stops = extents.row_stops + chan_starts = extents.chan_starts + chan_stops = extents.chan_stops + + # Determine loop variables based on where we are in the chain. + # gt means greater than (n>j) and lt means less than (n