Source code for pipeline.hsd.tasks.baseline.baseline

"""Spectral baseline subtraction stage."""
from __future__ import annotations

import collections
import os
from typing import TYPE_CHECKING

import numpy

import pipeline.infrastructure as infrastructure
import pipeline.infrastructure.basetask as basetask
import pipeline.infrastructure.callibrary as callibrary
import pipeline.infrastructure.mpihelpers as mpihelpers
import pipeline.infrastructure.vdp as vdp
import pipeline.infrastructure.sessionutils as sessionutils

from pipeline.hsd.tasks.common.inspection_util import generate_ms, inspect_reduction_group, merge_reduction_group

from pipeline.domain import DataType
from pipeline.hsd.heuristics import MaskDeviationHeuristic
from pipeline.hsd.tasks.common.inspection_util import generate_ms, inspect_reduction_group, merge_reduction_group
from pipeline.infrastructure import task_registry

from . import maskline
from . import worker
from .. import common
from ..common import compress
from ..common import utils

if TYPE_CHECKING:
    from collections.abc import Callable
    from typing import Any

    from pipeline.infrastructure.api import Heuristic
    from pipeline.infrastructure.launcher import Context
    from .typing import FitFunc, FitOrder, LineWindow

# import memory_profiler
LOG = infrastructure.logging.get_logger(__name__)


