# Source code for pipeline.hsd.tasks.atmcor.atmcor

"""Offline ATM correction stage."""
from __future__ import annotations

import collections
import collections.abc
import os
from typing import TYPE_CHECKING

import pipeline.hsd.heuristics.SDcalatmcorr as SDcalatmcorr
import pipeline.infrastructure.basetask as basetask
import pipeline.infrastructure.callibrary as callibrary
import pipeline.infrastructure.casa_tasks as casa_tasks
import pipeline.infrastructure.casa_tools as casa_tools
import pipeline.infrastructure as infrastructure
import pipeline.infrastructure.sessionutils as sessionutils
import pipeline.infrastructure.utils as utils
import pipeline.infrastructure.vdp as vdp
from pipeline.domain import DataType
from pipeline.h.heuristics import fieldnames
from pipeline.hsd.tasks.common.inspection_util import generate_ms, inspect_reduction_group, merge_reduction_group
from pipeline.infrastructure import task_registry
from pipeline.infrastructure.utils import relative_path
from .. import common

if TYPE_CHECKING:
    from pipeline.infrastructure.launcher import Context

LOG = infrastructure.logging.get_logger(__name__)


# Immutable record describing one ATM model candidate.
ATMModelParam = collections.namedtuple('ATMModelParam', ['atmtype', 'maxalt', 'dtem_dh', 'h0'])


def _format_atm_model(model) -> str:
    """Return a short human-readable summary of an ATM model.

    Only atmtype, dtem_dh and h0 are included in the summary;
    maxalt is intentionally not part of the string.

    Args:
        model: ATMModelParam instance

    Returns:
        summary string used in log messages
    """
    return f'atmtype {model.atmtype}, dtem_dh {model.dtem_dh}K/km, h0 {model.h0}km.'


ATMModelParam.__str__ = _format_atm_model


# atmtype candidates examined by the heuristics when atmtype is 'auto'
DEFAULT_ATMTYPE_LIST = [1, 2, 3, 4]


