"""general-purpose uncategorised utility functions and classes."""
from __future__ import annotations
import ast
import bisect
import collections
import contextlib
import copy
import errno
import fcntl
import glob
import inspect
import itertools
import json
import operator
import os
import pickle
import re
import shutil
import string
import tarfile
import time
from collections.abc import Iterable
from datetime import datetime, timezone
from functools import wraps
from numbers import Number
from typing import TYPE_CHECKING
from urllib.parse import urlparse
import numpy as np
from pipeline import infrastructure
from pipeline.infrastructure import casa_tools
from .conversion import commafy, dequote, range_to_list
if TYPE_CHECKING:
from collections.abc import Callable, Collection, Iterator, Sequence
from io import TextIOWrapper
from typing import Any, DefaultDict, OrderedDict
from numpy import generic
from numpy.typing import NDArray
from pipeline.domain import Field, MeasurementSet
from pipeline.infrastructure.filenamer import PipelineProductNameBuilder
from pipeline.infrastructure.launcher import Context
ConditionType = Callable | dict[str, dict[str, dict[str, Any]]]
# Module-level logger used by the helpers in this module.
LOG = infrastructure.logging.get_logger(__name__)
# Public API of this module (names exported by "from ... import *").
# Keep this list in sync with the public functions defined in this module.
__all__ = [
    'absolute_path',
    'approx_equal',
    'are_equal',
    'build_refantignore',
    'clear_time_cache',
    'compute_zenith_distance',
    'deduplicate',
    'dict_merge',
    'ensure_products_dir_exists',
    'export_weblog_as_tar',
    'fieldname_clean',
    'fieldname_for_casa',
    'filter_intents_for_ms',
    'find_ranges',
    'flagged_intervals',
    'get_casa_quantity',
    'get_field_identifiers',
    'get_obj_size',
    'get_products_dir',
    'get_receiver_type_for_spws',
    'get_row_count',
    'get_si_prefix',
    'get_spectralspec_to_spwid_map',
    'get_stokes',
    'get_task_result_count',
    'get_taskhistory_fromimage',
    'get_valid_url',
    'glob_ordered',
    'ignore_pointing',
    'imstat_items',
    'list_to_str',
    'nested_dict',
    'obs_long_lat',
    'obs_midtime',
    'open_with_lock',
    'place_repr_source_first',
    'relative_path',
    'remove_trailing_string',
    'string_to_val',
    'validate_url',
]
# Import TypedDict definitions from centralized module for type checking only
if TYPE_CHECKING:
from pipeline.infrastructure.utils.casa_types import DirectionDict, EpochDict, QuantityDict
[docs]
def find_ranges(data: str | list[int]) -> str:
"""Identify numeric ranges in string or list.
This utility function takes a string or a list of integers (e.g. spectral
window lists) and returns a string containing identified ranges.
Examples:
>>> find_ranges([1,2,3])
'1~3'
>>> find_ranges('1,2,3,5~12')
'1~3,5~12'
"""
if isinstance(data, str):
# barf if channel ranges are also in data, eg. 23:1~10,24
if ':' in data:
return data
data = range_to_list(data)
if len(data) == 0:
return ''
try:
integers = [int(d) for d in data]
except ValueError:
return ','.join(data)
s = sorted(integers)
ranges = []
for _, g in itertools.groupby(enumerate(s), lambda i_x: i_x[0] - i_x[1]):
rng = list(map(operator.itemgetter(1), g))
if len(rng) == 1:
ranges.append('%s' % rng[0])
else:
ranges.append('%s~%s' % (rng[0], rng[-1]))
return ','.join(ranges)
[docs]
def dict_merge(a: dict, b: dict | Any) -> dict:
"""Recursively merge dictionaries.
This utility function recursively merges dictionaries. If second argument
(b) is a dictionary, then a copy of first argument (dictionary a) is created
and the elements of b are merged into the new dictionary. Otherwise return
argument b.
Examples:
>>> dict_merge({'a': {'b': 1}}, {'c': 2})
{'a': {'b': 1}, 'c': 2}
"""
if not isinstance(b, dict):
return b
result = copy.deepcopy(a)
for k, v in b.items():
if k in result and isinstance(result[k], dict):
result[k] = dict_merge(result[k], v)
else:
result[k] = copy.deepcopy(v)
return result
[docs]
def are_equal(a: list | NDArray[generic], b: list | NDArray[generic]) -> bool:
"""Return True if the contents of the given arrays are equal.
This utility function check the equivalence of array like objects. Two arrays
are equal if they have the same number of elements and elements of the same
index are equal.
Examples:
>>> are_equal([1, 2, 3], [1, 2, 3])
True
>>> are_equal([1, 2, 3], [1, 2, 3, 4])
False
"""
return len(a) == len(b) and len(a) == sum([1 for i, j in zip(a, b) if i == j])
[docs]
def approx_equal(x: float, y: float, tol: float = 1e-15) -> bool:
    """Return True if two numbers are equal within the given tolerance.

    Two numbers are considered equal when they compare equal (which also covers
    matching infinities) or when their absolute difference does not exceed
    ``tol``. NaN is never approximately equal to anything, in either argument
    position.

    Examples:
    >>> approx_equal(1.0e-2, 1.2e-2, 1e-2)
    True
    >>> approx_equal(1.0e-2, 1.2e-2, 1e-3)
    False
    >>> approx_equal(1.0, float('nan'))
    False
    """
    # abs() rather than min()/max(): min/max silently drop a NaN passed as the
    # second argument, which made approx_equal(1.0, nan) return True.
    return x == y or abs(x - y) <= tol
[docs]
def flagged_intervals(vec: list | NDArray[generic]) -> list:
"""Idendity isnads of ones in input array or list.
This utility function finds islands of ones in array or list provided in argument.
Used to find contiguous flagged channels in a given spw. Returns a list of
tuples with the start and end channels.
Examples:
>>> flagged_intervals([0, 1, 0, 1, 1])
[(1, 1), (3, 4)]
"""
if len(vec) == 0:
return []
elif not isinstance(vec, np.ndarray):
vec = np.array(vec)
edges, = np.nonzero(np.diff((vec == True) * 1))
edge_vec = [edges + 1]
if vec[0] != 0:
edge_vec.insert(0, [0])
if vec[-1] != 0:
edge_vec.append([len(vec)])
edges = np.concatenate(edge_vec)
return list(zip(edges[::2], edges[1::2] - 1))
[docs]
def fieldname_for_casa(field: str) -> str:
    """Prepare a field string for use as a CASA argument.

    A field name that is purely numeric, or that contains characters other than
    ASCII letters, digits, '+' and '-', is wrapped in double quotes so CASA
    treats it as a name. Any other field name is returned unchanged.

    Examples:
    >>> fieldname_for_casa('helm=30')
    '"helm=30"'
    """
    needs_quotes = field.isdigit() or fieldname_clean(field) != field
    if not needs_quotes:
        return field
    return f'"{field}"'