class SDBaselineInputs(vdp.StandardInputs):
    """Inputs for baseline subtraction task."""

    # Search order of input vis
    processing_data_type = [DataType.ATMCORR, DataType.REGCAL_CONTLINE_ALL, DataType.RAW]

    infiles = vdp.VisDependentProperty(default='', null_input=['', None, [], ['']])
    spw = vdp.VisDependentProperty(default='')
    pol = vdp.VisDependentProperty(default='')
    field = vdp.VisDependentProperty(default='')
    linewindow = vdp.VisDependentProperty(default=[])
    linewindowmode = vdp.VisDependentProperty(default='replace')
    edge = vdp.VisDependentProperty(default=(0, 0))
    broadline = vdp.VisDependentProperty(default=True)
    fitorder = vdp.VisDependentProperty(default=-1)
    fitfunc = vdp.VisDependentProperty(default='cspline')
    switchpoly = vdp.VisDependentProperty(default=True)
    clusteringalgorithm = vdp.VisDependentProperty(default='hierarchy')
    wave_number = vdp.VisDependentProperty(default=None)
    deviationmask = vdp.VisDependentProperty(default=True)
    deviationmask_sigma_threshold = vdp.VisDependentProperty(default=5.0)

    # Synchronization between infiles and vis is still necessary
    @vdp.VisDependentProperty
    def vis(self) -> str:
        """Return input MS name (mirrors ``infiles``)."""
        return self.infiles

    # handle conversion from string to bool
    @switchpoly.convert
    def switchpoly(self, value: str | bool) -> bool:
        """Convert switchpoly value into bool.

        Args:
            value: Any value provided by the user. It should be a boolean
                value or its string representation in lower, upper, or
                capitalized form, i.e. 'true', 'TRUE', 'True', 'false',
                'FALSE', or 'False'.

        Raises:
            ValueError: ``value`` is neither a bool nor one of the accepted
                string representations.

        Returns:
            boolean value
        """
        if isinstance(value, bool):
            return value
        if isinstance(value, str):
            for b in (True, False):
                s = str(b)
                if value in (s.lower(), s.upper(), s.capitalize()):
                    return b
        # Explicit raise instead of ``assert``: assertions are stripped under
        # ``python -O``, which would silently turn invalid input into None.
        raise ValueError(f'Invalid value for switchpoly: {value!r}. '
                         'Expected a bool or its string representation.')

    # use common implementation for parallel inputs argument
    parallel = sessionutils.parallel_inputs_impl()

    # docstring and type hints: supplements hsd_baseline
    def __init__(self,
                 context: Context,
                 infiles: list[str] | None = None,
                 antenna: list[str] | None = None,
                 spw: list[str] | None = None,
                 pol: list[str] | None = None,
                 field: list[str] | None = None,
                 linewindow: LineWindow | None = None,
                 linewindowmode: str | None = None,
                 edge: tuple[int, int] | None = None,
                 broadline: bool | None = None,
                 fitfunc: FitFunc | None = None,
                 fitorder: FitOrder | None = None,
                 switchpoly: bool | None = None,
                 clusteringalgorithm: str | None = None,
                 wave_number: list[int] | None = None,
                 deviationmask: bool | None = None,
                 deviationmask_sigma_threshold: float | None = None,
                 parallel: str | None = None) -> None:
        """Construct SDBaselineInputs instance.

        Args:
            context: Pipeline context

            infiles: List of data files. These must be names of MeasurementSets that are
                registered to context via hsd_importdata or hsd_restoredata task.

                Example: ``infiles=['X227.ms', 'X228.ms']``
                Default: ``None`` (process all registered MeasurementSets)

            antenna: Data selection by antenna.

                Example: '1' (select by ANTENNA_ID), 'PM03' (select by antenna name), '' (all antennas)

                Default: ``None`` (equivalent to ``''``)

            spw: Data selection by spw.

                Example: ``'3,4'`` (process spw 3 and 4), ['0','2'] (spw 0 for first data, 2 for second), ``''`` (all spws)

                Default: ``None`` (equivalent to ``''``)

            pol: Data selection by polarizations.

                Example: ``'0'`` (process pol 0), ``['0~1','0']`` (pol 0 and 1 for first data, only 0 for second), ``''`` (all polarizations)

                Default: ``None`` (equivalent to ``''``)

            field: Data selection by field.

                Example: ``'1'`` (select by FIELD_ID), ``'M100*'`` (select by field name), ``''`` (all fields)

                Default: ``None`` (equivalent to ``''``)

            linewindow: Pre-defined line window. If this is set, specified line windows are used as
                a line mask for baseline subtraction instead of determining masks based on line detection and
                validation stage. Several types of format are acceptable. One is channel-based window.
                ::

                    [min_chan, max_chan]

                where min_chan and max_chan should be an integer. For multiple windows, nested list is
                also acceptable.
                ::

                    [[min_chan0, max_chan0], [min_chan1, max_chan1], ...]

                Another way is frequency-based window.
                ::

                    [min_freq, max_freq]

                where min_freq and max_freq should be either a float or a string. If float value is given,
                it is interpreted as a frequency in Hz. String should be a quantity consisting
                of "value" and "unit", e.g., '100GHz'. Multiple windows are also supported.
                ::

                    [[min_freq0, max_freq0], [min_freq1, max_freq1], ...]

                Note that the specified frequencies are assumed to be the value in LSRK frame.
                Note also that there is a limitation when multiple MSes are processed.
                If native frequency frame of the data is not LSRK (e.g. TOPO), frequencies need to be
                converted to that frame. As a result, corresponding channel range may vary
                between MSes. However, current implementation is not able to handle such case.
                Frequencies are converted to desired frame using representative MS (time, position,
                direction).

                In the above cases, specified line windows are applied to all science spws.
                In case when line windows vary with spw, line windows can be specified by a dictionary whose
                key is spw id while value is line window. For example, the following dictionary gives different
                line windows to spws 17 and 19. Other spws, if available, will have an empty line window.
                ::

                    {17: [[100, 200], [1200, 1400]], 19: ['112115MHz', '112116MHz']}

                Furthermore, linewindow accepts MS selection string. The following string gives
                [[100,200],[1200,1400]] for spw 17 while [1000,1500] for spw 21.
                ::

                    "17:100~200;1200~1400,21:1000~1500"

                The string also accepts frequency with units. Note, however, that frequency reference frame
                in this case is not fixed to LSRK. Instead, the frame will be taken from the MS
                (typically TOPO for ALMA). Thus, the following two frequency-based line windows
                result in different channel selections.
                ::

                    {19: ['112115MHz', '112116MHz']} # frequency frame is LSRK
                    "19:112115MHz~112116MHz" # frequency frame is taken from the data (TOPO for ALMA)

                None is allowed as a value of dictionary input to indicate that no line detection/validation
                is required even if manually specified line window does not exist. When None is given as a value
                and if ``linewindowmode`` is 'replace', line detection/validation is not performed
                for the corresponding spw. For example, suppose the following parameters are given for the data
                with four science spws, 17, 19, 21, and 23.
                ::

                    linewindow={17: [112.1e9, 112.2e9], 19: [113.1e9, 113.15e9], 21: None}
                    linewindowmode='replace'

                The task will use given line window for 17 and 19 while the task performs
                line detection/validation for spw 23 because no line window is set.
                On the other hand, line detection/validation is skipped for spw 21 due to the effect of None.

                Example: [100,200] (channel), [115e9, 115.1e9] (frequency in Hz), ['115GHz', '115.1GHz'] (see above for more examples)

                Default: ``None`` (equivalent to ``[]``)

            linewindowmode: Merge or replace given manual line window with line detection/validation result.
                If 'replace' is given, line detection and validation will not be performed.
                On the other hand, when 'merge' is specified, line detection/validation will be performed
                and manually specified line windows are added to the result.
                Note that this has no effect when linewindow for target spw is an empty list. In that case,
                line detection/validation will be performed regardless of the value of linewindowmode.
                If neither a line window nor line detection/validation is necessary, you should set
                linewindowmode to 'replace' and specify None as a value of the linewindow dictionary
                for the spw to apply. See parameter description of ``linewindow`` for detail.

                Default: ``None`` (equivalent to ``'replace'``)

            edge: Number of edge channels to be dropped from baseline subtraction.
                The value must be a list (or tuple) of length 2, whose values specify left and right
                edge channels, respectively.

                Example: ``[10,10]``

                Default: ``None`` (equivalent to ``(0, 0)``)

            broadline: Try to detect broad component of spectral line if ``True``.

                Default: ``None`` (equivalent to ``True``)

            fitfunc: Fitting function for baseline subtraction. You can choose either cubic spline
                ('spline' or 'cspline'), polynomial ('poly' or 'polynomial'), or sinusoid.

                Accepts:

                - A string: Applies the same function to all spectral windows (SPWs).
                - A dictionary: Maps SPW IDs (int or str) to a specific fitting function.

                If an SPW ID is not present in the dictionary, ``'cspline'`` will be used as the default.

                Default: ``None`` (equivalent to ``'cspline'``)

            fitorder: Fitting order for polynomial. For cubic spline, it is used to determine how
                much the spectrum is segmented into.

                Accepts:

                - An integer: Applies the same order to all SPWs. Valid values: ``-1`` (automatic), ``0``, or any positive integer.
                - A dictionary: Maps SPW IDs (int or str) to a specific fitting order.

                If an SPW ID is not present in the dictionary, ``-1`` will be used as the default, triggering automatic order selection.

                Default: ``None`` (equivalent to ``-1``)

            switchpoly: Whether to fall back the fits from cubic spline to 1st or 2nd order polynomial
                when large masks exist at the edges of the spw. Condition for switching is as follows:

                - if nmask > nchan/2      => 1st order polynomial
                - else if nmask > nchan/4 => 2nd order polynomial
                - else                    => use fitfunc and fitorder

                where nmask is a number of channels for mask at edge while
                nchan is a number of channels of entire spectral window.

                Default: ``None`` (equivalent to ``True``)

            clusteringalgorithm: Selection of the algorithm used in the clustering analysis
                to check the validity of detected line features. The 'kmean' algorithm,
                hierarchical clustering algorithm, 'hierarchy', and their combination ('both')
                are so far implemented.

                Default: ``None`` (equivalent to ``'hierarchy'``)

            wave_number: a list of sinusoidal wave numbers. The maximum wave numbers should not exceed the
                ((number of channels/2)-1) limit. If the offset is present in the data, add 0 to the number
                of waves. That is, nwave=[0] is a constant term, nwave=[0,1,2] fits with a maximum of 2
                sinusoids, and so on.

                Default: ``None`` (equivalent to ``False``)

            deviationmask: Apply deviation mask in addition to masks determined by the automatic line detection.

                Default: ``None`` (equivalent to ``True``)

            deviationmask_sigma_threshold: Threshold factor (F) to detect the deviation.
                Actual threshold will be median + F * standard-deviation of the spectrum.

                Default: ``None`` (equivalent to ``5.0``)

            parallel: Execute using CASA HPC functionality, if available.

                Options: ``'automatic'``, ``'true'``, ``'false'``, ``True``, ``False``.

                Default: ``None`` (equivalent to ``'automatic'``).

        """
        super().__init__()

        self.context = context
        self.infiles = infiles
        self.antenna = antenna
        self.spw = spw
        self.pol = pol
        self.field = field
        self.linewindow = linewindow
        self.linewindowmode = linewindowmode
        self.edge = edge
        self.broadline = broadline
        self.fitorder = fitorder
        self.fitfunc = fitfunc
        self.switchpoly = switchpoly
        self.clusteringalgorithm = clusteringalgorithm
        self.wave_number = wave_number
        self.deviationmask = deviationmask
        self.deviationmask_sigma_threshold = deviationmask_sigma_threshold
        self.parallel = parallel

    def to_casa_args(self) -> dict:
        """Convert Inputs instance into task execution parameter.

        When ``infiles`` is a list, only its first element is exposed to the
        base-class conversion. The original value is always restored
        afterwards, even if the conversion raises.

        Returns:
            Task execution parameter as kwargs. ``'antenna'`` is always
            present (``''`` if not otherwise set).
        """
        infiles = self.infiles
        if isinstance(infiles, list):
            self.infiles = infiles[0]
        try:
            args = super().to_casa_args()
        finally:
            # restore the user-facing value regardless of success
            self.infiles = infiles

        if 'antenna' not in args:
            args['antenna'] = ''
        return args


