# Source code for pipeline.hsd.tasks.applycal.applycal

from __future__ import annotations

import os
from typing import TYPE_CHECKING
import numpy

import pipeline.extern.sd_applycal_qa.sd_applycal_qa as sd_applycal_qa
import pipeline.extern.sd_applycal_qa.sd_qa_reports as sd_qa_reports
import pipeline.infrastructure as infrastructure
import pipeline.infrastructure.basetask as basetask
import pipeline.infrastructure.renderer.logger as logger
import pipeline.infrastructure.vdp as vdp
import pipeline.infrastructure.sessionutils as sessionutils
from pipeline.domain.datatable import DataTableImpl as DataTable
from pipeline.domain import DataType
from pipeline.h.tasks.applycal.applycal import SerialApplycal, ApplycalInputs, ApplycalResults
from pipeline.hsd.tasks.applycal.display import ApplyCalSingleDishPlotmsSpwComposite, ApplyCalSingleDishPlotmsAntSpwComposite
import pipeline.hsd.tasks.applycal.display as display
from pipeline.infrastructure import casa_tools
from pipeline.infrastructure import task_registry

if TYPE_CHECKING:
    from pipeline.domain import MeasurementSet
    from pipeline.infrastructure import CalApplication
    from pipeline.infrastructure.launcher import Context

LOG = infrastructure.get_logger(__name__)


class SDApplycalInputs(ApplycalInputs):
    """
    SDApplycalInputs defines the inputs for the SDApplycal pipeline task.
    """
    # use common implementation for parallel inputs argument
    parallel = sessionutils.parallel_inputs_impl()

    # SD-specific defaults: detailed flag summary is generated, and the
    # intent selection defaults to 'TARGET'
    flagdetailedsum = vdp.VisDependentProperty(default=True)
    intent = vdp.VisDependentProperty(default='TARGET')

    # docstring and type hints: supplements hsd_applycal
    def __init__(self,
                 context: Context,
                 output_dir: str | None = None,
                 vis: str | list[str] | None = None,
                 field: str | list[str] | None = None,
                 spw: str | list[str] | None = None,
                 antenna: str | list[str] | None = None,
                 intent: str | list[str] | None = None,
                 parang: bool | None = None,
                 applymode: str | None = None,
                 flagbackup: bool | None = None,
                 flagsum: bool | None = None,
                 flagdetailedsum: bool | None = None,
                 parallel: bool | str | None = None):
        """Inputs for SDApplycal task.

        Args:
            context: Pipeline context object containing state information.

            output_dir: Output directory.

            vis: The list of input MeasurementSets. Defaults to the list of MeasurementSets in the pipeline context.

                Example: ['X227.ms']

            field: A string containing the list of field names or field ids to which the calibration will be applied.
                Defaults to all fields in the pipeline context.

                Example: '3C279', '3C279, M82'

            spw: The list of spectral windows and channels to which the calibration will be applied.
                Defaults to all science windows in the pipeline context.

                Example: '17', '11, 15'

            antenna: The selection of antennas to which the calibration will be applied.
                Defaults to all antennas. Not currently supported.

            intent: A string containing the list of intents against which the selected fields will be matched.
                Defaults to 'TARGET'.

                Example: `'*TARGET*'`

            parang: Apply parallactic angle correction. Not used.

            applymode: Calibration apply mode.

                - 'calflag': calibrate data and apply flags from solutions.
                - 'calflagstrict': same as above except flag spws for which calibration is
                  unavailable in one or more tables (instead of allowing them to pass
                  uncalibrated and unflagged). This is the default applymode.
                - 'trial': report on flags from solutions, dataset entirely unchanged.
                - 'flagonly': apply flags from solutions only, data not calibrated.
                - 'flagonlystrict': same as above except flag spws for which calibration is
                  unavailable in one or more tables.
                - 'calonly': calibrate data only, flags from solutions NOT applied.

            flagbackup: Backup the flags before the apply.

                Default: None (equivalent to True)

            flagsum: Run flagdata task for flagging summary.

                Default: None (equivalent to True)

            flagdetailedsum: Generate detailed flagging summary.

                Default: None (equivalent to True)

            parallel: Execute using CASA HPC functionality, if available.
                Default is None, which is equivalent to 'automatic' that intends to
                turn on parallel processing if possible.

                Options: 'automatic', 'true', 'false', True, False
        """
        super().__init__(
            context, output_dir=output_dir, vis=vis,
            field=field, spw=spw, antenna=antenna, intent=intent,
            parang=parang, applymode=applymode, flagbackup=flagbackup,
            flagsum=flagsum, flagdetailedsum=flagdetailedsum,
            parallel=parallel
        )


