Skip to content

Commit

Permalink
Add beginnings of baseline-based corrections. (#235)
Browse files Browse the repository at this point in the history
* Add beginnings of baseline-based corrections. More for analysis than application.

* Fix bug in baseline corrections.
  • Loading branch information
JSKenyon authored Mar 3, 2023
1 parent 1f609c5 commit c3d17fc
Show file tree
Hide file tree
Showing 5 changed files with 372 additions and 3 deletions.
27 changes: 26 additions & 1 deletion quartical/calibration/calibrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
make_net_xds_list,
populate_net_xds_list)
from quartical.interpolation.interpolate import load_and_interpolate_gains
from quartical.gains.baseline import (compute_baseline_corrections,
apply_baseline_corrections)
from loguru import logger # noqa
from collections import namedtuple

Expand Down Expand Up @@ -163,6 +165,17 @@ def add_calibration_graph(
else:
net_xds_lod = []

if output_opts.compute_baseline_corrections:
bl_corr_xds_list = compute_baseline_corrections(
data_xds_list,
gain_xds_lod,
t_map_list,
f_map_list,
d_map_list
)
else:
bl_corr_xds_list = None

# Update the data xarray.Datasets with visibility outputs.
data_xds_list = make_visibility_output(
data_xds_list,
Expand All @@ -173,8 +186,20 @@ def add_calibration_graph(
output_opts
)

if output_opts.apply_baseline_corrections:
data_xds_list = apply_baseline_corrections(
data_xds_list,
bl_corr_xds_list
)

# Return the resulting graphs for the gains and updated xds.
return gain_xds_lod, net_xds_lod, data_xds_list, stats_xds_list
return (
gain_xds_lod,
net_xds_lod,
data_xds_list,
stats_xds_list,
bl_corr_xds_list
)


def make_visibility_output(data_xds_list, solved_gain_xds_lod, t_map_list,
Expand Down
6 changes: 6 additions & 0 deletions quartical/config/external.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,8 @@ class Outputs(Input):
apply_p_jones_inv: bool = False
subtract_directions: Optional[List[int]] = None
net_gains: Optional[List[Any]] = None
compute_baseline_corrections: bool = False
apply_baseline_corrections: bool = False

def __post_init__(self):
self.validate_choice_fields()
Expand All @@ -131,6 +133,10 @@ def __post_init__(self):
"Must be strictly a list or list of lists.")
# In the non-nested case, introduce outer list (consistent).
self.net_gains = [self.net_gains]
if self.apply_baseline_corrections:
assert self.compute_baseline_corrections, \
("output.compute_baseline_corrections must be enabled if "
"output.apply_baseline corrections is enabled.")


@dataclass
Expand Down
7 changes: 7 additions & 0 deletions quartical/config/helpstrings.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,13 @@ output:
Which model directions to subtract when generating residuals. Must be
specified as a list of integers e.g. [0, 5, 7]. The default will subtract
all directions.
compute_baseline_corrections:
Enable or disable computation of baseline-based corrections. Functionality
is currently limited to a solution per-channel, per-chunk. These solutions
are useful for analysis and are stored in output.gain_directory.
apply_baseline_corrections:
Enable or disable application of baseline-based corrections. Extreme
caution advised - this can and will lead to overfitting.


mad_flags:
Expand Down
13 changes: 11 additions & 2 deletions quartical/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from quartical.flagging.flagging import finalise_flags, add_mad_graph
from quartical.scheduling import install_plugin
from quartical.gains.datasets import write_gain_datasets
from quartical.gains.baseline import write_baseline_datasets
from quartical.utils.dask import compute_context


Expand Down Expand Up @@ -128,7 +129,11 @@ def _execute(exitstack):
output_opts
)

gain_xds_lod, net_xds_lod, data_xds_list, stats_xds_list = cal_outputs
(gain_xds_lod,
net_xds_lod,
data_xds_list,
stats_xds_list,
bl_corr_xds_list) = cal_outputs

if mad_flag_opts.enable:
data_xds_list = add_mad_graph(data_xds_list, mad_flag_opts)
Expand All @@ -149,6 +154,9 @@ def _execute(exitstack):
net_xds_lod,
output_opts)

bl_corr_writes = write_baseline_datasets(bl_corr_xds_list,
output_opts)

logger.success("{:.2f} seconds taken to build graph.", time.time() - t0)

logger.info("Computation starting. Please be patient - log messages will "
Expand All @@ -158,10 +166,11 @@ def _execute(exitstack):

with compute_context(dask_opts, output_opts, time_str):

_, _, stats_xds_list = dask.compute(
_, _, stats_xds_list, _ = dask.compute(
ms_writes,
gain_writes,
stats_xds_list,
bl_corr_writes,
num_workers=dask_opts.threads,
optimize_graph=True,
scheduler=dask_opts.scheduler
Expand Down
Loading

0 comments on commit c3d17fc

Please sign in to comment.