class SDBaselineResults(common.SingleDishResults):
    """Result container for the baseline subtraction task."""

    def __init__(self,
                 task: type[basetask.StandardTaskTemplate] | None = None,
                 success: bool | None = None,
                 outcome: Any = None) -> None:
        """Initialize the result object.

        Args:
            task: Task class that produced the result.
            success: Whether task execution is successful or not.
            outcome: Outcome of the task execution.
        """
        super().__init__(task, success, outcome)
        # MeasurementSet domain objects of the output (baselined) MSes;
        # SDBaseline.analyse appends to this list.
        self.out_mses = []

    # @utils.profiler
    def merge_with_context(self, context: Context) -> None:
        """Merge result instance into context.

        Merging the baseline subtraction result updates the Pipeline
        context as follows:

          - register MeasurementSet domain object for the output of sdbaseline
            task to Pipeline context, namely measurement_sets list and
            reduction_group
          - register detected spectral lines to reduction group
          - register deviation mask to each MeasurementSet domain object

        Args:
            context: Pipeline context object containing state information.
        """
        super().merge_with_context(context)

        observing_run = context.observing_run

        # Step 1: register output MS domain objects and their reduction groups.
        for new_ms in self.out_mses:
            # an MS with the same name already in context gets replaced
            stale_index = next(
                (idx for idx, registered in enumerate(observing_run.get_measurement_sets())
                 if registered.name == new_ms.name),
                None,
            )
            if stale_index is not None:
                LOG.info('Replace {} in context'.format(new_ms.name))
                del observing_run.measurement_sets[stale_index]

            LOG.info('Adding {} to context'.format(new_ms.name))
            observing_run.add_measurement_set(new_ms)

            # start with an empty calibration state for the new MS
            cal_target = callibrary.CalTo(vis=new_ms.name)
            LOG.info('Registering {} with callibrary'.format(new_ms.name))
            context.callibrary.add(cal_target, [])

            merge_reduction_group(observing_run, inspect_reduction_group(new_ms))

        # Step 2: bump iteration counters (input MS only) and attach detected
        # lines to both input and output MS members of each reduction group.
        groups = observing_run.ms_reduction_group
        for entry in self.outcome['baselined']:
            gid = entry['group_id']
            members = entry['members']
            detected = entry['lines']
            chan_range = entry['channelmap_range']
            desc = groups[gid]
            for in_ms, fld, ant, spw in utils.iterate_group_member(desc, members):
                desc.iter_countup(in_ms, ant, spw, fld)
                baselined_ms = observing_run.get_ms(name=self.outcome['vis_map'][in_ms.name])
                for each_ms in (in_ms, baselined_ms):
                    desc.add_linelist(detected, each_ms, ant, spw, fld,
                                      channelmap_range=chan_range)

        # Step 3: make sure every MS carries a deviation_mask attribute, then
        # copy computed masks onto input MS and share them with the output MS.
        for any_ms in observing_run.measurement_sets:
            if not hasattr(any_ms, 'deviation_mask'):
                any_ms.deviation_mask = None

        if 'deviation_mask' not in self.outcome:
            return

        for basename, masks in self.outcome['deviation_mask'].items():
            src_ms = observing_run.get_ms(basename)
            src_ms.deviation_mask = {}
            for fld in src_ms.get_fields(intent='TARGET'):
                for ant in src_ms.antennas:
                    for spw in src_ms.get_spectral_windows(science_windows_only=True):
                        mask_key = (fld.id, ant.id, spw.id)
                        if mask_key not in masks:
                            continue
                        src_ms.deviation_mask[mask_key] = masks[mask_key]
            dst_ms = observing_run.get_ms(name=self.outcome['vis_map'][src_ms.name])
            dst_ms.deviation_mask = src_ms.deviation_mask

    def _outcome_name(self) -> str:
        """Summarize the processed reduction groups, one per line.

        Returns:
            Newline-separated summary of group ids and their members.
        """
        summary_lines = []
        for entry in self.outcome['baselined']:
            summary_lines.append(
                'Reduction Group {0}: member {1}'.format(entry['group_id'], entry['members']))
        return '\n'.join(summary_lines)


