from __future__ import annotations
import collections
import collections.abc
import itertools
import os
import re
import shutil
import string
from typing import TYPE_CHECKING
import numpy as np
import pipeline.infrastructure as infrastructure
import pipeline.infrastructure.basetask as basetask
import pipeline.infrastructure.utils as utils
import pipeline.infrastructure.vdp as vdp
import pipeline.infrastructure.sessionutils as sessionutils
from pipeline.domain import DataTable
from pipeline.h.tasks.flagging import flagdeterbase
from pipeline.infrastructure import casa_tools
from pipeline.infrastructure import casa_tasks
from pipeline.infrastructure import task_registry
from pipeline.infrastructure.displays import pointing
from pipeline.hsd.heuristics.pointing_outlier import PointingOutlierHeuristics
from pipeline.hsd.tasks.common.flagcmd_util import datatable_rowid_to_timerange
if TYPE_CHECKING:
from collections.abc import Generator
from pipeline.domain import SpectralWindow
from pipeline.infrastructure import Context
# Module-level logger used by every function and class in this module.
LOG = infrastructure.logging.get_logger(__name__)

# Record describing the pointing outliers found for one (field, antenna) pair.
# Instances are created in
# SerialFlagDeterALMASingleDish._detect_pointing_outliers and are filled from
# the result of PointingOutlierHeuristics as follows:
#   cx, cy          -- ``cx`` / ``cy`` of the heuristic result
#                      (presumably the center of the pointing distribution;
#                      TODO confirm against the heuristic)
#   median_distance -- ``med_dist`` of the heuristic result
#   factor          -- ``factor`` of the heuristic result
#   outliers        -- datatable row indices classified as outliers
#   timerange       -- list of timerange strings covering the outlier rows
#   separations     -- ``dist`` of the heuristic result for the outlier rows
#                      (reported as degrees in the task's warning message)
PointingOutlierStats = collections.namedtuple(
    "PointingOutlierStats",
    ["cx", "cy", "median_distance", "factor",
     "outliers", "timerange", "separations"]
)
class FlagDeterALMASingleDishInputs(flagdeterbase.FlagDeterBaseInputs):
    """Inputs of the FlagDeterALMASingleDish (hsd_flagdata) pipeline task.

    Extends the base deterministic flagging inputs with single-dish specific
    parameters: pointing flag handling (``pointing``, ``filepointing``,
    ``incompleteraster``), QA0/QA2 switches and the ``parallel`` option.
    """

    # Tier-0 parallelization switch
    parallel = sessionutils.parallel_inputs_impl()

    # Defaults that are declared by this class
    autocorr = vdp.VisDependentProperty(default=False)
    edgespw = vdp.VisDependentProperty(default=True)
    fracspw = vdp.VisDependentProperty(default='1.875GHz')
    fracspwfps = vdp.VisDependentProperty(default=0.048387)

    @vdp.VisDependentProperty
    def intents(self) -> str:
        """Return the default comma-separated list of intents to flag.

        Only unwanted intents that are actually present in the MS are
        included.
        """
        unwanted = {'POINTING', 'FOCUS', 'ATMOSPHERE', 'SIDEBAND',
                    'UNKNOWN', 'SYSTEM_CONFIGURATION', 'CHECK'}
        present = self.ms.intents.intersection(unwanted)
        return ','.join(present)

    template = vdp.VisDependentProperty(default=True)

    @flagdeterbase.FlagDeterBaseInputs.filetemplate.postprocess
    def filetemplate(self, unprocessed: str | list[str]) -> str:
        """Post-process the filetemplate value.

        A list holding exactly one item is unwrapped to that item. Any other
        value is passed through unchanged.

        Args:
            unprocessed: Unprocessed value of filetemplate.

        Returns:
            Post-processed value of filetemplate.
        """
        is_single_item_list = isinstance(unprocessed, list) and len(unprocessed) == 1
        return unprocessed[0] if is_single_item_list else unprocessed

    pointing = vdp.VisDependentProperty(default=True)
    incompleteraster = vdp.VisDependentProperty(default=True)

    @vdp.VisDependentProperty
    def filepointing(self) -> str:
        """Return the default name of the pointing flag file.

        Returns:
            The vis name with its extension replaced by '.flagpointing.txt'.
        """
        root, _ = os.path.splitext(self.vis)
        return root + '.flagpointing.txt'

    # QA0 / QA2 flag switches
    qa0 = vdp.VisDependentProperty(default=True)
    qa2 = vdp.VisDependentProperty(default=True)

    # docstring and type hints: supplements hsd_flagdata
    def __init__(self,
                 context: Context,
                 vis: list[str] | None = None,
                 output_dir: str | None = None,
                 flagbackup: str | bool | None = None,
                 autocorr: str | bool | None = None,
                 shadow: str | bool | None = None,
                 scan: str | bool | None = None,
                 scannumber: str | None = None,
                 intents: str | None = None,
                 edgespw: str | bool | None = None,
                 fracspw: str | None = None,
                 fracspwfps: str | float | None = None,
                 online: str | bool | None = None,
                 fileonline: str | None = None,
                 template: str | bool | None = None,
                 filetemplate: str | None = None,
                 pointing: str | bool | None = None,
                 filepointing: str | None = None,
                 incompleteraster: str | bool | None = None,
                 hm_tbuff: str | None = None,
                 tbuff: str | float | None = None,
                 qa0: str | bool | None = None,
                 qa2: str | bool | None = None,
                 parallel: str | bool | None = None):
        """Construct FlagDeterALMASingleDishInputs instance.

        Args:
            context: Pipeline context object containing state information.
            vis: The list of input MeasurementSets. Defaults to the list of
                MeasurementSets defined in the pipeline context.
            output_dir: Output directory.
            flagbackup: Back up any pre-existing flags before applying new ones.
                Default: None (equivalent to True)
            autocorr: Flag autocorrelation data.
                Default: None (equivalent to False)
            shadow: Flag shadowed antennas.
                Default: None (equivalent to True)
            scan: Flag a list of scans and intents specified by scannumber
                and intents.
                Default: None (equivalent to True)
            scannumber: A string containing a comma delimited list of scans
                to be flagged.
                Default: None (equivalent to '')
            intents: A string containing a comma delimited list of intents
                against which the scans to be flagged are matched.
                Defaults to intents that are not relevant to pipeline
                processing.
                Example: `'*BANDPASS*'`
            edgespw: Flag the edge spectral window channels.
                Default: None (equivalent to True)
            fracspw: Specification of the edge channels to be flagged. A
                string is taken as a frequency quantity giving the bandwidth
                limit: edge channels of spws wider than the limit are
                flagged. A float is taken as the fraction of channels to flag
                on each side. An iterable gives separate values for the left
                and right edges.
                Default: None (equivalent to '1.875GHz')
            fracspwfps: Fraction of the ACA correlator TDM edge channels to
                be flagged.
                Default: None (equivalent to 0.048387)
            online: Apply the online flags.
                Default: None (equivalent to True)
            fileonline: File containing the online flags. These are computed
                by the h_init or hsd_importdata data tasks. If the online
                flags files are undefined a name of the form
                'msname.flagonline.txt' is assumed.
            template: Apply a flagging template.
                Default: None (equivalent to True)
            filetemplate: The name of a text file that contains the flagging
                template for RFI, birdies, telluric lines, etc. If the
                template flags files is undefined a name of the form
                'msname.flagtemplate.txt' is assumed.
            pointing: Apply a flagging template for pointing flag.
                Default: None (equivalent to True)
            filepointing: The name of a text file that contains the flagging
                template for pointing flag. If the template flags files is
                undefined a name of the form 'msname.flagpointing.txt' is
                assumed.
            incompleteraster: Apply commands to flag incomplete raster
                sequence. If this is False, relevant commands in filepointing
                are simply commented out.
                Default: None (equivalent to True)
            hm_tbuff: The heuristic for computing the default time interval
                padding parameter. The options are 'halfint' and 'manual'.
                In 'halfint' mode tbuff is set to half the maximum of the
                median integration time of the science and calibrator target
                observations.
                Default: None (equivalent to 'halfint')
            tbuff: The time in seconds used to pad flagging command time
                intervals if hm_tbuff='manual'.
                Default: None (equivalent to 0.0)
            qa0: QA0 flags.
                Default: None (equivalent to True)
            qa2: QA2 flags.
                Default: None (equivalent to True)
            parallel: Execute using CASA HPC functionality, if available.
                Options: 'automatic', 'true', 'false', True, False
                Default: None (equivalent to 'automatic')
        """
        # parameters handled by the base class
        super().__init__(
            context,
            vis=vis,
            output_dir=output_dir,
            flagbackup=flagbackup,
            autocorr=autocorr,
            shadow=shadow,
            scan=scan,
            scannumber=scannumber,
            intents=intents,
            edgespw=edgespw,
            fracspw=fracspw,
            fracspwfps=fracspwfps,
            online=online,
            fileonline=fileonline,
            template=template,
            filetemplate=filetemplate,
            hm_tbuff=hm_tbuff,
            tbuff=tbuff,
        )

        # QA0 / QA2 flags
        self.qa0 = qa0
        self.qa2 = qa2

        # pointing flag
        self.pointing = pointing
        self.filepointing = filepointing
        self.incompleteraster = incompleteraster

        # Tier-0 parallelization
        self.parallel = parallel

    def to_casa_args(self):
        """Return the task arguments exactly as built by the base class."""
        return super().to_casa_args()
