from collections.abc import Mapping
from collections import defaultdict
import numba
from numba import types
from numba.core.errors import TypingError
import numpy as np
from africanus.util.patterns import Multiton
from africanus.util.numba import overload, njit, JIT_OPTIONS
from africanus.experimental.rime.fused.arguments import ArgumentDependencies
from africanus.experimental.rime.fused.intrinsics import IntrinsicFactory
from africanus.experimental.rime.fused.specification import RimeSpecification
DATASET_TYPES = []
try:
from daskms.dataset import Dataset as dmsds
except ImportError:
pass
else:
DATASET_TYPES.append(dmsds)
try:
from xarray import Dataset as xrds
except ImportError:
pass
else:
DATASET_TYPES.append(xrds)
def rime_impl_factory(rime_spec, ncorr):
@njit(**JIT_OPTIONS)
def rime(*args):
return rime_impl(*args)
def rime_impl(*args):
raise NotImplementedError
@overload(rime_impl, jit_options=JIT_OPTIONS, prefer_literal=True)
def nb_rime(*args):
if not len(args) > 0:
raise TypingError(
"rime must be at least be called with the signature argument"
)
if not isinstance(args[0], types.StringLiteral):
raise TypingError(
f"Compiling the RIME failed for specification {rime_spec} failed. "
f"This often results from passing incorrectly shaped or typed inputs "
f"to the various Terms and Transformers or bugs in their implementations. "
f"Please search the associate stack trace for the error"
)
if not len(args) % 2 == 1:
raise TypingError(
f"Length of named arguments {len(args)} is not divisible by 2"
)
argstart = 1 + (len(args) - 1) // 2
names = args[1:argstart]
if not all(isinstance(n, types.StringLiteral) for n in names):
raise TypingError(f"{names} must be a Tuple of Literal strings")
if not all(n.literal_type is types.unicode_type for n in names):
raise TypingError(f"{names} must be a Tuple of Literal strings")
# Get literal argument names
names = tuple(n.literal_value for n in names)
# Generate intrinsics
argdeps = ArgumentDependencies(names, rime_spec.terms, rime_spec.transformers)
factory = IntrinsicFactory(argdeps)
out_names, pack_opts_indices = factory.pack_optionals_and_indices_fn()
out_names, pack_transformed = factory.pack_transformed_fn(out_names)
term_state = factory.term_state_fn(out_names)
term_sampler = factory.term_sampler_fn()
try:
lm_i = out_names.index("lm")
uvw_i = out_names.index("uvw")
chan_freq_i = out_names.index("chan_freq")
except ValueError as e:
raise ValueError(f"{str(e)} is required")
def impl(*args):
args_opt_idx = pack_opts_indices(args[argstart:])
args = pack_transformed(args_opt_idx)
state = term_state(args)
nsrc, _ = args[lm_i].shape
nrow, _ = args[uvw_i].shape
(nchan,) = args[chan_freq_i].shape
vis = np.zeros((nrow, nchan, ncorr), np.complex128)
# Kahan summation compensation
compensation = np.zeros_like(vis)
for s in range(nsrc):
for r in range(nrow):
t = state.time_inverse[r]
a1 = state.antenna1[r]
a2 = state.antenna2[r]
f1 = state.feed1[r]
f2 = state.feed2[r]
for ch in range(nchan):
X = term_sampler(state, s, r, t, f1, f2, a1, a2, ch)
for co, value in enumerate(numba.literal_unroll(X)):
# Kahan summation
y = value - compensation[r, ch, co]
current = vis[r, ch, co]
x = current + y
compensation[r, ch, co] = (x - current) - y
vis[r, ch, co] = x
return vis
return impl
return rime
class RimeFactory(metaclass=Multiton):
REQUIRED_ARGS = ArgumentDependencies.REQUIRED_ARGS
REQUIRED_ARGS_LITERAL = tuple(types.literal(n) for n in REQUIRED_ARGS)
DEFAULT_SPEC = "(Kpq, Bpq): [I, Q, U, V] -> [XX, XY, YX, YY]"
def __reduce__(self):
return (RimeFactory, (self.rime_spec,))
def __hash__(self):
return hash(self.rime_spec)
def __eq__(self, rhs):
return isinstance(rhs, RimeFactory) and self.rime_spec == rhs.rime_spec
def __init__(self, rime_spec=DEFAULT_SPEC):
if isinstance(rime_spec, RimeSpecification):
pass
elif isinstance(rime_spec, (list, tuple)):
rime_spec = RimeSpecification(*rime_spec)
elif isinstance(rime_spec, str):
rime_spec = RimeSpecification(rime_spec)
self.rime_spec = rime_spec
self.impl = rime_impl_factory(rime_spec, len(rime_spec.corrs))
def dask_blockwise_args(self, **kwargs):
"""Get the dask schema"""
argdeps = ArgumentDependencies(
tuple(kwargs.keys()), self.rime_spec.terms, self.rime_spec.transformers
)
# Holds kwargs + any dummy outputs from transformations
dummy_kw = kwargs.copy()
dask_schema = defaultdict(list)
for a in argdeps.REQUIRED_ARGS:
dask_schema[a].append(("internal", ("row",)))
POISON = object()
for transformer in argdeps.can_create.values():
kw = {}
for a in transformer.ARGS:
v = dummy_kw.get(a, None if a in argdeps.KEY_ARGS else POISON)
kw[a] = v
for a, d in transformer.KWARGS.items():
kw[a] = dummy_kw.get(a, d)
inputs, outputs = transformer.dask_schema(**kw)
for k, schema in inputs.items():
dask_schema[k].append((transformer, schema))
dummy_kw.update(outputs)
for term in self.rime_spec.terms:
kw = {a: dummy_kw[a] for a in term.ALL_ARGS if a in dummy_kw}
for k, v in term.dask_schema(**kw).items():
dask_schema[k].append((term, v))
merged_schema = {}
for a, candidates in dask_schema.items():
dims = set(pair[1] for pair in candidates)
if len(dims) != 1:
raise ValueError(
f"Multiple candidates provided conflicting "
f"dimension definitions for {a}: {candidates}."
)
merged_schema[a] = dims.pop()
names = list(sorted(argdeps.valid_inputs & set(kwargs.keys())))
blockwise_args = [
e for n in names for e in (kwargs[n], merged_schema.get(n, None))
]
assert 2 * len(names) == len(blockwise_args)
return names, blockwise_args
def __call__(self, time, antenna1, antenna2, feed1, feed2, **kwargs):
keys = self.REQUIRED_ARGS_LITERAL + tuple(map(types.literal, kwargs.keys()))
args = keys + (time, antenna1, antenna2, feed1, feed2) + tuple(kwargs.values())
return self.impl(types.literal(self.rime_spec.spec_hash), *args)
def consolidate_args(args, kw):
mapping = {}
oargs = []
for element in args:
if isinstance(element, tuple(DATASET_TYPES)):
mapping.update((k.lower(), v.data) for k, v in element.items())
elif isinstance(element, Mapping):
mapping.update(element)
else:
oargs.append(element)
mapping.update(zip(oargs, RimeFactory.REQUIRED_ARGS))
mapping.update(kw)
return mapping
[docs]
def rime(rime_spec, *args, **kw):
"""
Evaluates the Radio Interferometer Measurement Equation (RIME), given
the Specification of the RIME :code:`rime_spec`, as well as the
inputs to the RIME given in :code:`*args` and :code:`**kwargs`.
"""
mapping = consolidate_args(args, kw)
factory = RimeFactory(rime_spec=rime_spec)
return factory(**mapping)