[docs]
def fieldname_clean(field: str) -> str:
    """Replace characters that are not safe in a CASA field name.

    Every character other than ASCII letters, digits, '+' and '-' is replaced
    with an underscore. fieldname_for_casa() compares the result with the
    original to decide whether a field name must be quoted.

    Examples:
    >>> fieldname_clean('helm=30')
    'helm_30'
    """
    permitted = frozenset(string.ascii_letters + string.digits + '+-')
    cleaned = (ch if ch in permitted else '_' for ch in field)
    return ''.join(cleaned)
[docs]
def filter_intents_for_ms(ms: MeasurementSet, intents: str) -> str:
"""
Filter string of comma-separated intents to keep only those that are present
in the given MS.
"""
intents = set(intents.split(','))
removed_intents = intents - ms.intents
if removed_intents:
LOG.debug(f"The following intents are not present in the MS {ms.basename} and will be skipped:"
f" {commafy(removed_intents)}")
intents.intersection_update(ms.intents)
return ','.join(sorted(intents))
def get_field_accessor(ms: MeasurementSet, field: Field) -> operator.attrgetter:
"""Returns accessor to field name or field ID, if field name is ambiguous.
"""
fields = ms.get_fields(name=field.name)
if len(fields) == 1:
return operator.attrgetter('name')
def accessor(x):
return str(operator.attrgetter('id')(x))
return accessor
[docs]
def get_field_identifiers(ms: MeasurementSet) -> dict[int, str | int]:
"""Maps numeric field IDs to field names.
Get a dict of numeric field ID to unambiguous field identifier, using the
field name where possible and falling back to numeric field ID where the
name is duplicated, for instance in mosaic pointings.
"""
field_name_accessors = {field.id: get_field_accessor(ms, field) for field in ms.fields}
return {field.id: field_name_accessors[field.id](field) for field in ms.fields}
[docs]
def get_receiver_type_for_spws(ms: MeasurementSet, spwids: Sequence) -> dict[int, str]:
"""Return dictionary of receiver types for requested spectral window IDs.
If spwid is not found in MeasurementSet instance, then detector type is
set to "N/A".
Args:
ms: MeasurementSet to query for receiver types.
spwids: list of spw ids (integers) to query for.
Returns:
A dictionary assigning receiver types as values to spwid keys.
"""
rxmap = {}
for spwid in spwids:
spw = ms.get_spectral_windows(spwid, science_windows_only=False)
if not spw:
rxmap[spwid] = "N/A"
else:
rxmap[spwid] = spw[0].receiver
return rxmap
[docs]
def get_spectralspec_to_spwid_map(spws: Collection) -> DefaultDict[str | None, list[int]]:
"""
Returns a dictionary of spectral specs mapped to corresponding spectral
window IDs for requested list of spectral window objects.
:param spws: list of spectral window objects
:return: dictionary with spectral spec as keys, and corresponding
list of spectral window IDs as values.
"""
spwmap = collections.defaultdict(list)
for spw in sorted(spws, key=lambda s: s.id):
spwmap[spw.spectralspec].append(spw.id)
return spwmap
[docs]
def imstat_items(
image: Any,
items: list[str] = ['min', 'max'],
mask: str | None = None,
) -> OrderedDict[str, Any]:
"""Extract desired stats properties (per Stokes) using ia.statistics().
Beside the standard output, some additional stats property keys are supported.
Note: 'image' is expected to an instance of CASA ia tool.
"""
imstats = image.statistics(robust=True, axes=[0, 1, 3], mask=mask)
stats = collections.OrderedDict()
for item in items:
if item.lower() == 'madrms':
stats['madrms'] = imstats['medabsdevmed']*1.4826 # see CAS-9631
elif item.lower() == 'max/madrms':
stats['max/madrms'] = imstats['max']/imstats['medabsdevmed']*1.4826 # see CAS-9631
elif item.lower() == 'maxabs':
stats['maxabs'] = np.maximum(np.abs(imstats['max']), np.abs(imstats['min']))
elif 'pct<' in item:
threshold = float(item.replace('pct<', ''))
imstats_threshold = image.statistics(robust=True, axes=[0, 1, 3], includepix=[0, threshold], mask=mask)
if len(imstats_threshold['npts']) == 0:
# if no pixel is selected from the restricted pixel value range, the return of ia.statitics() would be empty.
imstats_threshold['npts'] = np.zeros(4)
stats[item] = imstats_threshold['npts']/imstats['npts']
elif item.lower() == 'pct_masked':
im_shape = (imstats['trc']-imstats['blc'])+1
stats[item] = 1.-imstats['npts']/im_shape[0]/im_shape[1]
elif item.lower() == 'peak': # Here 'peak' means the pixel value with largest deviation from zero.
stats[item] = np.where(np.abs(imstats['max']) > np.abs(imstats['min']), imstats['max'], imstats['min'])
elif item.lower() == 'peak/madrms':
peak = np.where(np.abs(imstats['max']) > np.abs(imstats['min']), imstats['max'], imstats['min'])
madrms = imstats['medabsdevmed']*1.4826 # see CAS-9631
stats['peak/madrms'] = peak/madrms
else:
stats[item] = imstats[item.lower()]
return stats
[docs]
def get_stokes(imagename: str) -> list[str]:
    """Get the labels of all Stokes planes present in a CASA image.

    Args:
        imagename: Path to the CASA image.

    Returns:
        Stokes labels from the image coordinate system, one per plane along
        image axis 2.
    """
    with casa_tools.ImageReader(imagename) as image:
        cs = image.coordsys()
        try:
            stokes_labels = cs.stokes()
            # NOTE(review): assumes axis 2 of the image shape is the Stokes axis -
            # confirm for images with non-standard axis ordering.
            n_planes = image.shape()[2]
            stokes_present = [stokes_labels[idx] for idx in range(n_planes)]
        finally:
            # Release the coordsys tool even if the lookup above fails.
            cs.done()
    return stokes_present
[docs]
def get_casa_quantity(value: None | dict | str | float | int) -> QuantityDict:
    """Wrapper around quanta.quantity() that handles None input.

    Starting with CASA 6, quanta.quantity() no longer accepts None as input.
    Here None is converted to a dimensionless zero before calling the CASA
    quanta tool; every other value is passed through unchanged.

    Returns:
        A CASA quanta.quantity (dictionary)

    Examples:
    >>> get_casa_quantity(None)
    {'unit': '', 'value': 0.0}
    >>> get_casa_quantity('10klambda')
    {'unit': 'klambda', 'value': 10.0}
    """
    quantity_input = 0.0 if value is None else value
    return casa_tools.quanta.quantity(quantity_input)