@task_registry.set_equivalent_casa_task('hsd_baseline')
@task_registry.set_casa_commands_comment(
    'Subtracts spectral baseline by least-square fitting with N-sigma clipping. Spectral lines are automatically '
    'detected and examined to determine the region that is masked to protect these features from the fit.\n'
    'This stage performs a pipeline calculation without running any CASA commands to be put in this file.'
)
class SDBaseline(basetask.StandardTaskTemplate):
    """Baseline subtraction task."""

    Inputs = SDBaselineInputs

    is_multi_vis_task = True

    # @memory_profiler.profile
    def prepare(self) -> SDBaselineResults | basetask.ResultsList:
        """Perform baseline subtraction.

        The method first evaluates deviation mask for each MS if the mask is
        not available, then performs line detection by combining all spectral
        data, and finally performs baseline subtraction using sdbaseline task.

        Returns:
            SDBaselineResults instance that holds list of output MS names with
            the map among input MS names, necessary data for weblog rendering,
            and the metric representing the quality of the baseline
            subtraction. If the fitting worker fails, a ResultsList holding
            only the failed results (with ``origtask_cls`` set to this task
            class) is returned instead.
        """
        LOG.debug('Starting SDMDBaseline.prepare')
        inputs = self.inputs
        context = inputs.context
        reduction_group = context.observing_run.ms_reduction_group
        vis_list = inputs.vis
        args = inputs.to_casa_args()

        # Spw selection will accept virtual spw so inputs.spw should be
        # (a list of) virtual spw ids. Intention here is to get a list
        # of selected *real* spw ids per MeasurementSet and store them
        # into args_real_spw as dict.
        args_real_spw = utils.convert_spw_virtual2real(context, inputs.spw)

        window = inputs.linewindow
        windowmode = inputs.linewindowmode
        LOG.info('{}: window={}, windowmode={}'.format(self.__class__.__name__, window, windowmode))
        edge = inputs.edge
        broadline = inputs.broadline
        fitorder = 'automatic' if inputs.fitorder is None else inputs.fitorder
        fitfunc = inputs.fitfunc
        wave_number = inputs.wave_number
        switchpoly = inputs.switchpoly
        clusteringalgorithm = inputs.clusteringalgorithm
        deviationmask = inputs.deviationmask
        deviationmask_sigma_threshold = inputs.deviationmask_sigma_threshold

        # mkdir stage_dir if it doesn't exist
        stage_number = context.task_counter
        stage_dir = os.path.join(context.report_dir, "stage%d" % stage_number)
        if not os.path.exists(stage_dir):
            # exist_ok guards against a concurrent creation between check and mkdir
            os.makedirs(stage_dir, exist_ok=True)

        # This is a dictionary for deviation mask that will be merged with top-level context
        deviation_mask = collections.defaultdict(dict)

        # collection of field, antenna, and spw ids in reduction group per MS
        registry = collections.defaultdict(utils.RGAccumulator)

        # outcome for baseline subtraction
        baselined = []

        # dictionary of org_direction keyed by source name
        org_directions_dict = {}

        # loop over reduction group
        LOG.debug('Starting per reduction group processing: number of groups is %d', len(reduction_group))
        for (group_id, group_desc) in reduction_group.items():
            LOG.info('Processing Reduction Group %s', group_id)
            LOG.info('Group Summary:')
            # per-group value; initialized so an empty group cannot leak the
            # previous group's value (or raise NameError on the first group)
            org_direction = None
            for m in group_desc:
                # m.spw_id is real spw id
                LOG.info(
                    f'\tMS "{m.ms.basename}" '
                    f'Antenna "{m.antenna_name}" (ID {m.antenna_id}) '
                    f'Spw ID {m.spw_id} '
                    f'Field "{m.field_name}" (ID {m.field_id})'
                )
                # scan for org_direction and register it once per source
                msobj = context.observing_run.get_ms(m.ms.basename)
                fields = msobj.get_fields(field_id=m.field_id)
                source = fields[0].source
                source_name = source.name
                if source_name not in org_directions_dict:
                    if source.is_eph_obj or source.is_known_eph_obj:
                        registered = source.org_direction
                    else:
                        registered = None
                    org_directions_dict[source_name] = registered
                    LOG.info("registered org_direction[{}]={}".format(source_name, registered))
                # Always read back from the registry: a source registered by an
                # earlier group must not inherit the stale org_direction of a
                # different source from a previous iteration.
                org_direction = org_directions_dict[source_name]

            # skip channel averaged spw
            nchan = group_desc.nchan
            LOG.debug('nchan for group %s is %s', group_id, nchan)
            if nchan == 1:
                first_member = group_desc[0]
                virtual_spw = context.observing_run.real2virtual_spw_id(first_member.spw.id, first_member.ms)
                LOG.info('Skip channel averaged spw %s.', virtual_spw)
                continue

            LOG.debug('spw=\'%s\'', args_real_spw)
            LOG.debug('vis_list=%s', vis_list)
            member_list = numpy.fromiter(
                utils.get_valid_ms_members(group_desc, vis_list, args['antenna'], args['field'], args_real_spw),
                dtype=numpy.int32)
            # skip this group if valid member list is empty
            LOG.debug('member_list=%s', member_list)
            if len(member_list) == 0:
                LOG.info('Skip reduction group %s', group_id)
                continue

            member_list.sort()

            # assume all members have same iteration
            first_member = group_desc[member_list[0]]
            iteration = first_member.iteration
            LOG.debug('iteration for group %s is %s', group_id, iteration)

            LOG.info('Members to be processed:')
            for (gms, gfield, gant, greal_spw) in utils.iterate_group_member(group_desc, member_list):
                LOG.info('\tMS "%s" Field ID %s Antenna ID %s Spw ID %s', gms.basename, gfield, gant, greal_spw)

            # Deviation Mask
            # NOTE: deviation mask is evaluated per ms per field per spw
            if deviationmask:
                LOG.info('Apply deviation mask to baseline fitting')
                self._evaluate_deviation_mask(context, group_desc, member_list, deviation_mask,
                                              deviationmask_sigma_threshold)
            else:
                LOG.info('Deviation mask is disabled by the user')
            LOG.debug('deviation_mask=%s', deviation_mask)

            # Spectral Line Detection and Validation
            # MaskLine will update DataTable.MASKLIST column
            maskline_inputs = maskline.MaskLine.Inputs(context, iteration, group_id, member_list,
                                                       window, windowmode, edge, broadline,
                                                       clusteringalgorithm)
            maskline_task = maskline.MaskLine(maskline_inputs)
            maskline_result = self._executor.execute(maskline_task, merge=False)
            grid_table = maskline_result.outcome['grid_table']
            if grid_table is None:
                LOG.info('Skip reduction group %s', group_id)
                continue
            compressed_table = compress.CompressedObj(grid_table)
            del grid_table
            detected_lines = maskline_result.outcome['detected_lines']
            channelmap_range = maskline_result.outcome['channelmap_range']
            cluster_info = maskline_result.outcome['cluster_info']
            flag_digits = maskline_result.outcome['flag_digits']

            # register ids to per MS id collection
            for i in member_list:
                member = group_desc[i]
                registry[member.ms].append(member.field_id, member.antenna_id, member.spw_id,
                                           grid_table=compressed_table,
                                           channelmap_range=channelmap_range)

            # add entry to outcome
            baselined.append({'group_id': group_id, 'iteration': iteration,
                              'members': member_list,
                              'lines': detected_lines,
                              'channelmap_range': channelmap_range,
                              'clusters': cluster_info,
                              'flag_digits': flag_digits,
                              'org_direction': org_direction})

        # - end of the loop over reduction group

        def blparam_file(ms) -> str:
            """Return the baseline parameter file name for the given MS."""
            return ms.basename.rstrip('/') + '_blparam_stage{stage}.txt'.format(stage=stage_number)

        vis_map = {}  # key and value are input and output vis name, respectively
        plot_list = []
        baseline_quality_stat = []

        # Generate and apply baseline fitting solutions
        vislist = [ms.name for ms in registry]
        plan = [registry[ms] for ms in registry]
        blparam = [blparam_file(ms) for ms in registry]
        deviationmask_list = [deviation_mask[ms.basename] for ms in registry]
        edge_list = [edge for _ in registry]

        worker_cls = worker.BaselineSubtractionWorker
        fitter_inputs = vdp.InputsContainer(worker_cls, context,
                                            vis=vislist, plan=plan,
                                            fit_func=fitfunc, wave_number=wave_number, fit_order=fitorder,
                                            switchpoly=switchpoly,
                                            edge=edge_list,
                                            blparam=blparam,
                                            deviationmask=deviationmask_list,
                                            org_directions_dict=org_directions_dict,
                                            parallel=self.inputs.parallel)
        fitter_task = worker_cls(fitter_inputs)
        fitter_results = self._executor.execute(fitter_task, merge=False)

        # Check if fitting was successful
        failed_results = self._collect_failed_results(fitter_results)
        if failed_results is not None:
            return failed_results

        results_dict = {os.path.basename(r.outcome['infile']): r for r in fitter_results}
        for ms in context.observing_run.measurement_sets:
            if ms.basename not in results_dict:
                continue

            result = results_dict[ms.basename]
            vis = result.outcome['infile']
            outfile = result.outcome['outfile']
            LOG.debug('infile: {0}, outfile: {1}'.format(os.path.basename(vis), os.path.basename(outfile)))
            vis_map[ms.name] = outfile

            if 'plot_list' in result.outcome:
                plot_list.extend(result.outcome['plot_list'])
            if 'baseline_quality_stat' in result.outcome:
                baseline_quality_stat.extend(result.outcome['baseline_quality_stat'])

        outcome = {'baselined': baselined,
                   'vis_map': vis_map,
                   'edge': edge,
                   'deviation_mask': deviation_mask,
                   'plots': plot_list,
                   'baseline_quality_stat': baseline_quality_stat}
        results = SDBaselineResults(task=self.__class__,
                                    success=True,
                                    outcome=outcome)

        return results

    def _evaluate_deviation_mask(self,
                                 context: Context,
                                 group_desc: Any,
                                 member_list: numpy.ndarray,
                                 deviation_mask: dict,
                                 deviationmask_sigma_threshold: float) -> None:
        """Collect deviation masks of the given group members into ``deviation_mask``.

        Masks already cached on the MS domain object (``ms.deviation_mask``)
        are reused. Missing ones are computed via ``deviation_mask_heuristic``
        (one task per MS), then cached on the MS domain object as well.

        Args:
            context: Pipeline context.
            group_desc: Reduction group description.
            member_list: Sorted indices of the group members to process.
            deviation_mask: Per-MS-basename dict mapping
                (field id, antenna id, real spw id) to a mask list.
                Updated in place.
            deviationmask_sigma_threshold: Threshold factor to detect the deviation.
        """
        dvtasks = []
        # per MS name: (field ids, antenna ids, real spw ids) still to evaluate
        dvparams = collections.defaultdict(lambda: ([], [], []))
        for (ms, fieldid, antennaid, real_spwid) in utils.iterate_group_member(group_desc, member_list):
            if (not hasattr(ms, 'deviation_mask')) or ms.deviation_mask is None:
                ms.deviation_mask = {}
            key = (fieldid, antennaid, real_spwid)
            if key not in ms.deviation_mask:
                LOG.debug('Evaluating deviation mask for %s field %s antenna %s spw %s',
                          ms.basename, fieldid, antennaid, real_spwid)
                dvparams[ms.name][0].append(fieldid)
                dvparams[ms.name][1].append(antennaid)
                dvparams[ms.name][2].append(real_spwid)
            else:
                deviation_mask[ms.basename][key] = ms.deviation_mask[key]

        for (vis, params) in dvparams.items():
            field_list, antenna_list, real_spw_list = params
            dvtasks.append(deviation_mask_heuristic(vis=vis,
                                                    field_list=field_list,
                                                    antenna_list=antenna_list,
                                                    real_spw_list=real_spw_list,
                                                    deviationmask_sigma_threshold=deviationmask_sigma_threshold,
                                                    consider_flag=True,
                                                    parallel=self.inputs.parallel))

        for vis, dvtask in dvtasks:
            dvmasks = dvtask.get_result()
            field_list, antenna_list, real_spw_list = dvparams[vis]
            ms = context.observing_run.get_ms(vis)
            for field_id, antenna_id, real_spw_id, mask_list in zip(field_list, antenna_list,
                                                                    real_spw_list, dvmasks):
                # key: (fieldid, antennaid, real spwid)
                key = (field_id, antenna_id, real_spw_id)
                LOG.debug('deviation mask: key %s %s %s mask %s', field_id, antenna_id, real_spw_id, mask_list)
                ms.deviation_mask[key] = mask_list
                deviation_mask[ms.basename][key] = ms.deviation_mask[key]
                LOG.debug('evaluated deviation mask is %s', ms.deviation_mask[key])

    def _collect_failed_results(self, fitter_results: Any) -> basetask.ResultsList | None:
        """Return failed fitting results tagged with this task class, or None.

        Args:
            fitter_results: Return value of executing the baseline fitting worker.

        Returns:
            ResultsList of FailedTaskResults (each with ``origtask_cls`` set to
            this task class) if any failure occurred, otherwise None.
        """
        if isinstance(fitter_results, basetask.FailedTaskResults):
            failed_results = basetask.ResultsList([fitter_results])
        elif isinstance(fitter_results, basetask.ResultsList) and any(
                isinstance(r, basetask.FailedTaskResults) for r in fitter_results):
            failed_results = basetask.ResultsList(
                [r for r in fitter_results if isinstance(r, basetask.FailedTaskResults)])
        else:
            return None

        for r in failed_results:
            r.origtask_cls = self.__class__
        return failed_results

    def analyse(self, result: SDBaselineResults) -> SDBaselineResults:
        """Generate MeasurementSet domain objects for output MSes.

        For every (input, output) pair in ``result.outcome['vis_map']``, a
        MeasurementSet domain object is generated from the output MS, its
        DATA column is marked as BASELINED, and it is appended to
        ``result.out_mses``.

        Args:
            result: SDBaselineResults instance

        Returns:
            SDBaselineResults instance
        """
        # Generate domain object of baselined MS
        for infile, outfile in result.outcome['vis_map'].items():
            in_ms = self.inputs.context.observing_run.get_ms(infile)
            new_ms = generate_ms(outfile, in_ms)
            new_ms.set_data_column(DataType.BASELINED, 'DATA')
            result.out_mses.append(new_ms)

        return result