class FlagDeterALMASingleDishResults(flagdeterbase.FlagDeterBaseResults):
    """Results of hsd_flagdata, carrying pointing outlier statistics."""

    def __init__(
            self,
            summaries: list[dict],
            flagcmds: list[str],
            pointing_outlier_stats: dict[tuple[int, int], PointingOutlierStats]):
        """Initialize results object for hsd_flagdata.

        Args:
            summaries: List of flagging summaries.
            flagcmds: List of flagging commands.
            pointing_outlier_stats: Statistics of pointing outliers, keyed by
                (field id, antenna id).
        """
        super().__init__(summaries, flagcmds)
        self.pointing_outlier_stats = pointing_outlier_stats

    def merge_with_context(self, context):
        """Merge the results into the context and refresh pointing plots.

        After delegating to the parent implementation, the pointing plots of
        the processed MS are regenerated so that they reflect the updated
        flags. Plot regeneration is skipped when the weblog is disabled.

        Args:
            context: Pipeline context object.
        """
        # call parent's method
        super().merge_with_context(context)

        if basetask.DISABLE_WEBLOG:
            return

        # regenerate pointing plots
        measures = casa_tools.measures
        ephem_names = measures.listcodes(measures.direction())['extra']
        valid_ephem_names = {name for name in ephem_names if name != 'COMET'}

        LOG.info('Regenerate pointing plots to update flag information')
        msobj = context.observing_run.get_ms(self.inputs['vis'])
        task = pointing.SingleDishPointingChart(context, msobj)
        field_strategy = msobj.calibration_strategy['field_strategy']
        for antenna in msobj.antennas:
            for target, reference in field_strategy.items():
                LOG.debug('target field id %s / reference field id %s', target, reference)
                task.plot(revise_plot=True, antenna=antenna, target_field_id=target,
                          reference_field_id=reference, target_only=True)
                task.plot(revise_plot=True, antenna=antenna, target_field_id=target,
                          reference_field_id=reference, target_only=False)

                # if the target is ephemeris, offset pointing pattern should
                # also be plotted
                source_name = msobj.fields[target].source.name
                if source_name.upper() in valid_ephem_names:
                    task.plot(revise_plot=True, antenna=antenna, target_field_id=target,
                              reference_field_id=reference, target_only=True, ofs_coord=True)
