"""
The sorting module contains utility functions used by the pipeline web log.
"""
# Do not evaluate type annotations at definition time.
from __future__ import annotations
import collections
import datetime
import functools
import html
import itertools
import operator
import os
from typing import TYPE_CHECKING
import astropy.table
import numpy as np
from pipeline import infrastructure
from pipeline.infrastructure import casa_tools, utils
from pipeline.infrastructure.utils import conversion
if TYPE_CHECKING:
from pipeline.domain import MeasurementSet
from pipeline.domain.measures import Distance
__all__ = ['OrderedDefaultdict', 'merge_td_columns', 'merge_td_rows', 'get_vis_from_plots', 'total_time_on_source',
'total_time_on_target_on_source', 'get_logrecords', 'get_intervals', 'table_to_html', 'plots_to_html',
'scale_uv_range', 'split_spw', 'get_directory_size']
LOG = infrastructure.logging.get_logger(__name__)
[docs]
class OrderedDefaultdict(collections.OrderedDict):
"""This class behaves as defaultdict from the collections module but maintaining the order of insertion.
It is usually called in our codebase using the following structure: my_list = utils.OrderedDefaultdict(list)
The dict can then be filled straight away: my_list[2] = [1, 2, 3]
For example,
>>> my_list = OrderedDefaultdict(list)
>>> my_list[2] = [1, 2, 3]
>>> my_list[1]
[]
>>> my_list[2]
[1, 2, 3]
Note that from Python 3.8 this class should probably work as collections.defaultdict given that dicts now preserve
the insertion order as a feature and the __reverse__ method is implemented in dicts.
"""
def __init__(self, *args, **kwargs):
if not args:
self.default_factory = None
else:
if not (args[0] is None or callable(args[0])):
raise TypeError('first argument must be callable or None')
self.default_factory = args[0]
args = args[1:]
super().__init__(*args, **kwargs)
def __missing__(self, key):
if self.default_factory is None:
raise KeyError(key)
self[key] = default = self.default_factory()
return default
def __reduce__(self): # optional, for pickle support
args = (self.default_factory,) if self.default_factory else ()
return self.__class__, args, None, None, iter(self.items())
[docs]
def merge_td_columns(rows, num_to_merge=None, vertical_align=False):
"""Merge HTML TD columns with identical values using rowspan.
Arguments:
rows -- a list of tuples, one tuple per row, containing n elements for the
n columns.
num_to_merge -- the number of columns to merge, starting from the left
hand column. Leave as None to merge all columns.
vertical_align -- Set to True to vertically centre any merged/unmerged cells.
Output:
A list of strings, one string per row, containing TD elements.
"""
transposed = list(zip(*rows))
if num_to_merge is None:
num_to_merge = len(transposed)
valign = ' style="vertical-align:middle;"' if vertical_align else ''
new_cols = []
for col_idx, col in enumerate(transposed):
if col_idx > num_to_merge - 1:
new_cols.append(['<td>%s</td>' % v for v in col])
continue
merged = []
start = 0
while start < len(col):
l = col[start:]
same_vals = list(itertools.takewhile(lambda x: x == col[start], l))
rowspan = len(same_vals)
start += rowspan
if rowspan > 1:
new_td = [f'<td rowspan="{rowspan}"{valign}>{same_vals[0]}</td>']
blanks = [''] * (rowspan - 1)
merged.extend(new_td + blanks)
else:
td = f'<td{valign}>{same_vals[0]}</td>'
merged.append(td)
new_cols.append(merged)
return list(zip(*new_cols))
[docs]
def merge_td_rows(table):
"""
Merge HTML TD rows with identical values using colspan.
Arguments:
table -- a list of tuples, one tuple per row, containing n elements for the n columns.
Output:
A list of tuples with adjusted idential values merged with colspan.
"""
new_table = []
for row in table:
row_list = list(row)
start = 0
while start < len(row):
start_cell = row[start]
merge_count = 0
end = start+1
while end < len(row):
if start_cell == row[end]:
row_list[end] = ''
merge_count += 1
end += 1
else:
break
if merge_count > 0:
row_list[start] = row_list[start].replace('<td', fr'<td colspan="{merge_count+1}"')
start += 1
new_table.append(tuple(row_list))
return new_table
[docs]
def get_vis_from_plots(plots):
"""
Get the name to be used for the MS from the given plots.
:param plots:
:return:
"""
vis = {p.parameters['vis'] for p in plots}
vis = vis.pop() if len(vis) == 1 else 'all data'
return vis
[docs]
def total_time_on_target_on_source(ms, autocorr_only=False):
"""
Return the nominal total time on target source for the given MeasurementSet
excluding OFF-source integrations (REFERENCE). The flag is not taken into account.
Background of development: ALMA-TP observations have integrations of both TARGET
and REFERENCE intents in one scan. Scan.time_on_source does not return appropriate
exposure time in the case.
:param ms: MeasurementSet domain object to examine
:param autocorr_only:
:return: a datetime.timedelta object set to the total time on source
"""
science_spws = ms.get_spectral_windows(science_windows_only=True)
state_ids = [s.id for s in ms.states if 'TARGET' in s.intents]
max_time = 0.0
ant_ids = [a.id for a in ms.antennas]
dds = [ms.get_data_description(spw=spw) for spw in science_spws]
science_dds = np.unique([dd.id for dd in dds])
with casa_tools.TableReader(ms.name) as tb:
for dd in science_dds:
for a1 in ant_ids:
for a2 in ant_ids:
if autocorr_only and a1 != a2:
continue
seltb = tb.query('DATA_DESC_ID == %d AND ANTENNA1 == %d AND ANTENNA2 == %d AND STATE_ID IN %s' % (
dd, a1, a2, utils.list_to_str(state_ids)))
try:
if seltb.nrows() == 0:
continue
target_exposures = seltb.getcol('EXPOSURE').sum()
LOG.debug(
"Selected %d ON-source rows for DD=%d, Ant1=%d, Ant2=%d: total exposure time = %f sec" % (
seltb.nrows(), dd, a1, a2, target_exposures))
max_time = max(max_time, target_exposures)
finally:
seltb.close()
LOG.debug("Max ON-source exposure time = %f sec" % max_time)
return datetime.timedelta(int(max_time / 86400), int(max_time % 86400), int((max_time % 1) * 1e6))
[docs]
def total_time_on_source(scans):
"""
Return the total time on source for the given Scans.
:param scans: collection of Scan domain objects
:return: a datetime.timedelta object set to the total time on source
"""
times_on_source = [scan.time_on_source for scan in scans]
if times_on_source:
return functools.reduce(operator.add, times_on_source)
else:
# could potentially be zero matching scans, such as when the
# measurement set is missing scans with science intent
return datetime.timedelta(0)
[docs]
def get_logrecords(result, loglevel):
"""
Get the logrecords for the result, removing any duplicates
:param result: a result containing logrecords
:param loglevel: the loglevel to match
:return:
"""
try:
# WeakProxy is registered as an Iterable (and a Container, Hashable, etc.)
# so we can't check for isinstance(result, collections.abc.Iterable)
# see https://bugs.python.org/issue24067
_ = iter(result)
except TypeError:
if not hasattr(result, 'logrecords'):
return []
records = [l for l in result.logrecords if l.levelno is loglevel]
else:
# note that flatten returns a generator, which empties after
# traversal. we convert to a list to allow multiple traversals
g = conversion.flatten([get_logrecords(r, loglevel) for r in result])
records = list(g)
# append the message target to the LogRecord so we can link to the
# matching page in the web log
try:
target = os.path.basename(result.inputs['vis'])
for r in records:
r.target = {'vis': target}
except:
pass
dset = set()
# relies on the fact that dset.add() always returns None.
return [r for r in records if
r.msg not in dset and not dset.add(r.msg)]
[docs]
def get_intervals(context, calapp, spw_ids=None):
"""
Get the integration intervals for scans processed by a calibration.
The scan and spw selection is formed through inspection of the
CalApplication representing the calibration.
:param context: the pipeline context
:param calapp: the CalApplication representing the calibration
:param spw_ids: a set of spw IDs to get intervals for. Leave as None to
use all spws specified in the CalApplication.
:return: a list of datetime objects representing the unique scan intervals
"""
# With the advent of session calibrations, the target MS for the
# calibration may be different from the MS used to calculate the
# calibration. Therefore, we must look to the calapp.origin, which
# refers to the originating calls, to calculate the true values.
vis = {o.inputs['vis'] for o in calapp.origin}
assert (len(vis) == 1)
vis = vis.pop()
ms = context.observing_run.get_ms(vis)
from_intent = {o.inputs['intent'] for o in calapp.origin}
assert (len(from_intent) == 1)
from_intent = from_intent.pop()
# let CASA parse spw arg in case it contains channel spec
if not spw_ids:
task_spw_args = {o.inputs['spw'] for o in calapp.origin}
spw_arg = ','.join(task_spw_args)
spw_ids = {spw_id for (spw_id, _, _, _) in conversion.spw_arg_to_id(vis, spw_arg, ms.get_spectral_windows)}
# from_intent is given in CASA intents, ie. *AMPLI*, *PHASE*
# etc. We need this in pipeline intents.
pipeline_intent = conversion.to_pipeline_intent(ms, from_intent)
scans = ms.get_scans(scan_intent=pipeline_intent)
# scan with intent may not have data for the spws used in the
# gaincal call, eg. X20fb. Only get the solint for spws in the call
# by using the intersection.
all_solints = {scan.mean_interval(spw_id)
for scan in scans
for spw_id in spw_ids.intersection({spw.id for spw in scan.spws})}
return all_solints
# ms = context.observing_run.get_ms(calapp.vis)
#
# from_intent = calapp.origin.inputs['intent']
# # from_intent is given in CASA intents, ie. *AMPLI*, *PHASE*
# # etc. We need this in pipeline intents.
# pipeline_intent = to_pipeline_intent(ms, from_intent)
# scans = ms.get_scans(scan_intent=pipeline_intent)
#
# # let CASA parse spw arg in case it contains channel spec
# if not spw_ids:
# spw_ids = set([spw_id for (spw_id, _, _, _)
# in spw_arg_to_id(calapp.vis, calapp.spw)])
#
# all_solints = set()
# for scan in scans:
# scan_spw_ids = set([spw.id for spw in scan.spws])
# # scan with intent may not have data for the spws used in
# # the gaincal call, eg. X20fb, so only get the solint for
# # the intersection
# solints = [scan.mean_interval(spw_id)
# for spw_id in spw_ids.intersection(scan_spw_ids)]
# all_solints.update(set(solints))
#
# return all_solints
[docs]
def table_to_html(table, tableclass='table table-bordered table-striped table-condensed', rotate=False):
"""Convert a astropy.table.Table object to an HTML table snippet."""
if rotate:
table_rows = [table.colnames]+list(table.as_array())
table_rotate = astropy.table.QTable(rows=list(zip(*table_rows)))
table_html = table_rotate.pformat(html=True, max_width=-1, tableclass=tableclass, show_name=False)
else:
table_html = table.pformat(html=True, max_width=-1, tableclass=tableclass, show_name=True)
table_html = '\n'.join([html.unescape(line) for line in table_html])
return table_html
[docs]
def plots_to_html(plots, title=None, alt=None, caption=None, group=None,
align='middle', width='auto', height='auto', report_dir='./'):
"""Convert a list of plots to HTML snippets.
examples:
plots_to_html(plots, caption=None, width='400px', height='300px')
notes:
the generated snippet requires lazyload.
"""
def desc_lookup(plot, key, value=None):
"""Get a plot description value from the plot object attribute or parameters dictionary.
The order of precedence:
non-None input >
matching plot parameters dict key >
attribute with the same name
"""
ret_value = ''
if hasattr(plot, key):
ret_value = getattr(plot, key)
if hasattr(plot, 'parameters') and key in plot.parameters:
ret_value = plot.parameters[key]
if value is not None:
ret_value = value
return ret_value
plots_html = []
for plot in plots:
fullsize_relpath = os.path.relpath(plot.abspath, report_dir)
thumbnail_relpath = os.path.relpath(plot.thumbnail, report_dir)
html_args = {
'fullsize': fullsize_relpath,
'thumbnail': thumbnail_relpath,
'title': desc_lookup(plot, 'title', title),
'caption': desc_lookup(plot, 'caption', caption),
'alt': desc_lookup(plot, 'alt', alt),
'group': desc_lookup(plot, 'group', group),
'width': width,
'height': height,
'align': align,
}
html = ('<a href="{fullsize}"'
' title="{title}"'
' data-fancybox="{group}"'
' data-caption="{caption}">'
' <img data-src="{thumbnail}"'
' style="width:{width};height:{height}"'
' title="{title}"'
' alt="{alt}"'
' align="{align}"'
' class="lazyload img-responsive">'
'</a>'.format(**html_args))
plots_html.append(html)
return plots_html
[docs]
def scale_uv_range(ms: MeasurementSet) -> tuple[Distance, str]:
"""
Return the UV range that captures the inner half of the data.
This function returns a 2-tuple with first value set to the UV range
as a domain object, and the second value set to a uvrange constraint
suitable for use in plotms calls.
:param ms: measurement set to analyse
:return: tuple of (UV range, plotms constraint)
"""
# method=higher is preferred over method=nearest to ensure that the upper
# limit data point is included in the selection
uv_max = np.percentile(ms.antenna_array.baselines_m, 50, method='higher')
uv_range = f"<{uv_max}"
# domain.measures classes must be imported at runtime to avoid a circular
# dependency
from pipeline.domain.measures import Distance
from pipeline.domain.measures import DistanceUnits
return Distance(value=uv_max, units=DistanceUnits.METRE), uv_range
[docs]
def split_spw(spw_string: str) -> str:
"""
Splits an SPW string on '#' characters while retaining them, and inserts a <br/>
before the 'ALMA' portion to improve readability. Returns the original string
if 'ALMA' is not found.
Example:
Input: "X100001#X900000004#ALMA_RB_06#BB_1#SW-01#FULL_RES"
Output: "X100001#<br/>X900000004#<br/>ALMA_RB_06#BB_1#SW-01#FULL_RES"
Args:
spw_string: The original SPW string containing '#' delimiters.
Returns:
The reformatted SPW string with <br/> inserted before 'ALMA'.
"""
parts = []
current = ""
for segment in spw_string.split('#'):
if current:
parts.append(current + '#')
current = segment
parts.append(current)
for i, part in enumerate(parts):
if 'ALMA' in part:
return "<br/>".join(parts[:i] + ["".join(parts[i:])])
return spw_string
[docs]
def get_directory_size(directory):
"""
Calculate the total size of a directory in megabytes (MB).
This includes all files in the directory and its subdirectories.
Symlinks are ignored to avoid counting linked files or causing errors.
Parameters:
directory (str): Path to the target directory.
Returns:
float: Total size of the directory in megabytes.
"""
total_size = 0
for dirpath, dirnames, filenames in os.walk(directory):
for filename in filenames:
filepath = os.path.join(dirpath, filename)
if not os.path.islink(filepath):
try:
total_size += os.path.getsize(filepath)
except OSError as e:
# make it log.warning
print(f"Error accessing {filepath}: {e}")
return total_size / (1024 * 1024)