[docs]
def get_si_prefix(value: float, select: str = 'mu', lztol: int = 0) -> tuple[str, float]:
    """Obtain the best SI unit prefix option for a numeric value.

    The "best" SI prefix from the allowed collection is the one that, once
    applied, minimizes:
      * leading zeros (possibly up to a tolerance, see `lztol`)
      * significant digits before the decimal point

    The unprefixed unit is always a candidate.

    Args:
        value: the numerical value for picking the prefix.
        select: SI prefix candidates, a substring of "yzafpnum kMGTPEZY".
            Defaults to 'mu'.
        lztol: leading zeros tolerance.
            Defaults to 0 (avoid any leading zeros when possible).

    Returns:
        tuple: (prefix_string, prefix_scale)

    Examples:
    e.g. for frequency value in Hz
    >>> get_si_prefix(10**7,select='kMGT')
    ('M', 1000000.0)
    e.g. for flux value in Jy
    >>> get_si_prefix(1.0,select='um')
    ('', 1.0)
    >>> get_si_prefix(0.0,select='um')
    ('', 1.0)
    >>> get_si_prefix(-0.9,select='um')
    ('m', 0.001)
    >>> get_si_prefix(0.9,select='um',lztol=1)
    ('', 1.0)
    >>> get_si_prefix(1e-7,select='um')
    ('u', 1e-06)
    >>> get_si_prefix(1e3,select='um')
    ('', 1.0)
    """
    if value == 0:
        return '', 1.0

    table = "yzafpnum kMGTPEZY"
    allowed = select + ' '
    # Position 8 in the table (the space) is the unprefixed unit, 10**0.
    candidates = [(symbol, (pos - 8) * 3.0) for pos, symbol in enumerate(table) if symbol in allowed]
    symbols = [symbol for symbol, _ in candidates]
    exponents = [exponent for _, exponent in candidates]

    position = bisect.bisect_right(exponents, np.log10(abs(value)) + lztol)
    chosen = position - 1 if position > 0 else 0
    return symbols[chosen].strip(), 10.0 ** exponents[chosen]
[docs]
def absolute_path(name: str) -> str:
    """Return an absolute path of a given file.

    Environment variables are expanded first, then a leading '~', and the
    result is made absolute relative to the current working directory.
    """
    expanded = os.path.expandvars(name)
    expanded = os.path.expanduser(expanded)
    return os.path.abspath(expanded)
[docs]
def relative_path(name: str, start: str | None = None) -> str:
    """Return a relative path of a given file with respect to a given origin.

    Both paths are normalized with absolute_path() before the relative path is
    computed.

    Args:
        name: A path to file.
        start: An origin of relative path. If the start is not given, the
            current directory is used as the origin of relative path.

    Examples:
    >>> relative_path('/root/a/b.txt', '/root/c')
    '../a/b.txt'
    >>> relative_path('../a/b.txt', './c')
    '../../a/b.txt'
    """
    origin = None if start is None else absolute_path(start)
    target = absolute_path(name)
    return os.path.relpath(target, origin)
[docs]
def get_task_result_count(context: Context, taskname: str = 'hif_makeimages') -> int:
"""Count occurrences of a task result in the context.results list.
Loop over the content of the context.results list and compare taskname to the pipeline_casa_task
attribute of each result object. Increase counter if taskname substring is found in the attribute.
The order number is determined by counting the number of previous execution of
the task, based on the content of the context.results list. The introduction
of this method is necessary because VLASS-SE-CONT imaging happens in multiple
stages (hif_makeimages calls). Imaging parameters change from stage to stage,
therefore it is necessary to know what is the current stage ordinal number.
"""
count = 0
for r in context.results:
# Work around the fact that r has read() method in some cases (e.g. editimlist)
# but not in others (e.g. in tclean renderer)
try:
if taskname in r.read().pipeline_casa_task:
count += 1
except AttributeError:
if taskname in r.pipeline_casa_task:
count += 1
return count
[docs]
def place_repr_source_first(itemlist: list[str] | list[tuple], repr_source: str) -> list[str] | list[tuple]:
"""
Place representative source first in a list of source names
or tuples with source name as first tuple element.
"""
try:
itemtype = type(itemlist[0])
if itemtype is str:
repr_source_index = [dequote(item) for item in itemlist].index(dequote(repr_source))
elif itemtype is tuple or itemtype is list:
repr_source_index = [dequote(item[0]) for item in itemlist].index(dequote(repr_source))
else:
raise Exception('Cannot handle items of type {}'.format(itemtype))
repr_source_entry = itemlist.pop(repr_source_index)
itemlist = [repr_source_entry] + itemlist
except ValueError:
LOG.warning('Could not reorder field list to place representative source first')
return itemlist
[docs]
def get_taskhistory_fromimage(imagename: str) -> list[OrderedDict[str, Any]]:
    """Retrieve past CASA/tclean() call parameters from the image history.

    Note: the tclean history is only added to images/logtable in CASA ver>=6.2 (see CAS-13247).
    For tclean products generated by earlier CASA versions, an empty list will be returned.

    A history line containing 'taskname' and '=' starts a new record. Following
    'CASAtools:' version lines set its 'taskversion', and following 'key = value'
    lines add parameters. Values are parsed with ast.literal_eval(); a value
    that is not a Python literal is kept as its raw (stripped) string instead of
    aborting the whole retrieval.

    Args:
        imagename: Path to the CASA image.

    Returns:
        List of ordered dicts, one per task call found in the history.
    """
    taskhistory_list = []
    with casa_tools.ImageReader(imagename) as image:
        history_list = image.history(list=False)

    is_fromtask = False
    for line in history_list:
        if 'taskname' in line and '=' in line:
            is_fromtask = True
            taskname = line.partition('=')[2].strip()
            # NOTE: the 'unkown' spelling is kept as-is; downstream code may rely on it.
            taskhistory_list.append(collections.OrderedDict([('taskname', taskname), ('taskversion', 'unkown')]))
            continue
        if 'version:' in line and 'CASAtools:' in line and is_fromtask:
            taskhistory_list[-1]['taskversion'] = line
            continue
        if '=' in line and is_fromtask:
            key, _, value = line.partition('=')
            value = value.strip()
            try:
                parsed = ast.literal_eval(value)
            except (ValueError, SyntaxError, TypeError):
                LOG.debug(f'Keeping non-literal task history value for {key.strip()!r} as a string: {value}')
                parsed = value
            taskhistory_list[-1][key.strip()] = parsed
        else:
            # Any other line ends the current task record.
            is_fromtask = False
    LOG.info(f'Found {len(taskhistory_list)} task history entry/entries from {imagename}')
    return taskhistory_list
