from __future__ import annotations
import dataclasses
import os
import re
import numpy as np
import pipeline.extern.tsys_contamination as extern
import pipeline.h.tasks.tsysflag.tsysflag as tsysflag
import pipeline.infrastructure as infrastructure
import pipeline.infrastructure.sessionutils as sessionutils
import pipeline.infrastructure.vdp as vdp
from pipeline.extern.TsysDataClassFile import TsysData
from pipeline.h.tasks.common import calibrationtableaccess as caltableaccess
from pipeline.h.tasks.tsysflag.resultobjects import TsysflagResults
from pipeline.infrastructure import task_registry
from pipeline.infrastructure.basetask import StandardTaskTemplate
from pipeline.infrastructure.exceptions import PipelineException
from pipeline.infrastructure.pipelineqa import QAScore, TargetDataSelection
# Public API of this module: the parallel-capable task and its inputs class.
__all__ = ["TsysFlagContamination", "TsysFlagContaminationInputs"]
# Module-level logger, named after this module.
LOG = infrastructure.logging.get_logger(__name__)
# QA score for multi-source, multi-tuning EBs, full polarization EBs, etc.
# that cannot be processed by the heuristic. The intent is that these EBs
# be directed to the QA slow lane for inspection and potential manual
# flagging of line contamination.
UNPROCESSABLE_DATA_QA_SCORE = 0.6
# Short summary message used when the heuristic sees data it is not
# validated for.
UNPROCESSABLE_DATA_QA_SHORTMSG = 'Heuristics not applied'
class TsysFlagContaminationInputs(vdp.StandardInputs):
    """
    Inputs for the TsysFlagContamination pipeline task.

    ``continue_on_failure`` selects what happens when the heuristic raises an
    exception. With ``continue_on_failure = True`` the failure is logged and
    a QA score of -0.1 is recorded; with ``False`` the failure is passed to
    the pipeline's standard task failure mechanisms, where no QA score is
    recorded and the pipeline run terminates.

    Parameters specific to this task:

    - remove_n_extreme: defaults to 2
    - relative_detection_factor: defaults to 0.005
    - diagnostic_plots: include diagnostic plots in the weblog. Defaults to
      True.
    - continue_on_failure: continue after heuristic failure (True) or
      terminate the pipeline run (False). Defaults to True.
    """

    @vdp.VisDependentProperty
    def caltable(self):
        # Candidate tables are the active Tsys caltables in the callibrary;
        # pick the first one whose originating vis is contained in self.vis.
        tsys_tables = self.context.callibrary.active.get_caltable(caltypes="tsys")
        for table_name in tsys_tables:
            # vis that this Tsys table was derived from
            parent_vis = caltableaccess.CalibrationTableDataFiller._readvis(
                table_name
            )
            if parent_vis in self.vis:
                return table_name
        # no Tsys table matched the vis being handled
        return None

    @vdp.VisDependentProperty
    def filetemplate(self):
        # flagging template path: the vis name with its extension replaced
        stem, _ = os.path.splitext(self.vis)
        return f"{stem}.flag_tsys_contamination.txt"

    @vdp.VisDependentProperty
    def logpath(self):
        # heuristic log path: the vis name with its extension replaced
        stem, _ = os.path.splitext(self.vis)
        return f"{stem}.ms_tsys_contamination.txt"

    remove_n_extreme = vdp.VisDependentProperty(default=2)
    relative_detection_factor = vdp.VisDependentProperty(default=0.005)
    diagnostic_plots = vdp.VisDependentProperty(default=True)
    continue_on_failure = vdp.VisDependentProperty(default=True)
    parallel = sessionutils.parallel_inputs_impl(default=False)

    # docstring and type hints: supplements hifa_tsysflagcontamination
    def __init__(
        self,
        context,
        output_dir=None,
        vis=None,
        caltable=None,
        filetemplate=None,
        logpath=None,
        remove_n_extreme=None,
        relative_detection_factor=None,
        diagnostic_plots=None,
        continue_on_failure=None,
        parallel=None,
    ):
        """Initialize Inputs.

        Args:
            context: Pipeline context object containing state information.

            output_dir: Output directory.
                Defaults to None, which corresponds to the current working directory.

            vis: List of input MeasurementSets (Not used).

            caltable: List of input Tsys calibration tables.

                Default: [] - Use the table currently stored in the pipeline context.

                Example: caltable=['X132.ms.tsys.s2.tbl']

            filetemplate: output file to which regions to flag will be written

            logpath: output file to which heuristic log statements will be
                written

            remove_n_extreme: expert parameter for contamination heuristic

                Default: 2

            relative_detection_factor: expert parameter for contamination detection heuristic

                Default: 0.005

            diagnostic_plots: create diagnostic plots for the line contamination heuristic

                Default: True

            continue_on_failure: controls whether pipeline execution continues if a failure
                occurs in the underlying contamination detection heuristic.

                Default: True

            parallel: Process multiple MeasurementSets in parallel using the casampi parallelization framework.

                Options: ``'automatic'``, ``'true'``, ``'false'``, ``True``, ``False``

                Default: ``None`` (equivalent to ``False``)
        """
        super().__init__()

        # -- pipeline inputs --
        self.context = context
        # vis goes in before everything else because other properties may
        # derive their values from it
        self.vis = vis
        self.output_dir = output_dir

        # -- data selection arguments --
        self.caltable = caltable
        self.filetemplate = filetemplate
        self.logpath = logpath
        self.diagnostic_plots = diagnostic_plots
        self.continue_on_failure = continue_on_failure

        # -- heuristic parameter arguments --
        self.remove_n_extreme = remove_n_extreme
        self.relative_detection_factor = relative_detection_factor
        self.parallel = parallel
