From bbcff943db32600d0a5caea5ba4577f4897d643d Mon Sep 17 00:00:00 2001 From: JSKenyon Date: Thu, 16 Jan 2025 14:45:31 +0200 Subject: [PATCH] Add option to collapse the Jones chain inside the solver (#356) * Add WIP on dynamically collapsing the chain of gain terms - hacky DI version working. * Improve intelligencein collapse to keep chain as short as possible. Fix inciental regex warning. * Chain collapse is now an optional feature which will be configured to be the default in a future commit. * Better handling of DD case. * Expose parameter to control chain collapse. Fix formatting. * Add time input to parameterized terms. * Remove time from parallactic angle term as it is now used in all terms. * Add newline. * Fix incorrect handling of direction dependent terms. Ignore intervals on parallactic angle term. --- quartical/calibration/solver.py | 168 +++++++++++++++++- quartical/config/argument_schema.yaml | 10 ++ quartical/config/helper.py | 2 +- quartical/gains/gain.py | 3 +- quartical/gains/parallactic_angle/__init__.py | 6 +- quartical/gains/parameterized_gain.py | 3 +- 6 files changed, 187 insertions(+), 5 deletions(-) diff --git a/quartical/calibration/solver.py b/quartical/calibration/solver.py index e36586bf..3c747f26 100644 --- a/quartical/calibration/solver.py +++ b/quartical/calibration/solver.py @@ -4,6 +4,7 @@ from numba import set_num_threads from collections import namedtuple from itertools import cycle +from quartical.gains.general.generics import combine_gains, combine_flags from quartical.weights.robust import robust_reweighting from quartical.statistics.stat_kernels import compute_mean_postsolve_chisq from quartical.statistics.logging import log_chisq @@ -207,9 +208,26 @@ def solver_wrapper( **{k: chain_kwargs.get(k, None) for k in chain_fields} ) + # NOTE: This is newer code which is difficult to test in all cases. + # If users run in to issues, particularly with direction-dependent + # terms, this may well be the culprit. + if solver_opts.collapse_chain: + mapping_inputs, chain_inputs, collapsed_term = get_collapsed_inputs( + ms_kwargs, + mapping_kwargs, + chain_kwargs, + term_spec_list, + chain, + active_term + ) + else: + collapsed_term = active_term + + # NOTE: Collapsed term is either the active term or its equivalent in + # the collapsed chain. See get_collapsed_inputs. meta_inputs = meta_args_nt( iters, - active_term, + collapsed_term, solver_opts.convergence_fraction, solver_opts.convergence_criteria, solver_opts.threads, @@ -283,3 +301,151 @@ def solver_wrapper( gc.collect() return results_dict + + +def get_collapsed_inputs( + ms_kwargs, + mapping_kwargs, + chain_kwargs, + term_spec_list, + chain, + active_term +): + + term = chain[active_term] + + _, net_t_map = np.unique(ms_kwargs["TIME"], return_inverse=True) + net_t_map = net_t_map.astype(np.int32) + n_t = mapping_kwargs["time_bins"][active_term].size + n_f = mapping_kwargs["freq_maps"][active_term].size + net_t_bins = np.arange(n_t, dtype=np.int32) + net_f_map = np.arange(n_f, dtype=np.int32) + + n_a = max([s.shape[2] for s in term_spec_list]) + n_d = max([s.shape[3] for s in term_spec_list]) + n_c = max([s.shape[4] for s in term_spec_list]) + + l_terms = term_spec_list[:active_term] or None + r_terms = term_spec_list[active_term + 1:] or None + + # Determine how to slice collapsed inputs to produce the shortest chain + # possible. Special cases for length one chains, and the first and last + # element in a chain. + if len(chain) == 1: + sel = slice(1, 2) + collapsed_term = 0 + elif active_term == 0: + sel = slice(1, None) + collapsed_term = 0 + elif active_term == len(chain) - 1: + sel = slice(0, 2) + collapsed_term = 1 + else: + sel = slice(0, None) + collapsed_term = 1 + + if l_terms: + n_l_d = max([s.shape[3] for s in l_terms]) + dir_map_func = np.arange if n_l_d > 1 else np.zeros + l_dir_map = dir_map_func(n_d, dtype=np.int32) + + # TODO: Cache array to avoid allocation? + l_gain = combine_gains( + chain_kwargs["gains"][:active_term], + mapping_kwargs["time_bins"][:active_term], + mapping_kwargs["freq_maps"][:active_term], + mapping_kwargs["dir_maps"][:active_term], + (n_t, n_f, n_a, n_l_d, n_c), + n_c + ) + + # TODO: Add dtype option/alternative. + l_flag = combine_flags( + chain_kwargs["gain_flags"][:active_term], + mapping_kwargs["time_bins"][:active_term], + mapping_kwargs["freq_maps"][:active_term], + mapping_kwargs["dir_maps"][:active_term], + (n_t, n_f, n_a, n_l_d, n_c), + ).astype(np.int8) + else: + l_dir_map, l_gain, l_flag = (None,) * 3 + + if r_terms: + n_r_d = max([s.shape[3] for s in r_terms]) + dir_map_func = np.arange if n_r_d > 1 else np.zeros + r_dir_map = dir_map_func(n_d, dtype=np.int32) + + r_gain = combine_gains( + chain_kwargs["gains"][active_term + 1:], + mapping_kwargs["time_bins"][active_term + 1:], + mapping_kwargs["freq_maps"][active_term + 1:], + mapping_kwargs["dir_maps"][active_term + 1:], + (n_t, n_f, n_a, n_r_d, n_c), + n_c + ) + + # TODO: Add dtype option/alternative. + r_flag = combine_flags( + chain_kwargs["gain_flags"][active_term + 1:], + mapping_kwargs["time_bins"][active_term + 1:], + mapping_kwargs["freq_maps"][active_term + 1:], + mapping_kwargs["dir_maps"][active_term + 1:], + (n_t, n_f, n_a, n_r_d, n_c) + ).astype(np.int8) + else: + r_dir_map, r_gain, r_flag = (None,) * 3 + + mapping_kwargs = { + "time_bins": ( + net_t_bins, mapping_kwargs["time_bins"][active_term], net_t_bins + ), + "time_maps": ( + net_t_map, mapping_kwargs["time_maps"][active_term], net_t_map + ), + "freq_maps": ( + net_f_map, mapping_kwargs["freq_maps"][active_term], net_f_map + ), + "dir_maps": ( + l_dir_map, mapping_kwargs["dir_maps"][active_term], r_dir_map + ), + "param_time_bins": ( + net_t_bins, + mapping_kwargs["param_time_bins"][active_term], + net_t_bins + ), + "param_time_maps": ( + net_t_map, + mapping_kwargs["param_time_maps"][active_term], + net_t_map + ), + "param_freq_maps": ( + net_f_map, + mapping_kwargs["param_freq_maps"][active_term], + net_f_map + ), + } + + mapping_kwargs = {k: v[sel] for k, v in mapping_kwargs.items()} + + mapping_fields = term.mapping_inputs._fields + mapping_inputs = term.mapping_inputs( + **{k: mapping_kwargs.get(k, None) for k in mapping_fields} + ) + + chain_kwargs = { + "gains": (l_gain, chain_kwargs["gains"][active_term], r_gain), + "gain_flags": (l_flag, chain_kwargs["gain_flags"][active_term], r_flag), + "params": (chain_kwargs["params"][active_term],) * 3, + "param_flags": ( + l_flag, chain_kwargs["param_flags"][active_term], r_flag + ) + } + + chain_kwargs = {k: v[sel] for k, v in chain_kwargs.items()} + + chain_fields = term.chain_inputs._fields + chain_inputs = term.chain_inputs( + **{k: chain_kwargs.get(k, None) for k in chain_fields} + ) + + return mapping_inputs, chain_inputs, collapsed_term diff --git a/quartical/config/argument_schema.yaml b/quartical/config/argument_schema.yaml index af3b369a..46164086 100644 --- a/quartical/config/argument_schema.yaml +++ b/quartical/config/argument_schema.yaml @@ -357,6 +357,16 @@ solver: almost always be enabled so that data associated with diverging gains is properly flagged. + collapse_chain: + dtype: bool + default: True + info: + Determines whether the chain is collapsed into the minimum number of + terms inside the solver. This will typically increase memory footprint, + but may speed up calibration when utilising many gain terms. Set to false + to apply every term on-the-fly inside the solver (behaviour prior to + v0.2.6). + robust: dtype: bool default: False diff --git a/quartical/config/helper.py b/quartical/config/helper.py index d81db760..4194d8b4 100644 --- a/quartical/config/helper.py +++ b/quartical/config/helper.py @@ -88,7 +88,7 @@ def help(): print_help(HelpConfig) else: selection = help_arg.split("=")[-1] - selection = re.sub('[\[\] ]', "", selection) # noqa + selection = re.sub(r'[\[\] ]', "", selection) selection = selection.split(",") print_help(HelpConfig, section_names=selection) diff --git a/quartical/gains/gain.py b/quartical/gains/gain.py index 0b18d4f2..805563df 100644 --- a/quartical/gains/gain.py +++ b/quartical/gains/gain.py @@ -39,7 +39,8 @@ "WEIGHT", "FLAG", "ROW_MAP", - "ROW_WEIGHTS" + "ROW_WEIGHTS", + "TIME" ) ) diff --git a/quartical/gains/parallactic_angle/__init__.py b/quartical/gains/parallactic_angle/__init__.py index 9afc079f..5cc0513a 100644 --- a/quartical/gains/parallactic_angle/__init__.py +++ b/quartical/gains/parallactic_angle/__init__.py @@ -16,7 +16,7 @@ ms_inputs = namedtuple( 'ms_inputs', ParameterizedGain.ms_inputs._fields + \ - ('RECEPTOR_ANGLE', 'POSITION', 'TIME') + ('RECEPTOR_ANGLE', 'POSITION') ) class ParallacticAngle(ParameterizedGain): @@ -38,6 +38,10 @@ def __init__(self, term_name, term_opts): super().__init__(term_name, term_opts) + # NOTE: Ignore user-specified values on this term. + self.time_interval = 1 + self.freq_interval = 0 + @classmethod def _make_freq_map(cls, chan_freqs, chan_widths, freq_interval): # Overload gain mapping construction - we evaluate it in every channel. diff --git a/quartical/gains/parameterized_gain.py b/quartical/gains/parameterized_gain.py index ab68361c..cbd5f4a3 100644 --- a/quartical/gains/parameterized_gain.py +++ b/quartical/gains/parameterized_gain.py @@ -14,7 +14,8 @@ "WEIGHT", "FLAG", "ROW_MAP", - "ROW_WEIGHTS" + "ROW_WEIGHTS", + "TIME" ) )