[docs]
def get_obj_size(obj: Any, serialize: bool = True) -> int:
"""Estimate the size of a Python object.
If serialize is True, returns the size of the serialized object. Note that this is NOT
the same as the object size in memory.
When serialize is False, returns the memory consumption of the object via the asizeof
method of Pympler (https://pypi.org/project/Pympler).
An alternative is the get_deep_size() function from objsize (https://pypi.org/project/objsize).
The serialization-based approach was a fallback solution for the issues described in PIPE-1698/PIPE-2877.
The `asize.py` bug described in PIPE-2877 remain unresolved as of Pympler ver 1.1.
See the GH issues for details and PIPE-1698 and PIPE-2877 for background:
https://github.com/pympler/pympler/issues/155
https://github.com/pympler/pympler/issues/151
Args:
obj: The Python object to measure.
serialize: If True, measure serialized size; if False, measure memory size using Pympler.
Returns:
The size of the object in bytes.
Raises:
Exception: If serialize is False and Pympler is not installed.
"""
if serialize:
return len(pickle.dumps(obj, protocol=-1))
try:
from pympler.asizeof import asizeof
return asizeof(obj)
except ImportError as err:
LOG.debug('Pympler import failed: %s', err)
raise ModuleNotFoundError(
'Pympler is required for in-memory size calculation. Please install it with: pip install pympler'
) from err
[docs]
def glob_ordered(pattern: str, *args, order: str | None = None, **kwargs) -> list[str]:
"""Return a sorted list of paths matching a pathname pattern."""
path_list = glob.glob(pattern, *args, **kwargs)
if order == 'mtime':
path_list.sort(key=os.path.getmtime)
elif order == 'ctime':
path_list.sort(key=os.path.getctime)
else:
if order is not None:
LOG.warning("Unknown sorting order requested: order=%r. Only 'mtime', 'ctime', or None is allowed.", order)
LOG.warning("We will use the default alphabetically/numerically ascending order (order=None) instead.")
path_list = sorted(path_list)
return path_list
[docs]
def deduplicate(items: Iterable) -> list:
"""Remove duplicate entries from a list, but preserve the order.
Note that the use of list(set(x)) can cause random order in the output.
The return of this function is guaranteed to be in the order that unique items show up in the input, unlike
a deduplicate-resorting solution like sorted(set(x).
Ref: https://stackoverflow.com/questions/480214/how-do-i-remove-duplicates-from-a-list-while-preserving-order
This solution only works for Python 3.7+.
"""
deduplicated_items = list(dict.fromkeys(items))
return deduplicated_items
[docs]
@contextlib.contextmanager
def ignore_pointing(vis: str | list[str] | set[str]):
"""A context manager to ignore pointing tables of MSes during I/O operations.
The original pointing table will be temperarily renamed to POINTING_ORIGIN, and a new empty pointing table
is created. When the context manager exits, the original table is restored.
For example, to ignore the pointing table of a MS during mstransform() calls, use:
with ignore_pointing('test.ms'):
casatasks.mstransform(vis='test.ms',outputvis='test_output.ms',scan='16',datacolumn='data')
The pointing table of the output MS should be empty.
On the other hand, if the pointing table is needed in the output vis, e.g. for imaging with tclean(usepointing=True),
we can manually create hardlinks of pointing table afterwards while minimizing the disk space usage:
import shutil, os
shutil.rmtree('test_small.ms/POINTING')
shutil.copytree('test.ms/POINTING', 'test_output.ms/POINTING', copy_function=os.link)
One can verify the inodes of the pointing table files, which should be the same:
ls -lih test.ms/POINTING
ls -lih test_small.ms/POINTING
"""
if isinstance(vis, (list, set)):
vis_list = vis
else:
vis_list = [vis]
vis_list_ignore = []
try:
for ms in vis_list:
if not os.path.isdir(ms+'/POINTING') and not os.path.isdir(ms+'/POINTING_ORIGIN'):
LOG.warning(f'No pointing table found in {ms}.')
continue
vis_list_ignore.append(ms)
if not os.path.isdir(ms+'/POINTING_ORIGIN'):
LOG.info(f'backup the pointing table for {ms}')
shutil.move(ms+'/POINTING', ms+'/POINTING_ORIGIN')
with casa_tools.TableReader(ms+'/POINTING_ORIGIN', nomodify=True) as table:
tabdesc = table.getdesc()
dminfo = table.getdminfo()
if os.path.isdir(ms+'/POINTING'):
shutil.rmtree(ms+'/POINTING')
LOG.info(f'empty the pointing table for {ms}')
tb = casa_tools.table
tb.create(ms+'/POINTING', tabdesc, dminfo=dminfo)
tb.close()
yield
finally:
for ms in vis_list_ignore:
if os.path.isdir(ms+'/POINTING_ORIGIN'):
if os.path.isdir(ms+'/POINTING'):
shutil.rmtree(ms+'/POINTING')
LOG.info(f'restore the pointing table for {ms}')
shutil.move(ms+'/POINTING_ORIGIN', ms+'/POINTING')
[docs]
@contextlib.contextmanager
def open_with_lock(filename: str, mode: str = 'r', *args: Any, **kwargs: Any) -> Iterator[TextIOWrapper]:
    """Open a file with an exclusive lock.

    This context manager attempts to acquire an exclusive lock on the file upon opening.
    Other processes using `open_with_lock` will wait until the lock is automatically released
    when exiting the context.

    The file is always opened in a non-truncating read/write mode (created if missing),
    and the requested `mode` is emulated: for 'w' modes the file is truncated, and for
    'a' modes the position is moved to end-of-file. Both happen after the lock attempt,
    whether or not the lock could be acquired.

    Args:
        filename: Path to the file to open and lock.
        mode: File open mode (e.g., 'rt', 'wt', 'a', 'rb').
        *args: Additional positional arguments to the built-in `open()` function.
        **kwargs: Additional keyword arguments to the built-in `open()` function.

    Yields:
        The opened file object with an exclusive lock acquired (if supported by the filesystem).

    Notes:
        All file open calls with `open_with_lock` will block other `open_with_lock` usages
        (even from different processes/Lustre clients) from read/write until the lock is
        released. However, this locking/blocking mechanism does not prevent file access
        (which could potentially cause IO errors) from other file openers (e.g., the
        built-in `open()` function).

        **Filesystem Support:**
        The Python `fcntl` API's behavior depends on the underlying OS/storage
        implementation: https://docs.python.org/3/library/fcntl.html. Not all
        OS/file systems fully support file locking.
        - **Lustre**: Requires mount option `-o flock`. Check with: `mount -l | grep lustre`
          e.g. : `192.168.1.30@o2ib:/aoclst03 on /.lustre/aoc type lustre (rw,flock,user_xattr,lazystatfs)`
        - **NFS**: Support varies by configuration (see PIPE-2051 for details)
        - **Local filesystems**: Generally well-supported on Unix-like systems

        **Testing Lock Behavior:**
        To verify exclusive locking works on your system, run this from multiple processes:
        ```python
        import fcntl
        with open(filename, 'w') as fd:
            fcntl.flock(fd, fcntl.LOCK_EX | fcntl.LOCK_NB)  # Should block/fail on second process
        ```
    """
    needs_truncate = 'w' in mode
    needs_append = 'a' in mode
    is_binary = 'b' in mode

    # Open without truncating: truncating ('w') before the lock is held would
    # wipe content another lock holder may still be using.
    base_mode = 'r+b' if is_binary else 'r+'
    try:
        fd = open(filename, base_mode, *args, **kwargs)
    except FileNotFoundError:
        base_mode = 'w+b' if is_binary else 'w+'
        fd = open(filename, base_mode, *args, **kwargs)

    LOG.debug('Attempting to acquire lock on %s', filename)
    lock_acquired = False
    try:
        try:
            # Acquire an exclusive lock on the file
            fcntl.flock(fd, fcntl.LOCK_EX)
            lock_acquired = True
            LOG.debug('Successfully acquired lock on %s', filename)
        except OSError as ex:
            LOG.warning('Failed to acquire file lock for %s due to filesystem limitation, '
                        'which might cause racing conditions if multiple processes access '
                        'the file simultaneously: %s', filename, ex)

        if needs_truncate:
            # Honour 'w' semantics even without a lock; otherwise shorter new
            # content would leave stale trailing bytes from the old file.
            fd.truncate(0)
            fd.seek(0)
            LOG.debug('Truncated file for write mode: %s', filename)
        elif needs_append:
            # 'r+' starts at offset 0; move to EOF so writes append instead of overwriting.
            fd.seek(0, os.SEEK_END)

        yield fd
    finally:
        if lock_acquired:
            try:
                fcntl.flock(fd, fcntl.LOCK_UN)
                LOG.debug('Successfully released lock on %s', filename)
            except OSError as ex:
                LOG.warning('Failed to release file lock for %s due to filesystem limitation, '
                            'which might cause racing conditions if multiple processes access '
                            'the file simultaneously: %s', filename, ex)
        fd.close()
