"""Spectral baseline subtraction stage."""
from __future__ import annotations
import collections
import os
from typing import TYPE_CHECKING
import numpy
import pipeline.infrastructure as infrastructure
import pipeline.infrastructure.basetask as basetask
import pipeline.infrastructure.callibrary as callibrary
import pipeline.infrastructure.mpihelpers as mpihelpers
import pipeline.infrastructure.vdp as vdp
import pipeline.infrastructure.sessionutils as sessionutils
from pipeline.hsd.tasks.common.inspection_util import generate_ms, inspect_reduction_group, merge_reduction_group
from pipeline.domain import DataType
from pipeline.hsd.heuristics import MaskDeviationHeuristic
from pipeline.hsd.tasks.common.inspection_util import generate_ms, inspect_reduction_group, merge_reduction_group
from pipeline.infrastructure import task_registry
from . import maskline
from . import worker
from .. import common
from ..common import compress
from ..common import utils
if TYPE_CHECKING:
from collections.abc import Callable
from typing import Any
from pipeline.infrastructure.api import Heuristic
from pipeline.infrastructure.launcher import Context
from .typing import FitFunc, FitOrder, LineWindow
# import memory_profiler
# Module-level logger, named after this module.
LOG = infrastructure.logging.get_logger(__name__)
class SDBaselineInputs(vdp.StandardInputs):
    """Inputs for baseline subtraction task."""

    # Search order of input vis
    processing_data_type = [DataType.ATMCORR, DataType.REGCAL_CONTLINE_ALL, DataType.RAW]

    infiles = vdp.VisDependentProperty(default='', null_input=['', None, [], ['']])
    spw = vdp.VisDependentProperty(default='')
    pol = vdp.VisDependentProperty(default='')
    field = vdp.VisDependentProperty(default='')
    linewindow = vdp.VisDependentProperty(default=[])
    linewindowmode = vdp.VisDependentProperty(default='replace')
    edge = vdp.VisDependentProperty(default=(0, 0))
    broadline = vdp.VisDependentProperty(default=True)
    fitorder = vdp.VisDependentProperty(default=-1)
    fitfunc = vdp.VisDependentProperty(default='cspline')
    switchpoly = vdp.VisDependentProperty(default=True)
    clusteringalgorithm = vdp.VisDependentProperty(default='hierarchy')
    wave_number = vdp.VisDependentProperty(default=None)
    deviationmask = vdp.VisDependentProperty(default=True)
    deviationmask_sigma_threshold = vdp.VisDependentProperty(default=5.0)

    # Synchronization between infiles and vis is still necessary
    @vdp.VisDependentProperty
    def vis(self) -> str:
        """Return input MS name."""
        return self.infiles

    # handle conversion from string to bool
    @switchpoly.convert
    def switchpoly(self, value: str | bool) -> bool:
        """Convert switchpoly value into bool.

        Args:
            value: any value provided by the user. It should be a
                boolean value (Python or numpy) or its string representation
                such as 'True' or 'FALSE'. String matching is
                case-insensitive.

        Returns:
            boolean value

        Raises:
            ValueError: value is neither a boolean nor a string spelling
                'true' or 'false'.
        """
        if isinstance(value, (bool, numpy.bool_)):
            return bool(value)
        if isinstance(value, str):
            normalized = value.lower()
            if normalized == 'true':
                return True
            if normalized == 'false':
                return False
        # Explicit exception rather than assert: asserts are stripped under
        # ``python -O``, which would silently turn invalid input into None.
        raise ValueError(f'switchpoly must be a boolean or "True"/"False", got {value!r}')

    # use common implementation for parallel inputs argument
    parallel = sessionutils.parallel_inputs_impl()

    # docstring and type hints: supplements hsd_baseline
    def __init__(self,
                 context: Context,
                 infiles: list[str] | None = None,
                 antenna: list[str] | None = None,
                 spw: list[str] | None = None,
                 pol: list[str] | None = None,
                 field: list[str] | None = None,
                 linewindow: LineWindow | None = None,
                 linewindowmode: str | None = None,
                 edge: tuple[int, int] | None = None,
                 broadline: bool | None = None,
                 fitfunc: FitFunc | None = None,
                 fitorder: FitOrder | None = None,
                 switchpoly: bool | None = None,
                 clusteringalgorithm: str | None = None,
                 wave_number: list[int] | None = None,
                 deviationmask: bool | None = None,
                 deviationmask_sigma_threshold: float | None = None,
                 parallel: str | None = None) -> None:
        """Construct SDBaselineInputs instance.

        Args:
            context: Pipeline context

            infiles: List of data files. These must be a name of MeasurementSets that are
                registered to context via hsd_importdata or hsd_restoredata task.

                Example: ``infiles=['X227.ms', 'X228.ms']``

                Default: ``None`` (process all registered MeasurementSets)

            antenna: Data selection by antenna.

                Example: '1' (select by ANTENNA_ID), 'PM03' (select by antenna name), '' (all antennas)

                Default: ``None`` (equivalent to ``''``)

            spw: Data selection by spw.

                Example: ``'3,4'`` (process spw 3 and 4), ['0','2'] (spw 0 for first data, 2 for second), ``''`` (all spws)

                Default: ``None`` (equivalent to ``''``)

            pol: Data selection by polarizations.

                Example: ``'0'`` (process pol 0), ``['0~1','0']`` (pol 0 and 1 for first data, only 0 for second), ``''`` (all polarizations)

                Default: ``None`` (equivalent to ``''``)

            field: Data selection by field.

                Example: ``'1'`` (select by FIELD_ID), ``'M100*'`` (select by field name), ``''`` (all fields)

                Default: ``None`` (equivalent to ``''``)

            linewindow: Pre-defined line window. If this is set, specified line windows are used as
                a line mask for baseline subtraction instead of determining masks based on the line
                detection and validation stage. Several types of format are acceptable. One is a
                channel-based window.

                ::

                    [min_chan, max_chan]

                where min_chan and max_chan should be integers. For multiple windows, a nested list
                is also acceptable.

                ::

                    [[min_chan0, max_chan0], [min_chan1, max_chan1], ...]

                Another way is a frequency-based window.

                ::

                    [min_freq, max_freq]

                where min_freq and max_freq should be either a float or a string. If a float value
                is given, it is interpreted as a frequency in Hz. A string should be a quantity
                consisting of "value" and "unit", e.g., '100GHz'. Multiple windows are also supported.

                ::

                    [[min_freq0, max_freq0], [min_freq1, max_freq1], ...]

                Note that the specified frequencies are assumed to be values in the LSRK frame.
                Note also that there is a limitation when multiple MSes are processed.
                If the native frequency frame of the data is not LSRK (e.g. TOPO), frequencies need
                to be converted to that frame. As a result, the corresponding channel range may vary
                between MSes. However, the current implementation is not able to handle such a case.
                Frequencies are converted to the desired frame using a representative MS (time,
                position, direction).

                In the above cases, specified line windows are applied to all science spws.
                In case line windows vary with spw, line windows can be specified by a dictionary
                whose key is the spw id while the value is the line window. For example, the
                following dictionary gives different line windows to spws 17 and 19. Other spws, if
                available, will have an empty line window.

                ::

                    {17: [[100, 200], [1200, 1400]], 19: ['112115MHz', '112116MHz']}

                Furthermore, linewindow accepts an MS selection string. The following string gives
                [[100,200],[1200,1400]] for spw 17 while [1000,1500] for spw 21.

                ::

                    "17:100~200;1200~1400,21:1000~1500"

                The string also accepts frequency with units. Note, however, that the frequency
                reference frame in this case is not fixed to LSRK. Instead, the frame will be taken
                from the MS (typically TOPO for ALMA). Thus, the following two frequency-based line
                windows result in different channel selections.

                ::

                    {19: ['112115MHz', '112116MHz']} # frequency frame is LSRK
                    "19:112115MHz~112116MHz" # frequency frame is taken from the data (TOPO for ALMA)

                None is allowed as a value of dictionary input to indicate that no line
                detection/validation is required even if a manually specified line window does not
                exist. When None is given as a value and if ``linewindowmode`` is 'replace', line
                detection/validation is not performed for the corresponding spw. For example,
                suppose the following parameters are given for the data with four science spws,
                17, 19, 21, and 23.

                ::

                    linewindow={17: [112.1e9, 112.2e9], 19: [113.1e9, 113.15e9], 21: None}
                    linewindowmode='replace'

                The task will use the given line window for 17 and 19 while the task performs
                line detection/validation for spw 23 because no line window is set.
                On the other hand, line detection/validation is skipped for spw 21 due to the
                effect of None.

                Example: [100,200] (channel), [115e9, 115.1e9] (frequency in Hz), ['115GHz', '115.1GHz'] (see above for more examples)

                Default: None

            linewindowmode: Merge or replace given manual line window with line detection/validation result.
                If 'replace' is given, line detection and validation will not be performed.
                On the other hand, when 'merge' is specified, line detection/validation will be
                performed and manually specified line windows are added to the result.
                Note that this has no effect when linewindow for the target spw is an empty list.
                In that case, line detection/validation will be performed regardless of the value
                of linewindowmode. If neither a line window nor line detection/validation is
                necessary, set linewindowmode to 'replace' and specify None as the value of the
                linewindow dictionary for the spw to apply. See the parameter description of
                ``linewindow`` for details.

            edge: Number of edge channels to be dropped from baseline subtraction.
                The value must be a list with length of 2, whose values specify left and right
                edge channels, respectively.

                Example: ``[10,10]``

                Default: ``None``

            broadline: Try to detect broad component of spectral line if ``True``.

                Default: ``None`` (equivalent to ``True``)

            fitfunc: Fitting function for baseline subtraction. You can choose either cubic spline
                ('spline' or 'cspline'), polynomial ('poly' or 'polynomial'), sinusoid.

                Accepts:

                - A string: Applies the same function to all spectral windows (SPWs).
                - A dictionary: Maps SPW IDs (int or str) to a specific fitting function.
                  If an SPW ID is not present in the dictionary, ``'cspline'`` will be used as the default.

                Default: ``None`` (equivalent to ``'cspline'``)

            fitorder: Fitting order for polynomial. For cubic spline, it is used to determine how
                much the spectrum is segmented into.

                Accepts:

                - An integer: Applies the same order to all SPWs. Valid values: ``-1`` (automatic), ``0``, or any positive integer.
                - A dictionary: Maps SPW IDs (int or str) to a specific fitting order.
                  If an SPW ID is not present in the dictionary, ``-1`` will be used as the default, triggering automatic order selection.

                Default: ``None`` (equivalent to ``-1``)

            switchpoly: Whether to fall back the fits from cubic spline to 1st or 2nd order polynomial
                when large masks exist at the edges of the spw. Condition for switching is as follows:

                - if nmask > nchan/2 => 1st order polynomial
                - else if nmask > nchan/4 => 2nd order polynomial
                - else => use fitfunc and fitorder

                where nmask is a number of channels for mask at edge while
                nchan is a number of channels of entire spectral window.

                Default: ``None`` (equivalent to ``True``)

            clusteringalgorithm: Selection of the algorithm used in the clustering analysis
                to check the validity of detected line features. The 'kmean' algorithm,
                hierarchical clustering algorithm, 'hierarchy', and their combination ('both')
                are so far implemented.

                Default: ``None`` (equivalent to ``'hierarchy'``)

            wave_number: a list of sinusoidal wave numbers. The maximum wave numbers should not
                exceed the ((number of channels/2)-1) limit. If the offset is present in the data,
                add 0 to the number of waves. That is, nwave=[0] is a constant term,
                nwave=[0,1,2] fits with a maximum of 2 sinusoids, and so on.

                Default: ``None`` (equivalent to ``False``)

            deviationmask: Apply deviation mask in addition to masks determined by the automatic line detection.

                Default: ``None`` (equivalent to ``True``)

            deviationmask_sigma_threshold: Threshold factor (F) to detect the deviation.
                Actual threshold will be median + F * standard-deviation of the spectrum.

                Default: ``None`` (equivalent to ``5.0``)

            parallel: Execute using CASA HPC functionality, if available.

                Options: ``'automatic'``, ``'true'``, ``'false'``, ``True``, ``False``.

                Default: ``None`` (equivalent to ``'automatic'``).
        """
        super().__init__()

        self.context = context
        self.infiles = infiles
        self.antenna = antenna
        self.spw = spw
        self.pol = pol
        self.field = field
        self.linewindow = linewindow
        self.linewindowmode = linewindowmode
        self.edge = edge
        self.broadline = broadline
        self.fitorder = fitorder
        self.fitfunc = fitfunc
        self.switchpoly = switchpoly
        self.clusteringalgorithm = clusteringalgorithm
        self.wave_number = wave_number
        self.deviationmask = deviationmask
        self.deviationmask_sigma_threshold = deviationmask_sigma_threshold
        self.parallel = parallel

    def to_casa_args(self) -> dict:
        """Convert Inputs instance into task execution parameter.

        When ``infiles`` is a list, only its first element is exposed to the
        base-class serialization. ``infiles`` is restored afterwards, even
        if serialization raises.

        Returns:
            Task execution parameter as kwargs. ``'antenna'`` is always
            present (``''`` when not otherwise provided).
        """
        infiles = self.infiles
        if isinstance(infiles, list):
            self.infiles = infiles[0]
        try:
            args = super().to_casa_args()
        finally:
            # Always restore the caller-visible value; without ``finally`` a
            # failing serialization would leave infiles narrowed to one MS.
            self.infiles = infiles

        args.setdefault('antenna', '')
        return args
