Skip to content

Commit

Permalink
Add option to collapse the Jones chain inside the solver (#356)
Browse files Browse the repository at this point in the history
* Add WIP on dynamically collapsing the chain of gain terms - hacky DI version working.

* Improve intelligence in collapse to keep chain as short as possible. Fix incidental regex warning.

* Chain collapse is now an optional feature which will be configured to be the default in a future commit.

* Better handling of DD case.

* Expose parameter to control chain collapse. Fix formatting.

* Add time input to parameterized terms.

* Remove time from parallactic angle term as it is now used in all terms.

* Add newline.

* Fix incorrect handling of direction dependent terms. Ignore intervals on parallactic angle term.
  • Loading branch information
JSKenyon authored Jan 16, 2025
1 parent 203a13b commit bbcff94
Show file tree
Hide file tree
Showing 6 changed files with 187 additions and 5 deletions.
168 changes: 167 additions & 1 deletion quartical/calibration/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from numba import set_num_threads
from collections import namedtuple
from itertools import cycle
from quartical.gains.general.generics import combine_gains, combine_flags
from quartical.weights.robust import robust_reweighting
from quartical.statistics.stat_kernels import compute_mean_postsolve_chisq
from quartical.statistics.logging import log_chisq
Expand Down Expand Up @@ -207,9 +208,26 @@ def solver_wrapper(
**{k: chain_kwargs.get(k, None) for k in chain_fields}
)

# NOTE: This is newer code which is difficult to test in all cases.
# If users run in to issues, particularly with direction-dependent
# terms, this may well be the culprit.
if solver_opts.collapse_chain:
mapping_inputs, chain_inputs, collapsed_term = get_collapsed_inputs(
ms_kwargs,
mapping_kwargs,
chain_kwargs,
term_spec_list,
chain,
active_term
)
else:
collapsed_term = active_term

# NOTE: Collapsed term is either the active term or its equivalent in
# the collapsed chain. See get_collapsed_inputs.
meta_inputs = meta_args_nt(
iters,
active_term,
collapsed_term,
solver_opts.convergence_fraction,
solver_opts.convergence_criteria,
solver_opts.threads,
Expand Down Expand Up @@ -283,3 +301,151 @@ def solver_wrapper(
gc.collect()

return results_dict


def _combine_neighbours(
    mapping_kwargs,
    chain_kwargs,
    term_specs,
    sel,
    n_t,
    n_f,
    n_a,
    n_d,
    n_c
):
    """Collapse the terms selected by ``sel`` into one effective term.

    Args:
        mapping_kwargs: Per-term mappings; the "time_bins", "freq_maps" and
            "dir_maps" entries are sliced with ``sel``.
        chain_kwargs: Per-term chain values; the "gains" and "gain_flags"
            entries are sliced with ``sel``.
        term_specs: Specs of the selected terms (objects with a ``shape``).
            May be empty, in which case there is nothing to collapse.
        sel: Slice picking out the same terms from the per-term sequences.
        n_t, n_f, n_a, n_c: Leading/trailing dimensions of the combined
            arrays (presumably time, frequency, antenna and correlation -
            confirm against combine_gains).
        n_d: Length of the direction map generated for the combined term.

    Returns:
        A ``(dir_map, gain, flag)`` tuple, or ``(None, None, None)`` when
        ``term_specs`` is empty.
    """
    if not term_specs:
        return None, None, None

    # The combined term is direction-dependent if any constituent is.
    n_side_d = max(s.shape[3] for s in term_specs)

    if n_side_d > 1:
        dir_map = np.arange(n_d, dtype=np.int32)
    else:
        dir_map = np.zeros(n_d, dtype=np.int32)

    net_shape = (n_t, n_f, n_a, n_side_d, n_c)

    # TODO: Cache array to avoid allocation?
    gain = combine_gains(
        chain_kwargs["gains"][sel],
        mapping_kwargs["time_bins"][sel],
        mapping_kwargs["freq_maps"][sel],
        mapping_kwargs["dir_maps"][sel],
        net_shape,
        n_c
    )

    # TODO: Add dtype option/alternative.
    flag = combine_flags(
        chain_kwargs["gain_flags"][sel],
        mapping_kwargs["time_bins"][sel],
        mapping_kwargs["freq_maps"][sel],
        mapping_kwargs["dir_maps"][sel],
        net_shape
    ).astype(np.int8)

    return dir_map, gain, flag


def get_collapsed_inputs(
    ms_kwargs,
    mapping_kwargs,
    chain_kwargs,
    term_spec_list,
    chain,
    active_term
):
    """Build solver inputs for a chain collapsed around the active term.

    All terms to the left of the active term are combined into a single
    term, as are all terms to its right, so that the resulting chain has at
    most three elements: (left, active, right). Missing neighbours are
    dropped, so the collapsed chain may have one or two elements instead.

    Args:
        ms_kwargs: Mapping containing at least "TIME".
        mapping_kwargs: Mapping of per-term sequences. Must contain
            "time_bins", "time_maps", "freq_maps", "dir_maps",
            "param_time_bins", "param_time_maps" and "param_freq_maps".
        chain_kwargs: Mapping of per-term sequences. Must contain "gains",
            "gain_flags", "params" and "param_flags".
        term_spec_list: Per-term specs exposing a ``shape`` with at least
            five entries.
        chain: Sequence of term objects; the active one must provide the
            ``mapping_inputs`` and ``chain_inputs`` namedtuple types.
        active_term: Index of the term being solved for in ``chain``.

    Returns:
        A ``(mapping_inputs, chain_inputs, collapsed_term)`` tuple, where
        ``collapsed_term`` is the index of the active term in the collapsed
        chain. The input mappings are not modified.
    """
    term = chain[active_term]

    # Mappings for the collapsed terms: one solution per unique time and
    # per entry of the active term's frequency map.
    _, net_t_map = np.unique(ms_kwargs["TIME"], return_inverse=True)
    net_t_map = net_t_map.astype(np.int32)
    n_t = mapping_kwargs["time_bins"][active_term].size
    n_f = mapping_kwargs["freq_maps"][active_term].size
    net_t_bins = np.arange(n_t, dtype=np.int32)
    net_f_map = np.arange(n_f, dtype=np.int32)

    n_a = max(s.shape[2] for s in term_spec_list)
    n_d = max(s.shape[3] for s in term_spec_list)
    n_c = max(s.shape[4] for s in term_spec_list)

    # Determine how to slice collapsed inputs to produce the shortest chain
    # possible. Special cases for length one chains, and the first and last
    # element in a chain.
    if len(chain) == 1:
        sel = slice(1, 2)
        collapsed_term = 0
    elif active_term == 0:
        sel = slice(1, None)
        collapsed_term = 0
    elif active_term == len(chain) - 1:
        sel = slice(0, 2)
        collapsed_term = 1
    else:
        sel = slice(0, None)
        collapsed_term = 1

    l_dir_map, l_gain, l_flag = _combine_neighbours(
        mapping_kwargs,
        chain_kwargs,
        term_spec_list[:active_term],
        slice(None, active_term),
        n_t, n_f, n_a, n_d, n_c
    )

    r_dir_map, r_gain, r_flag = _combine_neighbours(
        mapping_kwargs,
        chain_kwargs,
        term_spec_list[active_term + 1:],
        slice(active_term + 1, None),
        n_t, n_f, n_a, n_d, n_c
    )

    # Each entry is a (left, active, right) triple which is then trimmed to
    # the terms which actually exist using sel.
    collapsed_mappings = {
        "time_bins": (
            net_t_bins, mapping_kwargs["time_bins"][active_term], net_t_bins
        ),
        "time_maps": (
            net_t_map, mapping_kwargs["time_maps"][active_term], net_t_map
        ),
        "freq_maps": (
            net_f_map, mapping_kwargs["freq_maps"][active_term], net_f_map
        ),
        "dir_maps": (
            l_dir_map, mapping_kwargs["dir_maps"][active_term], r_dir_map
        ),
        "param_time_bins": (
            net_t_bins,
            mapping_kwargs["param_time_bins"][active_term],
            net_t_bins
        ),
        "param_time_maps": (
            net_t_map,
            mapping_kwargs["param_time_maps"][active_term],
            net_t_map
        ),
        "param_freq_maps": (
            net_f_map,
            mapping_kwargs["param_freq_maps"][active_term],
            net_f_map
        ),
    }
    collapsed_mappings = {k: v[sel] for k, v in collapsed_mappings.items()}

    mapping_inputs = term.mapping_inputs(
        **{
            k: collapsed_mappings.get(k, None)
            for k in term.mapping_inputs._fields
        }
    )

    # NOTE: The collapsed neighbours reuse the active term's params and
    # take their param_flags from the combined gain flags.
    active_params = chain_kwargs["params"][active_term]
    collapsed_chain = {
        "gains": (l_gain, chain_kwargs["gains"][active_term], r_gain),
        "gain_flags": (
            l_flag, chain_kwargs["gain_flags"][active_term], r_flag
        ),
        "params": (active_params,) * 3,
        "param_flags": (
            l_flag, chain_kwargs["param_flags"][active_term], r_flag
        )
    }
    collapsed_chain = {k: v[sel] for k, v in collapsed_chain.items()}

    chain_inputs = term.chain_inputs(
        **{
            k: collapsed_chain.get(k, None)
            for k in term.chain_inputs._fields
        }
    )

    return mapping_inputs, chain_inputs, collapsed_term
10 changes: 10 additions & 0 deletions quartical/config/argument_schema.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,16 @@ solver:
almost always be enabled so that data associated with diverging gains
is properly flagged.

collapse_chain:
dtype: bool
default: True
info:
Determines whether the chain is collapsed into the minimum number of
terms inside the solver. This will typically increase memory footprint,
but may speed up calibration when utilising many gain terms. Set to false
to apply every term on-the-fly inside the solver (behaviour prior to
v0.2.6).

robust:
dtype: bool
default: False
Expand Down
2 changes: 1 addition & 1 deletion quartical/config/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def help():
print_help(HelpConfig)
else:
selection = help_arg.split("=")[-1]
selection = re.sub('[\[\] ]', "", selection) # noqa
selection = re.sub(r'[\[\] ]', "", selection)
selection = selection.split(",")
print_help(HelpConfig, section_names=selection)

Expand Down
3 changes: 2 additions & 1 deletion quartical/gains/gain.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@
"WEIGHT",
"FLAG",
"ROW_MAP",
"ROW_WEIGHTS"
"ROW_WEIGHTS",
"TIME"
)
)

Expand Down
6 changes: 5 additions & 1 deletion quartical/gains/parallactic_angle/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
ms_inputs = namedtuple(
'ms_inputs',
ParameterizedGain.ms_inputs._fields + \
('RECEPTOR_ANGLE', 'POSITION', 'TIME')
('RECEPTOR_ANGLE', 'POSITION')
)

class ParallacticAngle(ParameterizedGain):
Expand All @@ -38,6 +38,10 @@ def __init__(self, term_name, term_opts):

super().__init__(term_name, term_opts)

# NOTE: Ignore user-specified values on this term.
self.time_interval = 1
self.freq_interval = 0

@classmethod
def _make_freq_map(cls, chan_freqs, chan_widths, freq_interval):
# Overload gain mapping construction - we evaluate it in every channel.
Expand Down
3 changes: 2 additions & 1 deletion quartical/gains/parameterized_gain.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@
"WEIGHT",
"FLAG",
"ROW_MAP",
"ROW_WEIGHTS"
"ROW_WEIGHTS",
"TIME"
)
)

Expand Down

0 comments on commit bbcff94

Please sign in to comment.