class HeuristicsTask:
    """Executor for heuristics class. It is an adaptor to mpihelper framework."""

    def __init__(self, heuristics_cls: type[Heuristic], *args: Any, **kwargs: Any) -> None:
        """Construct HeuristicsTask instance.

        Args:
            heuristics_cls: Heuristic class to run. It is instantiated here
                without arguments.
            args: Positional arguments forwarded to the heuristic's
                ``calculate`` method.
            kwargs: Keyword arguments forwarded to the heuristic's
                ``calculate`` method.
        """
        self.heuristics = heuristics_cls()
        self.args = args
        self.kwargs = kwargs

    def execute(self) -> Any:
        """Perform Heuristics and return its result.

        Returns:
            Heuristics result. Actual contents of return value depends on
            the Heuristics class.
        """
        return self.heuristics.calculate(*self.args, **self.kwargs)

    def get_executable(self) -> Callable[[], Any]:
        """Return function that runs execute method.

        Returns:
            Function to run execute method
        """
        return lambda: self.execute()


class DeviationMaskHeuristicsTask(HeuristicsTask):
    """Executor class specialized to MaskDeviationHeuristic."""

    def __init__(self,
                 heuristics_cls: type[MaskDeviationHeuristic],
                 vis: str,
                 field_list: list[int],
                 antenna_list: list[int],
                 real_spw_list: list[int],
                 deviationmask_sigma_threshold: float,
                 consider_flag: bool = False) -> None:
        """Construct DeviationMaskHeuristicsTask instance.

        Executes heuristics to find deviation masks for given set of field
        id, antenna id, and (real) spw ids. Those ids are taken from
        field_list, antenna_list, and real_spw_list via zip so that their
        length must be identical.

        Args:
            heuristics_cls: Heuristics class
            vis: Name of the MS
            field_list: List of field ids to process
            antenna_list: List of antenna ids to process
            real_spw_list: List of (real) spectral window ids to process
            deviationmask_sigma_threshold: Threshold factor to detect the
                deviation. (see SDBaselineInputs for details)
            consider_flag: Consider flag when performing heuristics.
                Defaults to False.

        Raises:
            ValueError: field_list, antenna_list, and real_spw_list do not
                have the same length.
        """
        # zip() in execute() would silently drop trailing entries of the
        # longer lists, so reject mismatched lengths up front.
        if not len(field_list) == len(antenna_list) == len(real_spw_list):
            raise ValueError(
                'field_list, antenna_list, and real_spw_list must have the same length '
                '(got {0}, {1}, {2})'.format(len(field_list), len(antenna_list), len(real_spw_list))
            )
        super().__init__(heuristics_cls, vis=vis, consider_flag=consider_flag)
        self.vis = vis
        self.field_list = field_list
        self.antenna_list = antenna_list
        self.real_spw_list = real_spw_list
        self.deviationmask_sigma_threshold = deviationmask_sigma_threshold

    def execute(self) -> list:
        """Execute heuristics.

        Returns:
            Deviation mask for each set of field id, antenna id, and real spw
            id, in the order given by zipping the three id lists.
        """
        result = []
        for field_id, antenna_id, real_spw_id in zip(self.field_list, self.antenna_list, self.real_spw_list):
            # Build per-call keyword arguments rather than updating
            # self.kwargs in place, so the stored arguments stay unchanged.
            call_kwargs = {
                **self.kwargs,
                'field_id': field_id,
                'antenna_id': antenna_id,
                'spw_id': real_spw_id,
                'detection': self.deviationmask_sigma_threshold,
            }
            result.append(self.heuristics.calculate(*self.args, **call_kwargs))
        return result