class SDBaselineResults(common.SingleDishResults):
    """Results class to hold the result of baseline subtraction task."""

    def __init__(self,
                 task: type[basetask.StandardTaskTemplate] | None = None,
                 success: bool | None = None,
                 outcome: Any = None) -> None:
        """Construct SDBaselineResults instance.

        Args:
            task: Task class that produced the result.
            success: Whether task execution is successful or not.
            outcome: Outcome of the task execution.
        """
        super().__init__(task, success, outcome)
        # MeasurementSet domain objects of the output MSes; these are
        # registered to the context by merge_with_context.
        self.out_mses = []

    # @utils.profiler
    def merge_with_context(self, context: Context) -> None:
        """Merge result instance into context.

        Merge of the result instance of baseline subtraction task includes
        the following updates to Pipeline context,

            - register MeasurementSet domain object for the output of sdbaseline
              task to Pipeline context, namely measurement_sets list and
              reduction_group
            - register detected spectral lines to reduction group
            - register deviation mask to each MeasurementSet domain object

        Input MSes that have a deviation mask but no entry in
        ``outcome['vis_map']`` (i.e. no output MS was produced for them) only
        get their own deviation mask updated.

        Args:
            context: Pipeline context object containing state information.
        """
        super().merge_with_context(context)

        target = context.observing_run

        # register output MS domain object and reduction_group to context
        for ms in self.out_mses:
            # remove existing MS in context if the same MS is already in list.
            oldms_index = next(
                (index for index, oldms in enumerate(target.get_measurement_sets())
                 if oldms.name == ms.name),
                None)
            if oldms_index is not None:
                LOG.info('Replace {} in context'.format(ms.name))
                del target.measurement_sets[oldms_index]

            # Adding mses to context
            LOG.info('Adding {} to context'.format(ms.name))
            target.add_measurement_set(ms)

            # Initialize callibrary
            calto = callibrary.CalTo(vis=ms.name)
            LOG.info('Registering {} with callibrary'.format(ms.name))
            context.callibrary.add(calto, [])

            # register output MS to processing group
            merge_reduction_group(target, inspect_reduction_group(ms))

        # input MS name -> output MS name
        vis_map = self.outcome.get('vis_map', {})

        # increment iteration counter (only to in MS for now)
        # register detected lines to reduction group member (to both in and out MS)
        reduction_group = target.ms_reduction_group
        for b in self.outcome['baselined']:
            lines = b['lines']
            channelmap_range = b['channelmap_range']
            group_desc = reduction_group[b['group_id']]
            for (ms, field, ant, spw) in utils.iterate_group_member(group_desc, b['members']):
                group_desc.iter_countup(ms, ant, spw, field)
                out_ms = target.get_ms(name=vis_map[ms.name])
                for m in (ms, out_ms):
                    group_desc.add_linelist(lines, m, ant, spw, field,
                                            channelmap_range=channelmap_range)

        # merge deviation_mask with context
        for ms in target.measurement_sets:
            if not hasattr(ms, 'deviation_mask'):
                ms.deviation_mask = None

        if 'deviation_mask' in self.outcome:
            for (basename, masks) in self.outcome['deviation_mask'].items():
                ms = target.get_ms(basename)
                ms.deviation_mask = {}
                for field in ms.get_fields(intent='TARGET'):
                    for antenna in ms.antennas:
                        for spw in ms.get_spectral_windows(science_windows_only=True):
                            key = (field.id, antenna.id, spw.id)
                            if key in masks:
                                ms.deviation_mask[key] = masks[key]

                # A deviation mask can be evaluated for an MS whose reduction
                # group is skipped afterwards, so it may have no output MS.
                out_name = vis_map.get(ms.name)
                if out_name is None:
                    LOG.debug('No output MS for %s; deviation mask is kept on input MS only', ms.name)
                    continue
                out_ms = target.get_ms(name=out_name)
                out_ms.deviation_mask = ms.deviation_mask

    def _outcome_name(self) -> str:
        """Return string summarizing the outcome.

        Returns:
            summary of the outcome, one line per baselined reduction group.
        """
        return '\n'.join(['Reduction Group {0}: member {1}'.format(b['group_id'], b['members'])
                          for b in self.outcome['baselined']])