class SDATMCorrectionInputs(vdp.StandardInputs):
    """Inputs class for SDATMCorrection task."""
    # Search order of input vis
    processing_data_type = [DataType.REGCAL_CONTLINE_ALL, DataType.RAW]

    parallel = sessionutils.parallel_inputs_impl()

    atmtype = vdp.VisDependentProperty(default='auto')
    dtem_dh = vdp.VisDependentProperty(default=-5.6)
    h0 = vdp.VisDependentProperty(default=2.0)
    maxalt = vdp.VisDependentProperty(default=120)
    intent = vdp.VisDependentProperty(default='TARGET')

    @atmtype.convert
    def atmtype(self, value: int | str | list[int | str]) -> str | list[str]:
        """Normalize atmtype into a string or a list of strings.

        Any iterable other than str/dict is treated as a list of
        atmtype values. A list holding exactly one element is
        collapsed into that single element.

        Args:
            value: atmtype value(s)

        Returns:
            atmtype as a string, or a list of strings
        """
        is_sequence = isinstance(value, collections.abc.Iterable) and not isinstance(value, (str, dict))
        if not is_sequence:
            # scalar input: simply stringify
            return str(value)

        converted = [item if isinstance(item, str) else str(item) for item in value]
        # a one-element list is equivalent to a scalar
        if len(converted) == 1:
            return converted[0]
        return converted

    def __to_float_value(self, value: float | str | dict | list[float | str | dict], default_unit: str) -> float | list[float]:
        """Convert a value, or a list of values, into float(s) in default_unit.

        Each scalar input is turned into a casa quantity (dict inputs are
        taken as quantities as they are) and handled as follows:

            - no unit: the numerical value is returned unchanged
            - unit compatible with default_unit: value converted to default_unit
            - otherwise: a warning is emitted and 0 is returned

        Iterable inputs other than str/dict are converted element by
        element. A one-element result is collapsed into a scalar.

        Args:
            value: Numerical value, quantity string, casa quantity
                   dictionary, or a list of these.
            default_unit: Unit string the return value(s) refer to

        Returns:
            Float value or list of float values in default_unit.
        """
        is_sequence = isinstance(value, collections.abc.Iterable) and not isinstance(value, (str, dict))
        if is_sequence:
            converted = [self.__to_float_value(item, default_unit) for item in list(value)]
            return converted[0] if len(converted) == 1 else converted

        # scalar value
        qa = casa_tools.quanta
        quantity = value if isinstance(value, dict) else qa.quantity(value)

        if quantity['unit'] == '':
            # unitless input is taken to be in default_unit already
            return quantity['value']

        if qa.compare(quantity, qa.quantity(0, default_unit)):
            return qa.convert(quantity, default_unit)['value']

        LOG.warning(f'incompatible unit: input {value} requires unit {default_unit}')
        return 0.

    @h0.convert
    def h0(self, value: float | str | dict | list[float | str | dict]) -> float | list[float]:
        """Normalize h0 into float value(s) expressed in km.

        Accepts a number, a quantity string, a casa quantity
        dictionary, or a list of these.

        Args:
            value: h0 value(s)

        Returns:
            h0 value(s) in the unit of km
        """
        return self.__to_float_value(value, default_unit='km')

    @dtem_dh.convert
    def dtem_dh(self, value: float | str | dict | list[float | str | dict]) -> float | list[float]:
        """Normalize dtem_dh into float value(s) expressed in K/km.

        Accepts a number, a quantity string, a casa quantity
        dictionary, or a list of these.

        Args:
            value: dtem_dh value(s)

        Returns:
            dtem_dh value(s) in the unit of K/km
        """
        return self.__to_float_value(value, default_unit='K/km')

    @maxalt.convert
    def maxalt(self, value: float | str | dict | list[float | str | dict]) -> float | list[float]:
        """Normalize maxalt into float value(s) expressed in km.

        Accepts a number, a quantity string, a casa quantity
        dictionary, or a list of these.

        Args:
            value: maxalt value(s)

        Returns:
            maxalt value(s) in the unit of km
        """
        return self.__to_float_value(value, default_unit='km')

    @vdp.VisDependentProperty
    def infiles(self) -> str:
        """Return the default value of infiles.

        infiles is an alias of vis, so the default simply mirrors
        the current value of vis.

        Returns:
            current value of vis
        """
        return self.vis

    @infiles.convert
    def infiles(self, value: str) -> str:
        """Keep vis in sync whenever infiles is assigned.

        Args:
            value: new infiles value

        Returns:
            the input value, unchanged
        """
        # infiles is an alias of vis: propagate the new value to vis
        self.vis = value
        return value

    @vdp.VisDependentProperty
    def antenna(self) -> str:
        """Return the default antenna selection.

        The default is an empty string, which stands for
        all antennas.

        Returns:
            antenna selection string
        """
        return ''

    @antenna.convert
    def antenna(self, value: str) -> str:
        """Translate an antenna selection into a compact ID selection.

        The selection is resolved against the MS. When it covers as many
        antennas as the MS holds, it is normalized to an empty string.
        Otherwise the selected antenna IDs are passed to
        utils.find_ranges to build the selection string.

        Args:
            value: input antenna selection

        Returns:
            converted antenna selection
        """
        selected = self.ms.get_antenna(value)
        if len(selected) != len(self.ms.antennas):
            # partial selection: express it by antenna IDs
            return utils.find_ranges([ant.id for ant in selected])
        # everything is selected
        return ''

    @vdp.VisDependentProperty
    def field(self) -> str:
        """Return the default field selection string.

        By default, only fields that match the given intent are selected.
        Duplicated names are removed while the order in which the names
        were found is preserved, so that the resulting selection string
        is identical from run to run.

        Returns:
            comma-separated field selection string
        """
        # this will give something like '0542+3243,0343+242'
        field_finder = fieldnames.IntentFieldnames()
        intent_fields = field_finder.calculate(self.ms, self.intent)

        # De-duplicate with dict.fromkeys rather than set: iteration order of
        # a set of strings depends on string hashing and is not reproducible
        # across interpreter sessions, whereas dict keeps first-seen order.
        unique_fields = dict.fromkeys(utils.safe_split(intent_fields))

        return ','.join(unique_fields)

    @vdp.VisDependentProperty
    def spw(self) -> str:
        """Return the default spw selection string.

        The default consists of the IDs of the spws returned by
        get_spectral_windows(with_channels=True) for the MS.

        Returns:
            comma-separated spw selection string
        """
        spw_ids = (str(spwobj.id) for spwobj in self.ms.get_spectral_windows(with_channels=True))
        return ','.join(spw_ids)

    @vdp.VisDependentProperty
    def pol(self) -> str:
        """Return the default pol selection string.

        By default, polarizations corresponding to the selected spws are
        selected. Duplicates are removed while the order in which the
        polarizations were found is preserved, so that the resulting
        selection string is identical from run to run.

        Returns:
            comma-separated pol selection string
        """
        # filters polarization by self.spw
        selected_spwids = [int(spwobj.id) for spwobj in self.ms.get_spectral_windows(self.spw, with_channels=True)]

        # A dict is used as an ordered set: iteration order of a set of
        # strings depends on string hashing and is not reproducible across
        # interpreter sessions.
        pols = {}
        for idx in selected_spwids:
            pols.update(dict.fromkeys(self.ms.get_data_description(spw=idx).corr_axis))

        return ','.join(pols)

    # docstring and type hints: supplements hsd_atmcor
    def __init__(self,
                 context: Context,
                 atmtype: int | str | list[int] | list[str] | None = None,
                 dtem_dh: float | str | dict | list[float] | list[str] | list[dict] | None = None,
                 h0: float | str | dict | list[float] | list[str] | list[dict] | None = None,
                 maxalt: float | str | dict | list[float] | list[str] | list[dict] | None = None,
                 infiles: str | list[str] | None = None,
                 antenna: str | list[str] | None = None,
                 field: str | list[str] | None = None,
                 spw: str | list[str] | None = None,
                 pol: str | list[str] | None = None,
                 parallel: bool | str | None = None):
        """Initialize Inputs instance for hsd_atmcor.

        Args:
            context: pipeline context

            atmtype: Type of atmospheric transmission model represented as an integer.
                Available options are as follows. Integer values can be given as
                either integer or string, i.e. both 1 and '1' are acceptable.

                - 'auto': perform heuristics to choose best model (default).
                - 1: tropical.
                - 2: mid latitude summer.
                - 3: mid latitude winter.
                - 4: subarctic summer.
                - 5: subarctic winter.

                If list of integer is given, it also performs heuristics using the
                provided values instead of default, [1, 2, 3, 4], which is used
                when 'auto' is provided. List input should not contain 'auto'.

                Default: None (equivalent to 'auto')

            dtem_dh: Temperature gradient [K/km], e.g. -5.6 ("" = Tool default).
                The value is directly passed to initialization method for ATM model.
                Float and string types are acceptable. Float value is interpreted as
                the value in K/km. String value should be the numeric value with unit
                such as '-5.6K/km'. When list of values are given, it will
                trigger heuristics to choose best model from the provided value.

                Default: None (equivalent to tool default, -5.6K/km)

            h0: Scale height for water [km], e.g. 2.0 ("" = Tool default).
                The value is directly passed to initialization method for ATM model.
                Float and string types are acceptable. Float value is interpreted as
                the value in kilometer. String value should be the numeric value with
                unit compatible with length, such as '2km' or '2000m'.
                When list of values are given, it will trigger heuristics to
                choose best model from the provided value.

                Default: None (equivalent to tool default, 2.0km)

            maxalt: Maximum altitude of the model [km].
                Float, string, and casa quantity (dict) values are acceptable,
                as well as a list of them. Float value is interpreted as
                the value in kilometer. String value should be the numeric
                value with unit compatible with length. Note that this value
                is removed from the argument list for sdatmcor; it is handed
                to the ATM heuristics only.

                Default: None (presumably equivalent to the property
                default, 120km)

            infiles: ASDM or MS files to be processed. This parameter behaves as
                data selection parameter. The name specified by infiles must be
                registered to context before you run hsd_atmcor.

            antenna: Data selection by antenna names or ids.

                Example: 'PM03,PM04', '' (all antennas)

            field: Data selection by field names or ids.

                Example: '`*Sgr*,M100`', '' (all fields)

            spw: Data selection by spw ids.

                Example: '3,4' (spw 3 and 4), '' (all spws)

            pol: Data selection by polarizations.

                Example: 'XX,YY' (correlation XX and YY), '' (all polarizations)

            parallel: Execute using CASA HPC functionality, if available.
                Default is None, which intends to turn on parallel
                processing if possible.
        """
        super().__init__()

        # NOTE: assignment order matters. context must come first, and
        # infiles (which also sets vis) must be set before the data
        # selection parameters whose converters/defaults refer to self.ms.
        self.context = context
        self.atmtype = atmtype
        self.dtem_dh = dtem_dh
        self.h0 = h0
        self.maxalt = maxalt
        self.infiles = infiles
        self.antenna = antenna
        self.field = field
        self.spw = spw
        self.pol = pol
        self.parallel = parallel

    def _identify_datacolumn(self, vis: str) -> str:
        """Pick the datacolumn parameter according to columns in the MS.

        Columns are examined in the order of CORRECTED_DATA, FLOAT_DATA,
        and DATA. The first one found in the MS determines the result.

        Args:
            vis: MS name

        Raises:
            Exception: none of the data columns exists

        Returns:
            datacolumn parameter
        """
        with casa_tools.TableReader(vis) as table:
            available = table.colnames()

        # column name -> datacolumn value, in descending priority
        candidates = {
            'CORRECTED_DATA': 'corrected',
            'FLOAT_DATA': 'float_data',
            'DATA': 'data',
        }
        for column, datacolumn in candidates.items():
            if column in available:
                return datacolumn

        raise Exception('No datacolumn is found.')

    def get_caltable_from_callibrary(self) -> str:
        """Look up the k2jycal caltable applied to vis in the callibrary.

        The applied state of the callibrary is trimmed down to the
        current vis and searched for caltables of type 'amp' or 'gaincal'.

        Returns:
            Name of the caltable, or empty string if no such caltable
            has been applied.
        """
        applied = self.context.callibrary.applied
        target = callibrary.CalTo(vis=self.vis)
        found = applied.trimmed(self.context, target).get_caltable(caltypes=('amp', 'gaincal'))

        return found.pop() if len(found) > 0 else ''

    def get_gainfactor(self) -> float | str:
        """Return the gainfactor argument for sdatmcor.

        Returns:
            name of the k2jycal table if it has been applied,
            otherwise 1.0
        """
        # empty caltable name means no k2jycal table: fall back to unity
        return self.get_caltable_from_callibrary() or 1.0

    def to_casa_args(self) -> dict:
        """Build the argument dictionary for the sdatmcor task.

        The dictionary produced by the parent class is adapted to the
        interface of sdatmcor: parameters are renamed, parameters unknown
        to sdatmcor are dropped, and derived parameters are added.

        Note that the result might not be a valid argument list when
        the user intends to run heuristics for ATM parameter.
        Please check if require_atm_heuristics method returns
        True to make sure the return value is valid.

        Returns:
            task arguments for sdatmcor
        """
        args = super().to_casa_args()

        # vis -> infile (infiles is just an alias of vis)
        args.pop('infiles', None)
        infile = args.pop('vis')
        args['infile'] = infile

        # datacolumn depends on the columns available in the MS
        args['datacolumn'] = self._identify_datacolumn(infile)

        # numeric atmtype must be given as integer
        atmtype = args['atmtype']
        if isinstance(atmtype, str) and atmtype.isdigit():
            args['atmtype'] = int(atmtype)

        # sdatmcor has no maxalt parameter
        args.pop('maxalt')

        # default outfile is derived from infile and atmtype
        if 'outfile' not in args:
            ms_basename = os.path.basename(infile.rstrip('/'))
            outfile = f"{ms_basename}.atmcor.atmtype{args['atmtype']}"
            args['outfile'] = relative_path(os.path.join(self.output_dir, outfile))

        # gainfactor
        args['gainfactor'] = self.get_gainfactor()

        # overwrite is always True
        args['overwrite'] = True

        # spw -> outputspw
        args['outputspw'] = args.pop('spw', '')

        # pol is dropped: correlation selection should be empty
        # to avoid strange error in VI/VB2 framework
        args.pop('pol', None)
        args['correlation'] = ''

        # process ON_SOURCE data only
        args['intent'] = 'OBSERVE_TARGET#ON_SOURCE'

        # parallel is not a parameter of sdatmcor
        del args['parallel']

        return args

    def require_atm_heuristics(self) -> bool:
        """Tell whether the ATM heuristics has to be executed.

        The heuristics is necessary when at least one of the
        following holds:

            - atmtype is either 'auto' or list of type IDs
            - dtem_dh is a list of float values
            - h0 is a list of float values

        Returns:
            True if ATM heuristics is required. Otherwise, False.
        """
        atmtype = self.atmtype
        conditions = (
            isinstance(atmtype, list) or atmtype.lower() == 'auto',
            isinstance(self.dtem_dh, list),
            isinstance(self.h0, list),
        )
        return any(conditions)


