"""
The imaging module contains utility functions used by the imaging tasks.
TODO These utility functions should migrate to hif.tasks.common
"""
from __future__ import annotations
import os
import re
from typing import TYPE_CHECKING
import numpy
import pipeline.infrastructure as infrastructure
from pipeline.infrastructure import casa_tasks
from .. import casa_tools, utils
if TYPE_CHECKING:
from collections.abc import Generator
from typing import Any
LOG = infrastructure.logging.get_logger(__name__)
__all__ = ['chan_selection_to_frequencies', 'freq_selection_to_channels', 'spw_intersect', 'update_sens_dict',
'update_beams_dict', 'set_nested_dict', 'intersect_ranges', 'intersect_ranges_by_weight', 'merge_ranges', 'equal_to_n_digits',
'velocity_to_frequency', 'frequency_to_velocity',
'predict_kernel', 'get_vlass_image_type', 'get_stats']
def _get_cube_freq_axis(img: str) -> tuple[float, float, str, float, int]:
"""
Get CASA image/cube frequency axis.
Args:
img: CASA image/cube name
Returns:
Tuple of frequency axis components
(reference frequency, delta frequency per channel, frequency unit,
reference pixel, number of pixels of frequency axis)
"""
iaTool = casa_tools.image
# Get frequency axis
iaTool.open(img)
imInfo = iaTool.summary()
iaTool.close()
fIndex = imInfo['axisnames'].tolist().index('Frequency')
refFreq = imInfo['refval'][fIndex]
deltaFreq = imInfo['incr'][fIndex]
freqUnit = imInfo['axisunits'][fIndex]
refPix = imInfo['refpix'][fIndex]
numPix = imInfo['shape'][fIndex]
return refFreq, deltaFreq, freqUnit, refPix, numPix
[docs]
def chan_selection_to_frequencies(img: str, selection: str, unit: str = 'GHz') -> list[float] | list[str]:
"""
Convert channel selection to frequency tuples for a given CASA cube.
Args:
img: CASA cube name
selection: Channel selection string using CASA selection syntax
unit: Frequency unit
Returns:
List of pairs of frequency values (float) in the desired units
"""
if selection in ('NONE', 'ALL', 'ALLCONT'):
return [selection]
frequencies = []
if selection != '':
qaTool = casa_tools.quanta
# Get frequency axis
try:
refFreq, deltaFreq, freqUnit, refPix, numPix = _get_cube_freq_axis(img)
except:
LOG.error('No frequency axis found in %s.' % (img))
return ['NONE']
for crange in selection.split(';'):
c0, c1 = list(map(float, crange.split('~')))
# Make sure c0 is the lower channel so that the +/-0.5 channel
# adjustments below go in the right direction.
if (c1 < c0):
c0, c1 = c1, c0
# Convert the channel range (c0-c1) to the corresponding frequency range
# that spans between the outer edges of this channel range. I.e., from
# the lower frequency edge of c0 to the upper frequency edge of c1.
f0 = qaTool.convert({'value': refFreq + (c0 - 0.5 - refPix) * deltaFreq, 'unit': freqUnit}, unit)
f1 = qaTool.convert({'value': refFreq + (c1 + 0.5 - refPix) * deltaFreq, 'unit': freqUnit}, unit)
if qaTool.lt(f0, f1):
frequencies.append((f0['value'], f1['value']))
else:
frequencies.append((f1['value'], f0['value']))
else:
frequencies = ['NONE']
return frequencies
[docs]
def freq_selection_to_channels(img: str, selection: str) -> list[int] | list[str]:
"""
Convert frequency selection to channel tuples for a given CASA cube.
Args:
img: CASA cube name
selection: Frequency selection string using CASA syntax
Returns:
List of pairs of channel values (int)
"""
if selection in ('NONE', 'ALL', 'ALLCONT'):
return [selection]
channels = []
if selection != '':
qaTool = casa_tools.quanta
# Get frequency axis
try:
refFreq, deltaFreq, freqUnit, refPix, numPix = _get_cube_freq_axis(img)
except:
LOG.error('No frequency axis found in %s.' % (img))
return ['NONE']
p = re.compile(r'([\d.]*)(~)([\d.]*)(\D*)')
for frange in p.findall(selection.replace(';', '')):
f0 = qaTool.convert('%s%s' % (frange[0], frange[3]), freqUnit)['value']
f1 = qaTool.convert('%s%s' % (frange[2], frange[3]), freqUnit)['value']
# It is assumed here that the frequency ranges are given from
# the lower edge of the lowest frequency channel to the upper
# edge of the highest frequency channel, while the reference frequency
# is specified at the center of the reference pixel (channel). To calculate
# the corresponding channel range, we need to add 0.5 to the lower channel,
# and subtract 0.5 from the upper channel.
c0 = (f0 - refFreq) / deltaFreq
c1 = (f1 - refFreq) / deltaFreq
# Avoid stepping outside possible channel range
c0 = max(c0, 0)
c0 = min(c0, numPix - 1)
c0 = int(utils.round_half_up(c0 + 0.5))
c0 = max(c0, 0)
c0 = min(c0, numPix - 1)
c1 = max(c1, 0)
c1 = min(c1, numPix - 1)
c1 = int(utils.round_half_up(c1 - 0.5))
c1 = max(c1, 0)
c1 = min(c1, numPix - 1)
if c0 < c1:
channels.append((c0, c1))
else:
channels.append((c1, c0))
else:
channels = ['NONE']
return channels
[docs]
def spw_intersect(spw_range: list[float], line_regions: list[list[float]]) -> list[list[float]]:
"""
This utility function takes a frequency range (as numbers with arbitrary
but common units) and computes the intersection with a list of frequency
ranges defining the regions of spectral lines. It returns the remaining
ranges excluding the line frequency ranges.
Args:
spw_range: List of two numbers defining the spw frequency range
line_regions: List of lists of pairs of numbers defining frequency ranges
to be excluded
Returns:
List of lists of pairs of numbers defining the remaining frequency ranges
"""
spw_sel_intervals = []
for line_region in line_regions:
if (line_region[0] <= spw_range[0]) and (line_region[1] >= spw_range[1]):
spw_sel_intervals = []
spw_range = []
break
elif (line_region[0] <= spw_range[0]) and (line_region[1] >= spw_range[0]):
spw_range = [line_region[1], spw_range[1]]
elif (line_region[0] >= spw_range[0]) and (line_region[1] < spw_range[1]):
spw_sel_intervals.append([spw_range[0], line_region[0]])
spw_range = [line_region[1], spw_range[1]]
elif line_region[0] >= spw_range[1]:
spw_sel_intervals.append(spw_range)
spw_range = []
break
elif (line_region[0] >= spw_range[0]) and (line_region[1] >= spw_range[1]):
spw_sel_intervals.append([spw_range[0], line_region[0]])
spw_range = []
break
if spw_range != []:
spw_sel_intervals.append(spw_range)
return spw_sel_intervals
[docs]
def update_sens_dict(dct: dict, udct: dict) -> None:
"""
Update a sensitivity dictionary. All generic solutions
tried so far did not do the job. So this method assumes
an explicit dictionary structure of
['<MS name>']['<field name']['<intent>'][<spw>]: {<sensitivity result>}.
Args:
dct: Sensitivities dictionary
udct: Sensitivities update dictionary
Returns:
None. The main dictionary is modified in place.
"""
for msname in udct:
# Exclude special primary keys that are not MS names
if msname not in ['recalc', 'robust', 'uvtaper']:
if msname not in dct:
dct[msname] = {}
for field in udct[msname]:
if field not in dct[msname]:
dct[msname][field] = {}
for intent in udct[msname][field]:
if intent not in dct[msname][field]:
dct[msname][field][intent] = {}
for spw in udct[msname][field][intent]:
if spw not in dct[msname][field][intent]:
dct[msname][field][intent][spw] = {}
dct[msname][field][intent][spw] = udct[msname][field][intent][spw]
[docs]
def update_beams_dict(dct: dict, udct: dict) -> None:
"""
Update a beams dictionary. All generic solutions
tried so far did not do the job. So this method assumes
an explicit dictionary structure of
['<field name']['<intent>'][<spwids>]: {<beam>}.
Args:
dct: Beams dictionary
udct: Beams update dictionary
Returns:
None. The main dictionary is modified in place.
"""
for field in udct:
# Exclude special primary keys that are not MS names
if field not in ['recalc', 'robust', 'uvtaper']:
if field not in dct:
dct[field] = {}
for intent in udct[field]:
if intent not in dct[field]:
dct[field][intent] = {}
for spwids in udct[field][intent]:
if spwids not in dct[field][intent]:
dct[field][intent][spwids] = {}
dct[field][intent][spwids] = udct[field][intent][spwids]
[docs]
def set_nested_dict(dct: dict, keys: tuple[Any], value: Any) -> None:
"""
Set a hierarchy of dictionaries with given keys and value
for the lowest level key.
>>> d = {}
>>> set_nested_dict(d, ('key1', 'key2', 'key3'), 1)
>>> print(d)
{'key1': {'key2': {'key3': 1}}}
Args:
dct: Any dictionary
keys : List of keys to build hierarchy
value: Value for lowest level key
Returns:
None. The dictionary is modified in place.
"""
for key in keys[:-1]:
dct = dct.setdefault(key, {})
dct[keys[-1]] = value
[docs]
def intersect_ranges(ranges: list[tuple[float | int]]) -> tuple[float | int]:
"""
Compute intersection of ranges.
Args:
ranges: List of tuples defining (frequency) intervals
Returns:
intersect_range: Tuple of two numbers defining the intersection
"""
if len(ranges) == 0:
return ()
elif len(ranges) == 1:
return ranges[0]
else:
intersect_range = ranges[0]
for myrange in ranges[1:]:
i0 = max(intersect_range[0], myrange[0])
i1 = min(intersect_range[1], myrange[1])
if i0 <= i1:
intersect_range = (i0, i1)
else:
return ()
return intersect_range
[docs]
def intersect_ranges_by_weight(ranges: list[tuple[float | int]], delta: float, threshold: float) -> tuple[float]:
"""
Compute intersection of ranges through weight arrays and a threshold.
Args:
ranges: List of tuples defining frequency intervals
delta: Frequency step to be used for the intersection
threshold: Threshold to be used for the intersection
Returns:
intersect_range: Tuple of two numbers defining the intersection
"""
if len(ranges) == 0:
return ()
elif len(ranges) == 1:
return ranges[0]
else:
min_v = min(numpy.array(ranges).flatten())
max_v = max(numpy.array(ranges).flatten())
max_range = numpy.arange(min_v, max_v+delta, delta)
range_weights = numpy.zeros(max_range.shape, 'd')
for myrange in ranges:
range_weights += numpy.where((max_range >= myrange[0]) & (max_range <= myrange[1]), 1.0, 0.0)
range_weights /= len(ranges)
valid_indices = numpy.where(range_weights >= threshold)[0]
if valid_indices.shape != (0,):
return (max_range[valid_indices[0]], max_range[valid_indices[-1]])
else:
return ()
[docs]
def merge_ranges(ranges: list[tuple[float | int]]) -> Generator[list[tuple[float]], None, None]:
"""
Merge overlapping and adjacent ranges and yield the merged ranges
in order. The argument must be an iterable of pairs (start, stop).
Args:
ranges: List of tuples of two numbers defining ranges
Returns:
Generator yielding tuples of merged ranges
>>> list(merge_ranges([(5,7), (3,5), (-1,3)]))
[(-1, 7)]
>>> list(merge_ranges([(5,6), (3,4), (1,2)]))
[(1, 2), (3, 4), (5, 6)]
>>> list(merge_ranges([]))
[]
(c) Gareth Rees 02/2013
"""
ranges = iter(sorted(ranges))
try:
current_start, current_stop = next(ranges)
except StopIteration:
return
for start, stop in ranges:
if start > current_stop:
# Gap between segments: output current segment and start a new one.
yield current_start, current_stop
current_start, current_stop = start, stop
else:
# Segments adjacent or overlapping: merge.
current_stop = max(current_stop, stop)
yield current_start, current_stop
[docs]
def equal_to_n_digits(x: float, y: float, numdigits: int = 7) -> bool:
"""
Approximate equality check up to a given number of digits.
Args:
x: First floating point number
y: Second floating point number
numdigits: Number of digits to check
Returns:
Boolean
"""
try:
numpy.testing.assert_approx_equal(x, y, numdigits)
return True
except:
return False
[docs]
def velocity_to_frequency(velocity: dict | str, restfreq: dict | str) -> dict | str:
"""
Convert radial velocity to frequency using radio convention.
f = f_rest * (1 - v/c)
Args:
velocity: velocity quantity
restfreq: rest frequency quantity
Returns:
Frequency quantity in units of restfreq
"""
cqa = casa_tools.quanta
light_speed = float(cqa.getvalue(cqa.convert(cqa.constants('c'), 'km/s'))[0])
velocity = float(cqa.getvalue(cqa.convert(cqa.quantity(velocity), 'km/s'))[0])
val = float(cqa.getvalue(restfreq)[0]) * (1 - velocity / light_speed)
unit = cqa.getunit(restfreq)
frequency = cqa.tos(cqa.quantity(val, unit))
return frequency
[docs]
def frequency_to_velocity(frequency: dict | str, restfreq: dict | str) -> dict | str:
"""
Convert frequency to radial velocity using radio convention.
v = c * (f_rest - f) / f_rest
Args:
frequency: frequency quantity
restfreq: rest frequency quantity
Returns:
Velocity quantity in units of km/s
"""
cqa = casa_tools.quanta
light_speed = float(cqa.getvalue(cqa.convert(cqa.constants('c'), 'km/s'))[0])
restfreq = float(cqa.getvalue(cqa.convert(restfreq, 'MHz'))[0])
freq = float(cqa.getvalue(cqa.convert(frequency, 'MHz'))[0])
val = light_speed * ((restfreq - freq) / restfreq)
velocity = cqa.tos(cqa.quantity(val, 'km/s'))
return velocity
[docs]
def predict_kernel(beam, target_beam, pstol=1e-6, patol=1e-3):
"""Predict the required convolution kernel to each a target restoring beam.
pstol: the tolerance in arcsec for original vs. target bmaj/bmin identical or kernel "point source" like.
patol: the tolerance in degree for original vs. target bpa identical
return_code:
0: sucess, the target beam can be reached with a valid convolution kernel
1: fail, "point source" like
2: fail, unable to reach the target beam shape, and the original beam is probably already too large in a certain direction.
Note:
Although ia.deconvolvefrombeam() can also predict convolution kernel sizes, its return can be misleading
in some circumstances (see CAS-13804). Therefore, we use ia.beamforconvolvedsize() here even though we have to catch the CASA runtime error messages.
"""
cqa = casa_tools.quanta
cia = casa_tools.image
clog = casa_tools.casalog
# default return code and kernel: fail (code=2) and a dummy kernel
rt_kernel = {'major': {'unit': 'arcsec', 'value': 0.0},
'minor': {'unit': 'arcsec', 'value': 0.0},
'pa': {'unit': 'deg', 'value': 0.0}}
rt_code = 2
# ia.restoringbeam() return bpa under the key 'positionangle' while ia.commombeam() return bpa under 'pa'
# we search the exact key here so both versions will work.
t_bpa_key = 'positionangle' if 'positionangle' in target_beam else 'pa'
bpa_key = 'positionangle' if 'positionangle' in beam else 'pa'
t_bmaj = cqa.convert(target_beam['major'], 'arcsec')['value']
t_bmin = cqa.convert(target_beam['minor'], 'arcsec')['value']
t_bpa = cqa.convert(target_beam[t_bpa_key], 'deg')['value']
bmaj = cqa.convert(beam['major'], 'arcsec')['value']
bmin = cqa.convert(beam['minor'], 'arcsec')['value']
bpa = cqa.convert(beam[bpa_key], 'deg')['value']
if abs(t_bmaj-bmaj) < pstol and abs(t_bmin-bmin) < pstol and abs(t_bpa-bpa) < patol:
LOG.info(
'The target beam is identical or close to the original beam under the specified tolerance: ' +
f'pstol = {pstol} arcsec and patol = {patol} deg.')
rt_code = 1
else:
target_bm = [cqa.tos(target_beam['major']), cqa.tos(target_beam['minor']), cqa.tos(target_beam[t_bpa_key])]
origin_bm = [cqa.tos(beam['major']), cqa.tos(beam['minor']), cqa.tos(beam[bpa_key])]
# filter out the potential runtime error message when ia.beamforconvolvedsize() fails.
with infrastructure.logging.log_filtermsg('Unable to reach target resolution of major'):
try:
rt_kernel = cia.beamforconvolvedsize(source=origin_bm, convolved=target_bm)
if cqa.convert(rt_kernel['major'], 'arcsec')['value'] < pstol:
LOG.info('The kernel from ia.deconvolvefrombeam() is considered as a point-source under the tolerance ' +
f'pstol = {pstol} arcsec.')
rt_code = 1
else:
LOG.info(f'The convolution kernel prediced by ia.deconvolvefrombeam is {rt_kernel} and larger than the tolerence ' +
f'pstol = {pstol} arcsec')
rt_code = 0
except RuntimeError as e:
LOG.info("Unable to reach the target beam shape because the original beam is probably already too large.")
rt_code = 2
return rt_kernel, rt_code
[docs]
def get_vlass_image_type(filename:str, append_tt: bool = True) -> str:
"""
Determine the VLASS image type based on specific substrings in the filename.
"""
filename = os.path.basename(filename).lower()
base = (
"ALPHAERR" if ".alpha.error" in filename else
"ALPHA" if ".alpha" in filename else
"RMS" if ".rms" in filename else
"INTENSITY_PBCOR" if "image.pbcor." in filename else
"WEIGHT" if ".weight." in filename else
"SUMWT" if ".sumwt." in filename else
"PSF" if ".psf." in filename else
"UNKNOWN"
)
if base == "UNKNOWN" or not append_tt:
return base
tt = ("_TT0" if "tt0" in filename else "_TT1" if "tt1" in filename else "")
return base + tt
[docs]
def get_stats(image_name: str, metrics: list, stokes: str = 'I') -> dict:
"""Return a dict of requested statistics for the given image."""
if not os.path.exists(image_name):
return {m: None for m in metrics}
job = casa_tasks.imstat(imagename=image_name, stokes=stokes)
stats = job.execute()
return {m: (float(stats.get(m)[0]) if stats.get(m) is not None else None) for m in metrics}