@task_registry.set_equivalent_casa_task('hsd_baseline')
@task_registry.set_casa_commands_comment(
    'Subtracts spectral baseline by least-square fitting with N-sigma clipping. Spectral lines are automatically '
    'detected and examined to determine the region that is masked to protect these features from the fit.\n'
    'This stage performs a pipeline calculation without running any CASA commands to be put in this file.'
)
class SDBaseline(basetask.StandardTaskTemplate):
    """Baseline subtraction task."""

    Inputs = SDBaselineInputs
    is_multi_vis_task = True

    # @memory_profiler.profile
    def prepare(self) -> SDBaselineResults:
        """Perform baseline subtraction.

        The method first evaluates deviation mask for each MS if the mask
        is not available, then perform line detection by combining all
        spectral data, and finally perform baseline subtraction using
        sdbaseline task.

        Returns:
            SDBaselineResults instance that holds list of output MS names
            with the map among input MS names, necessary data for
            weblog rendering, and the metric representing the quality of
            the baseline subtraction. If the fitting worker fails, a
            ResultsList of the failed results is returned instead.
        """
        LOG.debug('Starting SDMDBaseline.prepare')
        inputs = self.inputs
        context = inputs.context

        reduction_group = context.observing_run.ms_reduction_group
        vis_list = inputs.vis
        args = inputs.to_casa_args()

        # Spw selection will accept virtual spw so inputs.spw should be
        # (a list of) virtual spw ids. Intention here is to get a list
        # of selected *real* spw ids per MeasurementSet and store them
        # into args_real_spw as dict.
        args_real_spw = utils.convert_spw_virtual2real(context, inputs.spw)

        window = inputs.linewindow
        windowmode = inputs.linewindowmode
        LOG.info('{}: window={}, windowmode={}'.format(self.__class__.__name__, window, windowmode))
        edge = inputs.edge
        broadline = inputs.broadline
        fitorder = 'automatic' if inputs.fitorder is None else inputs.fitorder
        fitfunc = inputs.fitfunc
        wave_number = inputs.wave_number
        switchpoly = inputs.switchpoly
        clusteringalgorithm = inputs.clusteringalgorithm
        deviationmask = inputs.deviationmask
        deviationmask_sigma_threshold = inputs.deviationmask_sigma_threshold

        # mkdir stage_dir if it doesn't exist (exist_ok avoids a check-then-create race)
        stage_number = context.task_counter
        stage_dir = os.path.join(context.report_dir, "stage%d" % stage_number)
        os.makedirs(stage_dir, exist_ok=True)

        # This is a dictionary for deviation mask that will be merged with top-level context
        deviation_mask = collections.defaultdict(dict)

        # collection of field, antenna, and spw ids in reduction group per MS
        registry = collections.defaultdict(utils.RGAccumulator)

        # outcome for baseline subtraction
        baselined = []

        # source name -> org_direction (None for non-ephemeris sources)
        org_directions_dict = {}

        # loop over reduction group
        LOG.debug('Starting per reduction group processing: number of groups is %d', len(reduction_group))
        for (group_id, group_desc) in reduction_group.items():
            LOG.info('Processing Reduction Group %s', group_id)
            LOG.info('Group Summary:')
            # org_direction of the source of the first member in this group.
            # It must be looked up per group: reusing the value left over from
            # the last newly registered source would attach another source's
            # direction to groups whose source was already registered.
            group_org_direction = None
            for member_index, m in enumerate(group_desc):
                # m.spw_id is real spw id
                LOG.info(
                    f'\tMS "{m.ms.basename}" '
                    f'Antenna "{m.antenna_name}" (ID {m.antenna_id}) '
                    f'Spw ID {m.spw_id} '
                    f'Field "{m.field_name}" (ID {m.field_id})'
                )
                # scan for org_direction and find the first one in group
                msobj = context.observing_run.get_ms(m.ms.basename)
                fields = msobj.get_fields(field_id=m.field_id)
                source = fields[0].source
                source_name = source.name
                if source_name not in org_directions_dict:
                    if source.is_eph_obj or source.is_known_eph_obj:
                        org_direction = source.org_direction
                    else:
                        org_direction = None
                    org_directions_dict[source_name] = org_direction
                    LOG.info("registered org_direction[{}]={}".format(source_name, org_direction))
                if member_index == 0:
                    group_org_direction = org_directions_dict[source_name]

            # skip channel averaged spw
            nchan = group_desc.nchan
            LOG.debug('nchan for group %s is %s', group_id, nchan)
            if nchan == 1:
                first_member = group_desc[0]
                virtual_spw = context.observing_run.real2virtual_spw_id(first_member.spw.id, first_member.ms)
                LOG.info('Skip channel averaged spw %s.', virtual_spw)
                continue

            LOG.debug('spw=\'%s\'', args_real_spw)
            LOG.debug('vis_list=%s', vis_list)
            member_list = numpy.fromiter(
                utils.get_valid_ms_members(group_desc, vis_list, args['antenna'], args['field'], args_real_spw),
                dtype=numpy.int32)
            # skip this group if valid member list is empty
            LOG.debug('member_list=%s', member_list)
            if len(member_list) == 0:
                LOG.info('Skip reduction group %s', group_id)
                continue

            member_list.sort()

            # assume all members have same iteration
            first_member = group_desc[member_list[0]]
            iteration = first_member.iteration
            LOG.debug('iteration for group %s is %s', group_id, iteration)

            LOG.info('Members to be processed:')
            for (gms, gfield, gant, greal_spw) in utils.iterate_group_member(group_desc, member_list):
                LOG.info('\tMS "%s" Field ID %s Antenna ID %s Spw ID %s',
                         gms.basename, gfield, gant, greal_spw)

            # Deviation Mask
            # NOTE: deviation mask is evaluated per ms per field per spw
            if deviationmask:
                LOG.info('Apply deviation mask to baseline fitting')
                dvtasks = []
                # MS name -> (field ids, antenna ids, real spw ids) still lacking a mask
                dvparams = collections.defaultdict(lambda: ([], [], []))
                for (ms, fieldid, antennaid, real_spwid) in utils.iterate_group_member(group_desc, member_list):
                    if (not hasattr(ms, 'deviation_mask')) or ms.deviation_mask is None:
                        ms.deviation_mask = {}
                    key = (fieldid, antennaid, real_spwid)
                    if key not in ms.deviation_mask:
                        LOG.debug('Evaluating deviation mask for %s field %s antenna %s spw %s',
                                  ms.basename, fieldid, antennaid, real_spwid)
                        dvparams[ms.name][0].append(fieldid)
                        dvparams[ms.name][1].append(antennaid)
                        dvparams[ms.name][2].append(real_spwid)
                    else:
                        # reuse the mask cached on the MS domain object
                        deviation_mask[ms.basename][key] = ms.deviation_mask[key]

                for (vis, params) in dvparams.items():
                    field_list, antenna_list, real_spw_list = params
                    dvtasks.append(deviation_mask_heuristic(vis=vis,
                                                            field_list=field_list,
                                                            antenna_list=antenna_list,
                                                            real_spw_list=real_spw_list,
                                                            deviationmask_sigma_threshold=deviationmask_sigma_threshold,
                                                            consider_flag=True,
                                                            parallel=self.inputs.parallel))

                for vis, dvtask in dvtasks:
                    dvmasks = dvtask.get_result()
                    field_list, antenna_list, real_spw_list = dvparams[vis]
                    ms = context.observing_run.get_ms(vis)
                    for field_id, antenna_id, real_spw_id, mask_list in zip(field_list, antenna_list, real_spw_list, dvmasks):
                        # key: (fieldid, antennaid, real spwid)
                        key = (field_id, antenna_id, real_spw_id)
                        LOG.debug('deviation mask: key %s %s %s mask %s', field_id, antenna_id, real_spw_id, mask_list)
                        ms.deviation_mask[key] = mask_list
                        deviation_mask[ms.basename][key] = ms.deviation_mask[key]
                        LOG.debug('evaluated deviation mask is %s', ms.deviation_mask[key])
            else:
                LOG.info('Deviation mask is disabled by the user')
            LOG.debug('deviation_mask=%s', deviation_mask)

            # Spectral Line Detection and Validation
            # MaskLine will update DataTable.MASKLIST column
            maskline_inputs = maskline.MaskLine.Inputs(context, iteration, group_id, member_list,
                                                       window, windowmode, edge, broadline, clusteringalgorithm)
            maskline_task = maskline.MaskLine(maskline_inputs)
            maskline_result = self._executor.execute(maskline_task, merge=False)
            grid_table = maskline_result.outcome['grid_table']
            if grid_table is None:
                LOG.info('Skip reduction group %s', group_id)
                continue

            compressed_table = compress.CompressedObj(grid_table)
            # release the uncompressed table as early as possible
            del grid_table
            detected_lines = maskline_result.outcome['detected_lines']
            channelmap_range = maskline_result.outcome['channelmap_range']
            cluster_info = maskline_result.outcome['cluster_info']
            flag_digits = maskline_result.outcome['flag_digits']

            # register ids to per MS id collection
            for i in member_list:
                member = group_desc[i]
                registry[member.ms].append(member.field_id, member.antenna_id, member.spw_id,
                                           grid_table=compressed_table, channelmap_range=channelmap_range)

            # add entry to outcome
            baselined.append({'group_id': group_id, 'iteration': iteration,
                              'members': member_list,
                              'lines': detected_lines,
                              'channelmap_range': channelmap_range,
                              'clusters': cluster_info,
                              'flag_digits': flag_digits,
                              'org_direction': group_org_direction})

        # - end of the loop over reduction group

        def blparam_file(ms) -> str:
            """Return the baseline parameter file name for the given MS."""
            return ms.basename.rstrip('/') + '_blparam_stage{stage}.txt'.format(stage=stage_number)

        vis_map = {}  # key and value are input and output vis name, respectively
        plot_list = []
        baseline_quality_stat = []

        # Generate and apply baseline fitting solutions
        vislist = [ms.name for ms in registry]
        plan = [registry[ms] for ms in registry]
        blparam = [blparam_file(ms) for ms in registry]
        deviationmask_list = [deviation_mask[ms.basename] for ms in registry]
        edge_list = [edge for _ in registry]

        worker_cls = worker.BaselineSubtractionWorker
        fitter_inputs = vdp.InputsContainer(worker_cls, context,
                                            vis=vislist, plan=plan,
                                            fit_func=fitfunc,
                                            wave_number=wave_number,
                                            fit_order=fitorder, switchpoly=switchpoly,
                                            edge=edge_list, blparam=blparam,
                                            deviationmask=deviationmask_list,
                                            org_directions_dict=org_directions_dict,
                                            parallel=self.inputs.parallel)
        fitter_task = worker_cls(fitter_inputs)
        fitter_results = self._executor.execute(fitter_task, merge=False)

        # Check if fitting was successful
        failed_results = None
        if isinstance(fitter_results, basetask.FailedTaskResults):
            failed_results = basetask.ResultsList([fitter_results])
        elif isinstance(fitter_results, basetask.ResultsList):
            failures = [r for r in fitter_results if isinstance(r, basetask.FailedTaskResults)]
            if failures:
                failed_results = basetask.ResultsList(failures)

        if failed_results is not None:
            for r in failed_results:
                r.origtask_cls = self.__class__
            return failed_results

        results_dict = {os.path.basename(r.outcome['infile']): r for r in fitter_results}
        for ms in context.observing_run.measurement_sets:
            if ms.basename not in results_dict:
                continue

            result = results_dict[ms.basename]
            vis = result.outcome['infile']
            outfile = result.outcome['outfile']
            LOG.debug('infile: {0}, outfile: {1}'.format(os.path.basename(vis), os.path.basename(outfile)))
            vis_map[ms.name] = outfile
            if 'plot_list' in result.outcome:
                plot_list.extend(result.outcome['plot_list'])
            if 'baseline_quality_stat' in result.outcome:
                baseline_quality_stat.extend(result.outcome['baseline_quality_stat'])

        outcome = {'baselined': baselined,
                   'vis_map': vis_map,
                   'edge': edge,
                   'deviation_mask': deviation_mask,
                   'plots': plot_list,
                   'baseline_quality_stat': baseline_quality_stat}
        results = SDBaselineResults(task=self.__class__,
                                    success=True,
                                    outcome=outcome)

        return results

    def analyse(self, result: SDBaselineResults) -> SDBaselineResults:
        """Generate measurementset domain object for output MS.

        The method generates MeasurementSet domain objects from the output
        MSes listed in ``result.outcome['vis_map']``, marks their DATA column
        as baselined data, and appends them to ``result.out_mses``.

        Args:
            result: SDBaselineResults instance

        Returns:
            SDBaselineResults instance
        """
        # Generate domain object of baselined MS
        for infile, outfile in result.outcome['vis_map'].items():
            in_ms = self.inputs.context.observing_run.get_ms(infile)
            new_ms = generate_ms(outfile, in_ms)
            new_ms.set_data_column(DataType.BASELINED, 'DATA')
            result.out_mses.append(new_ms)

        return result
