Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merge v0.2.4 dev branch for release. #349

Merged
merged 6 commits into from
Nov 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ jobs:
strategy:
matrix:
os: [ubuntu-20.04, ubuntu-22.04]
python-version: ["3.9", "3.10", "3.11"]
python-version: ["3.10", "3.11", "3.12"]

steps:
- name: Set up Python ${{ matrix.python-version }}
Expand Down
3,535 changes: 1,936 additions & 1,599 deletions poetry.lock

Large diffs are not rendered by default.

25 changes: 14 additions & 11 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,21 +24,24 @@ include = [
]

[tool.poetry.dependencies]
python = ">=3.9, <3.12"
astro-tigger-lsm = ">=1.7.2, <=1.7.3"
codex-africanus = {extras = ["dask", "scipy", "astropy", "python-casacore"], version = ">=0.3.6, <=0.3.6"}
python = ">=3.10, <3.13"
astro-tigger-lsm = [
{ version = ">=1.7.2, <=1.7.3", python = "<3.12" },
{ version = ">=1.7.4, <=1.7.4", python = ">=3.12"}
]
codex-africanus = {extras = ["dask", "scipy", "astropy", "python-casacore"], version = ">=0.4.1, <=0.4.1"}
colorama = ">=0.4.6, <=0.4.6"
columnar = ">=1.4.1, <=1.4.1"
dask = {extras = ["diagnostics"], version = ">=2023.5.0, <=2024.4.2"}
dask-ms = {extras = ["s3", "xarray", "zarr"], version = ">=0.2.20, <=0.2.20"}
distributed = ">=2023.5.0, <=2024.4.2"
dask = {extras = ["diagnostics"], version = ">=2023.5.0, <=2024.10.0"}
dask-ms = {extras = ["s3", "xarray", "zarr"], version = ">=0.2.23, <=0.2.23"}
distributed = ">=2023.5.0, <=2024.10.0"
loguru = ">=0.7.0, <=0.7.2"
matplotlib = ">=3.5.1, <=3.8.2"
matplotlib = ">=3.5.1, <=3.9.2"
omegaconf = ">=2.3.0, <=2.3.0"
pytest = ">=7.3.1, <=7.4.4"
requests = ">=2.31.0, <=2.31.0"
"ruamel.yaml" = ">=0.17.26, <=0.17.40"
stimela = "^2.0rc17" # Volatile - be less strict.
pytest = ">=7.3.1, <=8.3.3"
requests = ">=2.31.0, <=2.32.3"
"ruamel.yaml" = ">=0.17.26, <=0.18.6"
stimela = ">=2.0"
tbump = ">=6.10.0, <=6.11.0"

[tool.poetry.scripts]
Expand Down
3 changes: 3 additions & 0 deletions quartical/apps/plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,9 @@ def to_plot_dict(xdsl, iter_attrs):


def _plot(group, xds, args):
# get rid of question marks
qstrip = lambda x: x.replace('?', 'N/A')
group = tuple(map(qstrip, group))

xds = xds.compute(scheduler="single-threaded")

Expand Down
10 changes: 5 additions & 5 deletions quartical/calibration/constructor.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@


term_spec_tup = namedtuple("term_spec_tup", "name type shape pshape")
aux_info_fields = ("SCAN_NUMBER", "FIELD_ID", "DATA_DESC_ID")
log_info_fields = ("SCAN_NUMBER", "FIELD_ID", "DATA_DESC_ID")