[docs]
def ensure_products_dir_exists(products_dir: str) -> None:
    """Create the products directory (including parents) if it does not exist.

    An already-existing path is silently accepted; any other OSError is raised.

    Args:
        products_dir: Path of the products directory.
    """
    LOG.trace(f"Creating products directory: {products_dir}")
    # FileExistsError is the OSError subclass for errno EEXIST.
    with contextlib.suppress(FileExistsError):
        os.makedirs(products_dir)
[docs]
def export_weblog_as_tar(context: Context, products_dir: str, name_builder: PipelineProductNameBuilder) -> str:
    """Archive the weblog HTML directory as a gzipped tar file in products_dir.

    The archive name is built by name_builder.weblog() from the OUS status ID,
    suffixed with '.<recipe name>' when a recipe name is available.

    Args:
        context: Pipeline context providing the OUS ID, recipe name, project
            structure and report directory.
        products_dir: Directory in which the archive is written.
        name_builder: Builder used to construct the archive file name.

    Returns:
        The archive file name (relative to products_dir).
    """
    # Construct filename prefix from oussid and recipe name if available.
    prefix = context.get_oussid()
    recipe_name = context.get_recipe_name()
    if recipe_name:
        prefix = prefix + '.' + recipe_name

    # Construct filename for weblog output tar archive.
    tarfilename = name_builder.weblog(project_structure=context.project_structure,
                                      ousstatus_entity_id=prefix)

    # Save weblog directory to tar archive. The member path is relative to the
    # current working directory: <parent dir name of report_dir>/html.
    LOG.info(f"Saving weblog in {tarfilename}")
    # Context manager guarantees the archive is closed even if add() fails.
    with tarfile.open(os.path.join(products_dir, tarfilename), "w:gz") as tar:
        tar.add(os.path.join(os.path.basename(os.path.dirname(context.report_dir)), 'html'))

    return tarfilename
[docs]
def get_products_dir(context: Context) -> str:
if context.products_dir is None:
return os.path.abspath('./')
else:
return context.products_dir
class pl_defaultdict(collections.defaultdict):
    """A defaultdict whose representation looks like a plain dict."""

    def __repr__(self) -> str:
        # Hide the default_factory from the representation.
        return repr(dict(self))

    def as_plain_dict(self) -> dict:
        """Return a copy converted to plain dicts via to_plain_dict()."""
        return to_plain_dict(self)
[docs]
def nested_dict() -> dict:
    """Return an auto-vivifying nested dictionary.

    Missing keys create another nested_dict, so arbitrarily deep paths such as
    d['a']['b']['c'] = 1 can be assigned without creating the levels first.
    """
    factory = nested_dict
    return pl_defaultdict(factory)
def to_plain_dict(default_dict: dict) -> dict:
    """Convert a (nested) defaultdict into plain dicts.

    Values that are defaultdicts are converted recursively. Every other value,
    including plain dicts, is copied into the result by reference.
    """
    return {
        key: to_plain_dict(value) if isinstance(value, collections.defaultdict) else value
        for key, value in default_dict.items()
    }
[docs]
def string_to_val(s: str) -> Any:
"""
Convert a string to a Python data type.
"""
try:
pyobj = ast.literal_eval(s)
# PIPE-1030: prevent a string like "1,2,3" from being unexpectedly translated into tuple
if type(pyobj) is tuple and s.strip()[0] != '(':
pyobj = s
return pyobj
except ValueError:
return s
except SyntaxError:
return s
[docs]
def remove_trailing_string(s: str, t: str) -> str:
    """Remove a trailing string if it exists.

    Args:
        s: Input string.
        t: Suffix to remove. An empty suffix leaves ``s`` unchanged.

    Returns:
        ``s`` without the trailing ``t``, or ``s`` itself if it does not end with ``t``.
    """
    # Guard against t == '': s[:-0] is s[:0], which would return ''.
    if t and s.endswith(t):
        return s[:-len(t)]
    return s