class HeuristicsTask:
    """Run a Heuristic class through the mpihelpers task interface.

    The heuristic class is instantiated once; the stored positional and
    keyword arguments are forwarded to its ``calculate`` method every time
    the task is executed.
    """

    def __init__(self, heuristics_cls: type[Heuristic], *args: Any, **kwargs: Any) -> None:
        """Construct HeuristicsTask instance.

        Args:
            heuristics_cls: Heuristic class to run. It is instantiated with
                no arguments.
            *args: Positional arguments for ``calculate``.
            **kwargs: Keyword arguments for ``calculate``.
        """
        self.heuristics = heuristics_cls()
        self.args, self.kwargs = args, kwargs

    def execute(self) -> Any:
        """Invoke the heuristic with the stored arguments.

        Returns:
            Whatever the heuristic's ``calculate`` returns; its content
            depends on the Heuristic class.
        """
        heuristic = self.heuristics
        positional, keyword = self.args, self.kwargs
        return heuristic.calculate(*positional, **keyword)

    def get_executable(self) -> Callable[[], Any]:
        """Wrap ``execute`` into a zero-argument callable.

        The arguments are read when the callable is invoked, not when it is
        created.

        Returns:
            Function that runs the execute method.
        """
        def run_heuristics():
            return self.execute()

        return run_heuristics
