Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

EXPERIMENTAL: Add a crosshand phase solver which expolits the zero Stokes V assumption. #344

Merged
merged 10 commits into from
Nov 29, 2024
3 changes: 3 additions & 0 deletions quartical/apps/plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,9 @@ def to_plot_dict(xdsl, iter_attrs):


def _plot(group, xds, args):
# get rid of question marks
qstrip = lambda x: x.replace('?', 'N/A')
group = tuple(map(qstrip, group))

xds = xds.compute(scheduler="single-threaded")

Expand Down
10 changes: 5 additions & 5 deletions quartical/calibration/constructor.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@


term_spec_tup = namedtuple("term_spec_tup", "name type shape pshape")
aux_info_fields = ("SCAN_NUMBER", "FIELD_ID", "DATA_DESC_ID")
log_info_fields = ("SCAN_NUMBER", "FIELD_ID", "DATA_DESC_ID")


def construct_solver(
Expand Down Expand Up @@ -54,9 +54,9 @@ def construct_solver(
corr_mode = data_xds.sizes["corr"]

block_id_arr = get_block_id_arr(data_col)
aux_block_info = {
k: data_xds.attrs.get(k, "?") for k in aux_info_fields
}
data_xds_meta = data_xds.attrs.copy()
for k in log_info_fields:
data_xds_meta[k] = data_xds_meta.get(k, "?")

# Grab the number of input chunks - doing this on the data should be
# safe.
Expand Down Expand Up @@ -87,7 +87,7 @@ def construct_solver(
)
blocker.add_input("term_spec_list", spec_list, ("row", "chan"))
blocker.add_input("corr_mode", corr_mode)
blocker.add_input("aux_block_info", aux_block_info)
blocker.add_input("data_xds_meta", data_xds_meta)
blocker.add_input("solver_opts", solver_opts)
blocker.add_input("chain", chain)

Expand Down
27 changes: 16 additions & 11 deletions quartical/calibration/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def solver_wrapper(
solver_opts,
chain,
block_id_arr,
aux_block_info,
data_xds_meta,
corr_mode,
**kwargs
):
Expand Down Expand Up @@ -108,11 +108,11 @@ def solver_wrapper(
# Perform term specific setup e.g. init gains and params.
if term.is_parameterized:
gains, gain_flags, params, param_flags = term.init_term(
term_spec, ref_ant, ms_kwargs, term_kwargs
term_spec, ref_ant, ms_kwargs, term_kwargs, meta=data_xds_meta
)
else:
gains, gain_flags = term.init_term(
term_spec, ref_ant, ms_kwargs, term_kwargs
term_spec, ref_ant, ms_kwargs, term_kwargs, meta=data_xds_meta
)
# Dummy arrays with standard dtypes - aids compilation.
params = np.empty(term_pshape, dtype=np.float64)
Expand Down Expand Up @@ -190,6 +190,7 @@ def solver_wrapper(
for ind, (term, iters) in enumerate(zip(cycle(chain), iter_recipe)):

active_term = chain.index(term)
active_spec = term_spec_list[term_ind]

ms_fields = term.ms_inputs._fields
ms_inputs = term.ms_inputs(
Expand Down Expand Up @@ -219,13 +220,17 @@ def solver_wrapper(
term.solve_per
)

jhj, conv_iter, conv_perc = term.solver(
ms_inputs,
mapping_inputs,
chain_inputs,
meta_inputs,
corr_mode
)
if term.solver:
jhj, conv_iter, conv_perc = term.solver(
ms_inputs,
mapping_inputs,
chain_inputs,
meta_inputs,
corr_mode
)
else:
jhj = np.zeros(getattr(active_spec, "pshape", active_spec.shape))
conv_iter, conv_perc = 0, 1

# If reweighting is enabled, do it when the epoch changes, except
# for the final epoch - we don't reweight if we won't solve again.
Expand Down Expand Up @@ -269,7 +274,7 @@ def solver_wrapper(
corr_mode
)

log_chisq(presolve_chisq, postsolve_chisq, aux_block_info, block_id)
log_chisq(presolve_chisq, postsolve_chisq, data_xds_meta, block_id)

results_dict["presolve_chisq"] = presolve_chisq
results_dict["postsolve_chisq"] = postsolve_chisq
Expand Down
2 changes: 2 additions & 0 deletions quartical/config/gain_schema.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@ gain:
- rotation_measure
- rotation
- crosshand_phase
- crosshand_phase_null_v
- leakage
- parallactic_angle
info:
Type of gain to solve for.

Expand Down
82 changes: 41 additions & 41 deletions quartical/data_handling/angles.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,46 @@
_thread_local = threading.local()


def assign_parangle_data(ms_path, data_xds_list):

anttab = xds_from_storage_table(ms_path + "::ANTENNA")[0]
feedtab = xds_from_storage_table(ms_path + "::FEED")[0]
fieldtab = xds_from_storage_table(ms_path + "::FIELD")[0]

# We do the following eagerly to reduce graph complexity.
feeds = feedtab.POLARIZATION_TYPE.values
unique_feeds = np.unique(feeds)

if np.all([feed in "XxYy" for feed in unique_feeds]):
feed_type = "linear"
elif np.all([feed in "LlRr" for feed in unique_feeds]):
feed_type = "circular"
else:
raise ValueError("Unsupported feed type/configuration.")

phase_dirs = fieldtab.PHASE_DIR.values

updated_data_xds_list = []
for xds in data_xds_list:
xds = xds.assign(
{
"RECEPTOR_ANGLE": (
("ant", "feed"), clone(feedtab.RECEPTOR_ANGLE.data)
),
"POSITION": (
("ant", "xyz"),
clone(anttab.POSITION.data)
)
}
)
xds.attrs["FEED_TYPE"] = feed_type
xds.attrs["FIELD_CENTRE"] = tuple(phase_dirs[xds.FIELD_ID, 0])

updated_data_xds_list.append(xds)

return updated_data_xds_list


def make_parangle_xds_list(ms_path, data_xds_list):
"""Create a list of xarray.Datasets containing the parallactic angles."""

Expand Down Expand Up @@ -266,7 +306,7 @@ def nb_apply_parangle_rot(data_col, parangles, utime_ind, ant1_col, ant2_col,
v1_imul_v2 = factories.v1_imul_v2_factory(corr_mode)
v1_imul_v2ct = factories.v1_imul_v2ct_factory(corr_mode)
valloc = factories.valloc_factory(corr_mode)
rotmat = rotation_factory(corr_mode, feed_type)
rotmat = factories.rotation_factory(corr_mode, feed_type)

def impl(data_col, parangles, utime_ind, ant1_col, ant2_col,
corr_mode, feed_type):
Expand Down Expand Up @@ -299,43 +339,3 @@ def impl(data_col, parangles, utime_ind, ant1_col, ant2_col,
return data_col

return impl


def rotation_factory(corr_mode, feed_type):

if feed_type.literal_value == "circular":
if corr_mode.literal_value == 4:
def impl(rot0, rot1, out):
out[0] = np.exp(-1j*rot0)
out[1] = 0
out[2] = 0
out[3] = np.exp(1j*rot1)
elif corr_mode.literal_value == 2: # TODO: Is this sensible?
def impl(rot0, rot1, out):
out[0] = np.exp(-1j*rot0)
out[1] = np.exp(1j*rot1)
elif corr_mode.literal_value == 1: # TODO: Is this sensible?
def impl(rot0, rot1, out):
out[0] = np.exp(-1j*rot0)
else:
raise ValueError("Unsupported number of correlations.")
elif feed_type.literal_value == "linear":
if corr_mode.literal_value == 4:
def impl(rot0, rot1, out):
out[0] = np.cos(rot0)
out[1] = np.sin(rot0)
out[2] = -np.sin(rot1)
out[3] = np.cos(rot1)
elif corr_mode.literal_value == 2: # TODO: Is this sensible?
def impl(rot0, rot1, out):
out[0] = np.cos(rot0)
out[1] = np.cos(rot1)
elif corr_mode.literal_value == 1: # TODO: Is this sensible?
def impl(rot0, rot1, out):
out[0] = np.cos(rot0)
else:
raise ValueError("Unsupported number of correlations.")
else:
raise ValueError("Unsupported feed type.")

return factories.qcjit(impl)
23 changes: 17 additions & 6 deletions quartical/data_handling/ms_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,13 @@
from quartical.data_handling.selection import filter_xds_list
from quartical.data_handling.angles import apply_parangles

DASKMS_ATTRS = {
"__daskms_partition_schema__",
"SCAN_NUMBER",
"FIELD_ID",
"DATA_DESC_ID"
}


def read_xds_list(model_columns, ms_opts):
"""Reads a measurement set and generates a list of xarray data sets.
Expand Down Expand Up @@ -293,14 +300,18 @@ def write_xds_list(xds_list, ref_xds_list, ms_path, output_opts):

logger.info("Outputs will be written to {}.", ", ".join(output_cols))

# Select only the output columns to simplify datasets.
xds_list = [xds[list(output_cols)] for xds in xds_list]

# Remove all coords bar ROWID so that they do not get written.
xds_list = [
xds.drop_vars(set(xds.coords.keys()) - {"ROWID"}, errors='ignore')
for xds in xds_list
]

# Remove attrs added by QuartiCal so that they do not get written.
for xds in xds_list:
xds.attrs.pop("UTIME_CHUNKS", None)
xds.attrs.pop("FIELD_NAME", None)

# Remove coords added by QuartiCal so that they do not get written.
xds_list = [xds.drop_vars(["chan", "corr"], errors='ignore')
for xds in xds_list]
xds.attrs = {k: v for k, v in xds.attrs.items() if k in DASKMS_ATTRS}

with warnings.catch_warnings(): # We anticipate spurious warnings.
warnings.simplefilter("ignore")
Expand Down
4 changes: 3 additions & 1 deletion quartical/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
preprocess_xds_list,
postprocess_xds_list)
from quartical.data_handling.model_handler import add_model_graph
from quartical.data_handling.angles import make_parangle_xds_list
from quartical.data_handling.angles import (make_parangle_xds_list,
assign_parangle_data)
from quartical.calibration.calibrate import add_calibration_graph
from quartical.statistics.statistics import make_stats_xds_list
from quartical.statistics.logging import log_summary_stats
Expand Down Expand Up @@ -110,6 +111,7 @@ def _execute(exitstack):

# Preprocess the xds_list - initialise some values and fix bad data.
data_xds_list = preprocess_xds_list(data_xds_list, ms_opts)
data_xds_list = assign_parangle_data(ms_opts.path, data_xds_list)

# Make a list of datasets containing the parallactic angles as these
# can be expensive to compute and may be used several times. NOTE: At
Expand Down
7 changes: 5 additions & 2 deletions quartical/gains/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@
from quartical.gains.tec_and_offset import TecAndOffset
from quartical.gains.rotation import Rotation
from quartical.gains.rotation_measure import RotationMeasure
from quartical.gains.crosshand_phase import CrosshandPhase
from quartical.gains.crosshand_phase import CrosshandPhase, CrosshandPhaseNullV
from quartical.gains.leakage import Leakage
from quartical.gains.delay_and_tec import DelayAndTec
from quartical.gains.parallactic_angle import ParallacticAngle


TERM_TYPES = {
Expand All @@ -22,6 +23,8 @@
"rotation": Rotation,
"rotation_measure": RotationMeasure,
"crosshand_phase": CrosshandPhase,
"crosshand_phase_null_v": CrosshandPhaseNullV,
"leakage": Leakage,
"delay_and_tec": DelayAndTec
"delay_and_tec": DelayAndTec,
"parallactic_angle": ParallacticAngle
}
2 changes: 1 addition & 1 deletion quartical/gains/amplitude/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def make_param_names(cls, correlations):

return [f"amplitude_{c}" for c in param_corr]

def init_term(self, term_spec, ref_ant, ms_kwargs, term_kwargs):
def init_term(self, term_spec, ref_ant, ms_kwargs, term_kwargs, meta=None):
"""Initialise the gains (and parameters)."""

gains, gain_flags, params, param_flags = super().init_term(
Expand Down
10 changes: 9 additions & 1 deletion quartical/gains/crosshand_phase/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
crosshand_phase_solver,
crosshand_params_to_gains
)
from quartical.gains.crosshand_phase.null_v_kernel import (
null_v_crosshand_phase_solver
)
from quartical.gains.general.flagging import (
apply_gain_flags_to_gains,
apply_param_flags_to_params
Expand Down Expand Up @@ -39,7 +42,7 @@ def make_param_names(cls, correlations):

return [f"crosshand_phase_{c}" for c in param_corr]

def init_term(self, term_spec, ref_ant, ms_kwargs, term_kwargs):
def init_term(self, term_spec, ref_ant, ms_kwargs, term_kwargs, meta=None):
"""Initialise the gains (and parameters)."""

gains, gain_flags, params, param_flags = super().init_term(
Expand All @@ -54,3 +57,8 @@ def init_term(self, term_spec, ref_ant, ms_kwargs, term_kwargs):
apply_gain_flags_to_gains(gain_flags, gains)

return gains, gain_flags, params, param_flags


class CrosshandPhaseNullV(CrosshandPhase):

solver = staticmethod(null_v_crosshand_phase_solver)
Loading