def update_flag_pointing(filename: str, flag_incomplete_raster: bool):
    """Enable or disable "uniform_image_rms" flag commands in a file.

    Lines containing ``reason='SDPL:uniform_image_rms'`` are uncommented when
    ``flag_incomplete_raster`` is True and commented out otherwise. All other
    lines are left untouched. The file is edited in place; a temporary backup
    ("<filename>.bak") is used to restore the original content if the edit
    fails, and is always removed afterwards.

    Args:
        filename: Name of the flag commands file.
        flag_incomplete_raster: Set True to enable (uncomment) the
            "uniform_image_rms" flag commands. Any other value disables
            (comments out) them.

    Raises:
        OSError: The backup copy could not be created (e.g. ``filename``
            does not exist) or the original content could not be restored.
    """
    reason = "reason='SDPL:uniform_image_rms'"
    enable = flag_incomplete_raster is True

    def _toggle(line: str) -> str:
        # only commands carrying the target reason are touched
        if reason not in line:
            return line
        is_commented = line.startswith('#')
        if enable:
            return line.lstrip('#') if is_commented else line
        return line if is_commented else f'#{line}'

    tmpfile = filename + '.bak'
    # Take the backup before entering the guarded section: if this fails
    # there is nothing to restore from, so the error must surface as is.
    shutil.copy(filename, tmpfile)
    try:
        if not enable:
            LOG.info(f'Disabling flag commands for reason "{reason}"')
        with open(filename, 'r') as f:
            lines = [_toggle(line) for line in f]
        with open(filename, 'w') as f:
            f.writelines(lines)
    except Exception as e:
        # best effort: put the original content back and report the failure
        shutil.copy(tmpfile, filename)
        LOG.warning(
            'Failed to update flag commands in "%s" (%s). '
            'Original content was restored.', filename, e
        )
    finally:
        if os.path.exists(tmpfile):
            os.remove(tmpfile)