def function_io_dumper(to_pickle: bool = True, to_json: bool = False, json_max_depth: int = 5,
                       condition: ConditionType | None = None, timestamp: bool = True):
    """Dump the arguments and return value of the decorated function to pickle and/or JSON files.

    This function is a helper for development. It should not be used in production code.

    Usage:
        @function_io_dumper()
        def foobar(self, bar):
            ...
            return ret

    When foobar() is executed and the condition holds, the decorator creates a directory
    'foobar.[timestamp]' and dumps every bound argument (defaults included) and the
    non-None return value as pickle and/or JSON files. The call can be replayed from
    the pickles:

        with open('bar.pickle', 'rb') as f:
            bar = pickle.load(f)
        foobar(bar)

    The JSON-like output helps to understand the inputs/results of the function. To avoid
    recursive or tremendous output, its nesting is limited (default depth 5).

    Args:
        to_pickle (bool, optional): Dump pickle files or not. Defaults to True.
        to_json (bool, optional): Dump JSON-like files or not. Defaults to False.
        json_max_depth (int, optional): Depth of the dumped JSON tree. Defaults to 5.
        condition (ConditionType): A user-defined callable, or a dict condition, deciding
            whether a call is dumped. Defaults to None (always dump).
            ex1) Callable condition; it receives a dict mapping argument names to values:
                condition = my_condition
                where my_condition() is defined as:
                def my_condition(kwargs) -> bool:
                    # Serialize only if the keyword argument 'spw' is 10 or True.
                    return kwargs.get("spw") in (10, True)
            ex2) Dict condition, evaluated by _eval_condition():
                condition = {'self': {'spw': 10}}
        timestamp (bool, optional): Add a timestamp to the name of the output directory. If
            False and the same function is executed multiple times, the output directory is
            overwritten. Defaults to True.

    Returns:
        The decorator.
    """
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            # UTC timestamp with a zero-padded nanosecond field, so names have a fixed width.
            epoch_time = time.time()
            dt = datetime.fromtimestamp(epoch_time, tz=timezone.utc)
            ns = int((epoch_time - int(epoch_time)) * 1_000_000_000)
            _timestamp = dt.strftime(f'%Y%m%d-%H%M%S.{ns:09d}')

            # Bind the call to the signature so default values are included too.
            sig = inspect.signature(func)
            bound_args = sig.bind(*args, **kwargs)
            bound_args.apply_defaults()
            arguments = bound_args.arguments

            try:
                if callable(condition):
                    # Pass the name->value mapping as documented; a BoundArguments
                    # object has no .get(), so documented conditions always failed.
                    exec_dumpargs = condition(arguments)
                else:
                    exec_dumpargs = _eval_condition(condition, arguments)
            except Exception as e:
                LOG.warning(f'Error evaluating condition callable: {e}')
                exec_dumpargs = False

            extension = f'.{_timestamp}' if timestamp else ''
            output_folder_name = f"{func.__name__}{extension}"

            # Only dump when the condition holds (arguments used to be dumped
            # unconditionally). JSON conversion is inside the try so that a
            # failure there cannot prevent the decorated function from running.
            if exec_dumpargs:
                try:
                    json_dict = object_to_dict(arguments, max_depth=json_max_depth) if to_json else None
                    os.makedirs(output_folder_name, exist_ok=True)
                    LOG.info(f"Function '{_get_full_method_path(func)}' called at {_timestamp}")
                    for arg_name, arg_value in arguments.items():
                        _dump(arg_value, output_folder_name, arg_name, dump_pickle=to_pickle, dump_json=to_json,
                              json_dict=json_dict)
                except pickle.PicklingError as e:
                    exec_dumpargs = False
                    LOG.warning(f'Contained unpickleable object: {e}')
                except Exception as e:
                    exec_dumpargs = False
                    LOG.warning(f'Exception occurred: {e}')

            result = func(*args, **kwargs)

            if exec_dumpargs and result is not None:
                try:
                    _name = f'{output_folder_name}.result'
                    json_dict = object_to_dict({_name: result}, max_depth=json_max_depth) if to_json else None
                    _dump({_name: result}, output_folder_name, _name, dump_pickle=to_pickle, dump_json=to_json,
                          json_dict=json_dict)
                except pickle.PicklingError as e:
                    LOG.warning(f'Contained unpickleable object: {e}')
                except Exception as e:
                    LOG.warning(f'Exception occurred: {e}')
            return result
        return wrapper
    return decorator
def _dump(obj: object, path: str, name: str, dump_pickle: bool = True, dump_json: bool = False,
          json_dict: dict | None = None) -> None:
    """Write ``obj`` to ``<path>/<name>.pickle`` and/or ``<path>/<name>.json``.

    Args:
        obj: Object to pickle (and to serialize as JSON when ``json_dict`` is None).
        path: Output directory; it is not created here.
        name: Base file name; also the key looked up in ``json_dict``.
        dump_pickle: Write the pickle file.
        dump_json: Write the JSON file.
        json_dict: Pre-converted JSON-friendly data keyed by ``name``. When None
            (the default, previously a shared ``{}`` that always raised KeyError),
            ``obj`` itself is serialized, with non-JSON types rendered via ``str``.

    Raises:
        Exceptions from pickle.dump (e.g. pickle.PicklingError) for unpicklable objects,
        and KeyError if ``json_dict`` is given but lacks ``name``.
    """
    file_path = os.path.join(path, f'{name}')
    if dump_pickle:
        with open(file_path + '.pickle', 'wb') as f:
            pickle.dump(obj, f)
    if dump_json:
        payload = obj if json_dict is None else json_dict[name]
        with open(file_path + '.json', 'w') as f:
            f.write(json.dumps(payload, default=str))
def _get_full_method_path(func: Callable) -> str:
    """Return '<module>.<qualified name>' for a function, for log messages.

    Falls back to ``__name__`` when the callable has no ``__qualname__``.
    """
    qualified = func.__qualname__ if hasattr(func, '__qualname__') else func.__name__
    return f'{func.__module__}.{qualified}'
def _eval_condition(condition: dict | None, args: dict) -> bool:
    """Evaluate a dict-style dump condition against the bound call arguments.

    ``condition`` maps an argument name to a dict of ``{attribute_or_key: value}``
    pairs, e.g. ``{'self': {'spw': 10}}``. The condition holds if, for any listed
    argument present in the call, any listed attribute (or key, when the argument
    is a dict) equals the given value. A None condition always holds.

    Args:
        condition: The dict condition, or None.
        args: Mapping of argument names to values. An object with an ``arguments``
            mapping (e.g. ``inspect.BoundArguments``) is also accepted.

    Returns:
        True if the condition holds, False otherwise.
    """
    if condition is None:
        return True
    LOG.info(f'condition: {condition}')
    arguments = getattr(args, 'arguments', args)
    missing = object()
    # Previously this iterated the dict without .items() and called a
    # nonexistent dict.getattr(), so any dict condition failed.
    for arg_name, expected_props in condition.items():
        arg = arguments.get(arg_name, missing)
        if arg is missing:
            continue
        for prop_name, expected in expected_props.items():
            if isinstance(arg, dict):
                actual = arg.get(prop_name, missing)
            else:
                actual = getattr(arg, prop_name, missing)
            if actual is not missing and actual == expected:
                return True
    return False