@dataclasses.dataclass
class ExternFunctionArguments:
    """
    Plain-data bundle of the arguments consumed by the external heuristic.

    Acts as an adapter between the TsysFlagContaminationInputs task inputs
    class and the argument list the external contamination code expects.
    """

    vis: str
    diagnostic_plots: bool
    tsystable: str
    remove_n_extreme: float
    relative_detection_factor: float
    logpath: str
    filetemplate: str
    pl_run_dir: str
    plots_path: str
    single_polarization: bool

    @staticmethod
    def from_inputs(inputs: TsysFlagContaminationInputs) -> ExternFunctionArguments:
        """
        Build the heuristic arguments from a task inputs instance.

        Side effect: the weblog directory for the current stage
        (``<report_dir>/stage<task_counter>``) is created if missing, and is
        used as the plot output directory.

        Args:
            inputs: task inputs providing the context, the MS domain object
                and the user-facing parameters.

        Returns:
            A populated ExternFunctionArguments instance.
        """
        ctx = inputs.context

        # plots are written straight into this stage's weblog directory
        stage_dir = os.path.join(ctx.report_dir, f"stage{ctx.task_counter}")
        os.makedirs(stage_dir, exist_ok=True)

        # distinct polarization counts over all TARGET-intent spws
        polarization_counts = set()
        for target_spw in inputs.ms.get_spectral_windows(intent="TARGET"):
            data_description = inputs.ms.get_data_description(spw=target_spw.id)
            polarization_counts.add(data_description.num_polarizations)
        # single polarization only if no spw reports a count other than 1
        only_single_pol = all(count == 1 for count in polarization_counts)

        return ExternFunctionArguments(
            vis=inputs.vis,
            diagnostic_plots=inputs.diagnostic_plots,
            tsystable=inputs.caltable,
            remove_n_extreme=inputs.remove_n_extreme,
            relative_detection_factor=inputs.relative_detection_factor,
            logpath=inputs.logpath,
            filetemplate=inputs.filetemplate,
            pl_run_dir=ctx.output_dir,
            plots_path=stage_dir,
            single_polarization=only_single_pol,
        )