class SDApplycalResults(ApplycalResults):
    """Results produced by the SDApplycal task.

    Extends the generic applycal results with containers for the XX-YY
    deviation QA outcome. Refer to the parent class for the common
    attributes.
    """
    def __init__(self,
                 applied: list[CalApplication] | None = None,
                 data_type: DataType | None = None):
        """Initialize an SDApplycalResults object.

        Refer to the parent class for the common behavior.

        Args:
            applied: Calibration applications performed by the task.
            data_type: Data type enum describing the calibrated data.
        """
        super().__init__(applied, data_type=data_type)
        # QA scores of the XX-YY deviation check (filled in by the task)
        self.xy_deviation_score: list = []
        # file names of the XX-YY deviation QA plots (filled in by the task)
        self.xy_deviation_plots: list = []


class SerialSDApplycal(SerialApplycal):
    """
    Applycal executes CASA applycal tasks for the current context state,
    applying calibrations registered with the pipeline context to the target
    measurement set.

    Applying the results from this task to the context marks the referred
    tables as applied. As a result, they will not be included in future
    on-the-fly calibration arguments.

    On top of the generic behavior, this single-dish variant:

    - restricts applycal to the antenna selection '*&&&';
    - computes flag summaries on 'OBSERVE_TARGET#ON_SOURCE' data only;
    - updates flags and Tsys in the single-dish DataTable;
    - sets the unit of the MS data columns;
    - runs the XX-YY deviation QA (ALMA only) and generates amplitude
      vs. time plots in ``analyse``.
    """
    Inputs = SDApplycalInputs

    def modify_task_args(self, task_args: dict) -> dict:
        """Override template arguments for applycal execution.

        This method receives template task_args created by the parent task,
        and overrides it with our SD-specific antenna selection arguments.

        Args:
            task_args: Template arguments for applycal execution. The dict
                is modified in place.

        Returns:
            Updated task arguments.
        """
        # CASA antenna selection syntax: '*&&&' selects the
        # auto-correlations of all antennas only
        task_args['antenna'] = '*&&&'
        return task_args

    def _get_flagsum_arg(self, args: dict) -> dict:
        """Update arguments for flag summary.

        According to the requirement in CAS-8813, flag fraction should be
        based on target instead of total.

        Args:
            args: Template arguments for flag summary.

        Returns:
            Updated task arguments.
        """
        task_args = super()._get_flagsum_arg(args)
        task_args['intent'] = 'OBSERVE_TARGET#ON_SOURCE'
        return task_args

    def _tweak_flagkwargs(self, template: list[str]) -> list[str]:
        """Override flagging commands.

        According to the requirement in CAS-8813, flag fraction should be
        based on target instead of total.

        Args:
            template: List of flagging commands.

        Returns:
            Updated flagging commands.
        """
        # use of ' rather than " is required to prevent escaping of flagcmds
        return [row + " intent='OBSERVE_TARGET#ON_SOURCE'" for row in template]

    def prepare(self) -> SDApplycalResults:
        """Apply calibrations and update single-dish specific metadata.

        Runs the parent applycal. Then it opens the DataTable associated
        with the origin MS (writable), updates its flags and the Tsys of
        every applied 'tsys' calibration, and exports it back in full.
        Finally, it sets the unit of the MS data columns via ``set_unit``.

        Returns:
            SDApplycalResults holding the applied calibrations and the
            flag summaries of the parent results.
        """
        # execute Applycal
        results = super().prepare()

        context = self.inputs.context
        # this task uses _handle_multiple_vis framework
        msobj = self.inputs.ms

        # Update flags and Tsys in datatable
        origin_basename = os.path.basename(msobj.origin_ms)
        datatable_name = os.path.join(context.observing_run.ms_datatable_name, origin_basename)
        datatable = DataTable()
        datatable.importdata(name=datatable_name, readonly=False)
        datatable._update_flag(msobj.name)
        for calapp in results.applied:
            filename = os.path.join(context.output_dir, calapp.vis)
            fieldids = [fieldobj.id for fieldobj in msobj.get_fields(calapp.field)]
            for calfrom in calapp.calfrom:
                if calfrom.caltype != 'tsys':
                    continue
                LOG.info(f'Updating Tsys for {calapp.vis}')
                datatable._update_tsys(context, filename, calfrom.gaintable,
                                       calfrom.spwmap, fieldids, calfrom.gainfield)

        # here, full export is necessary
        datatable.exportdata(minimal=False)

        sdresults = SDApplycalResults(applied=results.applied, data_type=self.applied_data_type)
        sdresults.summaries = results.summaries
        # flagsummary is only present on the parent results in some cases
        if hasattr(results, 'flagsummary'):
            sdresults.flagsummary = results.flagsummary

        # set unit according to applied calibration
        set_unit(msobj, results.applied)
        return sdresults

    def analyse(self, results: SDApplycalResults) -> SDApplycalResults:
        """Analyse the results of the task.

        This method assesses the quality of the calibration applied in
        this stage. The analysis focuses on the deviation of calibrated
        data between XX and YY polarizations (ALMA data only), and also the
        generation of calibrated amplitude vs time plots.

        Args:
            results: Results instance returned by ``prepare``.

        Returns:
            SDApplycalResults: The results of the task.
        """
        results = super().analyse(results)
        context = self.inputs.context
        stage_dir = os.path.join(context.report_dir, f'stage{context.task_counter}')

        # perform XX-YY deviation QA
        if self.inputs.ms.antenna_array.name == 'ALMA':
            self._run_xy_deviation_qa(results, stage_dir)

        # Generating calibrated amplitude vs time plots
        self._make_amp_vs_time_plots(results, stage_dir)

        return results

    def _run_xy_deviation_qa(self, results: SDApplycalResults, stage_dir: str) -> None:
        """Run the XX-YY deviation QA and store scores and plots in results.

        CSV reports are written to './sd_applycal_output'. The weblog plots
        go to ``stage_dir``, or to the same QA directory when the weblog is
        disabled.

        Args:
            results: Results instance to be updated in place.
            stage_dir: Weblog directory of the current stage.
        """
        msobj = self.inputs.ms
        applycal_qa_dir = './sd_applycal_output'
        os.makedirs(applycal_qa_dir, exist_ok=True)

        if basetask.DISABLE_WEBLOG:
            # Since weblog is disabled, all the plots will be saved
            # in applycal_qa_dir
            weblog_output_dir = applycal_qa_dir
        else:
            os.makedirs(stage_dir, exist_ok=True)
            weblog_output_dir = stage_dir

        qascore_list, plots_fnames, _qascore_per_scan_list = sd_applycal_qa.get_ms_applycal_qascores(
            msNames=[msobj.name],
            plot_output_path=applycal_qa_dir,
            weblog_output_path=weblog_output_dir,
        )
        sd_qa_reports.makeSummaryTable(
            qascore_list,
            '',
            plfolder=applycal_qa_dir,
            output_file=os.path.join(applycal_qa_dir, f'qascore_summary_{msobj.basename}.csv')
        )
        sd_qa_reports.makeQAmsgTable(
            qascore_list,
            plfolder=applycal_qa_dir,
            output_file=os.path.join(applycal_qa_dir, f'qascores_details_{msobj.basename}.csv')
        )
        results.xy_deviation_score.extend(qascore_list)
        # "N/A" entries are placeholders rather than plot files; skip them
        results.xy_deviation_plots.extend(x for x in plots_fnames if x != "N/A")

    def _make_amp_vs_time_plots(self, results: SDApplycalResults, stage_dir: str) -> None:
        """Generate calibrated amplitude vs. time plots and attach them to results.

        The plot attributes are reset to None first. They are filled only
        when the weblog is enabled and the MS has at least one TARGET field.

        Args:
            results: Results instance to be updated in place.
            stage_dir: Weblog directory of the current stage.
        """
        results.amp_vs_time_summary_plots = None
        results.amp_vs_time_detail_plots = None
        if basetask.DISABLE_WEBLOG:
            return

        # mkdir stage_dir if it doesn't exist
        os.makedirs(stage_dir, exist_ok=True)

        if not self.inputs.ms.get_fields(intent='TARGET'):
            return

        context = self.inputs.context
        # For summary plots
        summary_plots = self.sd_plots_for_result(
            context,
            results,
            display.ApplyCalSingleDishPlotmsSpwComposite
        )
        # For detail plots
        detail_plots = self.sd_plots_for_result(
            context,
            results,
            display.ApplyCalSingleDishPlotmsAntSpwComposite
        )
        results.amp_vs_time_summary_plots = summary_plots
        results.amp_vs_time_detail_plots = detail_plots

    def sd_plots_for_result(
        self,
        context: Context,
        results: SDApplycalResults,
        plotter_cls: ApplyCalSingleDishPlotmsSpwComposite | ApplyCalSingleDishPlotmsAntSpwComposite,
        **kwargs
    ) -> list[logger.Plot]:
        """Generate amplitude vs. time plots from results instance.

        Args:
            context: Pipeline context object containing state information.
            results: Results instance.
            plotter_cls: Plotter class to generate plot objects of amplitude vs. time.
            **kwargs: Extra keyword arguments forwarded to ``plotter_cls``.

        Returns:
            plots: List of plot objects of amplitude vs. time.
        """
        xaxis = 'time'
        yaxis = 'real'
        msobj = context.observing_run.get_ms(self.inputs.vis)
        plotter = plotter_cls(context, results, msobj, xaxis, yaxis, **kwargs)
        return plotter.plot()