def construct_solver(
Expand Down Expand Up @@ -54,9 +54,9 @@ def construct_solver(
corr_mode = data_xds.sizes["corr"]

block_id_arr = get_block_id_arr(data_col)
aux_block_info = {
k: data_xds.attrs.get(k, "?") for k in aux_info_fields
}
data_xds_meta = data_xds.attrs.copy()
for k in log_info_fields:
data_xds_meta[k] = data_xds_meta.get(k, "?")

# Grab the number of input chunks - doing this on the data should be
# safe.
Expand Down Expand Up @@ -87,7 +87,7 @@ def construct_solver(
)
blocker.add_input("term_spec_list", spec_list, ("row", "chan"))
blocker.add_input("corr_mode", corr_mode)
blocker.add_input("aux_block_info", aux_block_info)
blocker.add_input("data_xds_meta", data_xds_meta)
blocker.add_input("solver_opts", solver_opts)
blocker.add_input("chain", chain)

Expand Down
27 changes: 16 additions & 11 deletions quartical/calibration/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def solver_wrapper(
solver_opts,
chain,
block_id_arr,
aux_block_info,
data_xds_meta,
corr_mode,
**kwargs
):
Expand Down Expand Up @@ -108,11 +108,11 @@ def solver_wrapper(
# Perform term specific setup e.g. init gains and params.
if term.is_parameterized:
gains, gain_flags, params, param_flags = term.init_term(
term_spec, ref_ant, ms_kwargs, term_kwargs
term_spec, ref_ant, ms_kwargs, term_kwargs, meta=data_xds_meta
)
else:
gains, gain_flags = term.init_term(
term_spec, ref_ant, ms_kwargs, term_kwargs
term_spec, ref_ant, ms_kwargs, term_kwargs, meta=data_xds_meta
)
# Dummy arrays with standard dtypes - aids compilation.
params = np.empty(term_pshape, dtype=np.float64)
Expand Down Expand Up @@ -190,6 +190,7 @@ def solver_wrapper(
for ind, (term, iters) in enumerate(zip(cycle(chain), iter_recipe)):

active_term = chain.index(term)
active_spec = term_spec_list[term_ind]

ms_fields = term.ms_inputs._fields
ms_inputs = term.ms_inputs(
Expand Down Expand Up @@ -219,13 +220,17 @@ def solver_wrapper(
term.solve_per
)

jhj, conv_iter, conv_perc = term.solver(
ms_inputs,
mapping_inputs,
chain_inputs,
meta_inputs,
corr_mode
)
if term.solver:
jhj, conv_iter, conv_perc = term.solver(
ms_inputs,
mapping_inputs,
chain_inputs,
meta_inputs,
corr_mode
)
else:
jhj = np.zeros(getattr(active_spec, "pshape", active_spec.shape))
conv_iter, conv_perc = 0, 1

# If reweighting is enabled, do it when the epoch changes, except
# for the final epoch - we don't reweight if we won't solve again.
Expand Down Expand Up @@ -269,7 +274,7 @@ def solver_wrapper(
corr_mode
)

log_chisq(presolve_chisq, postsolve_chisq, aux_block_info, block_id)
log_chisq(presolve_chisq, postsolve_chisq, data_xds_meta, block_id)

results_dict["presolve_chisq"] = presolve_chisq
results_dict["postsolve_chisq"] = postsolve_chisq
Expand Down
12 changes: 9 additions & 3 deletions quartical/config/argument_schema.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
input_ms:
path:
required: true
dtype: str
dtype: URI
writable: true
info:
Path to input measurement set.

Expand Down Expand Up @@ -164,15 +165,20 @@ input_model:
output:
gain_directory:
default: gains.qc
dtype: str
dtype: URI
writable: true
must_exist: false
write_parent_dir: true
info:
Name of directory in which QuartiCal gain outputs will be stored.
Accepts both local and s3 paths. QuartiCal will always produce gain
outputs.

log_directory:
default: logs.qc
dtype: str
dtype: Directory
writable: true
must_exist: false
info:
Name of directory in which QuartiCal logging outputs will be stored.
s3 is not currently supported for these outputs.
Expand Down
3 changes: 3 additions & 0 deletions quartical/config/gain_schema.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,15 @@ gain:
- amplitude
- delay
- delay_and_offset
- delay_and_tec
- phase
- tec_and_offset
- rotation_measure
- rotation
- crosshand_phase
- crosshand_phase_null_v
- leakage
- parallactic_angle
info:
Type of gain to solve for.

Expand Down
82 changes: 41 additions & 41 deletions quartical/data_handling/angles.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,46 @@
_thread_local = threading.local()


def assign_parangle_data(ms_path, data_xds_list):

anttab = xds_from_storage_table(ms_path + "::ANTENNA")[0]
feedtab = xds_from_storage_table(ms_path + "::FEED")[0]
fieldtab = xds_from_storage_table(ms_path + "::FIELD")[0]

# We do the following eagerly to reduce graph complexity.
feeds = feedtab.POLARIZATION_TYPE.values
unique_feeds = np.unique(feeds)

if np.all([feed in "XxYy" for feed in unique_feeds]):
feed_type = "linear"
elif np.all([feed in "LlRr" for feed in unique_feeds]):
feed_type = "circular"
else:
raise ValueError("Unsupported feed type/configuration.")

phase_dirs = fieldtab.PHASE_DIR.values

updated_data_xds_list = []
for xds in data_xds_list:
xds = xds.assign(
{
"RECEPTOR_ANGLE": (
("ant", "feed"), clone(feedtab.RECEPTOR_ANGLE.data)
),
"POSITION": (
("ant", "xyz"),
clone(anttab.POSITION.data)
)
}
)
xds.attrs["FEED_TYPE"] = feed_type
xds.attrs["FIELD_CENTRE"] = tuple(phase_dirs[xds.FIELD_ID, 0])

updated_data_xds_list.append(xds)

return updated_data_xds_list


def make_parangle_xds_list(ms_path, data_xds_list):
"""Create a list of xarray.Datasets containing the parallactic angles."""

Expand Down Expand Up @@ -266,7 +306,7 @@ def nb_apply_parangle_rot(data_col, parangles, utime_ind, ant1_col, ant2_col,
v1_imul_v2 = factories.v1_imul_v2_factory(corr_mode)
v1_imul_v2ct = factories.v1_imul_v2ct_factory(corr_mode)
valloc = factories.valloc_factory(corr_mode)
rotmat = rotation_factory(corr_mode, feed_type)
rotmat = factories.rotation_factory(corr_mode, feed_type)

def impl(data_col, parangles, utime_ind, ant1_col, ant2_col,
corr_mode, feed_type):
Expand Down Expand Up @@ -299,43 +339,3 @@ def impl(data_col, parangles, utime_ind, ant1_col, ant2_col,
return data_col

return impl


def rotation_factory(corr_mode, feed_type):

if feed_type.literal_value == "circular":
if corr_mode.literal_value == 4:
def impl(rot0, rot1, out):
out[0] = np.exp(-1j*rot0)
out[1] = 0
out[2] = 0
out[3] = np.exp(1j*rot1)
elif corr_mode.literal_value == 2: # TODO: Is this sensible?
def impl(rot0, rot1, out):
out[0] = np.exp(-1j*rot0)
out[1] = np.exp(1j*rot1)
elif corr_mode.literal_value == 1: # TODO: Is this sensible?
def impl(rot0, rot1, out):
out[0] = np.exp(-1j*rot0)
else:
raise ValueError("Unsupported number of correlations.")
elif feed_type.literal_value == "linear":
if corr_mode.literal_value == 4:
def impl(rot0, rot1, out):
out[0] = np.cos(rot0)
out[1] = np.sin(rot0)
out[2] = -np.sin(rot1)
out[3] = np.cos(rot1)
elif corr_mode.literal_value == 2: # TODO: Is this sensible?
def impl(rot0, rot1, out):
out[0] = np.cos(rot0)
out[1] = np.cos(rot1)
elif corr_mode.literal_value == 1: # TODO: Is this sensible?
def impl(rot0, rot1, out):
out[0] = np.cos(rot0)
else:
raise ValueError("Unsupported number of correlations.")
else:
raise ValueError("Unsupported feed type.")

return factories.qcjit(impl)
26 changes: 19 additions & 7 deletions quartical/data_handling/ms_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,13 @@
from quartical.data_handling.selection import filter_xds_list
from quartical.data_handling.angles import apply_parangles

DASKMS_ATTRS = {
"__daskms_partition_schema__",
"SCAN_NUMBER",
"FIELD_ID",
"DATA_DESC_ID"
}


def read_xds_list(model_columns, ms_opts):
"""Reads a measurement set and generates a list of xarray data sets.
Expand Down Expand Up @@ -237,7 +244,8 @@ def write_xds_list(xds_list, ref_xds_list, ms_path, output_opts):

# If the xds has fewer correlations than the measurement set, reindex.
if xds.sizes["corr"] < ms_n_corr:
xds = xds.reindex(corr=corr_types, fill_value=0)
# Note that we have to remove chunks from the reindexed axis.
xds = xds.reindex(corr=corr_types, fill_value=0).chunk({"corr": -1})

# Do some special handling on the flag column if we reindexed -
# we need a value dependent fill value.
Expand Down Expand Up @@ -292,14 +300,18 @@ def write_xds_list(xds_list, ref_xds_list, ms_path, output_opts):

logger.info("Outputs will be written to {}.", ", ".join(output_cols))

# Select only the output columns to simplify datasets.
xds_list = [xds[list(output_cols)] for xds in xds_list]

# Remove all coords bar ROWID so that they do not get written.
xds_list = [
xds.drop_vars(set(xds.coords.keys()) - {"ROWID"}, errors='ignore')
for xds in xds_list
]

# Remove attrs added by QuartiCal so that they do not get written.
for xds in xds_list:
xds.attrs.pop("UTIME_CHUNKS", None)
xds.attrs.pop("FIELD_NAME", None)

# Remove coords added by QuartiCal so that they do not get written.
xds_list = [xds.drop_vars(["chan", "corr"], errors='ignore')
for xds in xds_list]
xds.attrs = {k: v for k, v in xds.attrs.items() if k in DASKMS_ATTRS}

with warnings.catch_warnings(): # We anticipate spurious warnings.
warnings.simplefilter("ignore")
Expand Down
4 changes: 3 additions & 1 deletion quartical/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
preprocess_xds_list,
postprocess_xds_list)
from quartical.data_handling.model_handler import add_model_graph
from quartical.data_handling.angles import make_parangle_xds_list
from quartical.data_handling.angles import (make_parangle_xds_list,
assign_parangle_data)
from quartical.calibration.calibrate import add_calibration_graph
from quartical.statistics.statistics import make_stats_xds_list
from quartical.statistics.logging import log_summary_stats
Expand Down Expand Up @@ -110,6 +111,7 @@ def _execute(exitstack):

# Preprocess the xds_list - initialise some values and fix bad data.
data_xds_list = preprocess_xds_list(data_xds_list, ms_opts)
data_xds_list = assign_parangle_data(ms_opts.path, data_xds_list)

# Make a list of datasets containing the parallactic angles as these
# can be expensive to compute and may be used several times. NOTE: At
Expand Down
9 changes: 7 additions & 2 deletions quartical/gains/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@
from quartical.gains.tec_and_offset import TecAndOffset
from quartical.gains.rotation import Rotation
from quartical.gains.rotation_measure import RotationMeasure
from quartical.gains.crosshand_phase import CrosshandPhase
from quartical.gains.crosshand_phase import CrosshandPhase, CrosshandPhaseNullV
from quartical.gains.leakage import Leakage
from quartical.gains.delay_and_tec import DelayAndTec
from quartical.gains.parallactic_angle import ParallacticAngle


TERM_TYPES = {
Expand All @@ -21,5 +23,8 @@
"rotation": Rotation,
"rotation_measure": RotationMeasure,
"crosshand_phase": CrosshandPhase,
"leakage": Leakage
"crosshand_phase_null_v": CrosshandPhaseNullV,
"leakage": Leakage,
"delay_and_tec": DelayAndTec,
"parallactic_angle": ParallacticAngle
}
2 changes: 1 addition & 1 deletion quartical/gains/amplitude/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def make_param_names(cls, correlations):

return [f"amplitude_{c}" for c in param_corr]

def init_term(self, term_spec, ref_ant, ms_kwargs, term_kwargs):
def init_term(self, term_spec, ref_ant, ms_kwargs, term_kwargs, meta=None):
"""Initialise the gains (and parameters)."""

gains, gain_flags, params, param_flags = super().init_term(
Expand Down
Loading