def deviation_mask_heuristic(
        vis: str,
        field_list: list[int],
        antenna_list: list[int],
        real_spw_list: list[int],
        deviationmask_sigma_threshold: float,
        consider_flag: bool = False,
        parallel: bool | None = None) -> tuple[str, mpihelpers.SyncTask | mpihelpers.AsyncTask]:
    """Prepare task instance that can be executed in mpihelpers framework.

    Args:
        vis: Name of MS
        field_list: List of field ids to process
        antenna_list: List of antenna ids to process
        real_spw_list: List of (real) spectral window ids to process
        deviationmask_sigma_threshold: Threshold factor (F) to detect the
            deviation. (see SDBaselineInputs for detail)
        consider_flag: Consider flag when performing heuristics.
            Defaults to False.
        parallel: Parallel execution or not. Currently, disabled: the value
            is ignored and the task is always executed in serial mode.
            Defaults to None.

    Returns:
        Name of the MS and the task instance for mpihelpers framework.
    """
    mytask = DeviationMaskHeuristicsTask(MaskDeviationHeuristic,
                                         vis=vis,
                                         field_list=field_list,
                                         antenna_list=antenna_list,
                                         real_spw_list=real_spw_list,
                                         deviationmask_sigma_threshold=deviationmask_sigma_threshold,
                                         consider_flag=consider_flag)
    # Parallel execution is intentionally disabled; ``parallel`` is kept only
    # for interface compatibility.
    LOG.trace('Deviation Mask Heuristic always runs in serial mode.')
    task = mpihelpers.SyncTask(mytask)
    return vis, task