class DeviationMaskHeuristicsTask(HeuristicsTask):
    """Executor class specialized to MaskDeviationHeuristic.

    Runs the heuristic once per (field id, antenna id, real spw id) triple
    and collects the per-triple results in order.
    """

    def __init__(self,
                 heuristics_cls: type[MaskDeviationHeuristic],
                 vis: str,
                 field_list: list[int],
                 antenna_list: list[int],
                 real_spw_list: list[int],
                 deviationmask_sigma_threshold: float,
                 consider_flag: bool = False) -> None:
        """Construct DeviationMaskHeuristicsTask instance.

        Ids are taken from field_list, antenna_list, and real_spw_list in
        lockstep (via zip), so the three lists should have the same length;
        surplus entries in longer lists are ignored.

        Args:
            heuristics_cls: Heuristics class
            vis: Name of the MS
            field_list: List of field ids to process
            antenna_list: List of antenna ids to process
            real_spw_list: List of (real) spectral window ids to process
            deviationmask_sigma_threshold: Threshold factor to detect the deviation.
                (see SDBaselineInputs for details)
            consider_flag: Consider flag when performing heuristics. Defaults to False.
        """
        super().__init__(heuristics_cls, vis=vis, consider_flag=consider_flag)
        self.vis = vis
        self.field_list = field_list
        self.antenna_list = antenna_list
        self.real_spw_list = real_spw_list
        self.deviationmask_sigma_threshold = deviationmask_sigma_threshold

    def execute(self) -> list:
        """Execute heuristics for every id triple.

        Returns:
            Deviation mask for each set of field id, antenna id, and real spw id,
            in the order of the input lists.
        """
        masks = []
        for fid, aid, sid in zip(self.field_list, self.antenna_list, self.real_spw_list):
            # The base class forwards self.kwargs to the heuristic, so the
            # per-target ids are injected there before each call.
            call_kwargs = self.kwargs
            call_kwargs['field_id'] = fid
            call_kwargs['antenna_id'] = aid
            call_kwargs['spw_id'] = sid
            call_kwargs['detection'] = self.deviationmask_sigma_threshold
            masks.append(super().execute())
        return masks
