Skip to content

Commit

Permalink
Add a scalar mode to diagonal parameterised Jones terms (#358)
Browse files Browse the repository at this point in the history
* Checkpoint WIP on generic scalar terms.

* Update all terms to accept/error out appropriately when invoked in scalar mode.

* Add scalar mode support for diag_complex terms.

* Add some tests for scalar mode. Does not test all terms, but probes all functionality.
  • Loading branch information
JSKenyon authored Jan 16, 2025
1 parent bbcff94 commit 7aabafc
Show file tree
Hide file tree
Showing 22 changed files with 209 additions and 33 deletions.
2 changes: 2 additions & 0 deletions quartical/calibration/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
"threads",
"robust",
"reference_antenna",
"scalar",
"dd_term",
"pinned_directions",
"solve_per",
Expand Down Expand Up @@ -233,6 +234,7 @@ def solver_wrapper(
solver_opts.threads,
solver_opts.robust,
solver_opts.reference_antenna,
term.scalar,
term.direction_dependent,
term.pinned_directions,
term.solve_per
Expand Down
10 changes: 9 additions & 1 deletion quartical/config/gain_schema.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,17 @@ gain:
Determines whether this term should be solved per antenna (conventional)
or over the entire array (doesn't vary with antenna).

scalar:
dtype: bool
default: false
info:
Determines whether the term is treated as scalar i.e. whether it is
solved for as a single effect over all correlations. This is only
supported for terms which would otherwise be diagonal.

direction_dependent:
dtype: bool
default: False
default: false
info:
Determines whether this term is treated as direction dependent.

Expand Down
7 changes: 6 additions & 1 deletion quartical/gains/amplitude/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
upsampled_itermediaries,
per_array_jhj_jhr,
resample_solints,
downsample_jhj_jhr)
downsample_jhj_jhr,
scalar_jhj_jhr)
from quartical.gains.general.flagging import (flag_intermediaries,
update_gain_flags,
finalize_gain_flags,
Expand Down Expand Up @@ -86,6 +87,7 @@ def impl(
active_term = meta_inputs.active_term
max_iter = meta_inputs.iters
solve_per = meta_inputs.solve_per
scalar = meta_inputs.scalar
dd_term = meta_inputs.dd_term
n_thread = meta_inputs.threads

Expand Down Expand Up @@ -144,6 +146,9 @@ def impl(
if solve_per == "array":
per_array_jhj_jhr(native_imdry)

if scalar and corr_mode != 1:
scalar_jhj_jhr(native_imdry, 1)

if not max_iter: # Non-solvable term, we just want jhj.
conv_perc = 0 # Didn't converge.
loop_idx = -1 # Did zero iterations.
Expand Down
35 changes: 35 additions & 0 deletions quartical/gains/complex/diag_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ def impl(
active_term = meta_inputs.active_term
max_iter = meta_inputs.iters
solve_per = meta_inputs.solve_per
scalar = meta_inputs.scalar
dd_term = meta_inputs.dd_term
n_thread = meta_inputs.threads

Expand Down Expand Up @@ -129,6 +130,9 @@ def impl(
if solve_per == "array":
per_array_jhj_jhr(native_imdry)

if scalar and corr_mode != 1:
scalar_jhj_jhr(native_imdry)

if not max_iter: # Non-solvable term, we just want jhj.
conv_perc = 0 # Didn't converge.
loop_idx = -1 # Did zero iterations.
Expand Down Expand Up @@ -677,3 +681,34 @@ def impl(chain_inputs, meta_inputs, mode):
apply_gain_flags_to_gains(gain_flags, gains)

return impl


@njit(**JIT_OPTIONS)
def scalar_jhj_jhr(solver_imdry):
"""This manipulates the entries of jhj and jhr to be scalar."""

# NOTE: This differes from the generic implmenentation in generics.py.

jhj = solver_imdry.jhj
jhr = solver_imdry.jhr

n_tint, n_fint, n_ant, n_dir, n_corr = jhj.shape

for t in range(n_tint):
for f in range(n_fint):
for a in range(n_ant):
for d in range(n_dir):

jhr_sel = jhr[t, f, a, d]
jhj_sel = jhj[t, f, a, d]

# Sum to a single scalar element.
for p in range(1, n_corr):
jhr_sel[0] += jhr_sel[p]
jhr_sel[p] = 0
jhj_sel[0] += jhj_sel[p]
jhj_sel[p] = 0

# Repopluate appropriate zeroed values from scalar sum.
jhr_sel[-1] = jhr_sel[0]
jhj_sel[-1] = jhj_sel[0]
4 changes: 4 additions & 0 deletions quartical/gains/complex/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ def impl(
active_term = meta_inputs.active_term
max_iter = meta_inputs.iters
solve_per = meta_inputs.solve_per
scalar = meta_inputs.scalar
dd_term = meta_inputs.dd_term
n_thread = meta_inputs.threads

Expand Down Expand Up @@ -132,6 +133,9 @@ def impl(
if solve_per == "array":
per_array_jhj_jhr(native_imdry)

if scalar:
raise ValueError("Scalar mode not supported for complex terms.")

if not max_iter: # Non-solvable term, we just want jhj.
conv_perc = 0 # Didn't converge.
loop_idx = -1 # Did zero iterations.
Expand Down
6 changes: 6 additions & 0 deletions quartical/gains/crosshand_phase/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ def impl(
active_term = meta_inputs.active_term
max_iter = meta_inputs.iters
solve_per = meta_inputs.solve_per
scalar = meta_inputs.scalar
dd_term = meta_inputs.dd_term
n_thread = meta_inputs.threads

Expand Down Expand Up @@ -141,6 +142,11 @@ def impl(
if solve_per == "array":
per_array_jhj_jhr(native_imdry)

if scalar:
raise ValueError(
"Scalar mode not supported for crosshand phase terms."
)

if not max_iter: # Non-solvable term, we just want jhj.
conv_perc = 0 # Didn't converge.
loop_idx = -1 # Did zero iterations.
Expand Down
6 changes: 6 additions & 0 deletions quartical/gains/crosshand_phase/null_v_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ def impl(
active_term = meta_inputs.active_term
max_iter = meta_inputs.iters
solve_per = meta_inputs.solve_per
scalar = meta_inputs.scalar
dd_term = meta_inputs.dd_term
n_thread = meta_inputs.threads

Expand Down Expand Up @@ -149,6 +150,11 @@ def impl(
if solve_per == "array":
per_array_jhj_jhr(native_imdry)

if scalar:
raise ValueError(
"Scalar mode not supported for crosshand phase terms."
)

if not max_iter: # Non-solvable term, we just want jhj.
conv_perc = 0 # Didn't converge.
loop_idx = -1 # Did zero iterations.
Expand Down
7 changes: 6 additions & 1 deletion quartical/gains/delay/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
upsampled_itermediaries,
per_array_jhj_jhr,
resample_solints,
downsample_jhj_jhr
downsample_jhj_jhr,
scalar_jhj_jhr
)
from quartical.gains.general.flagging import (
flag_intermediaries,
Expand Down Expand Up @@ -94,6 +95,7 @@ def impl(
active_term = meta_inputs.active_term
max_iter = meta_inputs.iters
solve_per = meta_inputs.solve_per
scalar = meta_inputs.scalar
dd_term = meta_inputs.dd_term
n_thread = meta_inputs.threads

Expand Down Expand Up @@ -157,6 +159,9 @@ def impl(
if solve_per == "array":
per_array_jhj_jhr(native_imdry)

if scalar and corr_mode != 1:
scalar_jhj_jhr(native_imdry, 1)

if not max_iter: # Non-solvable term, we just want jhj.
conv_perc = 0 # Didn't converge.
loop_idx = -1 # Did zero iterations.
Expand Down
50 changes: 31 additions & 19 deletions quartical/gains/delay_and_offset/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,26 +2,34 @@
import numpy as np
from numba import prange, njit
from numba.extending import overload
from quartical.utils.numba import (coerce_literal,
JIT_OPTIONS,
PARALLEL_JIT_OPTIONS)
from quartical.gains.general.generics import (native_intermediaries,
upsampled_itermediaries,
per_array_jhj_jhr,
resample_solints,
downsample_jhj_jhr)
from quartical.gains.general.flagging import (flag_intermediaries,
update_gain_flags,
finalize_gain_flags,
apply_gain_flags_to_flag_col,
update_param_flags,
apply_gain_flags_to_gains,
apply_param_flags_to_params)
from quartical.gains.general.convenience import (get_row,
get_extents)
from quartical.utils.numba import (
coerce_literal,
JIT_OPTIONS,
PARALLEL_JIT_OPTIONS
)
from quartical.gains.general.generics import (
native_intermediaries,
upsampled_itermediaries,
per_array_jhj_jhr,
resample_solints,
downsample_jhj_jhr,
scalar_jhj_jhr
)
from quartical.gains.general.flagging import (
flag_intermediaries,
update_gain_flags,
finalize_gain_flags,
apply_gain_flags_to_flag_col,
update_param_flags,
apply_gain_flags_to_gains,
apply_param_flags_to_params
)
from quartical.gains.general.convenience import get_row, get_extents
import quartical.gains.general.factories as factories
from quartical.gains.general.inversion import (invert_factory,
inversion_buffer_factory)
from quartical.gains.general.inversion import (
invert_factory,
inversion_buffer_factory
)


def get_identity_params(corr_mode):
Expand Down Expand Up @@ -88,6 +96,7 @@ def impl(
active_term = meta_inputs.active_term
max_iter = meta_inputs.iters
solve_per = meta_inputs.solve_per
scalar = meta_inputs.scalar
dd_term = meta_inputs.dd_term
n_thread = meta_inputs.threads

Expand Down Expand Up @@ -151,6 +160,9 @@ def impl(
if solve_per == "array":
per_array_jhj_jhr(native_imdry)

if scalar and corr_mode != 1:
scalar_jhj_jhr(native_imdry, 2)

if not max_iter: # Non-solvable term, we just want jhj.
conv_perc = 0 # Didn't converge.
loop_idx = -1 # Did zero iterations.
Expand Down
7 changes: 6 additions & 1 deletion quartical/gains/delay_and_tec/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
upsampled_itermediaries,
per_array_jhj_jhr,
resample_solints,
downsample_jhj_jhr)
downsample_jhj_jhr,
scalar_jhj_jhr)
from quartical.gains.general.flagging import (flag_intermediaries,
update_gain_flags,
finalize_gain_flags,
Expand Down Expand Up @@ -88,6 +89,7 @@ def impl(
active_term = meta_inputs.active_term
max_iter = meta_inputs.iters
solve_per = meta_inputs.solve_per
scalar = meta_inputs.scalar
dd_term = meta_inputs.dd_term
n_thread = meta_inputs.threads

Expand Down Expand Up @@ -155,6 +157,9 @@ def impl(
if solve_per == "array":
per_array_jhj_jhr(native_imdry)

if scalar and corr_mode != 1:
scalar_jhj_jhr(native_imdry, 2)

if not max_iter: # Non-solvable term, we just want jhj.
conv_perc = 0 # Didn't converge.
loop_idx = -1 # Did zero iterations.
Expand Down
1 change: 1 addition & 0 deletions quartical/gains/gain.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ def __init__(self, term_name, term_opts):
self.name = term_name
self.type = term_opts.type
self.solve_per = term_opts.solve_per
self.scalar = term_opts.scalar
self.direction_dependent = term_opts.direction_dependent
self.pinned_directions = term_opts.pinned_directions
self.time_interval = term_opts.time_interval
Expand Down
40 changes: 40 additions & 0 deletions quartical/gains/general/generics.py
Original file line number Diff line number Diff line change
Expand Up @@ -705,6 +705,46 @@ def per_array_jhj_jhr(solver_imdry):
jhr[t, f, a] = jhr[t, f, 0]


@njit(**JIT_OPTIONS)
def scalar_jhj_jhr(solver_imdry, values_per_correlation):
"""This manipulates the entries of jhj and jhr to be scalar."""

jhj = solver_imdry.jhj
jhr = solver_imdry.jhr

vpc = values_per_correlation # For brevity.

n_tint, n_fint, n_ant, n_dir, n_par, _ = jhj.shape

for t in range(n_tint):
for f in range(n_fint):
for a in range(n_ant):
for d in range(n_dir):

jhr_sel = jhr[t, f, a, d]
jhj_sel = jhj[t, f, a, d]

# Sum bottom half into top half.
for p in range(vpc, n_par):
jhr_sel[p % vpc] += jhr_sel[p]

# Sum right half into left half and zero.
for p0 in range(n_par):
for p1 in range(vpc, n_par):
jhj_sel[p0, p1 % vpc] += jhj_sel[p0, p1]
jhj_sel[p0, p1] = 0

# Sum bottom half into top half and zero.
for p0 in range(vpc, n_par):
for p1 in range(vpc):
jhj_sel[p0 % vpc, p1] += jhj_sel[p0, p1]
jhj_sel[p0, p1] = 0

# Repopluate zeroed values from scalar sum.
jhr_sel[vpc:] = jhr_sel[:vpc]
jhj_sel[vpc:, vpc:] = jhj_sel[:vpc, :vpc]


@njit(**JIT_OPTIONS)
def resample_solints(native_map, native_shape, n_thread):

Expand Down
4 changes: 4 additions & 0 deletions quartical/gains/leakage/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def impl(
active_term = meta_inputs.active_term
max_iter = meta_inputs.iters
solve_per = meta_inputs.solve_per
scalar = meta_inputs.scalar
dd_term = meta_inputs.dd_term
n_thread = meta_inputs.threads

Expand Down Expand Up @@ -129,6 +130,9 @@ def impl(
if solve_per == "array":
per_array_jhj_jhr(native_imdry)

if scalar:
raise ValueError("Scalar mode not supported for leakage terms.")

if not max_iter: # Non-solvable term, we just want jhj.
conv_perc = 0 # Didn't converge.
loop_idx = -1 # Did zero iterations.
Expand Down
Loading

0 comments on commit 7aabafc

Please sign in to comment.