class SDATMCorrectionResults(common.SingleDishResults):
    """Results instance for hsd_atmcor."""

    def __init__(self,
                 task: basetask.StandardTaskTemplate | None = None,
                 success: bool | None = None,
                 outcome: dict | None = None):
        """Set up results instance for hsd_atmcor.

        Although outcome defaults to None, it is indexed right away,
        so in practice a dict with the following keys is required:

            - 'task_args': actual argument list of the sdatmcor
                           (The "inputs" dictionary associated with
                           the results object holds "nominal"
                           argument list)
            - 'atm_heuristics': status string of the ATM heuristics
            - 'model_list': list of attempted ATM models
            - 'best_model_index': index for the best ATM model

        Args:
            task: task class. Defaults to None.
            success: task execution was successful or not. Defaults to None.
            outcome: outcome of the task execution. Defaults to None.
        """
        super().__init__(task, success, outcome)

        # arguments actually handed to sdatmcor and values derived from them
        task_args = outcome['task_args']
        self.task_args = task_args
        self.atmcor_ms_name = task_args['outfile']
        self.best_atmtype = task_args['atmtype']

        # summary of the ATM heuristics
        self.atm_heuristics = outcome['atm_heuristics']
        self.best_model_index = outcome['best_model_index']
        self.model_list = outcome['model_list']

        # MS domain objects; starts empty and is iterated by merge_with_context
        self.out_mses = []

    def merge_with_context(self, context: Context):
        """Merge execution result of atmcor stage into pipeline context.

        Every output MS is registered to the observing run (replacing
        an already registered MS of the same name), to the callibrary
        with an empty calibration list, and to the reduction group.

        Args:
            context: pipeline context
        """
        super().merge_with_context(context)

        observing_run = context.observing_run
        for new_ms in self.out_mses:
            # locate the first registered MS that has the same name
            duplicate_index = next(
                (index for index, registered in enumerate(observing_run.get_measurement_sets())
                 if new_ms.name == registered.name),
                None
            )
            if duplicate_index is not None:
                LOG.info('Replace {} in context'.format(new_ms.name))
                del observing_run.measurement_sets[duplicate_index]

            # register the MS to context
            LOG.info('Adding {} to context'.format(new_ms.name))
            observing_run.add_measurement_set(new_ms)

            # initialize callibrary entry for the MS
            calto = callibrary.CalTo(vis=new_ms.name)
            LOG.info('Registering {} with callibrary'.format(new_ms.name))
            context.callibrary.add(calto, [])

            # register the MS to processing group
            merge_reduction_group(observing_run, inspect_reduction_group(new_ms))

    def _outcome_name(self) -> str:
        """Return a string that represents the outcome.

        For hsd_atmcor, it is the name of the output MS
        with leading directories stripped off.

        Returns:
            output MS name for sdatmcor
        """
        output_ms = self.atmcor_ms_name
        return os.path.basename(output_ms)