def deviation_mask_heuristic(
        vis: str,
        field_list: list[int],
        antenna_list: list[int],
        real_spw_list: list[int],
        deviationmask_sigma_threshold: float,
        consider_flag: bool = False,
        parallel: bool | None = None) -> tuple[str, mpihelpers.SyncTask | mpihelpers.AsyncTask]:
    """Prepare task instance that can be executed in mpihelpers framework.

    Args:
        vis: Name of MS
        field_list: List of field ids to process
        antenna_list: List of antenna ids to process
        real_spw_list: List of (real) spectral window ids to process
        deviationmask_sigma_threshold: Threshold factor (F) to detect the deviation.
            (see SDBaselineInputs for detail)
        consider_flag: Consider flag when performing heuristics. Defaults to False.
        parallel: Accepted for interface compatibility but currently ignored;
            the task is always executed in serial mode. Defaults to None.

    Returns:
        Name of the MS and the task instance for mpihelpers framework
        (currently always an ``mpihelpers.SyncTask``).
    """
    mytask = DeviationMaskHeuristicsTask(MaskDeviationHeuristic,
                                         vis=vis,
                                         field_list=field_list,
                                         antenna_list=antenna_list,
                                         real_spw_list=real_spw_list,
                                         deviationmask_sigma_threshold=deviationmask_sigma_threshold,
                                         consider_flag=consider_flag)

    # Parallel execution is deliberately disabled. To re-enable it, wrap
    # mytask in mpihelpers.AsyncTask when
    # mpihelpers.parse_mpi_input_parameter(parallel) is true.
    LOG.trace('Deviation Mask Heuristic always runs in serial mode.')
    task = mpihelpers.SyncTask(mytask)

    return vis, task