def object_to_dict(obj: object, max_depth: int = 5, current_depth: int = 0) -> object | dict | None:
if current_depth > max_depth:
return None
if isinstance(obj, dict):
result = {}
for key, value in obj.items():
if hasattr(value, '__dict__') or isinstance(value, Iterable):
result[key] = object_to_dict(value, max_depth=max_depth, current_depth=current_depth + 1)
else:
result[key] = value
return result
elif isinstance(obj, Iterable) and not isinstance(obj, (str, bytes)):
result = []
for item in obj:
if hasattr(item, '__dict__') or isinstance(item, Iterable):
result.append(object_to_dict(item, max_depth=max_depth, current_depth=current_depth + 1))
else:
result.append(item)
return result
elif hasattr(obj, '__dict__'):
result = {}
for key, value in obj.__dict__.items():
if hasattr(value, '__dict__') or isinstance(value, Iterable):
result[key] = object_to_dict(value, max_depth=max_depth, current_depth=current_depth + 1)
else:
result[key] = value
return result
else:
return obj
def decorate_io_dumper(cls: object, functions: list[str] | None = None, *args: Any, **kwargs: Any) -> None:
    """Apply function_io_dumper dynamically.

    Usage:
        ...
        import pipeline.infrastructure.utils.utils as ut
        ut.decorate_io_dumper(SDInspection, ['execute'])
        ...
        hsd_importdata()  # -> dump args of SDInspection.execute()

    Args:
        cls (object): the class whose functions are decorated (modified in place).
        functions (list[str], optional): Names of the functions to decorate. If None or
            empty, all plain functions of the class are decorated. Names that do not
            resolve to a callable attribute (and None entries) are skipped.
        *args, **kwargs: Passed on to function_io_dumper().

    NOTE(review): static/class methods are re-assigned as plain functions, which
    changes how they bind; decorate them with care.
    """
    if not functions:
        # getmembers() yields (name, function) pairs; the tuples used to be
        # treated as functions, which failed on .__name__.
        targets = inspect.getmembers(cls, predicate=inspect.isfunction)
    else:
        targets = []
        for name in functions:
            if name is None:
                continue
            member = _str_to_func(cls, name)
            if member:
                targets.append((name, member))

    for name, func in targets:
        # Re-assign under the looked-up name so aliased members stay correct.
        setattr(cls, name, function_io_dumper(*args, **kwargs)(func))
def _str_to_func(cls: object, _name: str) -> Callable | bool:
    """Fetch attribute *_name* from *cls* if it exists and is callable.

    Returns:
        The attribute itself when present and callable, otherwise False.
    """
    try:
        candidate = getattr(cls, _name)
    except AttributeError:
        return False
    return candidate if callable(candidate) else False
[docs]
def list_to_str(value: list[Number | str] | NDArray) -> str:
    """Convert a list or 1-dimensional numpy.ndarray into a string.

    If ``value`` is a list, or a 1-dimensional ndarray, whose elements are all
    numbers or strings, the result is ``str()`` of an equivalent list of Python
    builtin scalars, e.g. ``"[1, 2.5, 'a']"``. Numpy scalars are converted
    individually with ``.item()``. This keeps their printed form independent
    of numpy's scalar repr and does not coerce mixed element types to a
    common dtype. Anything else, including 0-d and multi-dimensional arrays,
    falls back to ``str(value)``.

    Args:
        value: List or numpy array to convert.

    Returns:
        String representation of ``value`` as described above.
    """
    is_1d_array = isinstance(value, np.ndarray) and value.ndim == 1
    # Iterating a 0-d array raises TypeError, so only 1-d arrays are iterated.
    if (isinstance(value, list) or is_1d_array) \
            and all(isinstance(x, (Number, str)) for x in value):
        return str([x.item() if isinstance(x, np.generic) else x for x in value])
    return str(value)
[docs]
def validate_url(url: str) -> bool:
    """Check whether *url* is a well-formed URL with a scheme and a host.

    Two checks must both succeed:
      1. The whole string matches an optional ``http://``/``https://``
         prefix, a host made of word characters, dots and hyphens, an
         optional ``:port`` and an optional whitespace-free path (matched
         case-insensitively).
      2. ``urllib.parse.urlparse`` reports a non-empty scheme and netloc.

    Args:
        url: The URL to validate.

    Returns:
        True if both checks pass, False otherwise.
    """
    # Pattern: optional scheme, host name or IP, optional port, optional path.
    looks_like_url = re.match(
        r'^(https?:\/\/)?([\w.-]+)(:\d+)?(\/[^\s]*)?$',
        url,
        flags=re.IGNORECASE,
    )
    if looks_like_url is None:
        return False
    components = urlparse(url)
    return bool(components.scheme and components.netloc)
[docs]
def obs_long_lat(observatory: str) -> tuple[QuantityDict, QuantityDict]:
    """Return the longitude and latitude of an observatory.

    Args:
        observatory: Observatory name passed to the CASA measures tool's
            ``observatory`` lookup.

    Returns:
        The ``'m0'`` (longitude) and ``'m1'`` (latitude) entries of the measure
        returned for that observatory.
    """
    site = casa_tools.measures.observatory(observatory)
    longitude = site['m0']
    latitude = site['m1']
    return longitude, latitude
[docs]
def obs_midtime(start_time: datetime, end_time: datetime) -> EpochDict:
    """Return the midpoint of a time interval as a CASA UTC epoch measure.

    Args:
        start_time: Start of the interval.
        end_time: End of the interval.

    Returns:
        The epoch measure built by the CASA measures tool from the ISO-8601
        string of the midpoint, with reference 'utc'.
    """
    half_span = (end_time - start_time) / 2
    midpoint = start_time + half_span
    return casa_tools.measures.epoch('utc', midpoint.isoformat())