class SerialSDATMCorrection(basetask.StandardTaskTemplate):
    """Offline ATM correction task."""

    Inputs = SDATMCorrectionInputs

    def prepare(self) -> SDATMCorrectionResults:
        """Run sdatmcor and wrap its outcome into a results instance.

        The argument list for sdatmcor is either taken from the inputs
        as it is, or determined by the ATM heuristics if required.

        Raises:
            Exception: output MS of sdatmcor does not exist

        Returns:
            results instance for hsd_atmcor stage
        """
        inputs = self.inputs
        if inputs.require_atm_heuristics():
            # let the heuristics select the best ATM model
            atm_heuristics, args, best_model_index, model_list = self._perform_atm_heuristics()
        else:
            # single ATM model given by the user
            atm_heuristics = 'N'
            args = inputs.to_casa_args()
            best_model_index = -1
            user_model = ATMModelParam(
                atmtype=args['atmtype'],
                maxalt=inputs.maxalt,
                dtem_dh=args['dtem_dh'],
                h0=args['h0']
            )
            model_list = [user_model]

        LOG.info('Processing parameter for sdatmcor: %s', args)
        task_exec_status = self._executor.execute(casa_tasks.sdatmcor(**args))
        LOG.info('atmcor: task_exec_status = %s', task_exec_status)

        if not os.path.exists(args['outfile']):
            raise Exception('Output MS does not exist. It seems sdatmcor failed.')

        # Only None ("no news is good news") is regarded as success.
        # False and any other unexpected value are treated as failure.
        is_successful = task_exec_status is None

        outcome = {
            'task_args': args,
            'atm_heuristics': atm_heuristics,
            'best_model_index': best_model_index,
            'model_list': model_list,
        }
        return SDATMCorrectionResults(task=self.__class__, success=is_successful, outcome=outcome)

    def analyse(self, result: SDATMCorrectionResults) -> SDATMCorrectionResults:
        """Attach MS domain object of the corrected MS to the results.

        A domain object for the output MS of sdatmcor is generated
        based on the input MS. Its DATA column is registered as
        DataType.ATMCORR.

        Args:
            result: results instance

        Returns:
            input results instance
        """
        source_ms = self.inputs.ms
        corrected_ms = generate_ms(result.atmcor_ms_name, source_ms)
        corrected_ms.set_data_column(DataType.ATMCORR, 'DATA')
        result.out_mses.append(corrected_ms)
        return result

    def _perform_atm_heuristics(self) -> tuple[str, dict, int, list[tuple[int, float, float, float]]]:
        """Perform ATM model heuristics.

        Perform ATM model heuristics, SDcalatmcorr, and construct the
        argument list for sdatmcor from the selected model. If the
        heuristics raises an exception, the default model
        (atmtype 1, maxalt 120km, dtem_dh -5.6K/km, h0 2.0km) is used.

        The list of atmtype candidates is determined as follows:

            - list of atmtype: the given values
            - single numeric atmtype (heuristics was triggered by a list of
              dtem_dh and/or h0): only the given value
            - otherwise ('auto'): DEFAULT_ATMTYPE_LIST

        Returns:
            Four tuple, status of ATM model heuristics ('Y' or 'Fallback'),
            argument list for sdatmcor, index of the best ATM model
            (-1 if not available), and list of attempted ATM models.
        """
        # create weblog directory
        stage_number = self.inputs.context.task_counter
        stage_dir = os.path.join(
            self.inputs.context.report_dir,
            f'stage{stage_number}'
        )
        os.makedirs(stage_dir, exist_ok=True)

        # perform atmtype heuristics if atmtype is 'auto'
        # run Harold's script here
        LOG.info('Performing atmtype heuristics')
        atm_heuristics = 'Fallback'
        default_model = ATMModelParam(atmtype=1, maxalt=120, dtem_dh=-5.6, h0=2.0)
        # best_model will fall back to default_model if heuristics is failed
        best_model = default_model
        args = self.inputs.to_casa_args()
        ms_name = args['infile']
        model_list = [default_model]
        best_model_index = -1
        LOG.info(f'default_model: {default_model}')

        # handle list inputs
        atmtype = args['atmtype']
        if isinstance(atmtype, list):
            atmtype_list = [int(x) for x in atmtype]
        elif isinstance(atmtype, int):
            # to_casa_args converts single numeric atmtype into int.
            # Reaching here with an int means that heuristics was requested
            # via dtem_dh and/or h0, so atmtype given by the user must be
            # kept fixed rather than replaced with the default candidates.
            atmtype_list = [atmtype]
        else:
            # should be 'auto'; copy so that the module-level default
            # can never be modified through the list passed around
            atmtype_list = list(DEFAULT_ATMTYPE_LIST)

        try:
            heuristics_result = SDcalatmcorr.selectModelParams(
                mslist=[ms_name],
                context=self.inputs.context,
                decisionmetric='intsqdiff',
                atmtype=atmtype_list,
                maxalt=self.inputs.maxalt,
                lapserate=self.inputs.dtem_dh,
                scaleht=self.inputs.h0,
                plotsfolder=stage_dir,
                defatmtype=default_model.atmtype,
                defmaxalt=default_model.maxalt,
                deflapserate=default_model.dtem_dh,
                defscaleht=default_model.h0,
                diffsmooth=0.016
            )
            best_model = ATMModelParam(*heuristics_result[0][ms_name])

            status = heuristics_result[3][ms_name]
            if status == 'bestfitmodel':
                atm_heuristics = 'Y'
                model_list = [ATMModelParam(*x) for x in heuristics_result[1][ms_name]]
                best_model_index = model_list.index(best_model)
                LOG.info(f'Best ATM model is {best_model}.')
            else:
                LOG.info(f'ATM heuristics failed. Using default model {default_model}.')
                model_list = [best_model]

        except Exception as e:
            LOG.info(f'ATM heuristics failed. Falling back to default model {default_model}.')
            LOG.info('Original error:')
            LOG.info(str(e))
            if LOG.isEnabledFor(infrastructure.logging.DEBUG):
                import traceback
                LOG.debug(traceback.format_exc())

            # The exception might have been raised after some of the
            # variables below were updated. Reset all of them so that the
            # returned values are consistent with the fallback announced
            # in the log.
            atm_heuristics = 'Fallback'
            best_model = default_model
            model_list = [default_model]
            best_model_index = -1

        # construct argument list for sdatmcor
        inputs_local = utils.pickle_copy(self.inputs)
        inputs_local.atmtype = best_model.atmtype
        inputs_local.maxalt = best_model.maxalt
        inputs_local.dtem_dh = best_model.dtem_dh
        inputs_local.h0 = best_model.h0
        args = inputs_local.to_casa_args()

        return atm_heuristics, args, best_model_index, model_list

@task_registry.set_equivalent_casa_task('hsd_atmcor')
@task_registry.set_casa_commands_comment(
    'Apply offline correction of atmospheric transmission model.'
)
class SDATMCorrection(sessionutils.ParallelTemplate):
    """Parallel implementation of offline ATM correction task."""

    # Inputs class shared with the serial task
    Inputs = SDATMCorrectionInputs
    # serial task class
    Task = SerialSDATMCorrection