def test_deviation_mask_heuristic(real_spw: int | None = None,
                                  deviationmask_sigma_threshold: float = 5.0) -> None:
    """Test deviation mask heuristic.

    Runs deviation_mask_heuristic for every MS matching ``uid___A002_X*.ms``
    in the current directory, once with ``parallel=False`` and once with
    ``parallel=True``. It prints timings and a per-target comparison of the
    resulting masks. deviation_mask_heuristic currently ignores ``parallel``,
    so both runs are actually serial.

    Args:
        real_spw: (real) spectral window id to process.
            If None is given, spw is set to 17.
        deviationmask_sigma_threshold: Threshold factor passed to the
            heuristic. Defaults to 5.0, the default of SDBaselineInputs.
    """
    import glob
    import time

    vislist = glob.glob('uid___A002_X*.ms')
    print('vislist={0}'.format(vislist))
    field_list = [1, 1, 1]
    antenna_list = [0, 1, 2]
    spw = 17 if real_spw is None else real_spw
    real_spw_list = [spw] * len(field_list)
    consider_flag = True

    def run(parallel: bool) -> tuple[list, float]:
        # deviationmask_sigma_threshold is a required argument of
        # deviation_mask_heuristic; omitting it raises TypeError.
        start_time = time.time()
        tasks = [deviation_mask_heuristic(vis=vis,
                                          field_list=field_list,
                                          antenna_list=antenna_list,
                                          real_spw_list=real_spw_list,
                                          deviationmask_sigma_threshold=deviationmask_sigma_threshold,
                                          consider_flag=consider_flag,
                                          parallel=parallel)
                 for vis in vislist]
        results = [(v, t.get_result()) for v, t in tasks]
        return results, time.time() - start_time

    serial_results, duration = run(False)
    print('serial execution duration {0}sec'.format(duration))

    parallel_results, duration = run(True)
    print('parallel execution duration {0}sec'.format(duration))

    parallel_by_vis = dict(parallel_results)
    for vis, smask in serial_results:
        if vis not in parallel_by_vis:
            continue
        pmask = parallel_by_vis[vis]
        # Masks are returned per (field, antenna, spw) triple in zip order,
        # so compare them element by element rather than over a product.
        for i, (field_id, antenna_id, real_spw_id) in enumerate(zip(field_list, antenna_list, real_spw_list)):
            print('vis "{0}", field {1} antenna {2} spw {3}:'.format(vis, field_id, antenna_id, real_spw_id))
            print(' serial mask: {0}'.format(smask[i]))
            print(' parallel mask: {0}'.format(pmask[i]))
            print(' {0}'.format(smask[i] == pmask[i]))