class SerialTsysFlagContamination(StandardTaskTemplate):
    """
    Flag line contamination in the Tsys tables.

    The purpose of this class is to call the external flagging heuristic to
    generate flagging commands based on line contamination in Tsys tables, and
    then pass those flagging commands to the standard h_tsysflag child task in
    manual flagging mode.

    The bulk of ``_call_extern_heuristic`` comes directly from the extern
    code, sandwiched between a few lines of code to extract input parameters
    and pass them to the heuristic. ``prepare`` then wraps the outcome - a
    file of flagging commands - into a manual flagging request for the
    existing h_tsysflag task. The results of this child task are captured and
    adapted so that the QA and weblog rendering code can operate on the
    results of this line contamination task.
    """

    Inputs = TsysFlagContaminationInputs

    # Interval keys handed back by the extern heuristic are parsed as
    # "<spw>_<field>"; compiled once rather than on every loop iteration.
    _SPW_FIELD_KEY_PATTERN = re.compile(r"(?P<spw>[0-9]+)_(?P<field>[0-9]+)")

    def prepare(self):
        """
        Run the contamination heuristic and the manual-flagging child task.

        Returns:
            A TsysflagResults populated either with the child task's flagging
            results, or - when preconditions are not met or the heuristic
            fails softly - with QA scores and a ``task_incomplete_reason``.

        Raises:
            PipelineException: the heuristic raised and
                ``continue_on_failure`` is False.
        """
        inputs = self.inputs

        result = TsysflagResults()
        result.vis = inputs.vis
        result.caltable = inputs.caltable
        # TODO from PIPE-2009: adding new attributes to the result after
        # instance construction isn't great but we don't have time to
        # rationalise and refactor the base class right now
        result.qascores_from_task = []

        # step 1: do not run the heuristic on data we know it cannot handle
        preflight_qascores = self._assert_heuristic_preconditions()
        result.qascores_from_task.extend(preflight_qascores)
        if preflight_qascores:
            result.task_incomplete_reason = "Preconditions for line contamination heuristic not met. See QA scores for details."
            result.metric_order = ["manual"]  # required for renderer
            return result

        # step 2: run extern heuristic
        extern_fn_args = ExternFunctionArguments.from_inputs(inputs)
        try:
            plot_wrappers, qascores = self._call_extern_heuristic(extern_fn_args)
        except Exception as e:
            reason = f"Heuristic failed while processing {inputs.vis}"
            if not inputs.continue_on_failure:
                # hard failure: raise exception, let the framework deal with it
                raise PipelineException(reason) from e

            # soft failure: QA score of -0.1 but do not halt the pipeline
            s = QAScore(
                score=-0.1,
                shortmsg="Heuristic failed",
                longmsg=reason,
                applies_to=TargetDataSelection(vis={inputs.vis}),
            )
            result.qascores_from_task.append(s)
            result.task_incomplete_reason = reason
            result.metric_order = ["manual"]  # required for renderer
            LOG.exception(reason, exc_info=e)
            return result

        result.plots = plot_wrappers
        result.extern_qascores = qascores

        # Step 3: do not flag data for DSB data
        # Set manual flagging template to that written by the heuristic
        # unless it's a DSB EB.
        if self._contains_dsb():
            filetemplate = None
            result.qascores_from_task.append(
                self._unprocessable_qascore("DSB data", inputs.vis)
            )
        else:
            filetemplate = inputs.filetemplate

        # Always run the child task - even for DSB - as the results are
        # required by the Tsyscalflag renderer
        child_inputs = tsysflag.Tsysflag.Inputs(
            inputs.context,
            output_dir=inputs.output_dir,
            vis=inputs.vis,
            caltable=inputs.caltable,
            filetemplate=filetemplate,
            flag_birdies=False,
            flag_derivative=False,
            flag_edgechans=False,
            flag_fieldshape=False,
            flag_nmedian=False,
            flag_toomany=False,
            fnm_byfield=False,
            normalize_tsys=False,
        )
        child_task = tsysflag.Tsysflag(child_inputs)
        child_result = self._executor.execute(child_task)

        # adopt the child task's results as our own
        result.pool.extend(child_result.pool)
        result.final.extend(child_result.final)
        result.components.update(child_result.components)
        result.summaries = child_result.summaries
        result.error.update(child_result.error)
        result.metric_order = list(child_result.metric_order)
        return result

    def analyse(self, result):
        """Return ``result`` unchanged; no further analysis is done here."""
        return result

    def _call_extern_heuristic(self, fn_args: ExternFunctionArguments):
        """
        Run the external heuristic and write its log and flagging template.

        The log file at ``fn_args.logpath`` is overwritten; flagging commands
        are appended to ``fn_args.filetemplate``. Fields whose intent is
        "bandpass" never produce a flagging command.

        Args:
            fn_args: arguments for the external heuristic.

        Returns:
            Tuple ``(plot_wrappers, qascores)`` exactly as returned by
            ``extern.get_tsys_contaminated_intervals``.

        Raises:
            ValueError: an interval key returned by the heuristic does not
                start with ``<spw>_<field>``.
        """
        tsystable = fn_args.tsystable

        tsys = TsysData(
            tsystable=tsystable,
            load_pickle=True,
            single_polarization=fn_args.single_polarization,
        )

        line_contamination_intervals, warnings_list, plot_wrappers, qascores = (
            extern.get_tsys_contaminated_intervals(
                tsys,
                plot=fn_args.diagnostic_plots,
                remove_n_extreme=fn_args.remove_n_extreme,
                relative_detection_factor=fn_args.relative_detection_factor,
                savefigfile=f"{fn_args.plots_path}/{fn_args.vis}.tsyscontamination",
            )
        )

        # Drop, in place, every entry whose interval lists are all empty.
        # Iterate over a snapshot because the mapping is mutated in the loop.
        for key, intervals in list(line_contamination_intervals.items()):
            if not any(len(ranges) for ranges in intervals.values()):
                del line_contamination_intervals[key]

        intents, scans, fields = (
            tsys.tsysdata[tsys.tsysfields.index(name)]
            for name in ("intent", "scan", "field")
        )
        # Built directly from the zipped columns rather than via an
        # intermediate set: with a set, a field carrying more than one intent
        # resolved to an arbitrary, hash-order-dependent one. Now the last
        # row deterministically wins.
        field_intent_dict = dict(zip(fields, intents))
        scan_field_dict = dict(zip(scans, fields))
        field_scanlist_dict = {}
        for scan, scan_field in scan_field_dict.items():
            field_scanlist_dict.setdefault(scan_field, []).append(scan)
        # end replace

        all_freqs_mhz = tsys.specdata[tsys.specfields.index("freq_mhz")]
        spw_column = tsys.specdata[tsys.specfields.index("spw")]

        with open(fn_args.logpath, "w") as f, open(fn_args.filetemplate, "a") as ft:
            f.write(
                f"\n# script version {extern.VERSION} {fn_args.pl_run_dir}\n# {tsystable}\n"
            )

            # group contaminated spws by field id
            field_contamination = {}
            for key in line_contamination_intervals:
                m = self._SPW_FIELD_KEY_PATTERN.match(key)
                if m is None:
                    raise ValueError(
                        f"Unexpected contamination interval key {key!r}: "
                        "expected '<spw>_<field>'"
                    )
                field_contamination.setdefault(np.int64(m.group("field")), []).append(
                    np.int64(m.group("spw"))
                )

            if not field_contamination:
                msg = "## No tsys contamination identified.\n"
                f.write(msg)
                LOG.info(msg)

            # v3.3 large baseline residual
            for w in warnings_list:
                # the final element of each warning is left out of the message
                msg = " ".join(w[:-1])
                f.write(f"# {msg}\n")
                LOG.info("# %s", msg)

            for field, spws in field_contamination.items():
                field_intent = field_intent_dict[field]
                # no flagging commands are generated for bandpass fields
                if field_intent == "bandpass":
                    continue

                spw_ranges = []
                spw_ranges_freq = []
                for spw in spws:
                    # frequency axis of the first specdata row for this spw,
                    # converted from MHz to GHz
                    freqs_ghz = all_freqs_mhz[np.nonzero(spw_column == spw)[0][0]] / 1000
                    intervals = line_contamination_intervals[f"{spw}_{field}"][
                        "tsys_contamination"
                    ]
                    rs = extern.intervals_to_casa_string(intervals)
                    rsf = extern.intervals_to_casa_string(
                        intervals,
                        scaled_array=freqs_ghz,
                        unit="GHz",
                        format=".3f",
                    )
                    if rs != "":
                        spw_ranges.append(f"{spw}:{rs}")
                        spw_ranges_freq.append(f"{spw}:{rsf}")

                if not spw_ranges:
                    continue  # v2.2

                spw_ranges = ",".join(spw_ranges)
                spw_ranges_freq = ",".join(spw_ranges_freq)
                scan_list = ",".join(
                    str(sc) for sc in sorted(field_scanlist_dict[field])
                )

                flagline = f"mode='manual' scan='{scan_list}' spw='{spw_ranges}' reason='Tsys:tsysflag_tsys_channel'\n"
                ft.write(flagline)

                msg = (
                    f"## {tsystable}: field={field}, intent={field_intent}\n"
                    f"# Frequency ranges: '{spw_ranges_freq}' \n"
                    f"{flagline}"
                )
                f.write(msg)
                LOG.info(msg)

        return plot_wrappers, qascores

    @staticmethod
    def _unprocessable_qascore(reason: str, vis: str) -> QAScore:
        """Build the QA score recorded when the heuristic is not applied to ``vis``."""
        return QAScore(
            score=UNPROCESSABLE_DATA_QA_SCORE,
            shortmsg=UNPROCESSABLE_DATA_QA_SHORTMSG,
            longmsg=f"Heuristic not applied: {reason} in {vis}.",
            applies_to=TargetDataSelection(vis={vis}),
        )

    def _assert_heuristic_preconditions(self) -> list[QAScore]:
        """
        Preflight checks to identify data that the heuristic cannot handle.

        Returns:
            One QAScore per failed check; an empty list when all checks pass.
        """
        qa_scores = []
        qa_scores.extend(self._assert_not_multisource_multituning())
        qa_scores.extend(self._assert_not_full_polarization())
        qa_scores.extend(self._assert_bandpass_is_present())
        return qa_scores

    def _assert_not_multisource_multituning(self) -> list[QAScore]:
        """
        Returns a list containing an appropriate QAScore if the EB is both
        multi-source and multi-tuning, otherwise an empty list is returned.
        """
        ms = self.inputs.ms

        # Algorithm to detect multi-source multi-tuning EBs is:
        #
        # 1. Identify all Field-Id with CALIBRATE_ATMOSPHERE intents.
        # 2. Among those Field-Ids identified above, identify all which are
        #    associated with OBSERVE_TARGET intent (in other scans)
        # 3. If there is more than 1 distinct Field IDs with the above two
        #    requirements, then the EB is "multi-source"
        # 4. If in addition the EB has more than one science spectral spec
        #    (disregarding pointing spectral specs etc.), then the EB is
        #    "multi-tuning multi-source" type.
        dual_intent_fields = {
            target_field
            # So, mapping the implementation to the algorithm above, we:
            # 1. identify all fields with ATMOSPHERE intent
            for tsys_field in ms.get_fields(intent="ATMOSPHERE")
            # 2. get scans for that field that are also associated with TARGET
            #    intent...
            for target_scan_on_tsys_field in ms.get_scans(
                scan_intent="TARGET", field=tsys_field.id
            )
            # ... and iterate over the fields for that scan...
            for target_field in target_scan_on_tsys_field.fields
            # ... but only if this TARGET scan does not also have ATMOSPHERE
            # intent, i.e., this has to be TARGET intent in *other* scans
            # associated with the Tsys field, not the *same* scan
            if "ATMOSPHERE" not in target_scan_on_tsys_field.intents
        }
        # 3. it's a multi-source EB if more than 1 field meets this requirement
        is_multi_source_eb = len(dual_intent_fields) > 1

        # 4. If in addition the EB has more than one science spectral spec
        #    (disregarding pointing spectral specs etc.), then the EB is
        #    "multi-tuning multi-source" type.
        science_specs = {
            spw.spectralspec
            for spw in ms.get_spectral_windows(science_windows_only=True)
        }
        is_multi_tuning_eb = len(science_specs) > 1

        if is_multi_tuning_eb and is_multi_source_eb:
            return [
                self._unprocessable_qascore(
                    "multi-source multi-tuning data present", ms.basename
                )
            ]
        return []

    def _assert_not_full_polarization(self) -> list[QAScore]:
        """
        Returns a list containing an appropriate QAScore if full polarization
        data are present, otherwise an empty list is returned.
        """
        ms = self.inputs.ms
        science_spws = ms.get_spectral_windows(science_windows_only=True)
        num_corr_axes = {
            len(ms.get_data_description(spw=spw.id).corr_axis)
            for spw in science_spws
        }
        # more than two correlation products on any science spw is treated
        # as full polarization
        if any(n > 2 for n in num_corr_axes):
            return [
                self._unprocessable_qascore(
                    "full-polarization data present", ms.basename
                )
            ]
        return []

    def _assert_bandpass_is_present(self) -> list[QAScore]:
        """
        Returns a list containing an appropriate QAScore if BANDPASS data are
        missing, otherwise an empty list is returned.
        """
        ms = self.inputs.ms
        # exclude TP (fails by design; needs bandpass intent scan)
        if "BANDPASS" not in ms.intents:
            return [self._unprocessable_qascore("no bandpass data", ms.basename)]
        return []

    def _contains_dsb(self) -> bool:
        """
        Returns True if any science spectral window uses a DSB receiver.
        """
        ms = self.inputs.ms
        receivers = [
            spw.receiver for spw in ms.get_spectral_windows(science_windows_only=True)
        ]
        return "DSB" in receivers
@task_registry.set_equivalent_casa_task("hifa_tsysflagcontamination")
@task_registry.set_casa_commands_comment(
    "Line contamination in the Tsys tables is detected and flagged."
)
class TsysFlagContamination(sessionutils.ParallelTemplate):
    """
    Registered entry point for the Tsys line contamination flagging task.

    Registered as the equivalent of the ``hifa_tsysflagcontamination`` CASA
    task. The work itself is implemented by SerialTsysFlagContamination,
    named here as ``Task``; this class only binds it and its inputs class to
    the ParallelTemplate base.
    """

    Inputs = TsysFlagContaminationInputs
    Task = SerialTsysFlagContamination