[docs]
def compute_zenith_distance(
    field_direction: DirectionDict,
    epoch: EpochDict,
    observatory: str,
    coordinate_frame: str = 'AZELGEO',
) -> QuantityDict:
    """Return the zenith distance of a field for a given time and observatory.

    The field direction is converted into the horizontal frame
    ``coordinate_frame`` with the CASA measures tool. The zenith distance is
    pi/2 minus the resulting elevation. The approach mirrors
    compute_az_el_to_field in htmlrenderer.py.

    Args:
        field_direction: CASA direction measure dictionary for the field.
        epoch: CASA epoch measure dictionary for the observation time.
        observatory: Name of the observatory (e.g., 'VLA', 'ALMA').
        coordinate_frame: Horizontal frame for the conversion (e.g.,
            'AZELGEO', 'AZEL'). Defaults to 'AZELGEO'.

    Returns:
        A CASA quantity dictionary holding the zenith distance in radians.

    Note:
        The epoch and observatory are set on ``casa_tools.measures`` with
        ``doframe`` and are not reset afterwards. The elevation value is used
        directly with unit 'rad'; this presumably relies on the measures tool
        returning radians — TODO confirm.

    Examples:
        >>> direction = {'m0': {'value': 4.18879, 'unit': 'rad'},
        ...              'm1': {'value': 0.58534, 'unit': 'rad'},
        ...              'refer': 'J2000', 'type': 'direction'}
        >>> epoch = {'m0': {'value': 58089.82, 'unit': 'd'}, 'refer': 'UTC', 'type': 'epoch'}
        >>> zd = compute_zenith_distance(direction, epoch, 'VLA')
        >>> print(f"{casa_tools.quanta.convert(zd, 'deg')['value']:.2f} degrees")  # doctest: +SKIP
    """
    measures_tool = casa_tools.measures
    # Attach the observation time, then the site position, to the frame.
    measures_tool.doframe(epoch)
    site_position = measures_tool.observatory(observatory)
    measures_tool.doframe(site_position)

    # 'm1' of the horizontal-frame measure is the elevation.
    elevation = measures_tool.measure(field_direction, coordinate_frame)['m1']['value']
    return casa_tools.quanta.quantity(np.pi / 2.0 - elevation, 'rad')
[docs]
def get_row_count(table_name: str, taql: str) -> int:
    """Return the number of rows in a CASA table that match a TaQL query.

    Args:
        table_name: Path to the CASA table.
        taql: The TaQL query string used to select rows.

    Returns:
        The number of matching rows, or 0 if opening the table, running the
        query or counting the rows raises. The error is logged as a warning.
    """
    try:
        with casa_tools.TableReader(table_name) as table:
            # closing() releases the query result even if nrows() raises.
            with contextlib.closing(table.query(taql)) as subtable:
                return subtable.nrows()
    except Exception as ex:
        # Deliberate best effort: callers get 0 rather than an exception.
        LOG.warning('Failed to count rows of %s matching query %r: %s', table_name, taql, ex)
        return 0
[docs]
def get_valid_url(env_var: str, default: str | list[str]) -> str | list[str]:
    """Read one or more URLs from an environment variable, with a fallback.

    The variable's value is split on commas. Surrounding whitespace and empty
    entries are discarded, and every remaining entry must pass validate_url().

    Args:
        env_var: The name of the environment variable.
        default: Value returned when the variable is unset or empty, has no
            non-blank entries, or contains any invalid URL.

    Returns:
        One of the following:
        - the URL string, when exactly one valid entry is present;
        - the list of URL strings, when several are present;
        - otherwise, ``default`` unchanged.
    """
    raw_value = os.getenv(env_var)
    if not raw_value:
        LOG.info('Environment variable %s not defined. Switching to default %s.', env_var, default)
        return default

    stripped = (entry.strip() for entry in raw_value.split(','))
    candidates = [entry for entry in stripped if entry]
    if not candidates:
        LOG.warning('Environment variable %s is empty after parsing. Switching to default %s.', env_var, default)
        return default

    # A single malformed entry invalidates the whole setting; report the first.
    first_invalid = next((entry for entry in candidates if not validate_url(entry)), None)
    if first_invalid is not None:
        LOG.warning('Environment variable %s URL was set to %s but is misconfigured.', env_var, first_invalid)
        LOG.info('Switching to default %s.', default)
        return default

    if len(candidates) > 1:
        LOG.info('Environment variable %s set to list of URLs %s.', env_var, candidates)
        return candidates

    LOG.info('Environment variable %s set to URL %s.', env_var, candidates[0])
    return candidates[0]
[docs]
def clear_time_cache() -> None:
    """Flush the time cache of the CASA measures tool.

    This is a workaround for PIPE-2891/PIPEREQ-402/CAS-13831, provided by
    T. Nakazato (NAOJ), 2025-10-16. It performs a throw-away AZELGEO -> ICRS
    direction conversion in a frame whose epoch lies well before any real
    observation, which triggers the cache clear. The measures tool is closed
    afterwards via ``contextlib.closing``.
    """
    # MJD 37665.0 is 1962-01-01, the beginning of IERS data coverage.
    dummy_epoch_mjd = 37665.0
    # Zenith pointing (azimuth 0 deg, elevation 90 deg).
    dummy_azimuth_deg, dummy_elevation_deg = 0.0, 90.0
    observatory_name = 'ALMA'

    with contextlib.closing(casa_tools.measures) as me:
        qa = casa_tools.quanta
        # Time frame first, then the observatory position frame.
        me.doframe(me.epoch(rf='UTC', v0=qa.quantity(dummy_epoch_mjd, 'd')))
        me.doframe(me.observatory(observatory_name))
        # The conversion itself is what clears the cache; its result is unused.
        zenith = me.direction(
            rf='AZELGEO',
            v0=qa.quantity(dummy_azimuth_deg, 'deg'),
            v1=qa.quantity(dummy_elevation_deg, 'deg'),
        )
        me.measure(rf='ICRS', v=zenith)
    LOG.debug('Successfully cleared the CASA measures tool time cache.')
[docs]
def build_refantignore(refantignore: str = "", ignorerefant: list | None = None) -> str:
    """Join refantignore and ignorerefant into one comma-separated string.

    Each entry has leading/trailing whitespace and commas removed. Entries
    that end up empty are dropped, so no empty fields (``,,``) or
    leading/trailing commas appear in the result. Commas inside an entry,
    e.g. ``'DA41, DA42'``, are kept. Non-string items are converted with
    ``str()``.

    Args:
        refantignore: Comma-separated antenna string. Defaults to "".
        ignorerefant: Additional antenna entries, or None. Defaults to None.

    Returns:
        The combined comma-separated string, or "" if nothing remains.
    """
    separators = string.whitespace + ','
    entries = (str(item).strip(separators) for item in [refantignore, *(ignorerefant or [])])
    return ",".join(entry for entry in entries if entry)