def test_deviation_mask_heuristic(real_spw: int | None = None,
                                  deviationmask_sigma_threshold: float = 5.0) -> None:
    """Test deviation mask heuristic.

    Runs deviation_mask_heuristic on every ``uid___A002_X*.ms`` in the
    current directory, once with ``parallel=False`` and once with
    ``parallel=True``. It prints the durations and, for each (field, antenna,
    spw) triple, both masks and whether they are equal.

    Args:
        real_spw: (real) spectral window id to process. If None is given,
            spw is set to 17.
        deviationmask_sigma_threshold: Threshold factor passed to
            deviation_mask_heuristic. Defaults to 5.0.
            NOTE(review): 5.0 is assumed to mirror the task-level default;
            confirm against SDBaselineInputs.
    """
    import glob
    import time

    vislist = glob.glob('uid___A002_X*.ms')
    print('vislist={0}'.format(vislist))
    field_list = [1, 1, 1]
    antenna_list = [0, 1, 2]
    spw = 17 if real_spw is None else real_spw
    real_spw_list = [spw, spw, spw]
    consider_flag = True

    def _run(parallel: bool) -> list:
        # Build one task per MS, then collect (vis, mask list) pairs.
        tasks = [deviation_mask_heuristic(vis=vis, field_list=field_list, antenna_list=antenna_list,
                                          real_spw_list=real_spw_list,
                                          deviationmask_sigma_threshold=deviationmask_sigma_threshold,
                                          consider_flag=consider_flag, parallel=parallel)
                 for vis in vislist]
        return [(v, t.get_result()) for v, t in tasks]

    start_time = time.time()
    serial_results = _run(False)
    end_time = time.time()
    print('serial execution duration {0}sec'.format(end_time-start_time))

    start_time = time.time()
    parallel_results = _run(True)
    end_time = time.time()
    print('parallel execution duration {0}sec'.format(end_time-start_time))

    parallel_by_vis = dict(parallel_results)
    for vis, smask in serial_results:
        if vis not in parallel_by_vis:
            continue
        pmask = parallel_by_vis[vis]
        # Each mask list holds one entry per zipped (field, antenna, spw)
        # triple, so report them pairwise rather than over the full product.
        for field_id, antenna_id, real_spw_id, s_entry, p_entry in zip(
                field_list, antenna_list, real_spw_list, smask, pmask):
            print('vis "{0}", field {1} antenna {2} spw {3}:'.format(vis, field_id, antenna_id, real_spw_id))
            print(' serial mask: {0}'.format(s_entry))
            print(' parallel mask: {0}'.format(p_entry))
            print(' {0}'.format(s_entry == p_entry))