def set_unit(ms: MeasurementSet, calapp: list[CalApplication]):
    """Set unit to MS data column according to applied calibrations.

    A unit is determined for each TARGET field from the calibration types
    applied to it:

    - 'K' when both 'ps' and 'tsys' calibrations are applied;
    - 'Jy' when, on top of a 'K' result, an 'amp' or 'gaincal' calibration
      is applied;
    - '' (undetermined) otherwise.

    The 'UNIT' keyword of the existing DATA, FLOAT_DATA and CORRECTED_DATA
    columns is written only when every TARGET field ends up with the same
    unit, 'K' or 'Jy'. When the MS has no TARGET field, or the units are
    mixed or undetermined, the MS is left untouched.

    Args:
        ms: MeasurementSet domain object.
        calapp: List of CalApplication objects.
    """
    target_fields = ms.get_fields(intent='TARGET')
    data_units = {f.id: '' for f in target_fields}
    for a in calapp:
        field_name = a.calto.field
        if not field_name:
            # empty field selection is treated as all TARGET fields
            fields = target_fields
        else:
            fields = [f for f in ms.get_fields(name=field_name) if 'TARGET' in f.intents]
        if not fields:
            continue

        caltypes = {cf.caltype for cf in a.calfrom}
        kelvin_calibrated = ('ps' in caltypes) and ('tsys' in caltypes)
        amp_calibrated = ('amp' in caltypes) or ('gaincal' in caltypes)
        if not kelvin_calibrated:
            LOG.warning(
                f'Calibration of {ms.basename} (field {field_name}) is not '
                'correct. Missing pscal and/or tsyscal.'
            )

        # a selection may match several TARGET fields (e.g. an empty
        # selection with multiple targets); label every one of them
        for field in fields:
            if kelvin_calibrated:
                data_units[field.id] = 'K'
            # amplitude calibration on top of K-scale data yields Jy
            if amp_calibrated and data_units[field.id] == 'K':
                data_units[field.id] = 'Jy'

    units = set(data_units.values())
    if units == {'K'}:
        data_unit = 'K'
    elif units == {'Jy'}:
        data_unit = 'Jy'
    else:
        if 'Jy' in units:
            LOG.warning(
                f'Calibration of {ms.basename} is not correct. '
                'Some of calibrations (pscal, tsyscal, ampcal) are missing.'
            )
        # mixed, undetermined, or no TARGET field at all: do not touch the MS
        data_unit = ''

    if not data_unit:
        return

    with casa_tools.TableReader(ms.name, nomodify=False) as tb:
        target_columns = set(tb.colnames()) & {'DATA', 'FLOAT_DATA', 'CORRECTED_DATA'}
        for col in sorted(target_columns):
            tb.putcolkeyword(col, 'UNIT', data_unit)


# Tier-0 parallelization
# Registered as the 'hsd_applycal' CASA task; SerialSDApplycal provides the
# actual implementation and SDApplycalInputs the task inputs.
@task_registry.set_equivalent_casa_task('hsd_applycal')
@task_registry.set_casa_commands_comment('Calibrations are applied to the data. Final flagging summaries are computed')
class SDApplycal(sessionutils.ParallelTemplate):
    Inputs = SDApplycalInputs
    Task = SerialSDApplycal