class SerialFlagDeterALMASingleDish(flagdeterbase.FlagDeterBase):
    """Deterministic flagging of a single ALMA single-dish MeasurementSet.

    On top of the base class flagging this task

    - applies the pointing flag commands stored in ``inputs.filepointing``,
    - detects pointing outliers per (field, antenna) and appends flag
      commands for them to the pointing flag file,
    - supports bandwidth-limited and left/right asymmetric edge channel
      flagging through ``inputs.fracspw``,
    - keeps the datatable consistent with the flags in the MS.
    """

    # Make the member functions of the FlagDeterALMASingleDishInputs() class
    # member functions of this class
    Inputs = FlagDeterALMASingleDishInputs

    @property
    def bandwidth_limit(self):
        """Bandwidth in Hz above which edge channels are flagged.

        When ``inputs.fracspw`` is a string it is converted to Hz with the
        CASA quanta tool. Otherwise the limit is 1.875 GHz.
        """
        if isinstance(self.inputs.fracspw, str):
            return casa_tools.quanta.convert(self.inputs.fracspw, 'Hz')['value']
        return 1.875e9  # 1.875GHz

    def prepare(self) -> FlagDeterALMASingleDishResults:
        """Run the flagging and generate the results object.

        Flow: apply deterministic flagging, run the pointing outlier
        detection, and, if outliers are found and pointing flagging is
        enabled, append outlier flag commands to the pointing flag file,
        restore the flag state saved at the beginning and apply the
        deterministic flagging again. The internally saved flag version is
        always deleted at the end.

        Returns:
            Results object holding summaries, flag commands and the pointing
            outlier statistics.
        """
        pointing_enabled = self.inputs.pointing

        if pointing_enabled:
            # save flag status before flagging
            self._execute_flagmanager(mode='save')

        try:
            # pre-apply deterministic flagging
            results = self._apply_deterministic_flagging()

            # run pointing outlier heuristic for each target field
            outlier_stats = self._detect_pointing_outliers()

            if outlier_stats:
                for (field_id, antenna_id), stats in outlier_stats.items():
                    field = self.inputs.ms.get_fields(field_id=field_id)[0]
                    antenna = self.inputs.ms.get_antenna(str(antenna_id))[0]
                    LOG.warning(
                        '[pointing_outlier_flagged] '
                        'Pointing outliers are detected in "%s", '
                        'Field "%s", Antenna "%s": time range "%s", '
                        'max separation %.2f deg',
                        self.inputs.vis, field.name, antenna.name,
                        ', '.join(stats.timerange), np.max(stats.separations)
                    )

                if pointing_enabled:
                    # outliers exist: update pointing flag file
                    self._append_outlier_flagcmd_to_flagpoinging_file(
                        outlier_stats
                    )

                    # restore flag status
                    self._execute_flagmanager(mode='restore')

                    # apply deterministic flagging and update datatable again
                    results = self._apply_deterministic_flagging()
        finally:
            if pointing_enabled:
                # delete flag state for internal use
                self._execute_flagmanager(mode='delete')

        return FlagDeterALMASingleDishResults(
            results.summaries,
            results.flagcmds(),
            pointing_outlier_stats=outlier_stats
        )

    def _apply_deterministic_flagging(self) -> flagdeterbase.FlagDeterBaseResults:
        """Apply deterministic flagging.

        It also updates the datatable.

        Returns:
            Results object of the base class.
        """
        results = super().prepare()
        # keep the datatable in sync with the MS
        self._update_datatable()
        return results

    def _execute_flagmanager(self, mode: str):
        """Run flagmanager to save/restore/delete flag status.

        The flag version name is fixed and intended for internal use by this
        task only.

        Args:
            mode: Execution mode. Should be either
                'save', 'restore', or 'delete'.
        """
        job = casa_tasks.flagmanager(
            vis=self.inputs.vis,
            mode=mode,
            versionname='SDPL_hsd_flagdata_internal_before_flagging'
        )
        self._executor.execute(job)

    def _get_datatable_name(self) -> str:
        """Return the path of the datatable associated with the origin MS."""
        return os.path.join(
            self.inputs.context.observing_run.ms_datatable_name,
            os.path.basename(self.inputs.ms.origin_ms)
        )

    def _update_datatable(self):
        """Update flag information in the datatable.

        This is necessary to make datatable consistent with the MS.
        """
        datatable = DataTable(name=self._get_datatable_name(), readonly=False)
        datatable._update_flag(self.inputs.ms.origin_ms)
        datatable.exportdata(minimal=False)

    def _detect_pointing_outliers(self) -> dict[tuple[int, int], PointingOutlierStats]:
        """Detect pointing outliers for each TARGET field and antenna.

        PointingOutlierHeuristics is applied to the SHIFT_RA / SHIFT_DEC
        values of the selected datatable rows of every (field, antenna) pair.

        Returns:
            Dictionary of outlier statistics keyed by (field id, antenna id).
            Pairs without outliers are not included.
        """
        msname = self.inputs.ms.basename
        datatable = DataTable(name=self._get_datatable_name(), readonly=True)
        antennas = self.inputs.ms.get_antenna()
        target_fields = self.inputs.ms.get_fields(intent="TARGET")

        ra_all = datatable.getcol("SHIFT_RA")
        dec_all = datatable.getcol("SHIFT_DEC")
        srctype_all = datatable.getcol("SRCTYPE")
        field_id_all = datatable.getcol("FIELD_ID")
        antenna_id_all = datatable.getcol("ANTENNA")
        # NOTE(review): slot 3 of FLAG_PERMANENT's second axis with value 1 is
        # presumably "online flag: valid", and SRCTYPE 0 presumably on-source
        # data -- confirm against the DataTable definition.
        flag_all = np.any(datatable.getcol("FLAG_PERMANENT")[:, 3] == 1, axis=0)
        valid_on_source = np.logical_and(srctype_all == 0, flag_all == 1)

        heuristic = PointingOutlierHeuristics()
        outlier_stats = {}
        for field, antenna in itertools.product(target_fields, antennas):
            # data selection
            field_antenna_flag = np.logical_and(
                field_id_all == field.id,
                antenna_id_all == antenna.id
            )
            selection = np.logical_and(valid_on_source, field_antenna_flag)
            rows = np.where(selection)[0]

            # run heuristic
            heuristic_result = heuristic(field.frame, ra_all[selection], dec_all[selection])

            # positions (within the selection) where the heuristic mask is False
            outlier_mask = np.where(np.logical_not(heuristic_result.mask))[0]
            if len(outlier_mask) == 0:
                LOG.debug(
                    'MS "%s" field "%s" antenna "%s": no outliers detected',
                    msname, field.name, antenna.name
                )
                continue

            # translate positions in the selection into datatable row indices
            outliers = rows[outlier_mask]
            LOG.info(
                'MS "%s" field "%s" antenna "%s": %d pointing outliers detected',
                msname, field.name, antenna.name, len(outliers)
            )
            LOG.debug("Outliers: %s", outliers)
            separations = heuristic_result.dist[outlier_mask]
            timerange_list = datatable_rowid_to_timerange(datatable, outliers)
            outlier_stats[(field.id, antenna.id)] = PointingOutlierStats(
                heuristic_result.cx, heuristic_result.cy,
                heuristic_result.med_dist, heuristic_result.factor,
                outliers, timerange_list, separations
            )

        return outlier_stats

    def _append_outlier_flagcmd_to_flagpoinging_file(
            self,
            outlier_stats: dict[tuple[int, int], PointingOutlierStats]):
        """Append flag commands for pointing outliers to the pointing flag file.

        One manual flag command is generated per timerange of every
        (field, antenna) entry. Nothing is written if the pointing flag file
        does not exist.

        Args:
            outlier_stats: Outlier statistics keyed by (field id, antenna id).
        """
        reason = "pointing_outlier"
        cmd_template = string.Template(
            "mode='manual' field='$field_id' antenna='$antenna_id&&&' timerange='$timerange'"
            f" reason='SDPL:{reason}'\n"
        )

        # generate flag commands
        flagcmds = [
            cmd_template.safe_substitute(
                field_id=field_id,
                antenna_id=antenna_id,
                timerange=timerange
            )
            for (field_id, antenna_id), stats in outlier_stats.items()
            for timerange in stats.timerange
        ]

        # append flag commands only when flagpointing file exists
        filepointing = self.inputs.filepointing
        if os.path.exists(filepointing):
            with open(filepointing, 'a') as f:
                f.writelines(flagcmds)

    def _yield_edge_spw_cmds(self) -> Generator[str, None, None]:
        """Yield flag commands to flag edge channels.

        Used when ``inputs.fracspw`` is an iterable: its first item applies
        to the left edge and its second item to the right edge. A single
        item is applied to both edges.

        Yields:
            flag command string to flag edge channels.
        """
        inputs = self.inputs

        # loop over the spectral windows, generate a flagging command for each
        # spw in the ms. Calling get_spectral_windows() with no arguments
        # returns just the science windows, which is exactly what we want.
        for spw in inputs.ms.get_spectral_windows():
            try:
                # test that this spw should be flagged by assessing number of
                # correlations, TDM/FDM mode etc.
                self.verify_spw(spw)
            except ValueError as e:
                # this spw should not be or is incapable of being flagged
                LOG.debug(str(e))
                continue

            # get_fracspw() reads inputs.fracspw, so each item is temporarily
            # set as the input value; the original value is always restored.
            fracspw_org = inputs.fracspw
            try:
                fracspw_list = []
                for _frac in fracspw_org:
                    inputs.fracspw = _frac
                    fracspw_list.append(self.get_fracspw(spw))
            finally:
                inputs.fracspw = fracspw_org

            if len(fracspw_list) == 0:
                continue
            elif len(fracspw_list) == 1:
                # single value: use it for both edges
                fracspw_list.append(fracspw_list[0])

            # If the number of flagged channels (left + right) is greater than
            # or equal to the number of channels of the spw, skip it.
            frac_chan_list = [int(utils.round_half_up(x * spw.num_channels)) for x in fracspw_list][:2]
            if sum(frac_chan_list) >= spw.num_channels:
                LOG.debug('Too many flagged channels %s for spw %s '
                          '' % (spw.num_channels, spw.id))
                continue

            # calculate the channel ranges to flag. No need to calculate the
            # left minimum as it is always channel 0.
            l_max = frac_chan_list[0] - 1
            r_min = spw.num_channels - frac_chan_list[1]
            r_max = spw.num_channels - 1

            # state the spw and channels to flag in flagdata format; an edge
            # with zero channels to flag produces no range
            channel_ranges = []
            if l_max >= 0:
                channel_ranges.append(f'0~{l_max}')
            if r_max >= r_min:
                channel_ranges.append(f'{r_min}~{r_max}')

            if len(channel_ranges) == 0:
                continue

            cmd = '{0}:{1}'.format(spw.id, ';'.join(channel_ranges))
            LOG.debug('list type edge fraction specification for spw %s' % spw.id)
            LOG.debug('cmd=\'%s\'' % cmd)

            yield cmd

    def _get_edgespw_cmds(self) -> list[str]:
        """Construct and return list of flag commands.

        Returned list contains flag commands for edge channel flagging.

        Returns:
            List of flag commands for edge channel flag.

        Raises:
            TypeError: ``inputs.fracspw`` is neither float, str, nor iterable.
        """
        fracspw = self.inputs.fracspw

        if isinstance(fracspw, (float, str)):
            return super()._get_edgespw_cmds()

        if isinstance(fracspw, collections.abc.Iterable):
            # fracspw is iterable indicating that the user wants to flag
            # edge channels with different fractions for left and right edges
            return list(self._yield_edge_spw_cmds())

        raise TypeError(
            'fracspw must be float, str, or an iterable of them: '
            f'got {type(fracspw).__name__}'
        )

    def get_fracspw(self, spw: SpectralWindow) -> float:
        """Get fraction of total number of spw channels that are to be flagged on each side of the spw.

        A float value of ``inputs.fracspw`` is returned as is. For a string
        value, the fraction is derived so that the bandwidth exceeding
        ``bandwidth_limit`` is shared equally by both edges; it is never
        negative.

        Args:
            spw: SpectralWindow domain object for target spw.

        Returns:
            Fraction of number of channels to be flagged.

        Raises:
            TypeError: ``inputs.fracspw`` is neither float nor str.
        """
        fracspw = self.inputs.fracspw

        if isinstance(fracspw, float):
            return fracspw

        if isinstance(fracspw, str):
            LOG.debug('bandwidth limited edge flagging for spw %s' % spw.id)
            bandwidth_limit = self.bandwidth_limit
            bandwidth = float(spw.bandwidth.value)
            fraction = 0.5 * (bandwidth - bandwidth_limit) / bandwidth
            LOG.debug('fraction is %s' % fraction)
            return max(0.0, fraction)

        raise TypeError(
            f'fracspw must be float or str: got {type(fracspw).__name__}'
        )

    def verify_spw(self, spw: SpectralWindow):
        """Test if given spw needs to be processed by edgespw flagging.

        Args:
            spw: SpectralWindow domain object for target spw.

        Raises:
            ValueError: Bandwidth of the spw is less than or equal to the
                bandwidth limit (only checked when ``inputs.fracspw`` is a
                string), or the base class verification failed.
        """
        # override the default verifier, adding bandwidth check
        super().verify_spw(spw)

        # Skip if edge channel flagging is based on bandwidth limit, and
        # bandwidth does not exceed the bandwidth limit
        if isinstance(self.inputs.fracspw, str) and spw.bandwidth.value <= self.bandwidth_limit:
            raise ValueError('Skipping edge flagging for spw %s' % spw.id)

    def _get_flag_commands(self) -> list[str]:
        """Build the list of flag commands.

        Pointing flag commands read from ``inputs.filepointing`` are inserted
        right after the shadow commands (PIPE-646 & PIPE-647), or appended at
        the end if there is no shadow command. All summary commands are then
        restricted to on-source target data so that summaries are based on
        target data instead of total.

        Returns:
            List of flag commands.
        """
        flag_cmds = super()._get_flag_commands()

        # PIPE-646 & PIPE-647
        # apply flag commands in flagpointing.txt
        if self.inputs.pointing:
            filepointing = self.inputs.filepointing
            if not os.path.exists(filepointing):
                LOG.warning(
                    "Pointing flag file '%s' was not found. Pointing "
                    "flagging for %s disabled.",
                    filepointing, self.inputs.ms.basename
                )
            else:
                update_flag_pointing(filepointing, self.inputs.incompleteraster)
                pointing_cmds = self._read_flagfile(filepointing)
                pointing_cmds.append("mode='summary' name='pointing' reason='SDPL:missing_pointing_data'")

                # insert flag commands between shadow and edgespw
                shadow_idx = [
                    i for i, cmd in enumerate(flag_cmds)
                    if re.search(r"(mode|name)='shadow'", cmd)
                ]
                if shadow_idx:
                    sep = shadow_idx[-1] + 1
                else:
                    # no shadow command (e.g. shadow flagging is disabled):
                    # do not fail, put pointing commands at the end instead
                    LOG.debug('No shadow flag command found. '
                              'Pointing flag commands are appended at the end.')
                    sep = len(flag_cmds)
                flag_cmds = flag_cmds[:sep] + pointing_cmds + flag_cmds[sep:]

        # summaries should be evaluated for on-source target data only
        return [
            cmd + " intent='OBSERVE_TARGET#ON_SOURCE'" if cmd.startswith("mode='summary'") else cmd
            for cmd in flag_cmds
        ]
# (stray "[docs]" link text from the Sphinx source export; not part of the program)
@task_registry.set_equivalent_casa_task('hsd_flagdata')
@task_registry.set_casa_commands_comment(
    'Flags generated by the online telescope software, by the QA0 process, and manually set by the pipeline user.'
)
class FlagDeterALMASingleDish(sessionutils.ParallelTemplate):
    """Task class registered as the equivalent of the 'hsd_flagdata' CASA task.

    It pairs FlagDeterALMASingleDishInputs with the serial implementation
    SerialFlagDeterALMASingleDish through sessionutils.ParallelTemplate
    (presumably running the serial task per MS, in parallel when the
    ``parallel`` input allows it -- confirm against ParallelTemplate).
    """
    # Inputs class accepted by the task
    Inputs = FlagDeterALMASingleDishInputs
    # Serial task that does the actual work
    Task = SerialFlagDeterALMASingleDish