From 4beaca42f07dea2406a36554197078b5948098e7 Mon Sep 17 00:00:00 2001 From: "Alec Thomson (S&A, Kensington WA)" Date: Mon, 22 Jul 2024 16:59:22 +0800 Subject: [PATCH 01/17] Ruffage --- .pre-commit-config.yaml | 16 ++++----- arrakis/makecat.py | 2 +- docs/source/imaging.rst | 14 ++++---- docs/source/pipeline.rst | 76 ++++++++++++++++++++-------------------- licenses/flint.txt | 2 +- pyproject.toml | 3 ++ 6 files changed, 58 insertions(+), 55 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 3e334cc3..3fea9444 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,20 +1,20 @@ repos: - repo: https://github.com/astral-sh/ruff-pre-commit # Ruff version. - rev: v0.4.10 + rev: v0.5.4 hooks: # Run the linter. - id: ruff args: [ --fix ] # Run the formatter. - id: ruff-format -# - repo: https://github.com/pre-commit/pre-commit-hooks -# rev: v4.6.0 -# hooks: -# - id: trailing-whitespace -# - id: end-of-file-fixer -# - id: check-yaml -# - id: check-added-large-files +- repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.6.0 + hooks: + - id: trailing-whitespace + - id: end-of-file-fixer + - id: check-yaml + - id: check-added-large-files ci: autofix_commit_msg: | diff --git a/arrakis/makecat.py b/arrakis/makecat.py index 00037135..6ec80be6 100644 --- a/arrakis/makecat.py +++ b/arrakis/makecat.py @@ -1174,7 +1174,7 @@ def main( # Replace all infs with nans for col in rmtab.colnames: # Check if column is a float - if type(rmtab[col][0]) == np.float_: + if isinstance(rmtab[col][0], np.float_): rmtab[col][np.isinf(rmtab[col])] = np.nan # Convert all mJy to Jy diff --git a/docs/source/imaging.rst b/docs/source/imaging.rst index 28bb005e..2a557a94 100644 --- a/docs/source/imaging.rst +++ b/docs/source/imaging.rst @@ -26,8 +26,8 @@ This can be run using: [--multiscale] [--multiscale_scale_bias MULTISCALE_SCALE_BIAS] [--multiscale_scales MULTISCALE_SCALES] [--absmem ABSMEM] [--make_residual_cubes] [--ms_glob_pattern MS_GLOB_PATTERN] [--data_column DATA_COLUMN] [--no_mf_weighting] [--skip_fix_ms] [--num_beams NUM_BEAMS] [--disable_pol_local_rms] [--disable_pol_force_mask_rounds] [--hosted-wsclean HOSTED_WSCLEAN | --local_wsclean LOCAL_WSCLEAN] msdir datadir - - + + mmm mmm mmm mmm mmm )-( )-( )-( )-( )-( ( S ) ( P ) ( I ) ( C ) ( E ) @@ -38,17 +38,17 @@ This can be run using: ( R ) ( A ) ( C ) ( S ) | | | | | | | | |___| |___| |___| |___| - + Arrkis imager - - + + options: -h, --help show this help message and exit --hosted-wsclean HOSTED_WSCLEAN Docker or Singularity image for wsclean (default: docker://alecthomson/wsclean:latest) --local_wsclean LOCAL_WSCLEAN Path to local wsclean Singularity image (default: None) - + imaging arguments: msdir Directory containing MS files --temp_dir_wsclean TEMP_DIR_WSCLEAN @@ -96,7 +96,7 @@ This can be run using: Disable local RMS for polarisation images (default: False) --disable_pol_force_mask_rounds Disable force mask rounds for polarisation images (default: False) - + workdir arguments: datadir Directory to create/find full-size images and 'cutout' directory diff --git a/docs/source/pipeline.rst b/docs/source/pipeline.rst index 3a1002ea..5a97b8d3 100644 --- a/docs/source/pipeline.rst +++ b/docs/source/pipeline.rst @@ -46,8 +46,8 @@ With an initalised database you can call the pipeline on a single field: :: [--phi_max PHI_MAX] [--dphi DPHI] [--n_samples N_SAMPLES] [--poly_ord POLY_ORD] [--no_stokes_i] [--show_plots] [--not_rmsf] [--debug] [--cutoff CUTOFF] [--max_iter MAX_ITER] [--gain GAIN] [--window WINDOW] 
[--leakage_degree LEAKAGE_DEGREE] [--leakage_bins LEAKAGE_BINS] [--leakage_snr LEAKAGE_SNR] [--catfile OUTFILE] [--npix NPIX] [--map_size MAP_SIZE] [--overwrite] [--config CONFIG] datadir field msdir - - + + mmm mmm mmm mmm mmm )-( )-( )-( )-( )-( ( S ) ( P ) ( I ) ( C ) ( E ) @@ -58,14 +58,14 @@ With an initalised database you can call the pipeline on a single field: :: ( R ) ( A ) ( C ) ( S ) | | | | | | | | |___| |___| |___| |___| - + Arrakis pipeline. - + Before running make sure to start a session of mongodb e.g. $ mongod --dbpath=/path/to/database --bind_ip $(hostname -i) - - - + + + options: -h, --help show this help message and exit --hosted-wsclean HOSTED_WSCLEAN @@ -73,7 +73,7 @@ With an initalised database you can call the pipeline on a single field: :: --local_wsclean LOCAL_WSCLEAN Path to local wsclean Singularity image (default: None) --config CONFIG Config file path (default: None) - + pipeline arguments: --dask_config DASK_CONFIG Config file for Dask SlurmCLUSTER. (default: None) @@ -89,10 +89,10 @@ With an initalised database you can call the pipeline on a single field: :: --skip_cat Skip catalogue stage. (default: False) --skip_validate Skip validation stage. (default: False) --skip_cleanup Skip cleanup stage. (default: False) - + workdir arguments: datadir Directory to create/find full-size images and 'cutout' directory - + generic arguments: field Name of field (e.g. RACS_2132-50). --sbid SBID SBID of observation. (default: None) @@ -106,7 +106,7 @@ With an initalised database you can call the pipeline on a single field: :: --password PASSWORD Password of mongodb. (default: None) --limit LIMIT Limit the number of islands to process. (default: None) --database Add data to MongoDB. (default: False) - + imaging arguments: msdir Directory containing MS files --temp_dir_wsclean TEMP_DIR_WSCLEAN @@ -154,17 +154,17 @@ With an initalised database you can call the pipeline on a single field: :: Disable local RMS for polarisation images (default: False) --disable_pol_force_mask_rounds Disable force mask rounds for polarisation images (default: False) - + cutout arguments: -p PAD, --pad PAD Number of beamwidths to pad around source [3]. (default: 3) -d, --dryrun Do a dry-run [False]. (default: False) - + linmos arguments: --holofile HOLOFILE Path to holography image (default: None) --yanda YANDA Yandasoft version to pull from DockerHub [1.3.0]. (default: 1.3.0) --yanda_image YANDA_IMAGE Path to an existing yandasoft singularity container image. (default: None) - + frion arguments: --ionex_server IONEX_SERVER IONEX server (default: ftp://ftp.aiub.unibe.ch/CODE/) @@ -174,13 +174,13 @@ With an initalised database you can call the pipeline on a single field: :: --ionex_proxy_server IONEX_PROXY_SERVER Proxy server. (default: None) --ionex_predownload Pre-download IONEX files. (default: False) - + common rm arguments: --dimension DIMENSION How many dimensions for RMsynth '1d' or '3d'. (default: 1d) --save_plots save the plots. (default: False) --rm_verbose Verbose RMsynth/RMClean. (default: False) - + rm-synth arguments: --ion Use ionospheric-corrected data. (default: False) --tt0 TT0 TT0 MFS image -- will be used for model of Stokes I -- also needs --tt1. (default: None) @@ -201,13 +201,13 @@ With an initalised database you can call the pipeline on a single field: :: --show_plots show the plots. (default: False) --not_rmsf Skip calculation of RMSF? (default: False) --debug turn on debugging messages & plots. 
(default: False) - + rm-clean arguments: --cutoff CUTOFF CLEAN cutoff (+ve = absolute, -ve = sigma). (default: -3) --max_iter MAX_ITER maximum number of CLEAN iterations. (default: 10000) --gain GAIN CLEAN loop gain. (default: 0.1) --window WINDOW Further CLEAN in mask to this threshold. (default: None) - + catalogue arguments: --leakage_degree LEAKAGE_DEGREE Degree of leakage polynomial fit. (default: 4) @@ -216,14 +216,14 @@ With an initalised database you can call the pipeline on a single field: :: --leakage_snr LEAKAGE_SNR SNR cut for leakage fit. (default: 30.0) --catfile OUTFILE File to save table to. (default: None) - + validation options: --npix NPIX Number of pixels in the gridded maps (default: 512) --map_size MAP_SIZE Size of the maps in degrees (default: 8) - + cleanup arguments: --overwrite Overwrite existing tarball (default: False) - + Args that start with '--' can also be set in a config file (specified via --config). Config file syntax allows: key=value, flag=true, stuff=[a,b,c] (for details, see syntax at https://goo.gl/R74nmi). In general, command-line values override config file values which override defaults. .. code-block:: yaml @@ -358,8 +358,8 @@ Similarly, you can merge multiple fields togther using: :: [--host host] [--username USERNAME] [--password PASSWORD] [--holofile HOLOFILE] [--yanda YANDA] [--yanda_image YANDA_IMAGE] [--dimension DIMENSION] [--save_plots] [--rm_verbose] [--ion] [--tt0 TT0] [--tt1 TT1] [--validate] [--own_fit] [--weight_type WEIGHT_TYPE] [--fit_function FIT_FUNCTION] [--fit_rmsf] [--phi_max PHI_MAX] [--dphi DPHI] [--n_samples N_SAMPLES] [--poly_ord POLY_ORD] [--no_stokes_i] [--show_plots] [--not_rmsf] [--debug] [--cutoff CUTOFF] [--max_iter MAX_ITER] [--gain GAIN] [--window WINDOW] [--leakage_degree LEAKAGE_DEGREE] [--leakage_bins LEAKAGE_BINS] [--leakage_snr LEAKAGE_SNR] [--catfile OUTFILE] [--npix NPIX] [--map_size MAP_SIZE] [--overwrite] [--config CONFIG] - - + + mmm mmm mmm mmm mmm )-( )-( )-( )-( )-( ( S ) ( P ) ( I ) ( C ) ( E ) @@ -370,18 +370,18 @@ Similarly, you can merge multiple fields togther using: :: ( R ) ( A ) ( C ) ( S ) | | | | | | | | |___| |___| |___| |___| - + Arrakis regional pipeline. - + Before running make sure to start a session of mongodb e.g. $ mongod --dbpath=/path/to/database --bind_ip $(hostname -i) - - - + + + options: -h, --help show this help message and exit --config CONFIG Config file path (default: None) - + pipeline arguments: --dask_config DASK_CONFIG Config file for Dask SlurmCLUSTER. (default: None) @@ -391,7 +391,7 @@ Similarly, you can merge multiple fields togther using: :: --skip_cat Skip catalogue stage [False]. (default: False) --skip_validate Skip validation stage. (default: False) --skip_cleanup Skip cleanup stage [False]. (default: False) - + merge arguments: --merge_name MERGE_NAME Name of the merged region (default: None) @@ -406,19 +406,19 @@ Similarly, you can merge multiple fields togther using: :: --host host Host of mongodb (probably $hostname -i). (default: None) --username USERNAME Username of mongodb. (default: None) --password PASSWORD Password of mongodb. (default: None) - + linmos arguments: --holofile HOLOFILE Path to holography image (default: None) --yanda YANDA Yandasoft version to pull from DockerHub [1.3.0]. (default: 1.3.0) --yanda_image YANDA_IMAGE Path to an existing yandasoft singularity container image. (default: None) - + common rm arguments: --dimension DIMENSION How many dimensions for RMsynth '1d' or '3d'. (default: 1d) --save_plots save the plots. 
(default: False) --rm_verbose Verbose RMsynth/RMClean. (default: False) - + rm-synth arguments: --ion Use ionospheric-corrected data. (default: False) --tt0 TT0 TT0 MFS image -- will be used for model of Stokes I -- also needs --tt1. (default: None) @@ -439,13 +439,13 @@ Similarly, you can merge multiple fields togther using: :: --show_plots show the plots. (default: False) --not_rmsf Skip calculation of RMSF? (default: False) --debug turn on debugging messages & plots. (default: False) - + rm-clean arguments: --cutoff CUTOFF CLEAN cutoff (+ve = absolute, -ve = sigma). (default: -3) --max_iter MAX_ITER maximum number of CLEAN iterations. (default: 10000) --gain GAIN CLEAN loop gain. (default: 0.1) --window WINDOW Further CLEAN in mask to this threshold. (default: None) - + catalogue arguments: --leakage_degree LEAKAGE_DEGREE Degree of leakage polynomial fit. (default: 4) @@ -454,14 +454,14 @@ Similarly, you can merge multiple fields togther using: :: --leakage_snr LEAKAGE_SNR SNR cut for leakage fit. (default: 30.0) --catfile OUTFILE File to save table to. (default: None) - + validation options: --npix NPIX Number of pixels in the gridded maps (default: 512) --map_size MAP_SIZE Size of the maps in degrees (default: 8) - + cleanup arguments: --overwrite Overwrite existing tarball (default: False) - + Args that start with '--' can also be set in a config file (specified via --config). Config file syntax allows: key=value, flag=true, stuff=[a,b,c] (for details, see syntax at https://goo.gl/R74nmi). In general, command-line values override config file values which override defaults. diff --git a/licenses/flint.txt b/licenses/flint.txt index fd0a43b2..cefac032 100644 --- a/licenses/flint.txt +++ b/licenses/flint.txt @@ -25,4 +25,4 @@ DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. \ No newline at end of file +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
diff --git a/pyproject.toml b/pyproject.toml index 9221a5f5..44614ed0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -123,3 +123,6 @@ create_mongodb = { reference="scripts/create_mongodb.py", type="file"} [tool.isort] profile = "black" + +[tool.ruff.lint] +select = ["E4", "E7", "E9", "F", "UP006"] From dac6e314cefcfb54dbbb14f8d8c020cff0252714 Mon Sep 17 00:00:00 2001 From: "Alec Thomson (S&A, Kensington WA)" Date: Mon, 22 Jul 2024 18:04:21 +0800 Subject: [PATCH 02/17] Ruffage --- arrakis/cleanup.py | 2 +- arrakis/columns_possum.py | 1 - arrakis/cutout.py | 3 +-- arrakis/frion.py | 5 ++--- arrakis/imager.py | 15 +++++++-------- arrakis/linmos.py | 3 +-- arrakis/logger.py | 3 +-- arrakis/rmclean_oncuts.py | 6 +++--- arrakis/rmsynth_oncuts.py | 3 +-- arrakis/utils/fitsutils.py | 2 +- arrakis/utils/io.py | 8 ++++---- arrakis/utils/json.py | 2 +- arrakis/utils/meta.py | 2 +- arrakis/utils/pipeline.py | 2 +- arrakis/validate.py | 5 ++--- pyproject.toml | 6 +++--- scripts/casda_prepare.py | 9 ++++----- scripts/compare_leakage.py | 11 +++++------ scripts/compute_leakage.py | 5 ++--- scripts/copy_cutouts.py | 6 +++--- scripts/copy_cutouts_askap.py | 6 +++--- scripts/copy_data.py | 3 +-- scripts/create_mongodb.py | 3 +-- scripts/find_row.py | 3 +-- scripts/find_sbid.py | 3 +-- scripts/fix_dr1_cat.py | 14 +++++++------- scripts/fix_src_cat.py | 5 ++--- scripts/spica.py | 6 +++--- scripts/tar_cubelets.py | 3 +-- 29 files changed, 64 insertions(+), 81 deletions(-) diff --git a/arrakis/cleanup.py b/arrakis/cleanup.py index 6725a95e..08cc9c6f 100644 --- a/arrakis/cleanup.py +++ b/arrakis/cleanup.py @@ -8,13 +8,13 @@ from pathlib import Path from typing import List -from arrakis.utils.io import verify_tarball import astropy.units as u import numpy as np from prefect import flow, get_run_logger, task from tqdm.auto import tqdm from arrakis.logger import TqdmToLogger, UltimateHelpFormatter, logger +from arrakis.utils.io import verify_tarball from arrakis.utils.pipeline import generic_parser, logo_str logger.setLevel(logging.INFO) diff --git a/arrakis/columns_possum.py b/arrakis/columns_possum.py index 57d09ee2..7427af50 100644 --- a/arrakis/columns_possum.py +++ b/arrakis/columns_possum.py @@ -1,5 +1,4 @@ #!/usr/bin/env python3 -# -*- coding: utf-8 -*- """ Column names from RM-tools to catalogue diff --git a/arrakis/cutout.py b/arrakis/cutout.py index 34894149..3acd0ece 100644 --- a/arrakis/cutout.py +++ b/arrakis/cutout.py @@ -8,9 +8,8 @@ from pathlib import Path from pprint import pformat from shutil import copyfile -from typing import List +from typing import List, Optional, Set, TypeVar from typing import NamedTuple as Struct -from typing import Optional, Set, TypeVar import astropy.units as u import numpy as np diff --git a/arrakis/frion.py b/arrakis/frion.py index 35823fa9..76716e6c 100644 --- a/arrakis/frion.py +++ b/arrakis/frion.py @@ -6,16 +6,14 @@ import os from pathlib import Path from pprint import pformat -from typing import Callable, Dict, List +from typing import Callable, Dict, List, Optional, Union from typing import NamedTuple as Struct -from typing import Optional, Union from urllib.error import URLError import astropy.units as u import numpy as np import pymongo from astropy.time import Time, TimeDelta -from FRion import correct, predict from prefect import flow, task from tqdm.auto import tqdm @@ -28,6 +26,7 @@ ) from arrakis.utils.fitsutils import getfreq from arrakis.utils.pipeline import generic_parser, logo_str, workdir_arg_parser +from FRion import correct, predict 
logger.setLevel(logging.INFO) TQDM_OUT = TqdmToLogger(logger, level=logging.INFO) diff --git a/arrakis/imager.py b/arrakis/imager.py index 13f7433e..1bb568d3 100644 --- a/arrakis/imager.py +++ b/arrakis/imager.py @@ -2,43 +2,42 @@ """Arrkis imager""" import argparse -from concurrent.futures import ThreadPoolExecutor import hashlib import logging import os import pickle import shutil +from concurrent.futures import ThreadPoolExecutor from glob import glob from pathlib import Path from subprocess import CalledProcessError -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional, Tuple, Union from typing import NamedTuple as Struct -from typing import Optional, Tuple, Union -from arrakis.utils.meta import my_ceil import astropy.units as u +import matplotlib +import matplotlib.pyplot as plt import numpy as np from astropy.io import fits from astropy.stats import mad_std from astropy.table import Table from astropy.visualization import ( - SqrtStretch, ImageNormalize, MinMaxInterval, + SqrtStretch, ) from fitscube import combine_fits from fixms.fix_ms_corrs import fix_ms_corrs from fixms.fix_ms_dir import fix_ms_dir -import matplotlib.pyplot as plt -import matplotlib from prefect import flow, get_run_logger, task from racs_tools import beamcon_2D -from spython.main import Client as sclient from skimage.transform import resize +from spython.main import Client as sclient from tqdm.auto import tqdm from arrakis.logger import TqdmToLogger, UltimateHelpFormatter, logger from arrakis.utils.io import parse_env_path +from arrakis.utils.meta import my_ceil from arrakis.utils.msutils import ( beam_from_ms, field_idx_from_ms, diff --git a/arrakis/linmos.py b/arrakis/linmos.py index 15391892..09d998e2 100644 --- a/arrakis/linmos.py +++ b/arrakis/linmos.py @@ -9,9 +9,8 @@ from glob import glob from pathlib import Path from pprint import pformat -from typing import Dict, List +from typing import Dict, List, Optional, Tuple from typing import NamedTuple as Struct -from typing import Optional, Tuple import astropy.units as u import numpy as np diff --git a/arrakis/logger.py b/arrakis/logger.py index 4c22b943..287936d1 100644 --- a/arrakis/logger.py +++ b/arrakis/logger.py @@ -1,5 +1,4 @@ #!/usr/bin/env python3 -# -*- coding: utf-8 -*- """Logging module for arrakis""" import argparse @@ -24,7 +23,7 @@ class TqdmToLogger(io.StringIO): buf = "" def __init__(self, logger, level=None): - super(TqdmToLogger, self).__init__() + super().__init__() self.logger = logger self.level = level or logging.INFO diff --git a/arrakis/rmclean_oncuts.py b/arrakis/rmclean_oncuts.py index 0f7a4a3e..1521ab38 100644 --- a/arrakis/rmclean_oncuts.py +++ b/arrakis/rmclean_oncuts.py @@ -7,13 +7,13 @@ import warnings from pathlib import Path from pprint import pformat -from typing import Optional from shutil import copyfile +from typing import Optional -import numpy as np -from matplotlib import pyplot as plt import matplotlib +import numpy as np import pymongo +from matplotlib import pyplot as plt from prefect import flow, task from RMtools_1D import do_RMclean_1D from RMtools_3D import do_RMclean_3D diff --git a/arrakis/rmsynth_oncuts.py b/arrakis/rmsynth_oncuts.py index 0d9c9e98..4e1a0668 100644 --- a/arrakis/rmsynth_oncuts.py +++ b/arrakis/rmsynth_oncuts.py @@ -9,9 +9,8 @@ from pathlib import Path from pprint import pformat from shutil import copyfile -from typing import List +from typing import List, Optional, Tuple, Union from typing import NamedTuple as Struct -from typing import Optional, Tuple, Union 
import astropy.units as u import matplotlib diff --git a/arrakis/utils/fitsutils.py b/arrakis/utils/fitsutils.py index 82531778..0a55d845 100644 --- a/arrakis/utils/fitsutils.py +++ b/arrakis/utils/fitsutils.py @@ -11,12 +11,12 @@ from astropy.io import fits from astropy.utils.exceptions import AstropyWarning from astropy.wcs import WCS -from FRion.correct import find_freq_axis from spectral_cube import SpectralCube from spectral_cube.utils import SpectralCubeWarning from arrakis.logger import logger from arrakis.utils.io import gettable +from FRion.correct import find_freq_axis warnings.filterwarnings(action="ignore", category=SpectralCubeWarning, append=True) warnings.simplefilter("ignore", category=AstropyWarning) diff --git a/arrakis/utils/io.py b/arrakis/utils/io.py index ed2d173f..96cbc698 100644 --- a/arrakis/utils/io.py +++ b/arrakis/utils/io.py @@ -3,12 +3,12 @@ import logging import os +import shlex import stat +import subprocess as sp import warnings from glob import glob from pathlib import Path -import shlex -import subprocess as sp from typing import Tuple from astropy.table import Table @@ -137,7 +137,7 @@ def copyfile(src, dst, *, follow_symlinks=True, verbose=True): """ if _samefile(src, dst): - raise SameFileError("{!r} and {!r} are the same file".format(src, dst)) + raise SameFileError(f"{src!r} and {dst!r} are the same file") for fn in [src, dst]: try: @@ -148,7 +148,7 @@ def copyfile(src, dst, *, follow_symlinks=True, verbose=True): else: # XXX What about other special files? (sockets, devices...) if stat.S_ISFIFO(st.st_mode): - raise SpecialFileError("`%s` is a named pipe" % fn) + raise SpecialFileError(f"`{fn}` is a named pipe") if not follow_symlinks and os.path.islink(src): os.symlink(os.readlink(src), dst) diff --git a/arrakis/utils/json.py b/arrakis/utils/json.py index 86c02a9a..3cae6d79 100644 --- a/arrakis/utils/json.py +++ b/arrakis/utils/json.py @@ -32,4 +32,4 @@ def default(self, obj): # pylint: disable=E0202 elif dataclasses.is_dataclass(obj): return dataclasses.asdict(obj) else: - return super(MyEncoder, self).default(obj) + return super().default(obj) diff --git a/arrakis/utils/meta.py b/arrakis/utils/meta.py index fd5d7ebd..1c230117 100644 --- a/arrakis/utils/meta.py +++ b/arrakis/utils/meta.py @@ -5,8 +5,8 @@ import warnings from itertools import zip_longest -from astropy.utils.exceptions import AstropyWarning import numpy as np +from astropy.utils.exceptions import AstropyWarning from spectral_cube.utils import SpectralCubeWarning warnings.filterwarnings(action="ignore", category=SpectralCubeWarning, append=True) diff --git a/arrakis/utils/pipeline.py b/arrakis/utils/pipeline.py index bdb1852a..334a5cd5 100644 --- a/arrakis/utils/pipeline.py +++ b/arrakis/utils/pipeline.py @@ -352,7 +352,7 @@ def __init__( start=True, **tqdm_kwargs, ): - super(TqdmProgressBar, self).__init__(keys, scheduler, interval, complete) + super().__init__(keys, scheduler, interval, complete) self.tqdm = tqdm(keys, **tqdm_kwargs) self.loop = loop or IOLoop() diff --git a/arrakis/validate.py b/arrakis/validate.py index 33d41a27..4a57bcc9 100644 --- a/arrakis/validate.py +++ b/arrakis/validate.py @@ -1,12 +1,11 @@ #!/usr/bin/env python -# -*- coding: utf-8 -*- """Make validation plots from a catalogue""" import argparse import logging +from importlib import resources from pathlib import Path from typing import NamedTuple as Struct -from importlib import resources import astropy.units as u import matplotlib.pyplot as plt @@ -347,7 +346,7 @@ def plot_rm( label="$\pm 5 \sigma$", ) 
abs_max_val = np.nanmax( - (np.abs(np.concatenate([racs_match["rm"], other_match["rm"]]))) + np.abs(np.concatenate([racs_match["rm"], other_match["rm"]])) ) ax.plot( [-abs_max_val, abs_max_val], diff --git a/pyproject.toml b/pyproject.toml index 44614ed0..de72d0fb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -121,8 +121,8 @@ spica = { reference="scripts/spica.py", type="file"} tar_cubelets = { reference="scripts/tar_cubelets.py", type="file"} create_mongodb = { reference="scripts/create_mongodb.py", type="file"} -[tool.isort] -profile = "black" +[tool.ruff] +src = ["arrakis", "tests", "scripts"] [tool.ruff.lint] -select = ["E4", "E7", "E9", "F", "UP006"] +select = ["E4", "E7", "E9", "F", "UP", "I"] diff --git a/scripts/casda_prepare.py b/scripts/casda_prepare.py index 17170379..a0454ab2 100755 --- a/scripts/casda_prepare.py +++ b/scripts/casda_prepare.py @@ -16,6 +16,10 @@ import numpy as np import pandas as pd import polspectra +from arrakis.logger import TqdmToLogger, logger +from arrakis.makecat import write_votable +from arrakis.utils.io import try_mkdir, try_symlink +from arrakis.utils.pipeline import chunk_dask from astropy.io import fits from astropy.table import Column, Table from astropy.units.core import get_current_unit_registry @@ -28,11 +32,6 @@ from spectral_cube.cube_utils import convert_bunit from tqdm.auto import tqdm -from arrakis.logger import TqdmToLogger, logger -from arrakis.makecat import write_votable -from arrakis.utils.io import try_mkdir, try_symlink -from arrakis.utils.pipeline import chunk_dask - TQDM_OUT = TqdmToLogger(logger, level=logging.INFO) diff --git a/scripts/compare_leakage.py b/scripts/compare_leakage.py index 53ba69e7..16b5128b 100644 --- a/scripts/compare_leakage.py +++ b/scripts/compare_leakage.py @@ -22,17 +22,16 @@ import matplotlib.pyplot as plt import numpy as np import pandas as pd -from astropy.coordinates import SkyCoord -from astropy.io import fits -from astropy.stats import mad_std, sigma_clip -from astropy.wcs import WCS -from dask import delayed - from arrakis.linmos import gen_seps from arrakis.logger import logger, logging from arrakis.utils.database import get_db from arrakis.utils.fitsutils import getfreq from arrakis.utils.pipeline import chunk_dask, logo_str +from astropy.coordinates import SkyCoord +from astropy.io import fits +from astropy.stats import mad_std, sigma_clip +from astropy.wcs import WCS +from dask import delayed def make_plot(data, comp, imfile): diff --git a/scripts/compute_leakage.py b/scripts/compute_leakage.py index 23581c9d..0bc91094 100644 --- a/scripts/compute_leakage.py +++ b/scripts/compute_leakage.py @@ -5,13 +5,12 @@ import astropy.units as units import matplotlib.pyplot as plt import numpy as np +from arrakis.logger import TqdmToLogger, logger +from arrakis.utils.database import get_db from astropy.coordinates import SkyCoord from astropy.wcs import WCS from tqdm.auto import tqdm, trange -from arrakis.logger import TqdmToLogger, logger -from arrakis.utils.database import get_db - TQDM_OUT = TqdmToLogger(logger, level=logging.INFO) diff --git a/scripts/copy_cutouts.py b/scripts/copy_cutouts.py index e184f206..66a0d70b 100755 --- a/scripts/copy_cutouts.py +++ b/scripts/copy_cutouts.py @@ -1,12 +1,12 @@ import argparse import os -import copy_data -import spica - from arrakis.logger import logger, logging from arrakis.utils.io import try_mkdir +import copy_data +import spica + logger.setLevel(logging.INFO) racs_area = os.path.abspath("/askapbuffer/payne/mcc381/RACS") diff --git 
a/scripts/copy_cutouts_askap.py b/scripts/copy_cutouts_askap.py index 1a3c65f2..0c1591b2 100644 --- a/scripts/copy_cutouts_askap.py +++ b/scripts/copy_cutouts_askap.py @@ -1,12 +1,12 @@ import argparse import os -import copy_data -import spica - from arrakis.logger import logger, logging from arrakis.utils.io import try_mkdir +import copy_data +import spica + logger.setLevel(logging.INFO) # racs_area = os.path.abspath('/askapbuffer/processing/len067/arrakis') diff --git a/scripts/copy_data.py b/scripts/copy_data.py index ca8426b8..61101d70 100755 --- a/scripts/copy_data.py +++ b/scripts/copy_data.py @@ -5,11 +5,10 @@ from pathlib import Path from shutil import SameFileError, copyfile -from astropy.table import Table - from arrakis.logger import logger from arrakis.utils.io import prsync, rsync, try_mkdir from arrakis.utils.meta import yes_or_no +from astropy.table import Table def main( diff --git a/scripts/create_mongodb.py b/scripts/create_mongodb.py index 5cd12b19..f9d94ab9 100644 --- a/scripts/create_mongodb.py +++ b/scripts/create_mongodb.py @@ -7,9 +7,8 @@ from typing import Optional import pymongo -from pymongo.database import Database - from arrakis.logger import logger +from pymongo.database import Database logger.setLevel(logging.INFO) diff --git a/scripts/find_row.py b/scripts/find_row.py index bd13a44b..f1d9761f 100755 --- a/scripts/find_row.py +++ b/scripts/find_row.py @@ -2,9 +2,8 @@ import argparse from pathlib import Path -from astropy.table import Table - from arrakis.logger import logger, logging +from astropy.table import Table logger.setLevel(logging.INFO) diff --git a/scripts/find_sbid.py b/scripts/find_sbid.py index 26063cb5..c1ad53a0 100755 --- a/scripts/find_sbid.py +++ b/scripts/find_sbid.py @@ -2,9 +2,8 @@ import argparse from pathlib import Path -from astropy.table import Table - from arrakis.logger import logger, logging +from astropy.table import Table logger.setLevel(logging.INFO) diff --git a/scripts/fix_dr1_cat.py b/scripts/fix_dr1_cat.py index 0f11365c..ed64e6fd 100755 --- a/scripts/fix_dr1_cat.py +++ b/scripts/fix_dr1_cat.py @@ -8,13 +8,6 @@ import astropy.units as u import numpy as np -from astropy.coordinates import SkyCoord -from astropy.table import Column, Table -from astropy.time import Time -from astropy.units import cds -from rmtable import RMTable -from spica import SPICA - from arrakis.logger import logger from arrakis.makecat import ( compute_local_rm_flag, @@ -22,6 +15,13 @@ is_leakage, write_votable, ) +from astropy.coordinates import SkyCoord +from astropy.table import Column, Table +from astropy.time import Time +from astropy.units import cds +from rmtable import RMTable + +from spica import SPICA def fix_fields( diff --git a/scripts/fix_src_cat.py b/scripts/fix_src_cat.py index 26af1342..b600263d 100644 --- a/scripts/fix_src_cat.py +++ b/scripts/fix_src_cat.py @@ -6,13 +6,12 @@ from pathlib import Path import numpy as np +from arrakis.logger import TqdmToLogger, logger +from arrakis.makecat import fix_blank_units, replace_nans, vot from astropy.coordinates import SkyCoord from astropy.table import Table from tqdm.auto import tqdm -from arrakis.logger import TqdmToLogger, logger -from arrakis.makecat import fix_blank_units, replace_nans, vot - TQDM_OUT = TqdmToLogger(logger, level=logging.INFO) logger.setLevel("DEBUG") diff --git a/scripts/spica.py b/scripts/spica.py index 8d5ef92f..82c858ba 100755 --- a/scripts/spica.py +++ b/scripts/spica.py @@ -4,12 +4,12 @@ from glob import glob from pathlib import Path -import copy_data import numpy as 
np -from astropy.table import Table - from arrakis.logger import logger, logging from arrakis.utils.io import try_mkdir +from astropy.table import Table + +import copy_data logger.setLevel(logging.INFO) diff --git a/scripts/tar_cubelets.py b/scripts/tar_cubelets.py index fddbffce..ac401344 100755 --- a/scripts/tar_cubelets.py +++ b/scripts/tar_cubelets.py @@ -6,11 +6,10 @@ from glob import glob import dask +from arrakis.logger import TqdmToLogger, logger from dask import delayed from tqdm.auto import tqdm -from arrakis.logger import TqdmToLogger, logger - TQDM_OUT = TqdmToLogger(logger, level=logging.INFO) From d5bf54adac88a778f995dfd74af241113c9c7727 Mon Sep 17 00:00:00 2001 From: "Alec Thomson (S&A, Kensington WA)" Date: Mon, 22 Jul 2024 18:16:56 +0800 Subject: [PATCH 03/17] Major ruffage --- arrakis/__init__.py | 2 + arrakis/cleanup.py | 8 +- arrakis/columns_possum.py | 1 + arrakis/cutout.py | 55 +++++---- arrakis/frion.py | 50 ++++---- arrakis/imager.py | 116 +++++++++--------- arrakis/init_database.py | 56 ++++----- arrakis/linmos.py | 68 ++++++----- arrakis/logger.py | 2 + arrakis/makecat.py | 62 +++++----- arrakis/merge_fields.py | 44 +++---- arrakis/process_region.py | 2 + arrakis/process_spice.py | 2 + arrakis/rmclean_oncuts.py | 36 +++--- arrakis/rmsynth_oncuts.py | 100 ++++++++-------- arrakis/utils/coordinates.py | 6 +- arrakis/utils/database.py | 26 ++-- arrakis/utils/exceptions.py | 2 + arrakis/utils/fitsutils.py | 18 +-- arrakis/utils/fitting.py | 69 ++++++----- arrakis/utils/io.py | 17 +-- arrakis/utils/json.py | 4 +- arrakis/utils/meta.py | 13 +- arrakis/utils/msutils.py | 220 +++++++++++++++++----------------- arrakis/utils/pipeline.py | 26 ++-- arrakis/utils/plotting.py | 7 +- arrakis/utils/typing.py | 2 + arrakis/validate.py | 2 + arrakis/wsclean_rmsynth.py | 56 +++------ docs/source/conf.py | 2 + pyproject.toml | 38 +++++- scripts/casda_prepare.py | 23 ++-- scripts/check_cutout.py | 2 + scripts/compare_leakage.py | 8 +- scripts/compute_leakage.py | 6 +- scripts/copy_cutouts.py | 2 + scripts/copy_cutouts_askap.py | 2 + scripts/copy_data.py | 16 +-- scripts/create_mongodb.py | 21 ++-- scripts/find_row.py | 2 + scripts/find_sbid.py | 2 + scripts/fix_dr1_cat.py | 6 +- scripts/fix_src_cat.py | 8 +- scripts/hello_mpi_world.py | 2 + scripts/make_links.py | 6 +- scripts/spica.py | 12 +- scripts/tar_cubelets.py | 17 +-- submit/test_image.py | 1 + tests/cli_test.py | 6 +- tests/unit_test.py | 7 +- 50 files changed, 663 insertions(+), 598 deletions(-) mode change 100644 => 100755 arrakis/__init__.py mode change 100644 => 100755 arrakis/cleanup.py mode change 100644 => 100755 arrakis/columns_possum.py mode change 100644 => 100755 arrakis/cutout.py mode change 100644 => 100755 arrakis/frion.py mode change 100644 => 100755 arrakis/imager.py mode change 100644 => 100755 arrakis/init_database.py mode change 100644 => 100755 arrakis/linmos.py mode change 100644 => 100755 arrakis/logger.py mode change 100644 => 100755 arrakis/makecat.py mode change 100644 => 100755 arrakis/merge_fields.py mode change 100644 => 100755 arrakis/process_region.py mode change 100644 => 100755 arrakis/process_spice.py mode change 100644 => 100755 arrakis/rmclean_oncuts.py mode change 100644 => 100755 arrakis/rmsynth_oncuts.py mode change 100644 => 100755 arrakis/validate.py mode change 100644 => 100755 arrakis/wsclean_rmsynth.py diff --git a/arrakis/__init__.py b/arrakis/__init__.py old mode 100644 new mode 100755 index 530a519d..100b27e3 --- a/arrakis/__init__.py +++ b/arrakis/__init__.py @@ -1,5 +1,7 @@ 
"""Processing polarized RACS data products.""" +from __future__ import annotations + from importlib.metadata import distribution __version__ = distribution("arrakis").version diff --git a/arrakis/cleanup.py b/arrakis/cleanup.py old mode 100644 new mode 100755 index 08cc9c6f..911da3f4 --- a/arrakis/cleanup.py +++ b/arrakis/cleanup.py @@ -1,12 +1,13 @@ #!/usr/bin/env python3 """DANGER ZONE: Purge directories of un-needed FITS files.""" +from __future__ import annotations + import argparse import logging import shutil import tarfile from pathlib import Path -from typing import List import astropy.units as u import numpy as np @@ -62,7 +63,8 @@ def make_cutout_tarball(cutdir: Path, overwrite: bool = False) -> Path: verification = verify_tarball(tarball) if not verification: - raise RuntimeError(f"Verification of {tarball} failed!") + msg = f"Verification of {tarball} failed!" + raise RuntimeError(msg) logger.critical(f"Removing {cutdir}") shutil.rmtree(cutdir) @@ -106,7 +108,7 @@ def main( total_file_size = np.sum([p.stat().st_size for p in to_purge_all]) * u.byte logger.warning(f"Purging {len(to_purge_all)} files from {datadir}") logger.warning(f"Will free {total_file_size.to(u.GB)}") - purged: List[Path] = [] + purged: list[Path] = [] for to_purge in tqdm(to_purge_all, file=TQDM_OUT, desc="Purging big beams"): purged.append(purge_cubelet_beams(to_purge)) logger.info(f"Files purged: {len(purged)}") diff --git a/arrakis/columns_possum.py b/arrakis/columns_possum.py old mode 100644 new mode 100755 index 7427af50..0b2491a5 --- a/arrakis/columns_possum.py +++ b/arrakis/columns_possum.py @@ -12,6 +12,7 @@ # header (from FITS header), meta (from meta data) # dict key or column name in pipeline # unit (string) +from __future__ import annotations import astropy.units as u from astropy.units import cds diff --git a/arrakis/cutout.py b/arrakis/cutout.py old mode 100644 new mode 100755 index 3acd0ece..a3f9dabe --- a/arrakis/cutout.py +++ b/arrakis/cutout.py @@ -1,6 +1,8 @@ #!/usr/bin/env python """Produce cutouts from RACS cubes""" +from __future__ import annotations + import argparse import logging import warnings @@ -8,8 +10,8 @@ from pathlib import Path from pprint import pformat from shutil import copyfile -from typing import List, Optional, Set, TypeVar from typing import NamedTuple as Struct +from typing import TypeVar import astropy.units as u import numpy as np @@ -70,7 +72,7 @@ class CutoutArgs(Struct): def cutout_weight( image_name: Path, source_id: str, - cutout_args: Optional[CutoutArgs], + cutout_args: CutoutArgs | None, field: str, stoke: str, beam_num: int, @@ -118,7 +120,7 @@ def cutout_image( old_header: fits.Header, cube: SpectralCube, source_id: str, - cutout_args: Optional[CutoutArgs], + cutout_args: CutoutArgs | None, field: str, beam_num: int, stoke: str, @@ -217,7 +219,7 @@ def get_args( comps: pd.DataFrame, source: pd.Series, outdir: Path, -) -> Optional[CutoutArgs]: +) -> CutoutArgs | None: """Get arguments for cutout function Args: @@ -248,7 +250,7 @@ def get_args( # Find image size ras: u.Quantity = comps.RA.values * u.deg decs: u.Quantity = comps.Dec.values * u.deg - majs: List[float] = comps.Maj.values * u.arcsec + majs: list[float] = comps.Maj.values * u.arcsec coords = SkyCoord(ras, decs, frame="icrs") padder = np.max(majs) @@ -339,8 +341,8 @@ def worker( beam_num: int, stoke: str, pad: float = 3, - username: Optional[str] = None, - password: Optional[str] = None, + username: str | None = None, + password: str | None = None, ): _, _, comp_col = get_db( host=host, 
epoch=epoch, username=username, password=password @@ -387,16 +389,18 @@ def big_cutout( epoch: int, field: str, pad: float = 3, - username: Optional[str] = None, - password: Optional[str] = None, - limit: Optional[int] = None, -) -> List[pymongo.UpdateOne]: + username: str | None = None, + password: str | None = None, + limit: int | None = None, +) -> list[pymongo.UpdateOne]: wild = f"image.restored.{stoke.lower()}*contcube*beam{beam_num:02}.conv.fits" images = list(datadir.glob(wild)) if len(images) == 0: - raise Exception(f"No images found matching '{wild}'") + msg = f"No images found matching '{wild}'" + raise Exception(msg) elif len(images) > 1: - raise Exception(f"More than one image found matching '{wild}'. Files {images=}") + msg = f"More than one image found matching '{wild}'. Files {images=}" + raise Exception(msg) image_name = images[0] @@ -413,7 +417,7 @@ def big_cutout( logger.critical(f"Limiting to {limit} islands") sources = sources[:limit] - updates: List[pymongo.UpdateOne] = [] + updates: list[pymongo.UpdateOne] = [] with ThreadPoolExecutor() as executor: futures = [] for _, source in sources.iterrows(): @@ -449,13 +453,13 @@ def cutout_islands( directory: Path, host: str, epoch: int, - sbid: Optional[int] = None, - username: Optional[str] = None, - password: Optional[str] = None, + sbid: int | None = None, + username: str | None = None, + password: str | None = None, pad: float = 3, - stokeslist: Optional[List[str]] = None, + stokeslist: list[str] | None = None, dryrun: bool = True, - limit: Optional[int] = None, + limit: int | None = None, ) -> None: """Flow to cutout islands in parallel. @@ -501,13 +505,14 @@ def cutout_islands( field_col=field_col, ) if not sbid_check: - raise ValueError(f"SBID {sbid} does not match field {field}") + msg = f"SBID {sbid} does not match field {field}" + raise ValueError(msg) query = {"$and": [{f"beams.{field}": {"$exists": True}}]} if sbid is not None: query["$and"].append({f"beams.{field}.SBIDs": sbid}) - unique_beams_nums: Set[int] = set( + unique_beams_nums: set[int] = set( beams_col.distinct(f"beams.{field}.beam_list", query) ) source_ids = sorted(beams_col.distinct("Source_ID", query)) @@ -530,21 +535,21 @@ def cutout_islands( ) beam_source_list = [] - for i, row in tqdm(beams_df.iterrows()): + for _i, row in tqdm(beams_df.iterrows()): beam_list = row.beams[field]["beam_list"] for b in beam_list: beam_source_list.append({"Source_ID": row.Source_ID, "beam": b}) beam_source_df = pd.DataFrame(beam_source_list) - beam_source_df.set_index("beam", inplace=True) + beam_source_df = beam_source_df.set_index("beam") comps_df = pd.DataFrame( comp_col.find({"Source_ID": {"$in": source_ids}}).sort("Source_ID") ) - comps_df.set_index("Source_ID", inplace=True) + comps_df = comps_df.set_index("Source_ID") # Create output dir if it doesn't exist outdir.mkdir(parents=True, exist_ok=True) - cuts: List[pymongo.UpdateOne] = [] + cuts: list[pymongo.UpdateOne] = [] for stoke in stokeslist: for beam_num in unique_beams_nums: results = big_cutout.submit( diff --git a/arrakis/frion.py b/arrakis/frion.py old mode 100644 new mode 100755 index 76716e6c..e4c4bcff --- a/arrakis/frion.py +++ b/arrakis/frion.py @@ -1,12 +1,14 @@ #!/usr/bin/env python3 """Correct for the ionosphere in parallel""" +from __future__ import annotations + import argparse import logging import os from pathlib import Path from pprint import pformat -from typing import Callable, Dict, List, Optional, Union +from typing import Callable from typing import NamedTuple as Struct from 
urllib.error import URLError @@ -46,7 +48,7 @@ class FrionResults(Struct): @task(name="FRion correction") def correct_worker( - beam: Dict, outdir: str, field: str, prediction: Prediction, island: dict + beam: dict, outdir: str, field: str, prediction: Prediction, island: dict ) -> pymongo.UpdateOne: """Apply FRion corrections to a single island @@ -88,9 +90,9 @@ def correct_worker( @task(name="FRion predction") def predict_worker( - island: Dict, + island: dict, field: str, - beam: Dict, + beam: dict, start_time: Time, end_time: Time, freq: np.ndarray, @@ -98,8 +100,8 @@ def predict_worker( plotdir: Path, server: str = "ftp://ftp.aiub.unibe.ch/CODE/", prefix: str = "", - formatter: Optional[Union[str, Callable]] = None, - proxy_server: Optional[str] = None, + formatter: str | Callable | None = None, + proxy_server: str | None = None, pre_download: bool = False, ) -> Prediction: """Make FRion prediction for a single island @@ -169,9 +171,8 @@ def predict_worker( logger.warning("Trying next prefix.") continue else: - raise FileNotFoundError( - f"Could not find IONEX file with prefixes {_prefixes_to_try}" - ) + msg = f"Could not find IONEX file with prefixes {_prefixes_to_try}" + raise FileNotFoundError(msg) predict_file = os.path.join(i_dir, f"{iname}_ion.txt") predict.write_modulation(freq_array=freq, theta=theta, filename=predict_file) @@ -198,11 +199,10 @@ def predict_worker( @task(name="Index beams") -def index_beams(island: dict, beams: List[dict]) -> dict: +def index_beams(island: dict, beams: list[dict]) -> dict: island_id = island["Source_ID"] - beam_idx = [i for i, b in enumerate(beams) if b["Source_ID"] == island_id][0] - beam = beams[beam_idx] - return beam + beam_idx = next(i for i, b in enumerate(beams) if b["Source_ID"] == island_id) + return beams[beam_idx] # We reduce the inner loop to a serial call @@ -219,8 +219,8 @@ def serial_loop( plotdir: Path, ionex_server: str, ionex_prefix: str, - ionex_proxy_server: Optional[str], - ionex_formatter: Optional[Union[str, Callable]], + ionex_proxy_server: str | None, + ionex_formatter: str | Callable | None, ionex_predownload: bool, ) -> FrionResults: prediction = predict_worker.fn( @@ -255,16 +255,16 @@ def main( outdir: Path, host: str, epoch: int, - sbid: Optional[int] = None, - username: Optional[str] = None, - password: Optional[str] = None, + sbid: int | None = None, + username: str | None = None, + password: str | None = None, database=False, ionex_server: str = "ftp://ftp.aiub.unibe.ch/CODE/", ionex_prefix: str = "codg", - ionex_proxy_server: Optional[str] = None, - ionex_formatter: Optional[Union[str, Callable]] = "ftp.aiub.unibe.ch", + ionex_proxy_server: str | None = None, + ionex_formatter: str | Callable | None = "ftp.aiub.unibe.ch", ionex_predownload: bool = False, - limit: Optional[int] = None, + limit: int | None = None, ): """FRion flow @@ -307,7 +307,8 @@ def main( field_col=field_col, ) if not sbid_check: - raise ValueError(f"SBID {sbid} does not match field {field}") + msg = f"SBID {sbid} does not match field {field}" + raise ValueError(msg) query_1 = {"$and": [{f"beams.{field}": {"$exists": True}}]} @@ -333,7 +334,8 @@ def main( # Raise error if too much or too little data if field_col.count_documents(query_3) > 1: logger.error(f"More than one SELECT=1 for {field} - try supplying SBID.") - raise ValueError(f"More than one SELECT=1 for {field} - try supplying SBID.") + msg = f"More than one SELECT=1 for {field} - try supplying SBID." 
+ raise ValueError(msg) elif field_col.count_documents(query_3) == 0: logger.error(f"No data for {field} with {query_3}, trying without SELECT=1.") @@ -361,7 +363,7 @@ def main( beams_cor = [] for island in islands: island_id = island["Source_ID"] - beam_idx = [i for i, b in enumerate(beams) if b["Source_ID"] == island_id][0] + beam_idx = next(i for i, b in enumerate(beams) if b["Source_ID"] == island_id) beam = beams[beam_idx] beams_cor.append(beam) diff --git a/arrakis/imager.py b/arrakis/imager.py old mode 100644 new mode 100755 index 1bb568d3..0a6e6008 --- a/arrakis/imager.py +++ b/arrakis/imager.py @@ -1,6 +1,8 @@ #!/usr/bin/env python3 """Arrkis imager""" +from __future__ import annotations + import argparse import hashlib import logging @@ -11,11 +13,11 @@ from glob import glob from pathlib import Path from subprocess import CalledProcessError -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any from typing import NamedTuple as Struct import astropy.units as u -import matplotlib +import matplotlib as mpl import matplotlib.pyplot as plt import numpy as np from astropy.io import fits @@ -51,7 +53,7 @@ workdir_arg_parser, ) -matplotlib.use("Agg") +mpl.use("Agg") TQDM_OUT = TqdmToLogger(logger, level=logging.INFO) @@ -63,9 +65,9 @@ class ImageSet(Struct): """Path to the measurement set that was imaged.""" prefix: str """Prefix used for the wsclean output files.""" - image_lists: Dict[str, List[str]] + image_lists: dict[str, list[str]] """Dictionary of lists of images. The keys are the polarisations and the values are the list of images for that polarisation.""" - aux_lists: Optional[Dict[Tuple[str, str], List[str]]] = None + aux_lists: dict[tuple[str, str], list[str]] | None = None """Dictionary of lists of auxillary images. The keys are a tuple of the polarisation and the image type, and the values are the list of images for that polarisation and image type.""" @@ -82,13 +84,13 @@ class MFSImage(Struct): @task(name="Get pol. axis") def get_pol_axis_task( - ms: Path, feed_idx: Optional[int] = None, col: str = "RECEPTOR_ANGLE" + ms: Path, feed_idx: int | None = None, col: str = "RECEPTOR_ANGLE" ) -> float: return get_pol_axis(ms=ms, feed_idx=feed_idx, col=col).to(u.deg).value @task(name="Merge ImageSets") -def merge_imagesets(image_sets: List[Optional[ImageSet]]) -> ImageSet: +def merge_imagesets(image_sets: list[ImageSet | None]) -> ImageSet: """Merge a collection of ImageSets into a single ImageSet. Args: @@ -134,7 +136,7 @@ def merge_imagesets(image_sets: List[Optional[ImageSet]]) -> ImageSet: def get_mfs_image( - prefix_str: str, pol: str, small_size: Tuple[int, int] = (512, 512) + prefix_str: str, pol: str, small_size: tuple[int, int] = (512, 512) ) -> MFSImage: """Get the MFS image from the image set. 
@@ -181,14 +183,14 @@ def make_validation_plots(prefix: Path, pols: str) -> None: mfs_image = get_mfs_image(prefix_str, stokes) fig, axs = plt.subplots(1, 3, figsize=(15, 5)) for ax, sub_image, title in zip(axs, mfs_image, ("Image", "Model", "Residual")): - sub_image = np.abs(sub_image) + abs_sub_image = np.abs(sub_image) if title == "Model": norm = ImageNormalize( - sub_image, interval=MinMaxInterval(), stretch=SqrtStretch() + abs_sub_image, interval=MinMaxInterval(), stretch=SqrtStretch() ) else: norm = ImageNormalize(mfs_image.residual, vmin=0, stretch=SqrtStretch()) - _ = ax.imshow(sub_image, origin="lower", norm=norm, cmap="cubehelix") + _ = ax.imshow(abs_sub_image, origin="lower", norm=norm, cmap="cubehelix") ax.set_title(title) ax.get_yaxis().set_visible(False) ax.get_xaxis().set_visible(False) @@ -207,7 +209,7 @@ def make_validation_plots(prefix: Path, pols: str) -> None: logger.info(f"Uploaded {fig_name} to {uuid}") -def get_wsclean(wsclean: Union[Path, str]) -> Path: +def get_wsclean(wsclean: Path | str) -> Path: """Pull wsclean image from dockerhub (or wherver). Args: @@ -249,7 +251,7 @@ def cleanup_imageset(purge: bool, image_set: ImageSet) -> None: # they were just copied across directly without modification if image_set.aux_lists: logger.critical("Removing auxillary images. ") - for (pol, aux), aux_list in image_set.aux_lists.items(): + for (pol, _aux), aux_list in image_set.aux_lists.items(): for aux_image in aux_list: try: logger.critical(f"Removing {aux_image}") @@ -298,25 +300,25 @@ def image_beam( mgain: float = 0.7, niter: int = 100_000, auto_mask: float = 3, - force_mask_rounds: Optional[int] = None, + force_mask_rounds: int | None = None, auto_threshold: float = 1, - gridder: Optional[str] = None, + gridder: str | None = None, robust: float = -0.5, mem: float = 90, - absmem: Optional[float] = None, - taper: Optional[float] = None, + absmem: float | None = None, + taper: float | None = None, minuv_l: float = 0.0, - parallel_deconvolution: Optional[int] = None, - nmiter: Optional[int] = None, + parallel_deconvolution: int | None = None, + nmiter: int | None = None, local_rms: bool = False, - local_rms_window: Optional[float] = None, + local_rms_window: float | None = None, multiscale: bool = False, - multiscale_scale_bias: Optional[float] = None, - multiscale_scales: Optional[str] = "0,2,4,8,16,32,64,128", + multiscale_scale_bias: float | None = None, + multiscale_scales: str | None = "0,2,4,8,16,32,64,128", data_column: str = "CORRECTED_DATA", no_mf_weighting: bool = False, no_update_model_required: bool = True, - beam_fitting_size: Optional[float] = 1.25, + beam_fitting_size: float | None = 1.25, disable_pol_local_rms: bool = False, disable_pol_force_mask_rounds: bool = False, ) -> ImageSet: @@ -436,9 +438,8 @@ def image_beam( logger.info(line.rstrip()) # Catch divergence - look for the string 'KJy' in the output if "KJy" in line: - raise ValueError( - f"Detected divergence in wsclean output: {line.rstrip()}" - ) + msg = f"Detected divergence in wsclean output: {line.rstrip()}" + raise ValueError(msg) except CalledProcessError as e: logger.error(f"Failed to run wsclean with command: {command}") logger.error(f"Stdout: {e.stdout}") @@ -449,7 +450,7 @@ def image_beam( # Purge ms_temp shutil.rmtree(ms_temp) - suffixes: List[str] = ["image", "model", "psf", "residual", "dirty"] + suffixes: list[str] = ["image", "model", "psf", "residual", "dirty"] if temp_dir_images != out_dir: # Copy the images to the output directory logger.info(f"Copying images to {out_dir}") @@ -547,8 
+548,8 @@ def make_cube( image_set: ImageSet, common_beam_pkl: Path, pol_angle_deg: float, - aux_mode: Optional[str] = None, -) -> Tuple[Path, Path]: + aux_mode: str | None = None, +) -> tuple[Path, Path]: """Make a cube from the images""" logger = get_run_logger() @@ -609,10 +610,10 @@ def make_cube( new_w_name = new_name.replace( f"image.{image_type}", f"weights.{image_type}" ).replace(".fits", ".txt") - data = dict( - Channel=np.arange(len(rmss_arr)), - Weight=1 / rmss_arr**2, # Want inverse variance - ) + data = { + "Channel": np.arange(len(rmss_arr)), + "Weight": 1 / rmss_arr**2, # Want inverse variance + } tab = Table(data) tab.write(new_w_name, format="ascii.commented_header", overwrite=True) @@ -620,7 +621,7 @@ def make_cube( @task(name="Get Beam", persist_result=True) -def get_beam(image_set: ImageSet, cutoff: Optional[float]) -> Path: +def get_beam(image_set: ImageSet, cutoff: float | None) -> Path: """Derive a common resolution across all images within a set of ImageSet Args: @@ -652,7 +653,8 @@ def get_beam(image_set: ImageSet, cutoff: Optional[float]) -> Path: logger.info(f"The common beam is: {common_beam=}") if any([np.isnan(common_beam.major), np.isnan(common_beam.minor)]): - raise ValueError("Common beam is NaN, consider raising the cutoff.") + msg = "Common beam is NaN, consider raising the cutoff." + raise ValueError(msg) # serialise the beam common_beam_pkl = Path(f"beam_{image_hash}.pkl") @@ -668,8 +670,8 @@ def get_beam(image_set: ImageSet, cutoff: Optional[float]) -> Path: def smooth_imageset( image_set: ImageSet, common_beam_pkl: Path, - cutoff: Optional[float] = None, - aux_mode: Optional[str] = None, + cutoff: float | None = None, + aux_mode: str | None = None, ) -> ImageSet: """Smooth all images described within an ImageSet to a desired resolution @@ -694,7 +696,7 @@ def smooth_imageset( logger.info(f"Smooting {image_set.ms} images") - images_to_smooth: Dict[str, List[str]] + images_to_smooth: dict[str, list[str]] if aux_mode is None: images_to_smooth = image_set.image_lists else: @@ -738,7 +740,7 @@ def smooth_imageset( @task(name="Cleanup") def cleanup( - purge: bool, image_sets: List[ImageSet], ignore_files: Optional[List[Any]] = None + purge: bool, image_sets: list[ImageSet], ignore_files: list[Any] | None = None ) -> None: """Utility to remove all images described by an collection of ImageSets. Internally called `cleanup_imageset`. @@ -751,7 +753,7 @@ def cleanup( """ logger = get_run_logger() - logger.warn(f"Ignoring files in {ignore_files=}. ") + logger.warning(f"Ignoring files in {ignore_files=}. ") if not purge: logger.info("Not purging intermediate files") @@ -792,7 +794,7 @@ def fix_ms_askap_corrs(ms: Path, *args, **kwargs) -> Path: """ logger = get_run_logger() - logger.info(f"Correcting {str(ms)} correlations for wsclean. ") + logger.info(f"Correcting {ms!s} correlations for wsclean. 
") fix_ms_corrs(ms=ms, *args, **kwargs) @@ -804,9 +806,9 @@ def main( msdir: Path, out_dir: Path, num_beams: int = 36, - temp_dir_images: Optional[Path] = None, - temp_dir_wsclean: Optional[Path] = None, - cutoff: Optional[float] = None, + temp_dir_images: Path | None = None, + temp_dir_wsclean: Path | None = None, + cutoff: float | None = None, robust: float = -0.5, pols: str = "IQU", nchan: int = 36, @@ -815,22 +817,22 @@ def main( mgain: float = 0.8, niter: int = 100_000, auto_mask: float = 3, - force_mask_rounds: Union[int, None] = None, + force_mask_rounds: int | None = None, auto_threshold: float = 1, - taper: Union[float, None] = None, + taper: float | None = None, purge: bool = False, minuv: float = 0.0, - parallel_deconvolution: Optional[int] = None, - gridder: Optional[str] = None, - nmiter: Optional[int] = None, + parallel_deconvolution: int | None = None, + gridder: str | None = None, + nmiter: int | None = None, local_rms: bool = False, - local_rms_window: Optional[float] = None, - wsclean_path: Union[Path, str] = "docker://alecthomson/wsclean:latest", - multiscale: Optional[bool] = None, - multiscale_scale_bias: Optional[float] = None, - multiscale_scales: Optional[str] = "0,2,4,8,16,32,64,128", - absmem: Optional[float] = None, - make_residual_cubes: Optional[bool] = False, + local_rms_window: float | None = None, + wsclean_path: Path | str = "docker://alecthomson/wsclean:latest", + multiscale: bool | None = None, + multiscale_scale_bias: float | None = None, + multiscale_scales: str | None = "0,2,4,8,16,32,64,128", + absmem: float | None = None, + make_residual_cubes: bool | None = False, ms_glob_pattern: str = "scienceData*_averaged_cal.leakage.ms", data_column: str = "CORRECTED_DATA", skip_fix_ms: bool = False, @@ -1072,8 +1074,6 @@ def main( logger.info("Imager finished!") - return - def imager_parser(parent_parser: bool = False) -> argparse.ArgumentParser: """Return the argument parser for the imager routine. 
diff --git a/arrakis/init_database.py b/arrakis/init_database.py old mode 100644 new mode 100755 index 85c4787b..54679d43 --- a/arrakis/init_database.py +++ b/arrakis/init_database.py @@ -1,11 +1,12 @@ #!/usr/bin/env python3 """Create the Arrakis database""" +from __future__ import annotations + import json import logging import time from pathlib import Path -from typing import Dict, List, Optional, Tuple, Union import numpy as np import pandas as pd @@ -41,11 +42,10 @@ def source2beams(ra: float, dec: float, database: Table, max_sep: float = 1) -> c1 = SkyCoord(database["RA_DEG"] * u.deg, database["DEC_DEG"] * u.deg, frame="icrs") c2 = SkyCoord(ra * u.deg, dec * u.deg, frame="icrs") sep = c1.separation(c2) - beams = database[sep < max_sep * u.deg] - return beams + return database[sep < max_sep * u.deg] -def ndix_unique(x: np.ndarray) -> Tuple[np.ndarray, List[np.ndarray]]: +def ndix_unique(x: np.ndarray) -> tuple[np.ndarray, list[np.ndarray]]: """Find the N-dimensional array of indices of the unique values in x From https://stackoverflow.com/questions/54734545/indices-of-unique-values-in-array @@ -67,7 +67,7 @@ def ndix_unique(x: np.ndarray) -> Tuple[np.ndarray, List[np.ndarray]]: def cat2beams( mastercat: Table, database: Table, max_sep: float = 1 -) -> Tuple[np.ndarray, np.ndarray, Angle]: +) -> tuple[np.ndarray, np.ndarray, Angle]: """Find the separations between sources in the master catalogue and the RACS beams Args: @@ -89,8 +89,7 @@ def cat2beams( m_dec = m_dec * u.deg c2 = SkyCoord(m_ra, m_dec, frame="icrs") - seps = search_around_sky(c1, c2, seplimit=max_sep * u.degree) - return seps + return search_around_sky(c1, c2, seplimit=max_sep * u.degree) def source_database( @@ -98,9 +97,9 @@ def source_database( compcat: Table, host: str, epoch: int, - username: Union[str, None] = None, - password: Union[str, None] = None, -) -> Tuple[InsertManyResult, InsertManyResult]: + username: str | None = None, + password: str | None = None, +) -> tuple[InsertManyResult, InsertManyResult]: """Insert sources into the database Following https://medium.com/analytics-vidhya/how-to-upload-a-pandas-dataframe-to-mongodb-ffa18c0953c1 @@ -177,8 +176,8 @@ def beam_database( islandcat: Table, host: str, epoch: int, - username: Union[str, None] = None, - password: Union[str, None] = None, + username: str | None = None, + password: str | None = None, ) -> InsertManyResult: """Insert beams into the database @@ -262,7 +261,7 @@ def get_catalogue(survey_dir: Path, epoch: int = 0) -> Table: return racs_fields -def get_beams(mastercat: Table, database: Table, epoch: int = 0) -> List[Dict]: +def get_beams(mastercat: Table, database: Table, epoch: int = 0) -> list[dict]: """Get beams from the master catalogue Args: @@ -296,14 +295,14 @@ def get_beams(mastercat: Table, database: Table, epoch: int = 0) -> List[Dict]: ) beam_list = [] - for i, (val, idx) in enumerate( + for _i, (val, idx) in enumerate( tqdm(zip(vals, ixs), total=len(vals), desc="Getting beams", file=TQDM_OUT) ): beam_dict = {} name = mastercat[val]["Source_Name"] isl_id = mastercat[val]["Source_ID"] beams = database[seps[0][idx.astype(int)]] - for j, field in enumerate(np.unique(beams["FIELD_NAME"])): + for _j, field in enumerate(np.unique(beams["FIELD_NAME"])): ndx = beams["FIELD_NAME"] == field field = field.replace("_test4_1.05_", "_") if epoch == 0 else field beam_dict.update( @@ -333,11 +332,11 @@ def beam_inf( survey_dir: Path, host: str, epoch: int, - username: Optional[str] = None, - password: Optional[str] = None, + username: str | None = 
None, + password: str | None = None, ) -> InsertManyResult: """Get the beam information""" - tabs: List[Table] = [] + tabs: list[Table] = [] for row in tqdm(database, desc="Reading beam info", file=TQDM_OUT): try: tab = read_racs_database( @@ -403,7 +402,8 @@ def read_racs_database( basedir = survey_dir / "db" / epoch_name data_file = basedir / f"{table}.csv" if not data_file.exists(): - raise FileNotFoundError(f"{data_file} not found!") + msg = f"{data_file} not found!" + raise FileNotFoundError(msg) return Table.read(data_file) @@ -412,9 +412,9 @@ def field_database( survey_dir: Path, host: str, epoch: int, - username: Optional[str] = None, - password: Optional[str] = None, -) -> Tuple[InsertManyResult, InsertManyResult]: + username: str | None = None, + password: str | None = None, +) -> tuple[InsertManyResult, InsertManyResult]: """Reset and load the field database Args: @@ -462,14 +462,14 @@ def field_database( def main( load: bool = False, - islandcat: Optional[str] = None, - compcat: Optional[str] = None, - database_path: Optional[Path] = None, + islandcat: str | None = None, + compcat: str | None = None, + database_path: Path | None = None, host: str = "localhost", - username: Optional[str] = None, - password: Optional[str] = None, + username: str | None = None, + password: str | None = None, field: bool = False, - epochs: List[int] = 0, + epochs: list[int] = 0, force: bool = False, ) -> None: """Main script diff --git a/arrakis/linmos.py b/arrakis/linmos.py old mode 100644 new mode 100755 index 09d998e2..f1473f7c --- a/arrakis/linmos.py +++ b/arrakis/linmos.py @@ -1,6 +1,8 @@ #!/usr/bin/env python3 """Run LINMOS on cutouts in parallel""" +from __future__ import annotations + import argparse import logging import os @@ -9,7 +11,6 @@ from glob import glob from pathlib import Path from pprint import pformat -from typing import Dict, List, Optional, Tuple from typing import NamedTuple as Struct import astropy.units as u @@ -41,16 +42,16 @@ class ImagePaths(Struct): """Class to hold image paths""" - images: List[Path] + images: list[Path] """List of image paths""" - weights: List[Path] + weights: list[Path] """List of weight paths""" @task(name="Find images") def find_images( field: str, - beams_row: Tuple[int, pd.Series], + beams_row: tuple[int, pd.Series], stoke: str, datadir: Path, ) -> ImagePaths: @@ -74,7 +75,7 @@ def find_images( field_beams = beams.beams[field] # First check that the images exist - image_list: List[Path] = [] + image_list: list[Path] = [] for bm in list(set(field_beams["beam_list"])): # Ensure list of beams is unique! imfile = Path(field_beams[f"{stoke.lower()}_beam{bm}_image_file"]) assert ( @@ -85,9 +86,10 @@ def find_images( image_list = sorted(image_list) if len(image_list) == 0: - raise Exception("No files found. Have you run imaging? Check your prefix?") + msg = "No files found. Have you run imaging? Check your prefix?" + raise Exception(msg) - weight_list: List[Path] = [] + weight_list: list[Path] = [] for bm in list(set(field_beams["beam_list"])): # Ensure list of beams is unique! wgtsfile = Path(field_beams[f"{stoke.lower()}_beam{bm}_weight_file"]) assert ( @@ -109,8 +111,8 @@ def find_images( @task(name="Smooth images") def smooth_images( - image_dict: Dict[str, ImagePaths], -) -> Dict[str, ImagePaths]: + image_dict: dict[str, ImagePaths], +) -> dict[str, ImagePaths]: """Smooth cubelets to a common resolution Args: @@ -119,9 +121,9 @@ def smooth_images( Returns: ImagePaths: Smoothed cubelets. 
""" - smooth_dict: Dict[str, ImagePaths] = {} + smooth_dict: dict[str, ImagePaths] = {} for stoke, image_list in image_dict.items(): - infiles: List[str] = [] + infiles: list[str] = [] for im in image_list.images: if im.suffix == ".fits": infiles.append(im.resolve().as_posix()) @@ -132,8 +134,8 @@ def smooth_images( conv_mode="robust", suffix="cres", ) - smooth_files: List[Path] = [] - for key, val in datadict.items(): + smooth_files: list[Path] = [] + for _key, val in datadict.items(): smooth_files.append(Path(val["outfile"])) smooth_dict[stoke] = ImagePaths(smooth_files, image_list.weights) @@ -145,7 +147,7 @@ def genparset( image_paths: ImagePaths, stoke: str, datadir: Path, - holofile: Optional[Path] = None, + holofile: Path | None = None, ) -> str: """Generate parset for LINMOS @@ -163,7 +165,7 @@ def genparset( """ logger.setLevel(logging.INFO) - pol_angles_list: List[float] = [] + pol_angles_list: list[float] = [] for im in image_paths.images: _pol_angle: float = fits.getheader(im)["INSTRUMENT_RECEPTOR_ANGLE"] pol_angles_list.append(_pol_angle) @@ -221,8 +223,8 @@ def genparset( @task(name="Run linmos") def linmos( - parset: Optional[str], fieldname: str, image: str, holofile: Path -) -> Optional[pymongo.UpdateOne]: + parset: str | None, fieldname: str, image: str, holofile: Path +) -> pymongo.UpdateOne | None: """Run linmos Args: @@ -242,7 +244,7 @@ def linmos( logger.setLevel(logging.INFO) if parset is None: - return + return None workdir = os.path.dirname(parset) rootdir = os.path.split(workdir)[0] @@ -272,7 +274,8 @@ def linmos( new_files = glob(f"{workdir}/*.cutout.image.restored.{stoke.lower()}*.linmos.fits") if len(new_files) != 1: - raise Exception(f"LINMOS file not found! -- check {log_file}?") + msg = f"LINMOS file not found! -- check {log_file}?" + raise Exception(msg) new_file = os.path.abspath(new_files[0]) outer = os.path.basename(os.path.dirname(new_file)) @@ -297,8 +300,7 @@ def get_yanda(version="1.3.0") -> str: str: Path to yandasoft image. 
""" sclient.load(f"docker://csirocass/yandasoft:{version}-galaxy") - image = os.path.abspath(sclient.pull()) - return image + return os.path.abspath(sclient.pull()) # We reduce the inner loop to a serial call @@ -306,12 +308,12 @@ def get_yanda(version="1.3.0") -> str: @task(name="LINMOS loop") def serial_loop( field: str, - beams_row: Tuple[int, pd.Series], - stokeslist: List[str], + beams_row: tuple[int, pd.Series], + stokeslist: list[str], cutdir: Path, holofile: Path, image: Path, -) -> List[Optional[pymongo.UpdateOne]]: +) -> list[pymongo.UpdateOne | None]: results = [] for stoke in stokeslist: image_path = find_images.fn( @@ -343,14 +345,14 @@ def main( datadir: Path, host: str, epoch: int, - sbid: Optional[int] = None, - holofile: Optional[Path] = None, - username: Optional[str] = None, - password: Optional[str] = None, + sbid: int | None = None, + holofile: Path | None = None, + username: str | None = None, + password: str | None = None, yanda: str = "1.3.0", - yanda_img: Optional[Path] = None, - stokeslist: Optional[List[str]] = None, - limit: Optional[int] = None, + yanda_img: Path | None = None, + stokeslist: list[str] | None = None, + limit: int | None = None, ) -> None: """LINMOS flow @@ -387,7 +389,7 @@ def main( logger.info(f"The query is {query=}") - island_ids: List[str] = sorted(beams_col.distinct("Source_ID", query)) + island_ids: list[str] = sorted(beams_col.distinct("Source_ID", query)) big_beams = pd.DataFrame( beams_col.find({"Source_ID": {"$in": island_ids}}).sort("Source_ID") ) @@ -415,7 +417,7 @@ def main( ) results.append(sub_results) - updates_lists: List[list] = [f.result() for f in results] + updates_lists: list[list] = [f.result() for f in results] # Flatten updates = [u for ul in updates_lists for u in ul] updates = [u for u in updates if u is not None] diff --git a/arrakis/logger.py b/arrakis/logger.py old mode 100644 new mode 100755 index 287936d1..a4846020 --- a/arrakis/logger.py +++ b/arrakis/logger.py @@ -1,6 +1,8 @@ #!/usr/bin/env python3 """Logging module for arrakis""" +from __future__ import annotations + import argparse import io import logging diff --git a/arrakis/makecat.py b/arrakis/makecat.py old mode 100644 new mode 100755 index 6ec80be6..25c15965 --- a/arrakis/makecat.py +++ b/arrakis/makecat.py @@ -1,6 +1,8 @@ #!/usr/bin/env python3 """Make an Arrakis catalogue""" +from __future__ import annotations + import argparse import logging import os @@ -8,7 +10,7 @@ import warnings from pathlib import Path from pprint import pformat -from typing import Callable, NamedTuple, Optional, Tuple, Union +from typing import Callable, NamedTuple import astropy.units as u import dask.dataframe as dd @@ -56,7 +58,7 @@ class SpectralIndices(NamedTuple): betas_err: np.ndarray -def combinate(data: ArrayLike) -> Tuple[ArrayLike, ArrayLike]: +def combinate(data: ArrayLike) -> tuple[ArrayLike, ArrayLike]: """Return all combinations of data with itself Args: @@ -156,7 +158,7 @@ def is_blended_component(sub_df: pd.DataFrame) -> pd.DataFrame: name="blend_ratio", dtype=float, ) - df = pd.DataFrame( + return pd.DataFrame( { "is_blended_flag": is_blended, "N_blended": n_blended, @@ -165,10 +167,8 @@ def is_blended_component(sub_df: pd.DataFrame) -> pd.DataFrame: index=sub_df.index, ) - return df - df = cat.to_pandas() - df.set_index("cat_id", inplace=True) + df = df.set_index("cat_id") ddf = dd.from_pandas(df, chunksize=1000) grp = ddf.groupby("source_id") logger.info("Identifying blended components...") @@ -318,7 +318,7 @@ def get_fit_func( degree: int = 2, do_plot: bool 
= False, high_snr_cut: float = 30.0, -) -> Tuple[Callable, plt.Figure]: +) -> tuple[Callable, plt.Figure]: """Fit an envelope to define leakage sources Args: @@ -445,12 +445,12 @@ def compute_local_rm_flag(good_cat: Table, big_cat: Table) -> Table: logger.info(f"Number of available sources: {len(good_cat)}.") df = good_cat.to_pandas() - df.reset_index(inplace=True) - df.set_index("cat_id", inplace=True) + df = df.reset_index() + df = df.set_index("cat_id") df_out = big_cat.to_pandas() - df_out.reset_index(inplace=True) - df_out.set_index("cat_id", inplace=True) + df_out = df_out.reset_index() + df_out = df_out.set_index("cat_id") df_out["local_rm_flag"] = False try: @@ -532,7 +532,7 @@ def masker(x): ) # Put flag into the catalogue df["local_rm_flag"] = perc_g.reset_index().set_index("cat_id")[0] - df.drop(columns=["bin_number"], inplace=True) + df = df.drop(columns=["bin_number"]) df_out.update(df["local_rm_flag"]) except Exception as e: @@ -658,9 +658,7 @@ def get_alpha(cat: TableLike) -> SpectralIndices: @task(name="Get integration times") -def get_integration_time( - cat: RMTable, field_col: Collection, sbid: Optional[int] = None -): +def get_integration_time(cat: RMTable, field_col: Collection, sbid: int | None = None): logger.warning("Will be stripping the trailing field character prefix. ") field_names = [ name[:-1] if name[-1] in ("A", "B") else name for name in list(cat["tile_id"]) @@ -692,13 +690,14 @@ def get_integration_time( doc_count = field_col.count_documents(query) if doc_count == 0: - raise ValueError(f"No data for query {query}") + msg = f"No data for query {query}" + raise ValueError(msg) else: logger.warning("Using SELECT=0 instead.") field_data = list(field_col.find(query, reutrn_vals)) tint_df = pd.DataFrame(field_data) - tint_df.set_index("FIELD_NAME", inplace=True, drop=False) + tint_df = tint_df.set_index("FIELD_NAME", drop=False) # Check for duplicates if len(tint_df.index) != len(set(tint_df.index)): @@ -783,7 +782,6 @@ def replace_nans(filename: str): Args: filename (str): File name """ - pass # with open(filename, "r") as f: # xml = f.read() # xml = xml.replace("NaN", "null") @@ -850,8 +848,8 @@ def update_tile_separations(rmtab: TableLike, field_col: Collection) -> TableLik {"FIELD_NAME": {"$in": list(field_names)}}, ) ) - field_data.drop_duplicates(subset=["FIELD_NAME"], inplace=True) - field_data.set_index("FIELD_NAME", inplace=True) + field_data = field_data.drop_duplicates(subset=["FIELD_NAME"]) + field_data = field_data.set_index("FIELD_NAME") field_coords = SkyCoord( ra=field_data["RA_DEG"], dec=field_data["DEC_DEG"], unit=(u.deg, u.deg) @@ -897,14 +895,14 @@ def main( field: str, host: str, epoch: int, - sbid: Optional[int] = None, + sbid: int | None = None, leakage_degree: int = 4, leakage_bins: int = 16, leakage_snr: float = 30.0, - username: Union[str, None] = None, - password: Union[str, None] = None, + username: str | None = None, + password: str | None = None, verbose: bool = True, - outfile: Union[str, None] = None, + outfile: str | None = None, ) -> None: """Make a catalogue from the Arrakis database flow @@ -936,7 +934,8 @@ def main( field_col=field_col, ) if not sbid_check: - raise ValueError(f"SBID {sbid} does not match field {field}") + msg = f"SBID {sbid} does not match field {field}" + raise ValueError(msg) logger.info("Starting beams collection query") tick = time.time() @@ -1021,7 +1020,7 @@ def main( # subset=["rmclean_summary", "rmsynth_summary", "rmclean1d", "rmsynth1d"], # inplace=True, # ) - comps_df.set_index("Source_ID", 
inplace=True) + comps_df = comps_df.set_index("Source_ID") tock = time.time() logger.info(f"Finished component collection query - {tock-tick:.2f}s") logger.info(f"Found {len(comps_df)} components to catalogue. ") @@ -1029,17 +1028,18 @@ def main( logger.info("Starting island collection query") tick = time.time() islands_df = pd.DataFrame(island_col.find({"Source_ID": {"$in": all_island_ids}})) - islands_df.set_index("Source_ID", inplace=True) + islands_df = islands_df.set_index("Source_ID") tock = time.time() logger.info(f"Finished island collection query - {tock-tick:.2f}s") if len(comps_df) == 0: logger.error("No components found for this field.") - raise ValueError("No components found for this field.") + msg = "No components found for this field." + raise ValueError(msg) rmtab = RMTable() # Add items to main cat using RMtable standard - for j, [name, typ, src, col, unit] in enumerate( + for _j, [name, typ, src, col, unit] in enumerate( tqdm( zip( columns_possum.output_cols, @@ -1174,7 +1174,7 @@ def main( # Replace all infs with nans for col in rmtab.colnames: # Check if column is a float - if isinstance(rmtab[col][0], np.float_): + if isinstance(rmtab[col][0], np.float64): rmtab[col][np.isinf(rmtab[col])] = np.nan # Convert all mJy to Jy @@ -1209,7 +1209,7 @@ def main( logger.info(f"Writing {outfile} to disk") _, ext = os.path.splitext(outfile) - if ext == ".xml" or ext == ".vot": + if ext in (".xml", ".vot"): write_votable(rmtab, outfile) else: rmtab.write(outfile, overwrite=True) diff --git a/arrakis/merge_fields.py b/arrakis/merge_fields.py old mode 100644 new mode 100755 index ce49d445..189c2141 --- a/arrakis/merge_fields.py +++ b/arrakis/merge_fields.py @@ -1,11 +1,12 @@ #!/usr/bin/env python3 """Merge multiple RACS fields""" +from __future__ import annotations + import argparse import os from pprint import pformat from shutil import copyfile -from typing import Dict, List, Optional import pymongo from prefect import flow, task @@ -17,16 +18,13 @@ def make_short_name(name: str) -> str: - short = os.path.join( - os.path.basename(os.path.dirname(name)), os.path.basename(name) - ) - return short + return os.path.join(os.path.basename(os.path.dirname(name)), os.path.basename(name)) @task(name="Copy singleton island") def copy_singleton( - beam: dict, field_dict: Dict[str, str], merge_name: str, data_dir: str -) -> List[pymongo.UpdateOne]: + beam: dict, field_dict: dict[str, str], merge_name: str, data_dir: str +) -> list[pymongo.UpdateOne]: """Copy an island within a single field to the merged field Args: @@ -43,7 +41,7 @@ def copy_singleton( """ updates = [] for field, vals in beam["beams"].items(): - if field not in field_dict.keys(): + if field not in field_dict: continue field_dir = field_dict[field] try: @@ -51,7 +49,8 @@ def copy_singleton( q_file_old = os.path.join(field_dir, vals["q_file_ion"]) u_file_old = os.path.join(field_dir, vals["u_file_ion"]) except KeyError: - raise KeyError("Ion files not found. Have you run FRion?") + msg = "Ion files not found. Have you run FRion?" 
+ raise KeyError(msg) new_dir = os.path.join(data_dir, beam["Source_ID"]) try_mkdir(new_dir, verbose=False) @@ -93,11 +92,11 @@ def copy_singleton( def copy_singletons( - field_dict: Dict[str, str], + field_dict: dict[str, str], data_dir: str, beams_col: pymongo.collection.Collection, merge_name: str, -) -> List[pymongo.UpdateOne]: +) -> list[pymongo.UpdateOne]: """Copy islands that don't overlap other fields Args: @@ -119,7 +118,7 @@ def copy_singletons( {"n_fields_DR1": 1}, ] } - for field in field_dict.keys() + for field in field_dict ] } @@ -176,7 +175,7 @@ def genparset( def merge_multiple_field( beam: dict, field_dict: dict, merge_name: str, data_dir: str, image: str -) -> List[pymongo.UpdateOne]: +) -> list[pymongo.UpdateOne]: """Merge an island that overlaps multiple fields Args: @@ -197,7 +196,7 @@ def merge_multiple_field( u_files_old = [] updates = [] for field, vals in beam["beams"].items(): - if field not in field_dict.keys(): + if field not in field_dict: continue field_dir = field_dict[field] try: @@ -205,7 +204,8 @@ def merge_multiple_field( q_file_old = os.path.join(field_dir, vals["q_file_ion"]) u_file_old = os.path.join(field_dir, vals["u_file_ion"]) except KeyError: - raise KeyError("Ion files not found. Have you run FRion?") + msg = "Ion files not found. Have you run FRion?" + raise KeyError(msg) i_files_old.append(i_file_old) q_files_old.append(q_file_old) u_files_old.append(u_file_old) @@ -224,12 +224,12 @@ def merge_multiple_field( @task(name="Merge multiple islands") def merge_multiple_fields( - field_dict: Dict[str, str], + field_dict: dict[str, str], data_dir: str, beams_col: pymongo.collection.Collection, merge_name: str, image: str, -) -> List[pymongo.UpdateOne]: +) -> list[pymongo.UpdateOne]: """Merge multiple islands that overlap multiple fields Args: @@ -252,7 +252,7 @@ def merge_multiple_fields( {"n_fields_DR1": {"$gt": 1}}, ] } - for field in field_dict.keys() + for field in field_dict ] } @@ -276,14 +276,14 @@ def merge_multiple_fields( @flow(name="Merge fields") def main( - fields: List[str], - field_dirs: List[str], + fields: list[str], + field_dirs: list[str], merge_name: str, output_dir: str, host: str, epoch: int, - username: Optional[str] = None, - password: Optional[str] = None, + username: str | None = None, + password: str | None = None, yanda="1.3.0", ) -> str: logger.debug(f"{fields=}") diff --git a/arrakis/process_region.py b/arrakis/process_region.py old mode 100644 new mode 100755 index c262f618..4eac0f4a --- a/arrakis/process_region.py +++ b/arrakis/process_region.py @@ -1,6 +1,8 @@ #!/usr/bin/env python3 """Arrakis multi-field pipeline""" +from __future__ import annotations + import argparse import logging import os diff --git a/arrakis/process_spice.py b/arrakis/process_spice.py old mode 100644 new mode 100755 index fe385b27..1d19093a --- a/arrakis/process_spice.py +++ b/arrakis/process_spice.py @@ -1,6 +1,8 @@ #!/usr/bin/env python3 """Arrakis single-field pipeline""" +from __future__ import annotations + import argparse import logging import os diff --git a/arrakis/rmclean_oncuts.py b/arrakis/rmclean_oncuts.py old mode 100644 new mode 100755 index 1521ab38..3edf38c8 --- a/arrakis/rmclean_oncuts.py +++ b/arrakis/rmclean_oncuts.py @@ -1,6 +1,8 @@ #!/usr/bin/env python3 """Run RM-synthesis on cutouts in parallel""" +from __future__ import annotations + import argparse import logging import os @@ -8,9 +10,8 @@ from pathlib import Path from pprint import pformat from shutil import copyfile -from typing import Optional -import matplotlib 
+import matplotlib as mpl import numpy as np import pymongo from matplotlib import pyplot as plt @@ -29,7 +30,7 @@ ) from arrakis.utils.pipeline import generic_parser, logo_str, workdir_arg_parser -matplotlib.use("Agg") +mpl.use("Agg") logger.setLevel(logging.INFO) TQDM_OUT = TqdmToLogger(logger, level=logging.INFO) @@ -42,7 +43,7 @@ def rmclean1d( cutoff: float = -3, maxIter=10000, gain=0.1, - sbid: Optional[int] = None, + sbid: int | None = None, savePlots=False, rm_verbose=True, window=None, @@ -83,7 +84,8 @@ def rmclean1d( logger.debug(f"Checking {f.absolute()}") if not f.exists(): logger.fatal(f"File does not exist: '{f}'.") - raise FileNotFoundError(f"File does not exist: '{f}'") + msg = f"File does not exist: '{f}'" + raise FileNotFoundError(msg) nBits = 32 try: @@ -112,13 +114,9 @@ def rmclean1d( # Ensure JSON serializable for k, v in outdict.items(): - if isinstance(v, np.float_): - outdict[k] = float(v) - elif isinstance(v, np.float32): + if isinstance(v, (np.float64, np.float32)): outdict[k] = float(v) - elif isinstance(v, np.int_): - outdict[k] = int(v) - elif isinstance(v, np.int32): + elif isinstance(v, (np.int_, np.int32)): outdict[k] = int(v) elif isinstance(v, np.ndarray): outdict[k] = v.tolist() @@ -151,7 +149,7 @@ def rmclean3d( field: str, island: dict, outdir: Path, - sbid: Optional[int] = None, + sbid: int | None = None, cutoff: float = -3, maxIter=10000, gain=0.1, @@ -218,13 +216,13 @@ def main( outdir: Path, host: str, epoch: int, - sbid: Optional[int] = None, - username: Optional[str] = None, - password: Optional[str] = None, + sbid: int | None = None, + username: str | None = None, + password: str | None = None, dimension="1d", database=False, savePlots=True, - limit: Optional[int] = None, + limit: int | None = None, cutoff: float = -3, maxIter=10000, gain=0.1, @@ -271,7 +269,8 @@ def main( field_col=field_col, ) if not sbid_check: - raise ValueError(f"SBID {sbid} does not match field {field}") + msg = f"SBID {sbid} does not match field {field}" + raise ValueError(msg) query = {"$and": [{f"beams.{field}": {"$exists": True}}]} if sbid is not None: @@ -401,7 +400,8 @@ def main( outputs.append(output) else: - raise ValueError(f"Dimension {dimension} not supported.") + msg = f"Dimension {dimension} not supported." 
+ raise ValueError(msg) if database: logger.info("Updating database...") diff --git a/arrakis/rmsynth_oncuts.py b/arrakis/rmsynth_oncuts.py old mode 100644 new mode 100755 index 4e1a0668..42041fd7 --- a/arrakis/rmsynth_oncuts.py +++ b/arrakis/rmsynth_oncuts.py @@ -1,6 +1,8 @@ #!/usr/bin/env python3 """Run RM-CLEAN on cutouts in parallel""" +from __future__ import annotations + import argparse import logging import os @@ -9,11 +11,10 @@ from pathlib import Path from pprint import pformat from shutil import copyfile -from typing import List, Optional, Tuple, Union from typing import NamedTuple as Struct import astropy.units as u -import matplotlib +import matplotlib as mpl import matplotlib.pyplot as plt import numpy as np import pandas as pd @@ -43,7 +44,7 @@ from arrakis.utils.fitting import fit_pl, fitted_mean, fitted_std from arrakis.utils.pipeline import generic_parser, logo_str, workdir_arg_parser -matplotlib.use("Agg") +mpl.use("Agg") logger.setLevel(logging.INFO) TQDM_OUT = TqdmToLogger(logger, level=logging.INFO) @@ -77,30 +78,30 @@ class StokesSpectra(Struct): class StokesIFitResult(Struct): """Stokes I fit results""" - alpha: Optional[float] + alpha: float | None """The alpha parameter of the fit""" - amplitude: Optional[float] + amplitude: float | None """The amplitude parameter of the fit""" - x_0: Optional[float] + x_0: float | None """The x_0 parameter of the fit""" - model_repr: Optional[str] + model_repr: str | None """The model representation of the fit""" - modStokesI: Optional[np.ndarray] + modStokesI: np.ndarray | None """The model Stokes I spectrum""" - fit_dict: Optional[dict] + fit_dict: dict | None """The dictionary of the fit results""" @task(name="3D RM-synthesis") def rmsynthoncut3d( island_id: str, - beam_tuple: Tuple[str, pd.Series], + beam_tuple: tuple[str, pd.Series], outdir: Path, freq: np.ndarray, field: str, - sbid: Optional[int] = None, - phiMax_radm2: Optional[float] = None, - dPhi_radm2: Optional[float] = None, + sbid: int | None = None, + phiMax_radm2: float | None = None, + dPhi_radm2: float | None = None, nSamples: int = 5, weightType: str = "variance", fitRMSF: bool = True, @@ -198,7 +199,7 @@ def rmsynthoncut3d( # Prep header head_dict = dict(header) head_dict.pop("", None) - if "COMMENT" in head_dict.keys(): + if "COMMENT" in head_dict: head_dict["COMMENT"] = str(head_dict["COMMENT"]) outer_dir = os.path.basename(os.path.dirname(ifile)) @@ -223,7 +224,7 @@ def rmsynthoncut3d( ) -def cubelet_bane(cubelet: np.ndarray, header: fits.Header) -> Tuple[np.ndarray]: +def cubelet_bane(cubelet: np.ndarray, header: fits.Header) -> tuple[np.ndarray]: """Background and noise estimation on a cubelet Args: @@ -275,10 +276,7 @@ def extract_single_spectrum( outdir: Path, ) -> Spectrum: """Extract a single spectrum from a cubelet""" - if ion and (stokes == "q" or stokes == "u"): - key = f"{stokes}_file_ion" - else: - key = f"{stokes}_file" + key = f"{stokes}_file_ion" if ion and stokes in ("q", "u") else f"{stokes}_file" filename = outdir / field_dict[key] with fits.open(filename, mode="denywrite", memmap=True) as hdulist: hdu = hdulist[0] @@ -343,7 +341,7 @@ def sigma_clip_spectra( Returns: StokesSpectra: The filtered Stokes spectra """ - filter_list: List[np.ndarry] = [] + filter_list: list[np.ndarry] = [] for spectrum in stokes_spectra: rms_filter = sigma_clip( spectrum.rms, @@ -354,7 +352,7 @@ def sigma_clip_spectra( filter_list.append(rms_filter.mask) filter_idx = np.any(filter_list, axis=0) - filtered_data_list: List[Spectrum] = [] + filtered_data_list: 
list[Spectrum] = [] for spectrum in stokes_spectra: filtered_data = spectrum.data.copy() filtered_data[filter_idx] = np.nan @@ -372,12 +370,12 @@ def sigma_clip_spectra( def fit_stokes_I( freq: np.ndarray, coord: SkyCoord, - tt0: Optional[str] = None, - tt1: Optional[str] = None, + tt0: str | None = None, + tt1: str | None = None, do_own_fit: bool = False, - iarr: Optional[np.ndarray] = None, - rmsi: Optional[np.ndarray] = None, - polyOrd: Optional[int] = None, + iarr: np.ndarray | None = None, + rmsi: np.ndarray | None = None, + polyOrd: int | None = None, ) -> StokesIFitResult: if tt0 and tt1: mfs_i_0 = fits.getdata(tt0, memmap=True) @@ -478,15 +476,15 @@ def update_rmtools_dict( @task(name="1D RM-synthesis") def rmsynthoncut1d( - comp_tuple: Tuple[str, pd.Series], - beam_tuple: Tuple[str, pd.Series], + comp_tuple: tuple[str, pd.Series], + beam_tuple: tuple[str, pd.Series], outdir: Path, freq: np.ndarray, field: str, - sbid: Optional[int] = None, + sbid: int | None = None, polyOrd: int = 3, - phiMax_radm2: Optional[float] = None, - dPhi_radm2: Optional[float] = None, + phiMax_radm2: float | None = None, + dPhi_radm2: float | None = None, nSamples: int = 5, weightType: str = "variance", fitRMSF: bool = True, @@ -496,8 +494,8 @@ def rmsynthoncut1d( debug: bool = False, rm_verbose: bool = False, fit_function: str = "log", - tt0: Optional[str] = None, - tt1: Optional[str] = None, + tt0: str | None = None, + tt1: str | None = None, ion: bool = False, do_own_fit: bool = False, ) -> pymongo.UpdateOne: @@ -660,13 +658,9 @@ def rmsynthoncut1d( # Ensure JSON serializable for k, v in mDict.items(): - if isinstance(v, np.float_): + if isinstance(v, (np.float64, np.float32)): mDict[k] = float(v) - elif isinstance(v, np.float32): - mDict[k] = float(v) - elif isinstance(v, np.int_): - mDict[k] = int(v) - elif isinstance(v, np.int32): + elif isinstance(v, (np.int_, np.int32)): mDict[k] = int(v) elif isinstance(v, np.ndarray): mDict[k] = v.tolist() @@ -680,7 +674,7 @@ def rmsynthoncut1d( # Prep header head_dict = dict(filtered_stokes_spectra.i.header) head_dict.pop("", None) - if "COMMENT" in head_dict.keys(): + if "COMMENT" in head_dict: head_dict["COMMENT"] = str(head_dict["COMMENT"]) logger.debug(f"Heading for {cname} is {pformat(head_dict)}") @@ -746,18 +740,18 @@ def main( outdir: Path, host: str, epoch: int, - sbid: Optional[int] = None, - username: Optional[str] = None, - password: Optional[str] = None, + sbid: int | None = None, + username: str | None = None, + password: str | None = None, dimension: str = "1d", verbose: bool = True, database: bool = False, - limit: Union[int, None] = None, + limit: int | None = None, savePlots: bool = False, weightType: str = "variance", fitRMSF: bool = True, - phiMax_radm2: Union[float, None] = None, - dPhi_radm2: Union[float, None] = None, + phiMax_radm2: float | None = None, + dPhi_radm2: float | None = None, nSamples: int = 5, polyOrd: int = 3, noStokesI: bool = False, @@ -766,8 +760,8 @@ def main( rm_verbose: bool = False, debug: bool = False, fit_function: str = "log", - tt0: Optional[str] = None, - tt1: Optional[str] = None, + tt0: str | None = None, + tt1: str | None = None, ion: bool = False, do_own_fit: bool = False, ) -> None: @@ -829,7 +823,8 @@ def main( field_col=field_col, ) if not sbid_check: - raise ValueError(f"SBID {sbid} does not match field {field}") + msg = f"SBID {sbid} does not match field {field}" + raise ValueError(msg) beam_query = {"$and": [{f"beams.{field}": {"$exists": True}}]} @@ -839,7 +834,7 @@ def main( logger.info(f"Querying 
beams with {beam_query}") beams = pd.DataFrame(list(beams_col.find(beam_query).sort("Source_ID"))) - beams.set_index("Source_ID", drop=False, inplace=True) + beams = beams.set_index("Source_ID", drop=False) island_ids = sorted(beams_col.distinct("Source_ID", beam_query)) isl_query = {"Source_ID": {"$in": island_ids}} @@ -857,7 +852,7 @@ def main( ).sort("Source_ID") ) ) - components.set_index("Source_ID", drop=False, inplace=True) + components = components.set_index("Source_ID", drop=False) component_ids = list(components["Gaussian_ID"]) n_comp = comp_col.count_documents(isl_query) @@ -1028,7 +1023,8 @@ def main( ) outputs.append(output) else: - raise ValueError("An incorrect RMSynth mode has been configured. ") + msg = "An incorrect RMSynth mode has been configured. " + raise ValueError(msg) if database: logger.info("Updating database...") diff --git a/arrakis/utils/coordinates.py b/arrakis/utils/coordinates.py index 04535a83..3caa185c 100644 --- a/arrakis/utils/coordinates.py +++ b/arrakis/utils/coordinates.py @@ -1,8 +1,8 @@ -#!/usr/bin/env python """Coordinate utilities""" +from __future__ import annotations + import warnings -from typing import Tuple from astropy.coordinates import SkyCoord from astropy.coordinates.angles import dms_tuple, hms_tuple @@ -47,7 +47,7 @@ def deg_to_dms(deg: float) -> dms_tuple: return dms_tuple(degree, minute, seconds) -def coord_to_string(coord: SkyCoord) -> Tuple[str, str]: +def coord_to_string(coord: SkyCoord) -> tuple[str, str]: """Convert coordinate to string without astropy Args: diff --git a/arrakis/utils/database.py b/arrakis/utils/database.py index 2c211dbf..8b5b4985 100644 --- a/arrakis/utils/database.py +++ b/arrakis/utils/database.py @@ -1,8 +1,8 @@ -#!/usr/bin/env python """Database utilities""" +from __future__ import annotations + import warnings -from typing import Optional, Tuple, Union import pymongo from astropy.utils.exceptions import AstropyWarning @@ -27,15 +27,16 @@ def validate_sbid_field_pair(field_name: str, sbid: int, field_col: Collection) bool: If field name and sbid pair is valid. 
""" logger.info(f"Validating field name and SBID pair: {field_name}, {sbid}") - field_data: Optional[dict] = field_col.find_one({"SBID": sbid}) + field_data: dict | None = field_col.find_one({"SBID": sbid}) if field_data is None: - raise ValueError(f"SBID {sbid} not found in database") + msg = f"SBID {sbid} not found in database" + raise ValueError(msg) return field_data["FIELD_NAME"] == field_name def test_db( - host: str, username: Union[str, None] = None, password: Union[str, None] = None + host: str, username: str | None = None, password: str | None = None ) -> bool: """Test connection to MongoDB @@ -64,7 +65,8 @@ def test_db( try: dbclient.list_database_names() except pymongo.errors.ServerSelectionTimeoutError: - raise Exception("Please ensure 'mongod' is running") + msg = "Please ensure 'mongod' is running" + raise Exception(msg) logger.info("MongoDB connection succesful!") @@ -74,9 +76,9 @@ def test_db( def get_db( host: str, epoch: int, - username: Union[str, None] = None, - password: Union[str, None] = None, -) -> Tuple[Collection, Collection, Collection]: + username: str | None = None, + password: str | None = None, +) -> tuple[Collection, Collection, Collection]: """Get MongoDBs Args: @@ -120,8 +122,7 @@ def get_field_db(host: str, epoch: int, username=None, password=None) -> Collect authMechanism="SCRAM-SHA-256", ) # type: pymongo.MongoClient mydb = dbclient[f"arrakis_epoch_{epoch}"] # Create/open database - field_col = mydb["fields"] # Create/open collection - return field_col + return mydb["fields"] # Create/open collection def get_beam_inf_db(host: str, epoch: int, username=None, password=None) -> Collection: @@ -143,5 +144,4 @@ def get_beam_inf_db(host: str, epoch: int, username=None, password=None) -> Coll authMechanism="SCRAM-SHA-256", ) # type: pymongo.MongoClient mydb = dbclient[f"arrakis_epoch_{epoch}"] # Create/open database - beam_inf_col = mydb["beam_inf"] # Create/open collection - return beam_inf_col + return mydb["beam_inf"] # Create/open collection diff --git a/arrakis/utils/exceptions.py b/arrakis/utils/exceptions.py index 49b22de7..4f3dd336 100644 --- a/arrakis/utils/exceptions.py +++ b/arrakis/utils/exceptions.py @@ -1,6 +1,8 @@ #!/usr/bin/env python """Errors and exceptions""" +from __future__ import annotations + import warnings from astropy.utils.exceptions import AstropyWarning diff --git a/arrakis/utils/fitsutils.py b/arrakis/utils/fitsutils.py index 0a55d845..7d219977 100644 --- a/arrakis/utils/fitsutils.py +++ b/arrakis/utils/fitsutils.py @@ -1,10 +1,12 @@ #!/usr/bin/env python """FITS utilities""" +from __future__ import annotations + import warnings from glob import glob from pathlib import Path -from typing import Any, Dict, Optional, Tuple, Union +from typing import Any import astropy.units as u import numpy as np @@ -22,7 +24,7 @@ warnings.simplefilter("ignore", category=AstropyWarning) -def head2dict(h: fits.Header) -> Dict[str, Any]: +def head2dict(h: fits.Header) -> dict[str, Any]: """Convert FITS header to a dict. Writes a cutout, as stored in source_dict, to disk. The file location @@ -67,10 +69,10 @@ def fix_header(cutout_header: fits.Header, original_header: fits.Header) -> fits def getfreq( - cube: Union[str, Path], - outdir: Optional[Path] = None, - filename: Union[str, Path, None] = None, -) -> Union[u.Quantity, Tuple[u.Quantity, Path]]: + cube: str | Path, + outdir: Path | None = None, + filename: str | Path | None = None, +) -> u.Quantity | tuple[u.Quantity, Path]: """Get list of frequencies from FITS data. 
Gets the frequency list from a given cube. Can optionally save @@ -176,7 +178,7 @@ def getdata(cubedir="./", tabledir="./", mapdata=None, verbose=True): mask = ~(u_cube == 0 * u.jansky / u.beam) u_cube = u_cube.with_mask(mask) - datadict = { + return { "i_tab": i_tab, "i_tab_comp": components, "i_taylor": i_taylor, @@ -191,5 +193,3 @@ def getdata(cubedir="./", tabledir="./", mapdata=None, verbose=True): "u_file": ucubes[0], "v_file": vcubes[0], } - - return datadict diff --git a/arrakis/utils/fitting.py b/arrakis/utils/fitting.py index e3b558a8..e83c3cb1 100644 --- a/arrakis/utils/fitting.py +++ b/arrakis/utils/fitting.py @@ -1,9 +1,10 @@ #!/usr/bin/env python """Fitting utilities""" +from __future__ import annotations + import warnings from functools import partial -from typing import Optional, Tuple import numpy as np from astropy.stats import akaike_info_criterion_lsq @@ -18,7 +19,7 @@ warnings.simplefilter("ignore", category=AstropyWarning) -def fitted_mean(data: np.ndarray, axis: Optional[int] = None) -> float: +def fitted_mean(data: np.ndarray, axis: int | None = None) -> float: """Calculate the mean of a distribution. Args: @@ -28,12 +29,13 @@ def fitted_mean(data: np.ndarray, axis: Optional[int] = None) -> float: float: Mean. """ if axis is not None: - raise NotImplementedError("Axis not implemented") + msg = "Axis not implemented" + raise NotImplementedError(msg) mean, _ = norm.fit(data) return mean -def fitted_std(data: np.ndarray, axis: Optional[int] = None) -> float: +def fitted_std(data: np.ndarray, axis: int | None = None) -> float: """Calculate the standard deviation of a distribution. Args: @@ -43,7 +45,8 @@ def fitted_std(data: np.ndarray, axis: Optional[int] = None) -> float: float: Standard deviation. """ if axis is not None: - raise NotImplementedError("Axis not implemented") + msg = "Axis not implemented" + raise NotImplementedError(msg) _, std = norm.fit(data) return std @@ -62,7 +65,7 @@ def chi_squared(model: np.ndarray, data: np.ndarray, error: np.ndarray) -> float return np.sum(((model - data) / error) ** 2) -def best_aic_func(aics: np.ndarray, n_param: np.ndarray) -> Tuple[float, int, int]: +def best_aic_func(aics: np.ndarray, n_param: np.ndarray) -> tuple[float, int, int]: """Find the best AIC for a set of AICs using Occam's razor.""" # Find the best AIC best_aic_idx = int(np.nanargmin(aics)) @@ -254,7 +257,7 @@ def fit_pl( # Now find the best model best_aic, best_n, best_aic_idx = best_aic_func( np.array([save_dict[n]["aics"] for n in range(nterms + 1)]), - np.array([n for n in range(nterms + 1)]), + np.array(list(range(nterms + 1))), ) logger.debug(f"Best fit: {best_n}, {best_aic}") best_p = save_dict[best_n]["params"] @@ -270,36 +273,36 @@ def fit_pl( error=fluxerr[goodchan], ) chi_sq_red = chi_sq / (goodchan.sum() - len(best_p)) - return dict( - best_n=best_n, - best_p=best_p, - best_e=best_e, - best_m=best_m, - best_h=best_h, - best_l=best_l, - best_f=best_f, - fit_flag=best_flag, - ref_nu=ref_nu, - chi_sq=chi_sq, - chi_sq_red=chi_sq_red, - ) + return { + "best_n": best_n, + "best_p": best_p, + "best_e": best_e, + "best_m": best_m, + "best_h": best_h, + "best_l": best_l, + "best_f": best_f, + "fit_flag": best_flag, + "ref_nu": ref_nu, + "chi_sq": chi_sq, + "chi_sq_red": chi_sq_red, + } except Exception as e: logger.critical(f"Failed to fit power law: {e}") - return dict( - best_n=np.nan, - best_p=[np.nan], - best_e=[np.nan], - best_m=np.ones_like(freq), - best_h=np.ones_like(freq), - best_l=np.ones_like(freq), - best_f=None, - fit_flag={ + return { + 
"best_n": np.nan, + "best_p": [np.nan], + "best_e": [np.nan], + "best_m": np.ones_like(freq), + "best_h": np.ones_like(freq), + "best_l": np.ones_like(freq), + "best_f": None, + "fit_flag": { "is_negative": True, "is_not_finite": True, "is_not_normal": True, "is_close_to_zero": True, }, - ref_nu=np.nan, - chi_sq=np.nan, - chi_sq_red=np.nan, - ) + "ref_nu": np.nan, + "chi_sq": np.nan, + "chi_sq_red": np.nan, + } diff --git a/arrakis/utils/io.py b/arrakis/utils/io.py index 96cbc698..5a17f8cd 100644 --- a/arrakis/utils/io.py +++ b/arrakis/utils/io.py @@ -1,6 +1,8 @@ #!/usr/bin/env python """I/O utilities""" +from __future__ import annotations + import logging import os import shlex @@ -9,7 +11,6 @@ import warnings from glob import glob from pathlib import Path -from typing import Tuple from astropy.table import Table from astropy.utils.exceptions import AstropyWarning @@ -92,7 +93,7 @@ def try_mkdir(dir_path: str, verbose=True): logger.info(f"Directory '{dir_path}' exists.") -def gettable(tabledir: str, keyword: str, verbose=True) -> Tuple[Table, str]: +def gettable(tabledir: str, keyword: str, verbose=True) -> tuple[Table, str]: """Get a table from a directory given a keyword to glob. Args: @@ -127,6 +128,7 @@ def _samefile(src, dst): return os.path.samefile(src, dst) except OSError: return False + return None def copyfile(src, dst, *, follow_symlinks=True, verbose=True): @@ -137,7 +139,8 @@ def copyfile(src, dst, *, follow_symlinks=True, verbose=True): """ if _samefile(src, dst): - raise SameFileError(f"{src!r} and {dst!r} are the same file") + msg = f"{src!r} and {dst!r} are the same file" + raise SameFileError(msg) for fn in [src, dst]: try: @@ -148,14 +151,14 @@ def copyfile(src, dst, *, follow_symlinks=True, verbose=True): else: # XXX What about other special files? (sockets, devices...) 
if stat.S_ISFIFO(st.st_mode): - raise SpecialFileError(f"`{fn}` is a named pipe") + msg = f"`{fn}` is a named pipe" + raise SpecialFileError(msg) if not follow_symlinks and os.path.islink(src): os.symlink(os.readlink(src), dst) else: - with open(src, "rb") as fsrc: - with open(dst, "wb") as fdst: - copyfileobj(fsrc, fdst, verbose=verbose) + with open(src, "rb") as fsrc, open(dst, "wb") as fdst: + copyfileobj(fsrc, fdst, verbose=verbose) return dst diff --git a/arrakis/utils/json.py b/arrakis/utils/json.py index 3cae6d79..c8ca006c 100644 --- a/arrakis/utils/json.py +++ b/arrakis/utils/json.py @@ -1,6 +1,8 @@ #!/usr/bin/env python """JSON utilities""" +from __future__ import annotations + import dataclasses import json @@ -23,7 +25,7 @@ def default(self, obj): # pylint: disable=E0202 return int(obj) elif isinstance(obj, np.floating): return float(obj) - elif isinstance(obj, np.complex): + elif isinstance(obj, complex): return (obj.real, obj.imag) elif isinstance(obj, np.ndarray): return obj.tolist() diff --git a/arrakis/utils/meta.py b/arrakis/utils/meta.py index 1c230117..d2eabcf2 100644 --- a/arrakis/utils/meta.py +++ b/arrakis/utils/meta.py @@ -1,6 +1,7 @@ -#!/usr/bin/env python """Generic program utilities""" +from __future__ import annotations + import importlib import warnings from itertools import zip_longest @@ -36,8 +37,7 @@ def class_for_name(module_name: str, class_name: str) -> object: # load the module, will raise ImportError if module cannot be loaded m = importlib.import_module(module_name) # get the class, will raise AttributeError if class cannot be found - c = getattr(m, class_name) - return c + return getattr(m, class_name) # stolen from https://stackoverflow.com/questions/32954486/zip-iterators-asserting-for-equal-length-in-python @@ -45,7 +45,8 @@ def zip_equal(*iterables): sentinel = object() for combo in zip_longest(*iterables, fillvalue=sentinel): if sentinel in combo: - raise ValueError("Iterables have different lengths") + msg = "Iterables have different lengths" + raise ValueError(msg) yield combo @@ -65,4 +66,6 @@ def yes_or_no(question: str) -> bool: elif reply[:1] == "n": return False else: - raise ValueError("Please answer 'y' or 'n'") + msg = "Please answer 'y' or 'n'" + raise ValueError(msg) + return None diff --git a/arrakis/utils/msutils.py b/arrakis/utils/msutils.py index 688b3bb5..d14332e5 100644 --- a/arrakis/utils/msutils.py +++ b/arrakis/utils/msutils.py @@ -1,10 +1,11 @@ #!/usr/bin/env python """MeasurementSet utilities""" +from __future__ import annotations + import copy import warnings from pathlib import Path -from typing import Optional import astropy.units as u from astropy.utils.exceptions import AstropyWarning @@ -18,7 +19,7 @@ def get_pol_axis( - ms: Path, feed_idx: Optional[int] = None, col: str = "RECEPTOR_ANGLE" + ms: Path, feed_idx: int | None = None, col: str = "RECEPTOR_ANGLE" ) -> u.Quantity: """Get the polarization axis from the ASKAP MS. Checks are performed to ensure this polarisation axis angle is constant throughout the observation. 
@@ -37,7 +38,8 @@ def get_pol_axis( """ _known_cols = ("RECEPTOR_ANGLE", "INSTRUMENT_RECEPTOR_ANGLE") if col not in _known_cols: - raise ValueError(f"Unknown column {col=}, please use one of {_known_cols}") + msg = f"Unknown column {col=}, please use one of {_known_cols}" + raise ValueError(msg) with table((ms / "FEED").as_posix(), readonly=True, ack=False) as tf: ms_feed = tf.getcol(col) * u.rad # PAF is at 45deg to feeds @@ -61,8 +63,7 @@ def beam_from_ms(ms: str) -> int: """Work out which beam is in this MS""" with table(ms, readonly=True, ack=False) as t: vis_feed = t.getcol("FEED1", 0, 1) - beam = vis_feed[0] - return beam + return vis_feed[0] def field_idx_from_ms(ms: str) -> int: @@ -70,10 +71,9 @@ def field_idx_from_ms(ms: str) -> int: with table(f"{ms}/FIELD", readonly=True, ack=False) as field: idxs = list(field.SOURCE_ID) assert len(idxs) == 1 or all( - [idx == idxs[0] for idx in idxs] + idx == idxs[0] for idx in idxs ), "More than one field in MS" - idx = idxs[0] - return idx + return idxs[0] def field_name_from_ms(ms: str) -> str: @@ -81,163 +81,162 @@ def field_name_from_ms(ms: str) -> str: with table(f"{ms}/FIELD", readonly=True, ack=False) as field: names = list(field.NAME) assert len(names) == 1, "More than one field in MS" - name = names[0] - return name + return names[0] def wsclean( mslist: list, use_mpi: bool, version: bool = False, - j: Optional[int] = None, - parallel_gridding: Optional[int] = None, - parallel_reordering: Optional[int] = None, + j: int | None = None, + parallel_gridding: int | None = None, + parallel_reordering: int | None = None, no_work_on_master: bool = False, - mem: Optional[float] = None, - abs_mem: Optional[float] = None, + mem: float | None = None, + abs_mem: float | None = None, verbose: bool = False, log_time: bool = False, quiet: bool = False, reorder: bool = False, no_reorder: bool = False, - temp_dir: Optional[str] = None, + temp_dir: str | None = None, update_model_required: bool = False, no_update_model_required: bool = False, no_dirty: bool = False, save_first_residual: bool = False, save_weights: bool = False, save_uv: bool = False, - reuse_psf: Optional[str] = None, - reuse_dirty: Optional[str] = None, + reuse_psf: str | None = None, + reuse_dirty: str | None = None, apply_primary_beam: bool = False, reuse_primary_beam: bool = False, use_differential_lofar_beam: bool = False, - primary_beam_limit: Optional[float] = None, - mwa_path: Optional[str] = None, + primary_beam_limit: float | None = None, + mwa_path: str | None = None, save_psf_pb: bool = False, - pb_grid_size: Optional[int] = None, - beam_model: Optional[str] = None, - beam_mode: Optional[str] = None, - beam_normalisation_mode: Optional[str] = None, + pb_grid_size: int | None = None, + beam_model: str | None = None, + beam_mode: str | None = None, + beam_normalisation_mode: str | None = None, dry_run: bool = False, - weight: Optional[str] = None, - super_weight: Optional[float] = None, + weight: str | None = None, + super_weight: float | None = None, mf_weighting: bool = False, no_mf_weighting: bool = False, - weighting_rank_filter: Optional[float] = None, - weighting_rank_filter_size: Optional[float] = None, - taper_gaussian: Optional[str] = None, - taper_tukey: Optional[float] = None, - taper_inner_tukey: Optional[float] = None, - taper_edge: Optional[float] = None, - taper_edge_tukey: Optional[float] = None, + weighting_rank_filter: float | None = None, + weighting_rank_filter_size: float | None = None, + taper_gaussian: str | None = None, + taper_tukey: float | None = 
None, + taper_inner_tukey: float | None = None, + taper_edge: float | None = None, + taper_edge_tukey: float | None = None, use_weights_as_taper: bool = False, store_imaging_weights: bool = False, - name: Optional[str] = None, - size: Optional[str] = None, - padding: Optional[float] = None, - scale: Optional[str] = None, + name: str | None = None, + size: str | None = None, + padding: float | None = None, + scale: str | None = None, predict: bool = False, ws_continue: bool = False, subtract_model: bool = False, - gridder: Optional[str] = None, - channels_out: Optional[int] = None, - shift: Optional[str] = None, + gridder: str | None = None, + channels_out: int | None = None, + shift: str | None = None, gap_channel_division: bool = False, - channel_division_frequencies: Optional[str] = None, - nwlayers: Optional[int] = None, - nwlayers_factor: Optional[float] = None, - nwlayers_for_size: Optional[str] = None, + channel_division_frequencies: str | None = None, + nwlayers: int | None = None, + nwlayers_factor: float | None = None, + nwlayers_for_size: str | None = None, no_small_inversion: bool = False, small_inversion: bool = False, - grid_mode: Optional[str] = None, - kernel_size: Optional[int] = None, - oversampling: Optional[int] = None, + grid_mode: str | None = None, + kernel_size: int | None = None, + oversampling: int | None = None, make_psf: bool = False, make_psf_only: bool = False, - visibility_weighting_mode: Optional[str] = None, + visibility_weighting_mode: str | None = None, no_normalize_for_weighting: bool = False, - baseline_averaging: Optional[float] = None, - simulate_noise: Optional[float] = None, - simulate_baseline_noise: Optional[str] = None, - idg_mode: Optional[str] = None, - wgridder_accuracy: Optional[float] = None, - aterm_config: Optional[str] = None, + baseline_averaging: float | None = None, + simulate_noise: float | None = None, + simulate_baseline_noise: str | None = None, + idg_mode: str | None = None, + wgridder_accuracy: float | None = None, + aterm_config: str | None = None, grid_with_beam: bool = False, - beam_aterm_update: Optional[int] = False, - aterm_kernel_size: Optional[float] = None, - apply_facet_solutions: Optional[str] = None, + beam_aterm_update: int | None = False, + aterm_kernel_size: float | None = None, + apply_facet_solutions: str | None = None, apply_facet_beam: bool = False, - facet_beam_update: Optional[int] = False, + facet_beam_update: int | None = False, save_aterms: bool = False, - pol: Optional[str] = None, - interval: Optional[str] = None, - intervals_out: Optional[int] = None, + pol: str | None = None, + interval: str | None = None, + intervals_out: int | None = None, even_timesteps: bool = False, odd_timesteps: bool = False, - channel_range: Optional[str] = None, - field: Optional[int] = None, - spws: Optional[str] = None, - data_column: Optional[str] = None, - maxuvw_m: Optional[float] = None, - minuvw_m: Optional[float] = None, - maxuv_l: Optional[float] = None, - minuv_l: Optional[float] = None, - maxw: Optional[float] = None, - niter: Optional[int] = None, - nmiter: Optional[int] = None, - threshold: Optional[float] = None, - auto_threshold: Optional[float] = None, - auto_mask: Optional[float] = None, - force_mask_rounds: Optional[int] = None, + channel_range: str | None = None, + field: int | None = None, + spws: str | None = None, + data_column: str | None = None, + maxuvw_m: float | None = None, + minuvw_m: float | None = None, + maxuv_l: float | None = None, + minuv_l: float | None = None, + maxw: float | None = None, + 
niter: int | None = None, + nmiter: int | None = None, + threshold: float | None = None, + auto_threshold: float | None = None, + auto_mask: float | None = None, + force_mask_rounds: int | None = None, local_rms: bool = False, - local_rms_window: Optional[float] = False, - local_rms_method: Optional[str] = None, - gain: Optional[float] = None, - mgain: Optional[float] = None, + local_rms_window: float | None = False, + local_rms_method: str | None = None, + gain: float | None = None, + mgain: float | None = None, join_polarizations: bool = False, - link_polarizations: Optional[str] = None, - facet_regions: Optional[str] = None, + link_polarizations: str | None = None, + facet_regions: str | None = None, join_channels: bool = False, - spectral_correction: Optional[str] = None, + spectral_correction: str | None = None, no_fast_subminor: bool = False, multiscale: bool = False, - multiscale_scale_bias: Optional[float] = None, - multiscale_max_scales: Optional[int] = None, - multiscale_scales: Optional[str] = None, - multiscale_shape: Optional[str] = None, - multiscale_gain: Optional[float] = None, - multiscale_convolution_padding: Optional[float] = None, + multiscale_scale_bias: float | None = None, + multiscale_max_scales: int | None = None, + multiscale_scales: str | None = None, + multiscale_shape: str | None = None, + multiscale_gain: float | None = None, + multiscale_convolution_padding: float | None = None, no_multiscale_fast_subminor: bool = False, - python_deconvolution: Optional[str] = None, + python_deconvolution: str | None = None, iuwt: bool = False, iuwt_snr_test: bool = False, no_iuwt_snr_test: bool = False, - moresane_ext: Optional[str] = None, - moresane_arg: Optional[str] = None, - moresane_sl: Optional[str] = None, + moresane_ext: str | None = None, + moresane_arg: str | None = None, + moresane_sl: str | None = None, save_source_list: bool = False, - clean_border: Optional[float] = None, - fits_mask: Optional[str] = None, - casa_mask: Optional[str] = None, - horizon_mask: Optional[str] = None, + clean_border: float | None = None, + fits_mask: str | None = None, + casa_mask: str | None = None, + horizon_mask: str | None = None, no_negative: bool = False, negative: bool = False, stop_negative: bool = False, - fit_spectral_pol: Optional[int] = None, - fit_spectral_log_pol: Optional[int] = None, - force_spectrum: Optional[str] = None, - deconvolution_channels: Optional[int] = None, + fit_spectral_pol: int | None = None, + fit_spectral_log_pol: int | None = None, + force_spectrum: str | None = None, + deconvolution_channels: int | None = None, squared_channel_joining: bool = False, - parallel_deconvolution: Optional[int] = None, - deconvolution_threads: Optional[int] = None, - restore: Optional[str] = None, - restore_list: Optional[str] = None, - beam_size: Optional[float] = None, - beam_shape: Optional[str] = None, + parallel_deconvolution: int | None = None, + deconvolution_threads: int | None = None, + restore: str | None = None, + restore_list: str | None = None, + beam_size: float | None = None, + beam_shape: str | None = None, fit_beam: bool = False, no_fit_beam: bool = False, - beam_fitting_size: Optional[float] = None, + beam_fitting_size: float | None = None, theoretic_beam: bool = False, circular_beam: bool = False, elliptical_beam: bool = False, @@ -750,10 +749,7 @@ def wsclean( mslist = arguments.pop("mslist") use_mpi = arguments.pop("use_mpi") # Check for MPI - if use_mpi: - command = "mpirun wsclean-mp" - else: - command = "wsclean " + command = "mpirun wsclean-mp" 
if use_mpi else "wsclean " # Check for square channels and multiscale if arguments["squared_channel_joining"] and arguments["multiscale"]: diff --git a/arrakis/utils/pipeline.py b/arrakis/utils/pipeline.py index 334a5cd5..7e8bd861 100644 --- a/arrakis/utils/pipeline.py +++ b/arrakis/utils/pipeline.py @@ -1,6 +1,8 @@ #!/usr/bin/env python """Pipeline and flow utility functions""" +from __future__ import annotations + import argparse import base64 import logging @@ -9,14 +11,13 @@ import time import warnings from pathlib import Path -from typing import List, Optional, Tuple, Union from uuid import UUID import astropy.units as u import dask.array as da -import dask.distributed as distributed import numpy as np from astropy.utils.exceptions import AstropyWarning +from dask import distributed from dask.delayed import Delayed from dask.distributed import get_client from distributed.client import futures_of @@ -57,7 +58,7 @@ # Stolen from Flint @task(name="Upload image as artifact") def upload_image_as_artifact_task( - image_path: Path, description: Optional[str] = None + image_path: Path, description: str | None = None ) -> UUID: """Create and submit a markdown artifact tracked by prefect for an input image. Currently supporting png formatted images. @@ -260,8 +261,8 @@ def __exit__(self, exc_type, exc_value, traceback): def inspect_client( - client: Union[distributed.Client, None] = None, -) -> Tuple[str, int, int, u.Quantity, int, u.Quantity]: + client: distributed.Client | None = None, +) -> tuple[str, int, int, u.Quantity, int, u.Quantity]: """_summary_ Args: @@ -311,9 +312,7 @@ def chunk_dask( return chunk_outputs -def delayed_to_da( - list_of_delayed: List[Delayed], chunk: Union[int, None] = None -) -> da.Array: +def delayed_to_da(list_of_delayed: list[Delayed], chunk: int | None = None) -> da.Array: """Convert list of delayed arrays to a dask array Args: @@ -324,18 +323,13 @@ def delayed_to_da( da.Array: Dask array """ sample = list_of_delayed[0].compute() - dim = (len(list_of_delayed),) + sample.shape - if chunk is None: - c_dim = dim - else: - c_dim = (chunk,) + sample.shape + dim = (len(list_of_delayed), *sample.shape) + c_dim = dim if chunk is None else (chunk, *sample.shape) darray_list = [ da.from_delayed(lazy, dtype=sample.dtype, shape=sample.shape) for lazy in list_of_delayed ] - darray = da.stack(darray_list, axis=0).reshape(dim).rechunk(c_dim) - - return darray + return da.stack(darray_list, axis=0).reshape(dim).rechunk(c_dim) # stolen from https://github.com/tqdm/tqdm/issues/278 diff --git a/arrakis/utils/plotting.py b/arrakis/utils/plotting.py index 31aa31b6..8c8fd2f2 100644 --- a/arrakis/utils/plotting.py +++ b/arrakis/utils/plotting.py @@ -1,6 +1,7 @@ -#!/usr/bin/env python """Plotting utilities""" +from __future__ import annotations + import warnings from astropy.utils.exceptions import AstropyWarning @@ -24,7 +25,7 @@ def latexify(fig_width=None, fig_height=None, columns=1): """ from math import sqrt - import matplotlib + import matplotlib as mpl # code adapted from http://www.scipy.org/Cookbook/Matplotlib/LaTeX_Examples # Width and max height in inches for IEEE journals taken from @@ -63,4 +64,4 @@ def latexify(fig_width=None, fig_height=None, columns=1): "font.family": "serif", } - matplotlib.rcParams.update(params) + mpl.rcParams.update(params) diff --git a/arrakis/utils/typing.py b/arrakis/utils/typing.py index cb5f8625..63acce00 100644 --- a/arrakis/utils/typing.py +++ b/arrakis/utils/typing.py @@ -1,6 +1,8 @@ #!/usr/bin/env python3 """Typing utilities""" +from 
__future__ import annotations + from pathlib import Path from typing import TypeVar diff --git a/arrakis/validate.py b/arrakis/validate.py old mode 100644 new mode 100755 index 4a57bcc9..82510c2c --- a/arrakis/validate.py +++ b/arrakis/validate.py @@ -1,6 +1,8 @@ #!/usr/bin/env python """Make validation plots from a catalogue""" +from __future__ import annotations + import argparse import logging from importlib import resources diff --git a/arrakis/wsclean_rmsynth.py b/arrakis/wsclean_rmsynth.py old mode 100644 new mode 100755 index 5de3c789..09e094d6 --- a/arrakis/wsclean_rmsynth.py +++ b/arrakis/wsclean_rmsynth.py @@ -11,6 +11,7 @@ # -niter 1000 -save-first-residual -auto-threshold 3 -mgain 0.8 \ # -interval 10 11 -size 1024 1024 -scale 1amin \ # 1052736496-averaged.ms/ +from __future__ import annotations from functools import partial from time import time @@ -68,8 +69,6 @@ def get_rmsynth_params( rmsf = kay * np.sum(weights * np.exp(np.outer(bees, ays)), 1) assert (np.abs(rmsf) <= 1.0).all(), "RMSF is not normalized" - print(f"{np.abs(rmsf).max()=}") - return RMSynthParams( phis=phis, phis_double=phis_double, @@ -232,7 +231,6 @@ def _clean_loop( # A clean component is "loop-gain * peak_fdf cc_scalar = gain * peak_fdf - print(f"{np.abs(cc_scalar)=}") cc_vec[idx_peak_fdf] += cc_scalar # At which channel is the cc_scalar located at in the RMSF? @@ -295,9 +293,7 @@ def proper_rm_clean( # Calculate the spectrum quarr = np.sum(cc_vec[:, np.newaxis] * np.exp(2.0j * np.outer(phis, ays)), axis=0) - spec = np.array([quarr.real, quarr.imag]).T - - return spec + return np.array([quarr.real, quarr.imag]).T def simple_clean( @@ -340,9 +336,7 @@ def simple_clean( # Update the CC at the peak cc_vec[np.argmin(np.abs(phis - peak_phi))] = cc_scalar quarr = np.sum(cc_vec[:, np.newaxis] * np.exp(2.0j * np.outer(phis, ays)), axis=0) - spec = np.array([quarr.real, quarr.imag]).T - - return spec + return np.array([quarr.real, quarr.imag]).T def deconvolve( @@ -352,12 +346,9 @@ def deconvolve( meta: dict, ): if meta.channels == []: - raise ValueError("No channels in meta") + msg = "No channels in meta" + raise ValueError(msg) nchan, npol, height, width = residual.shape - print( - "Python deconvolve() function was called for " - + f"{width} x {height} x {npol} (npol) x {nchan} (chan) dataset" - ) # residual and model are numpy arrays with dimensions nchan x npol x height x width # psf is a numpy array with dimensions nchan x height x width. 
@@ -365,7 +356,8 @@ def deconvolve( # This file demonstrates a very simple deconvolution strategy, which doesn't # support multiple polarizations: if npol != 2: - raise NotImplementedError("npol must be 2") + msg = "npol must be 2" + raise NotImplementedError(msg) # If there are channels missing (flagged), they will be set to NaN # They're here set to zero to make it easy to calculate the integrated image mask_3D = np.isnan(residual) @@ -377,10 +369,9 @@ def deconvolve( freqs = np.array([x.frequency for x in meta.channels]) weights_3d = np.ones_like(freqs) params_3d = get_rmsynth_params(freqs, weights_3d, nsamp=3) - tick = time() + time() # Mask residual with 2D mask res_masked = residual[:, ~mask_2D].reshape(nchan, npol, -1) - print(f"{res_masked.shape=}") # Now reshape to nchan x npol x height*width # Only need to CLEAN the unmasked pixels fdf_2d = rmsynth2d( @@ -390,13 +381,9 @@ def deconvolve( ays=params_3d.ays, phis=params_3d.phis, ) - tock = time() - print(f"RMSynth3D took {tock - tick:0.2f} seconds") + time() integrated_residual_1d = np.sum(np.abs(fdf_2d), axis=0) integrated_residual = np.zeros((height, width)) - print(f"{integrated_residual.shape=}") - print(f"{integrated_residual_1d.shape=}") - print(f"{integrated_residual[~mask_2D_p].shape=}") integrated_residual[~mask_2D_p] = integrated_residual_1d np.savetxt("integrated_residual.txt", integrated_residual) peak_index = np.unravel_index( @@ -410,9 +397,6 @@ def deconvolve( [meta.major_iter_threshold, meta.final_threshold, mgain_threshold] ) - print( - f"Starting iteration {meta.iteration_number}, peak={peak_value}, first threshold={first_threshold}" - ) while peak_value > first_threshold and meta.iteration_number < meta.max_iterations: y = peak_index[0] x = peak_index[1] @@ -422,7 +406,7 @@ def deconvolve( weights = (spectrum_complex.sum(axis=1) > 0).astype(int) params = get_rmsynth_params(freqs, weights) - tick = time() + time() fdf = rmsynth_1d( stokes_q=spectrum_complex[:, 0], stokes_u=spectrum_complex[:, 1], @@ -430,15 +414,14 @@ def deconvolve( ays=params.ays, phis=params.phis, ) - tock = time() - print(f"1D RMSynth took {tock - tick:0.2f} seconds") + time() # model_spectrum = simple_clean( # phis=params.phis, # fdf=fdf, # ays=params.ays, # fwhm=params.fwhm, # )\ - tick = time() + time() model_spectrum = proper_rm_clean( phis=params.phis, phis_double=params.phis_double, @@ -449,8 +432,7 @@ def deconvolve( cutoff=100e-6, max_iter=1000, ) - tock = time() - print(f"1D RM clean took {tock - tick:0.2f} seconds") + time() if meta.iteration_number < 10: np.savetxt( f"model_spectrum_iter_{meta.iteration_number}.txt", model_spectrum @@ -473,7 +455,7 @@ def deconvolve( ###################### # Update the residual - tick = time() + time() res_masked = residual[:, ~mask_2D].reshape((nchan, npol, -1)) fdf_2d = rmsynth2d( stokes_q=res_masked[:, 0], @@ -482,8 +464,7 @@ def deconvolve( ays=params_3d.ays, phis=params_3d.phis, ) - tock = time() - print(f"RMSynth3D took {tock - tick:0.2f} seconds") + time() integrated_residual_1d = np.sum(np.abs(fdf_2d), axis=0) integrated_residual = np.zeros((height, width)) integrated_residual[~mask_2D_p] = integrated_residual_1d @@ -492,14 +473,10 @@ def deconvolve( ) peak_value = integrated_residual[peak_index] - print(f"{peak_value=}") - meta.iteration_number = meta.iteration_number + 1 - print(f"Stopped after iteration {meta.iteration_number}, peak={peak_value}") - # Fill a dictionary with values that wsclean expects: - result = dict() + result = {} result["residual"] = residual result["model"] = model 
result["level"] = peak_value @@ -508,5 +485,4 @@ def deconvolve( and meta.iteration_number < meta.max_iterations ) - print("Finished deconvolve()") return result diff --git a/docs/source/conf.py b/docs/source/conf.py index 54acd114..5b193385 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import sys from importlib.metadata import distribution diff --git a/pyproject.toml b/pyproject.toml index de72d0fb..0306996f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -122,7 +122,41 @@ tar_cubelets = { reference="scripts/tar_cubelets.py", type="file"} create_mongodb = { reference="scripts/create_mongodb.py", type="file"} [tool.ruff] -src = ["arrakis", "tests", "scripts"] +src = ["arrakis", "scripts", "tests"] [tool.ruff.lint] -select = ["E4", "E7", "E9", "F", "UP", "I"] +extend-select = [ + "B", # flake8-bugbear + "I", # isort + "ARG", # flake8-unused-arguments + "C4", # flake8-comprehensions + "EM", # flake8-errmsg + "ICN", # flake8-import-conventions + # "G", # flake8-logging-format + "PGH", # pygrep-hooks + "PIE", # flake8-pie + "PL", # pylint + "PT", # flake8-pytest-style + "PTH", # flake8-use-pathlib + "RET", # flake8-return + "RUF", # Ruff-specific + "SIM", # flake8-simplify + "T20", # flake8-print + "UP", # pyupgrade + "YTT", # flake8-2020 + "EXE", # flake8-executable + "NPY", # NumPy specific rules + "PD", # pandas-vet +] +ignore = [ + "PLR09", # Too many <...> + "PLR2004", # Magic value used in comparison + "ISC001", # Conflicts with formatter +] +isort.required-imports = ["from __future__ import annotations"] +# Uncomment if using a _compat.typing backport +# typing-modules = ["cutout_fits._compat.typing"] + +[tool.ruff.lint.per-file-ignores] +"tests/**" = ["T20"] +"noxfile.py" = ["T20"] diff --git a/scripts/casda_prepare.py b/scripts/casda_prepare.py index a0454ab2..f18380e2 100755 --- a/scripts/casda_prepare.py +++ b/scripts/casda_prepare.py @@ -1,6 +1,8 @@ #!/usr/bin/env python3 """Prepare files for CASDA upload""" +from __future__ import annotations + import argparse import hashlib import logging @@ -226,7 +228,8 @@ def convert_spectra( elif keys[0] == "FDF": data = rmtables[keys[2]][keys[1]].values * u.Jy / u.beam / rmsf_unit else: - raise ValueError(f"Unknown column {col}") + msg = f"Unknown column {col}" + raise ValueError(msg) new_col = Column( name=col, @@ -304,7 +307,8 @@ def update_cube(cube: str, cube_dir: str) -> None: ".linmos.ion.edge.linmos.fits", ".linmos.edge.linmos.fits" ) if not os.path.exists(fname): - raise FileNotFoundError(f"Could not find {fname}") + msg = f"Could not find {fname}" + raise FileNotFoundError(msg) data[:, i, :, :] = fits.getdata( fname, memmap=True, @@ -398,10 +402,7 @@ def write_polspec(table: Table, filename: str, overwrite: bool = False): np.array(tabcol[0]) ) # get the type of each element in 2D array col_format = "Q" + fits.column._convert_record2fits(subtype) + "()" - if tabcol.unit is not None: - unit = tabcol.unit.to_string() - else: - unit = "" + unit = tabcol.unit.to_string() if tabcol.unit is not None else "" pfcol = fits.Column( name=tabcol.name, unit=unit, array=tabcol.data, format=col_format ) @@ -508,7 +509,8 @@ def main( polcat = polcat[:100] else: - raise ValueError(f"Unknown prep_type: {prep_type}") + msg = f"Unknown prep_type: {prep_type}" + raise ValueError(msg) casda_dir = os.path.join(data_dir, f"casda_{prep_type}") try_mkdir(casda_dir) @@ -570,7 +572,7 @@ def main( ), f"Number of cubes does not match number of sources -- {len(cubes)=} and 
{len(set(polcat['source_id']))=}" unique_ids, unique_idx = np.unique(polcat["source_id"], return_index=True) - lookup = {sid: i for sid, i in zip(unique_ids, unique_idx)} + lookup = dict(zip(unique_ids, unique_idx)) with tqdm(total=len(cubes), desc="Sorting cubes", file=TQDM_OUT) as pbar: def my_sorter(x, lookup=lookup, pbar=pbar): @@ -585,6 +587,7 @@ def my_sorter(x, lookup=lookup, pbar=pbar): idx += i pbar.update(1) return idx + return None sorted_cubes = sorted(cubes, key=my_sorter) @@ -613,7 +616,7 @@ def my_sorter(x, lookup=lookup, pbar=pbar): ), f"{len(spectra)=} and {len(polcat)=}" # Sanity check unique_ids, unique_idx = np.unique(polcat["cat_id"], return_index=True) - lookup = {sid: i for sid, i in zip(unique_ids, unique_idx)} + lookup = dict(zip(unique_ids, unique_idx)) with tqdm(total=len(spectra), desc="Sorting spectra", file=TQDM_OUT) as pbar: def my_sorter(x, lookup=lookup, pbar=pbar): @@ -663,7 +666,7 @@ def my_sorter(x, lookup=lookup, pbar=pbar): try_mkdir(spec_dir) plots = find_plots(data_dir=data_dir) unique_ids, unique_idx = np.unique(polcat["cat_id"], return_index=True) - lookup = {sid: i for sid, i in zip(unique_ids, unique_idx)} + lookup = dict(zip(unique_ids, unique_idx)) if prep_type != "full": # Drop plots that are not in the cut catalogue diff --git a/scripts/check_cutout.py b/scripts/check_cutout.py index aff5e432..1540fbb7 100755 --- a/scripts/check_cutout.py +++ b/scripts/check_cutout.py @@ -1,4 +1,6 @@ #!/usr/bin/env python3 +from __future__ import annotations + import os from glob import glob diff --git a/scripts/compare_leakage.py b/scripts/compare_leakage.py index 16b5128b..019916bd 100644 --- a/scripts/compare_leakage.py +++ b/scripts/compare_leakage.py @@ -13,6 +13,8 @@ y = sin(offset)*sin(angle)/incy + refy """ +from __future__ import annotations + import os import warnings from glob import glob @@ -159,7 +161,7 @@ def main( island_ids = sorted(beams_col.distinct("Source_ID", beam_query)) isl_query = {"Source_ID": {"$in": island_ids}} beams = pd.DataFrame(list(beams_col.find(isl_query).sort("Source_ID"))) - beams.set_index("Source_ID", drop=False, inplace=True) + beams = beams.set_index("Source_ID", drop=False) components = pd.DataFrame( list( comp_col.find( @@ -176,12 +178,12 @@ def main( ).sort("Source_ID") ) ) - components.set_index("Source_ID", drop=False, inplace=True) + components = components.set_index("Source_ID", drop=False) _ = list(components["Gaussian_ID"]) assert len(set(beams.index)) == len(set(components.index)) outputs = [] - for i, c in components.iterrows(): + for _i, c in components.iterrows(): if snr_cut is not None: noise = c.Noise signal = c.Total_flux_Gaussian diff --git a/scripts/compute_leakage.py b/scripts/compute_leakage.py index 0bc91094..b15c7448 100644 --- a/scripts/compute_leakage.py +++ b/scripts/compute_leakage.py @@ -1,12 +1,14 @@ #!/usr/bin/env python3 +from __future__ import annotations + import logging import astropy -import astropy.units as units import matplotlib.pyplot as plt import numpy as np from arrakis.logger import TqdmToLogger, logger from arrakis.utils.database import get_db +from astropy import units from astropy.coordinates import SkyCoord from astropy.wcs import WCS from tqdm.auto import tqdm, trange @@ -122,7 +124,7 @@ def trim_mean(x): num_points_in_aperture_list = [] # Init collectors logger.info("\nDeriving robust leakage estimates for interpolation grid...") - for row_idx, row in enumerate(tqdm(pair_dist, file=TQDM_OUT)): + for _row_idx, row in enumerate(tqdm(pair_dist, file=TQDM_OUT)): # Guide to 
where we're at # if row_idx%100==0: # logger.info('Processing row %d of %d'%(row_idx,len(pair_dist))) diff --git a/scripts/copy_cutouts.py b/scripts/copy_cutouts.py index 66a0d70b..6f8693c5 100755 --- a/scripts/copy_cutouts.py +++ b/scripts/copy_cutouts.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import argparse import os diff --git a/scripts/copy_cutouts_askap.py b/scripts/copy_cutouts_askap.py index 0c1591b2..d2f4d2f4 100644 --- a/scripts/copy_cutouts_askap.py +++ b/scripts/copy_cutouts_askap.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import argparse import os diff --git a/scripts/copy_data.py b/scripts/copy_data.py index 61101d70..04d96621 100755 --- a/scripts/copy_data.py +++ b/scripts/copy_data.py @@ -1,5 +1,8 @@ #!/usr/bin/env python3 +from __future__ import annotations + import argparse +import contextlib import os from glob import glob from pathlib import Path @@ -67,17 +70,14 @@ def main( idx = abspath.find("_F") f_no = abspath[idx + 1 : idx + 4] newpath = abspath.replace(f_no, "F00") - try: + with contextlib.suppress(SameFileError): copyfile(abspath, newpath) - except SameFileError: - pass logger.debug(os.path.basename(newpath)) - if clean: - if yes: - files = glob(f"{check}/CONTCUBE*") - for f in files: - os.remove(f) + if clean and yes: + files = glob(f"{check}/CONTCUBE*") + for f in files: + os.remove(f) def cli(): diff --git a/scripts/create_mongodb.py b/scripts/create_mongodb.py index f9d94ab9..ee0e56ea 100644 --- a/scripts/create_mongodb.py +++ b/scripts/create_mongodb.py @@ -1,10 +1,11 @@ #!/usr/bin/env python +from __future__ import annotations + import argparse import logging import subprocess as sp from pathlib import Path from pprint import pformat -from typing import Optional import pymongo from arrakis.logger import logger @@ -21,7 +22,7 @@ def start_mongod( dbpath: Path, logpath: Path, host: str = "localhost", - port: Optional[int] = None, + port: int | None = None, auth: bool = False, ): cmd = f"mongod --fork --dbpath {dbpath} --logpath {logpath} --bind_ip {host}" @@ -38,7 +39,8 @@ def start_mongod( f"mongod already running - try shutting down first with `mongod --dbpath {dbpath} --shutdown`" ) logger.error(f"{e}") - raise MongodError(f"Failed to start mongod. Command was: {cmd}") + msg = f"Failed to start mongod. Command was: {cmd}" + raise MongodError(msg) logger.info(proc.decode()) logger.info("Started mongod") @@ -53,7 +55,8 @@ def stop_mongod( proc = sp.check_output(cmd.split()) except sp.CalledProcessError as e: logger.error(f"{e}") - raise MongodError(f"Failed to stop mongod. Command was: {cmd}") + msg = f"Failed to stop mongod. 
Command was: {cmd}" + raise MongodError(msg) logger.info(proc.decode()) logger.info("Stopped mongod") @@ -86,7 +89,7 @@ def create_or_update_user( def create_admin_user( host: str, password: str, - port: Optional[int] = None, + port: int | None = None, username: str = "admin", ): logger.info(f"Creating admin user {username} on {host}:{port}") @@ -104,7 +107,7 @@ def create_admin_user( def create_read_only_user( host: str, password: str, - port: Optional[int] = None, + port: int | None = None, username: str = "reader", ): logger.info(f"Creating read-only user {username} on {host}:{port}") @@ -124,9 +127,9 @@ def main( admin_password: str, reader_password: str, host: str = "localhost", - port: Optional[int] = None, - admin_username: Optional[str] = "admin", - reader_username: Optional[str] = "reader", + port: int | None = None, + admin_username: str | None = "admin", + reader_username: str | None = "reader", ): logpath = dbpath.parent / "mongod.log" start_mongod( diff --git a/scripts/find_row.py b/scripts/find_row.py index f1d9761f..35b3d9bf 100755 --- a/scripts/find_row.py +++ b/scripts/find_row.py @@ -1,4 +1,6 @@ #!/usr/bin/env python3 +from __future__ import annotations + import argparse from pathlib import Path diff --git a/scripts/find_sbid.py b/scripts/find_sbid.py index c1ad53a0..75700e6d 100755 --- a/scripts/find_sbid.py +++ b/scripts/find_sbid.py @@ -1,4 +1,6 @@ #!/usr/bin/env python3 +from __future__ import annotations + import argparse from pathlib import Path diff --git a/scripts/fix_dr1_cat.py b/scripts/fix_dr1_cat.py index ed64e6fd..47fe1f69 100755 --- a/scripts/fix_dr1_cat.py +++ b/scripts/fix_dr1_cat.py @@ -1,6 +1,8 @@ #!/usr/bin/env python3 """Post process DR1 catalog""" +from __future__ import annotations + import logging import os import pickle @@ -139,7 +141,7 @@ def fix_fields( "day": u.d, } for col in new_tab.colnames: - if str(new_tab[col].unit) in dumb_units.keys(): + if str(new_tab[col].unit) in dumb_units: new_unit = dumb_units[str(new_tab[col].unit)] logger.debug(f"Fixing {col} unit from {new_tab[col].unit} to {new_unit}") new_tab[col].unit = new_unit @@ -203,7 +205,7 @@ def main(cat: str, survey_dir: Path, epoch: int = 0): logger.info(f"Wrote leakage fit to {outfit}") logger.info(f"Writing corrected catalogue to {outfile}") - if ext == ".xml" or ext == ".vot": + if ext in (".xml", ".vot"): write_votable(fix_flag_tab, outfile) else: tab.write(outfile, overwrite=True) diff --git a/scripts/fix_src_cat.py b/scripts/fix_src_cat.py index b600263d..55dab855 100644 --- a/scripts/fix_src_cat.py +++ b/scripts/fix_src_cat.py @@ -1,6 +1,8 @@ #!/usr/bin/env python3 """Post process DR1 source catalog""" +from __future__ import annotations + import logging import os from pathlib import Path @@ -27,7 +29,7 @@ def add_metadata(vo_table: vot.tree.Table, table: Table, filename: str): vot: VO Table object with metadata """ # Add metadata - for col_idx, col_name in enumerate(table.colnames): + for _col_idx, col_name in enumerate(table.colnames): col = table[col_name] vocol = vo_table.get_first_table().get_field_by_id(col_name) if hasattr(col, "description"): @@ -110,8 +112,8 @@ def main( field.add_index("FIELD_NAME") spice_df = spice_cat.to_pandas() - spice_df.set_index("source_id", inplace=True) - spice_df.sort_index(inplace=True) + spice_df = spice_df.set_index("source_id") + spice_df = spice_df.sort_index() source_cat.sort("Source_ID") spice_grp = spice_df.groupby("source_id") diff --git a/scripts/hello_mpi_world.py b/scripts/hello_mpi_world.py index f16539f3..f47226bb 100755 --- 
a/scripts/hello_mpi_world.py +++ b/scripts/hello_mpi_world.py @@ -3,6 +3,8 @@ Parallel Hello World """ +from __future__ import annotations + import sys from mpi4py import MPI diff --git a/scripts/make_links.py b/scripts/make_links.py index 647c4385..10b72866 100755 --- a/scripts/make_links.py +++ b/scripts/make_links.py @@ -1,4 +1,6 @@ #!/usr/bin/env python +from __future__ import annotations + import argparse import os import subprocess @@ -18,13 +20,13 @@ def main(indir, outdir): link = name.replace(".fits", ".conv.fits") cmd = f"ln -s {f} {os.path.abspath(outdir)}/{link}" logger.info(cmd) - subprocess.run(split(cmd)) + subprocess.run(split(cmd), check=False) for f in weights: name = os.path.basename(f) cmd = f"ln -s {f} {os.path.abspath(outdir)}/{name}" logger.info(cmd) - subprocess.run(split(cmd)) + subprocess.run(split(cmd), check=False) def cli(): diff --git a/scripts/spica.py b/scripts/spica.py index 82c858ba..ff5b69d3 100755 --- a/scripts/spica.py +++ b/scripts/spica.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import shlex import subprocess as sb @@ -57,9 +59,8 @@ def mslist(cal_sb, name): ms = glob(f"{racs_area}/{cal_sb}/RACS_test4_1.05_{name}/*beam00_*.ms")[0] except Exception as e: logger.error(e) - raise Exception( - f"Can't find '{racs_area}/{cal_sb}/RACS_test4_1.05_{name}/*beam00_*.ms'" - ) + msg = f"Can't find '{racs_area}/{cal_sb}/RACS_test4_1.05_{name}/*beam00_*.ms'" + raise Exception(msg) mslist_out = sb.run( shlex.split(f"mslist --full {ms}"), capture_output=True, check=False @@ -72,8 +73,7 @@ def mslist(cal_sb, name): shlex.split("date +%Y-%m-%d-%H%M%S"), capture_output=True, check=True ) - out = mslist_out.stderr.decode() + f"METADATA_IS_GOOD {date_out.stdout.decode()}" - return out + return mslist_out.stderr.decode() + f"METADATA_IS_GOOD {date_out.stdout.decode()}" def main( @@ -107,7 +107,7 @@ def main( cal_files = glob( f"{racs_area}/{cal_sbid}/RACS_test4_1.05_{name}/*averaged_cal.leakage.ms" ) - leak = not len(cal_files) == 0 + leak = len(cal_files) != 0 cubes = [] for stoke in ["i", "q", "u"]: cubes.extend( diff --git a/scripts/tar_cubelets.py b/scripts/tar_cubelets.py index ac401344..d43ff4d6 100755 --- a/scripts/tar_cubelets.py +++ b/scripts/tar_cubelets.py @@ -1,4 +1,5 @@ #!/usr/bin/env python3 +from __future__ import annotations import logging import os @@ -41,18 +42,18 @@ def main(casda_dir: str): """ casda_dir = os.path.abspath(casda_dir) if not os.path.exists(casda_dir): - raise FileNotFoundError(f"Directory {casda_dir} does not exist") + msg = f"Directory {casda_dir} does not exist" + raise FileNotFoundError(msg) if not os.path.exists(os.path.join(casda_dir, "cubelets")): - raise FileNotFoundError(f"Directory {casda_dir} does not contain cubelets/") + msg = f"Directory {casda_dir} does not contain cubelets/" + raise FileNotFoundError(msg) cube_list = glob(os.path.join(casda_dir, "cubelets", "*.fits")) logger.info(f"{len(cube_list)} cublets to tar...") - sources = set( - [ - os.path.basename(cube)[:13] - for cube in tqdm(cube_list, desc="Sources", file=TQDM_OUT) - ] - ) + sources = { + os.path.basename(cube)[:13] + for cube in tqdm(cube_list, desc="Sources", file=TQDM_OUT) + } logger.info(f"...into {len(sources)} sources") out_dir = os.path.join(casda_dir, "cubelets_tar") os.makedirs(out_dir, exist_ok=True) diff --git a/submit/test_image.py b/submit/test_image.py index 16ff1610..051dad61 100755 --- a/submit/test_image.py +++ b/submit/test_image.py @@ -6,6 +6,7 @@ # SBATCH --cpus-per-task=1 # SBATCH --account=OD-217087 # SBATCH 
--qos=express +from __future__ import annotations import logging from pathlib import Path diff --git a/tests/cli_test.py b/tests/cli_test.py index 9f45b7d8..c9a13799 100644 --- a/tests/cli_test.py +++ b/tests/cli_test.py @@ -1,5 +1,7 @@ """Tests for CLI.""" +from __future__ import annotations + import subprocess import unittest @@ -8,12 +10,12 @@ class test_cli(unittest.TestCase): def test_cli_init(self): """Tests that the CLI `spice_init` runs.""" res = subprocess.run(["spice_init", "--help"], check=True) - self.assertEqual(res.returncode, 0) + assert res.returncode == 0 def test_cli_process(self): """Tests that the CLI `spice_process` runs.""" res = subprocess.run(["spice_process", "--help"], check=True) - self.assertEqual(res.returncode, 0) + assert res.returncode == 0 if __name__ == "__main__": diff --git a/tests/unit_test.py b/tests/unit_test.py index b00649de..e22decf2 100644 --- a/tests/unit_test.py +++ b/tests/unit_test.py @@ -1,5 +1,7 @@ """Tests for functions.""" +from __future__ import annotations + import unittest # Test functions within arrakis.rmsyth_oncuts @@ -10,23 +12,18 @@ class TestRmsynthOncuts(unittest.TestCase): def test_rmsynthoncut3d(self): """Test rmsynthoncut3d.""" - pass def test_rms_1d(self): """Test rms_1d.""" - pass def test_estimate_noise_annulus(self): """Test estimate_noise_annulus.""" - pass def test_rmsynthoncut1d(self): """Test rmsynthoncut1d.""" - pass def test_rmsynthoncut_i(self): """Test rmsynthoncut_i.""" - pass if __name__ == "__main__": From 0f3c4cf2d9547e891bf164b70ef65a54c2187c2c Mon Sep 17 00:00:00 2001 From: "Alec Thomson (S&A, Kensington WA)" Date: Mon, 22 Jul 2024 21:38:25 +0800 Subject: [PATCH 04/17] Major ruffage --- arrakis/merge_fields.py | 64 ++++++++++----------- arrakis/process_region.py | 35 ++++++------ arrakis/process_spice.py | 19 +++--- arrakis/rmclean_oncuts.py | 3 +- arrakis/rmsynth_oncuts.py | 88 ++++++++++++++-------------- arrakis/utils/database.py | 4 +- arrakis/utils/exceptions.py | 1 - arrakis/utils/fitsutils.py | 78 ------------------------- arrakis/utils/fitting.py | 1 - arrakis/utils/io.py | 111 +----------------------------------- arrakis/utils/json.py | 15 +++-- arrakis/utils/meta.py | 8 +-- arrakis/utils/msutils.py | 1 - arrakis/utils/pipeline.py | 5 +- arrakis/utils/typing.py | 1 - arrakis/wsclean_rmsynth.py | 2 +- 16 files changed, 115 insertions(+), 321 deletions(-) diff --git a/arrakis/merge_fields.py b/arrakis/merge_fields.py index 189c2141..aa8ecf18 100755 --- a/arrakis/merge_fields.py +++ b/arrakis/merge_fields.py @@ -5,8 +5,10 @@ import argparse import os +from pathlib import Path from pprint import pformat from shutil import copyfile +from typing import Any import pymongo from prefect import flow, task @@ -23,15 +25,15 @@ def make_short_name(name: str) -> str: @task(name="Copy singleton island") def copy_singleton( - beam: dict, field_dict: dict[str, str], merge_name: str, data_dir: str + beam: dict[str, Any], field_dict: dict[str, Path], merge_name: str, data_dir: Path ) -> list[pymongo.UpdateOne]: """Copy an island within a single field to the merged field Args: beam (dict): Beam document - field_dict (Dict[str, str]): Field dictionary + field_dict (dict[str, Path]): Field dictionary merge_name (str): Merged field name - data_dir (str): Output directory + data_dir (Path): Output directory Raises: KeyError: If ion files not found @@ -45,25 +47,18 @@ def copy_singleton( continue field_dir = field_dict[field] try: - i_file_old = os.path.join(field_dir, vals["i_file"]) - q_file_old = os.path.join(field_dir, 
vals["q_file_ion"]) - u_file_old = os.path.join(field_dir, vals["u_file_ion"]) - except KeyError: + i_file_old = field_dir / str(vals["i_file"]) + q_file_old = field_dir / str(vals["q_file_ion"]) + u_file_old = field_dir / str(vals["u_file_ion"]) + except KeyError as e: msg = "Ion files not found. Have you run FRion?" - raise KeyError(msg) - new_dir = os.path.join(data_dir, beam["Source_ID"]) + raise KeyError(msg) from e + new_dir = data_dir / str(beam["Source_ID"]) + new_dir.mkdir(exist_ok=True) - try_mkdir(new_dir, verbose=False) - - i_file_new = os.path.join(new_dir, os.path.basename(i_file_old)).replace( - ".fits", ".edge.linmos.fits" - ) - q_file_new = os.path.join(new_dir, os.path.basename(q_file_old)).replace( - ".fits", ".edge.linmos.fits" - ) - u_file_new = os.path.join(new_dir, os.path.basename(u_file_old)).replace( - ".fits", ".edge.linmos.fits" - ) + i_file_new = (new_dir / i_file_old.name).replace(".fits", ".edge.linmos.fits") + q_file_new = (new_dir / q_file_old.name).replace(".fits", ".edge.linmos.fits") + u_file_new = (new_dir / u_file_old.name).replace(".fits", ".edge.linmos.fits") for src, dst in zip( [i_file_old, q_file_old, u_file_old], [i_file_new, q_file_new, u_file_new] @@ -174,7 +169,7 @@ def genparset( def merge_multiple_field( - beam: dict, field_dict: dict, merge_name: str, data_dir: str, image: str + beam: dict, field_dict: dict, merge_name: str, data_dir: Path, image: str ) -> list[pymongo.UpdateOne]: """Merge an island that overlaps multiple fields @@ -182,7 +177,7 @@ def merge_multiple_field( beam (dict): Beam document field_dict (dict): Field dictionary merge_name (str): Merged field name - data_dir (str): Data directory + data_dir (Path): Data directory image (str): Yandasoft image Raises: @@ -210,7 +205,7 @@ def merge_multiple_field( q_files_old.append(q_file_old) u_files_old.append(u_file_old) - new_dir = os.path.join(data_dir, beam["Source_ID"]) + new_dir = data_dir / beam["Source_ID"] try_mkdir(new_dir, verbose=False) @@ -277,9 +272,9 @@ def merge_multiple_fields( @flow(name="Merge fields") def main( fields: list[str], - field_dirs: list[str], + field_dirs: list[Path], merge_name: str, - output_dir: str, + output_dir: Path, host: str, epoch: int, username: str | None = None, @@ -292,9 +287,8 @@ def main( len(fields) == len(field_dirs) ), f"List of fields must be the same length as length of field dirs. 
{len(fields)=},{len(field_dirs)=}" - field_dict = { - field: os.path.join(field_dir, "cutouts") - for field, field_dir in zip(fields, field_dirs) + field_dict: dict[str, Path] = { + field: field_dir / "cutouts" for field, field_dir in zip(fields, field_dirs) } image = get_yanda(version=yanda) @@ -303,11 +297,11 @@ def main( host=host, epoch=epoch, username=username, password=password ) - output_dir = os.path.abspath(output_dir) - inter_dir = os.path.join(output_dir, merge_name) - try_mkdir(inter_dir) - data_dir = os.path.join(inter_dir, "cutouts") - try_mkdir(data_dir) + output_dir = output_dir.absolute() + inter_dir = output_dir / merge_name + inter_dir.mkdir(exist_ok=True) + data_dir = inter_dir / "cutouts" + data_dir.mkdir(exist_ok=True) singleton_updates = copy_singletons( field_dict=field_dict, @@ -368,14 +362,14 @@ def merge_parser(parent_parser: bool = False) -> argparse.ArgumentParser: parser.add_argument( "--datadirs", - type=str, + type=Path, nargs="+", help="Directories containing cutouts (in subdir outdir/cutouts)..", ) parser.add_argument( "--output_dir", - type=str, + type=Path, help="Path to save merged data (in output_dir/merge_name/cutouts)", ) parser.add_argument( diff --git a/arrakis/process_region.py b/arrakis/process_region.py index 4eac0f4a..4cacc065 100755 --- a/arrakis/process_region.py +++ b/arrakis/process_region.py @@ -5,11 +5,10 @@ import argparse import logging -import os +from importlib import resources from pathlib import Path import configargparse -import pkg_resources import yaml from astropy.time import Time from prefect import flow @@ -31,13 +30,15 @@ @flow -def process_merge(args, host: str, inter_dir: str, task_runner) -> None: +def process_merge( + args: argparse.Namespace, host: str, inter_dir: Path, task_runner +) -> None: """Workflow to merge spectra from overlapping fields together Args: args (Namespace): Parameters to use for this process host (str): Address of the mongoDB servicing the processing - inter_dir (str): Location to store data from merged fields + inter_dir (Path): Location to store data from merged fields """ previous_future = None previous_future = ( @@ -146,8 +147,8 @@ def main(args: configargparse.Namespace) -> None: args (configargparse.Namespace): Command line arguments. 
""" if args.dask_config is None: - config_dir = pkg_resources.resource_filename("arrakis", "configs") - args.dask_config = os.path.join(config_dir, "default.yaml") + with resources.path("arrakis", "configs") as config_dir: + args.dask_config = config_dir / "default.yaml" if args.outfile is None: args.outfile = f"{args.merge_name}.pipe.test.fits" @@ -159,16 +160,16 @@ def main(args: configargparse.Namespace) -> None: ) args_yaml = yaml.dump(vars(args)) - args_yaml_f = os.path.abspath(f"{args.merge_name}-config-{Time.now().fits}.yaml") + args_yaml_f = Path(f"{args.merge_name}-config-{Time.now().fits}.yaml").absolute() logger.info(f"Saving config to '{args_yaml_f}'") - with open(args_yaml_f, "w") as f: + with args_yaml_f.open("w") as f: f.write(args_yaml) dask_runner = process_spice.create_dask_runner( - dask_config=args.dask_config, + dask_config=Path(args.dask_config), ) - inter_dir = os.path.join(os.path.abspath(args.output_dir), args.merge_name) + inter_dir = Path(args.output_dir).absolute() / args.merge_name process_merge.with_options( name=f"Arrakis Merge: {args.merge_name}", task_runner=dask_runner @@ -198,23 +199,19 @@ def pipeline_parser(parent_parser: bool = False) -> argparse.ArgumentParser: help="Config file for Dask SlurmCLUSTER.", ) + parser.add_argument("--skip_frion", action="store_true", help="Skip cleanup stage.") parser.add_argument( - "--skip_frion", action="store_true", help="Skip cleanup stage [False]." + "--skip_rmsynth", action="store_true", help="Skip RM Synthesis stage." ) parser.add_argument( - "--skip_rmsynth", action="store_true", help="Skip RM Synthesis stage [False]." - ) - parser.add_argument( - "--skip_rmclean", action="store_true", help="Skip RM-CLEAN stage [False]." - ) - parser.add_argument( - "--skip_cat", action="store_true", help="Skip catalogue stage [False]." + "--skip_rmclean", action="store_true", help="Skip RM-CLEAN stage." ) + parser.add_argument("--skip_cat", action="store_true", help="Skip catalogue stage.") parser.add_argument( "--skip_validate", action="store_true", help="Skip validation stage." ) parser.add_argument( - "--skip_cleanup", action="store_true", help="Skip cleanup stage [False]." + "--skip_cleanup", action="store_true", help="Skip cleanup stage." ) return pipeline_parser diff --git a/arrakis/process_spice.py b/arrakis/process_spice.py index 1d19093a..ef8109ca 100755 --- a/arrakis/process_spice.py +++ b/arrakis/process_spice.py @@ -5,7 +5,6 @@ import argparse import logging -import os from importlib import resources from pathlib import Path @@ -210,22 +209,22 @@ def save_args(args: configargparse.Namespace) -> Path: Path: Output path of the saved file """ args_yaml = yaml.dump(vars(args)) - args_yaml_f = os.path.abspath(f"{args.field}-config-{Time.now().fits}.yaml") + args_yaml_f = Path(f"{args.field}-config-{Time.now().fits}.yaml").absolute() logger.info(f"Saving config to '{args_yaml_f}'") - with open(args_yaml_f, "w") as f: + with args_yaml_f.open("w") as f: f.write(args_yaml) return Path(args_yaml_f) def create_dask_runner( - dask_config: str, + dask_config: Path | None, overload: bool = False, ) -> DaskTaskRunner: """Create a DaskTaskRunner Args: - dask_config (str): Configuraiton file for the DaskTaskRunner + dask_config (Path | None): Configuraiton file for the DaskTaskRunner overload (bool, optional): Overload the options for threadded work. Defaults to False. 
Returns: @@ -237,7 +236,7 @@ def create_dask_runner( config_dir = resources.files("arrakis.configs") dask_config = config_dir / "default.yaml" - with open(dask_config) as f: + with dask_config.open() as f: logger.info(f"Loading {dask_config}") yaml_config: dict = yaml.safe_load(f) @@ -284,7 +283,7 @@ def main(args: configargparse.Namespace) -> None: # This is the client for the imager component of the arrakis # pipeline. dask_runner = create_dask_runner( - dask_config=args.imager_dask_config, + dask_config=Path(args.imager_dask_config), overload=True, ) @@ -342,7 +341,7 @@ def main(args: configargparse.Namespace) -> None: # This is the client and pipeline for the RM extraction dask_runner_2 = create_dask_runner( - dask_config=args.dask_config, + dask_config=Path(args.dask_config), ) # Define flow @@ -370,13 +369,13 @@ def pipeline_parser(parent_parser: bool = False) -> argparse.ArgumentParser: parser = pipeline_parser.add_argument_group("pipeline arguments") parser.add_argument( "--dask_config", - type=str, + type=Path, default=None, help="Config file for Dask SlurmCLUSTER.", ) parser.add_argument( "--imager_dask_config", - type=str, + type=Path, default=None, help="Config file for Dask SlurmCLUSTER.", ) diff --git a/arrakis/rmclean_oncuts.py b/arrakis/rmclean_oncuts.py index 3edf38c8..15953d1b 100755 --- a/arrakis/rmclean_oncuts.py +++ b/arrakis/rmclean_oncuts.py @@ -5,7 +5,6 @@ import argparse import logging -import os import warnings from pathlib import Path from pprint import pformat @@ -77,7 +76,7 @@ def rmclean1d( weightFile = outdir / f"{rm1dfiles['weights']}" rmSynthFile = outdir / f"{rm1dfiles['summary_json']}" - prefix = os.path.join(os.path.abspath(os.path.dirname(fdfFile)), cname) + prefix = (fdfFile.parent.absolute() / cname).as_posix() # Sanity checks for f in [weightFile, fdfFile, rmsfFile, rmSynthFile]: diff --git a/arrakis/rmsynth_oncuts.py b/arrakis/rmsynth_oncuts.py index 42041fd7..28e953fc 100755 --- a/arrakis/rmsynth_oncuts.py +++ b/arrakis/rmsynth_oncuts.py @@ -5,7 +5,6 @@ import argparse import logging -import os import traceback import warnings from pathlib import Path @@ -127,22 +126,21 @@ def rmsynthoncut3d( """ beam = dict(beam_tuple[1]) iname = island_id - ifile = os.path.join(outdir, beam["beams"][field]["i_file"]) + ifile = outdir / str(beam["beams"][field]["i_file"]) if ion: - qfile = os.path.join(outdir, beam["beams"][field]["q_file_ion"]) - ufile = os.path.join(outdir, beam["beams"][field]["u_file_ion"]) + qfile = outdir / str(beam["beams"][field]["q_file_ion"]) + ufile = outdir / str(beam["beams"][field]["u_file_ion"]) else: - qfile = os.path.join(outdir, beam["beams"][field]["q_file"]) - ufile = os.path.join(outdir, beam["beams"][field]["u_file"]) - # vfile = beam['beams'][field]['v_file'] + qfile = outdir / str(beam["beams"][field]["q_file"]) + ufile = outdir / str(beam["beams"][field]["u_file"]) header: fits.Header dataQ: np.ndarray dataI: np.ndarray - header, dataQ = do_RMsynth_3D.readFitsCube(qfile, rm_verbose) + header, dataQ = do_RMsynth_3D.readFitsCube(qfile.as_posix(), rm_verbose) header, dataU = do_RMsynth_3D.readFitsCube(ufile, rm_verbose) - header, dataI = do_RMsynth_3D.readFitsCube(ifile, rm_verbose) + header, dataI = do_RMsynth_3D.readFitsCube(ifile.as_posix(), rm_verbose) dataQ = np.squeeze(dataQ) dataU = np.squeeze(dataU) @@ -188,7 +186,7 @@ def rmsynthoncut3d( headtemplate=header, fitRMSF=fitRMSF, prefixOut=prefix, - outDir=os.path.dirname(ifile), + outDir=ifile.parent.as_posix(), write_seperate_FDF=True, not_rmsf=not_RMSF, nBits=32, @@ 
-202,17 +200,17 @@ def rmsynthoncut3d( if "COMMENT" in head_dict: head_dict["COMMENT"] = str(head_dict["COMMENT"]) - outer_dir = os.path.basename(os.path.dirname(ifile)) + outer_dir = Path(ifile.parent.name) newvalues = { "field": save_name, "rm3dfiles": { - "FDF_real_dirty": os.path.join(outer_dir, f"{prefix}FDF_real_dirty.fits"), - "FDF_im_dirty": os.path.join(outer_dir, f"{prefix}FDF_im_dirty.fits"), - "FDF_tot_dirty": os.path.join(outer_dir, f"{prefix}FDF_tot_dirty.fits"), - "RMSF_real": os.path.join(outer_dir, f"{prefix}RMSF_real.fits"), - "RMSF_tot": os.path.join(outer_dir, f"{prefix}RMSF_tot.fits"), - "RMSF_FWHM": os.path.join(outer_dir, f"{prefix}RMSF_FWHM.fits"), + "FDF_real_dirty": (outer_dir / f"{prefix}FDF_real_dirty.fits").as_posix(), + "FDF_im_dirty": (outer_dir / f"{prefix}FDF_im_dirty.fits").as_posix(), + "FDF_tot_dirty": (outer_dir / f"{prefix}FDF_tot_dirty.fits").as_posix(), + "RMSF_real": (outer_dir / f"{prefix}RMSF_real.fits").as_posix(), + "RMSF_tot": (outer_dir / f"{prefix}RMSF_tot.fits").as_posix(), + "RMSF_FWHM": (outer_dir / f"{prefix}RMSF_FWHM.fits").as_posix(), }, "rmsynth3d": True, "header": dict(header), @@ -256,11 +254,11 @@ def cubelet_bane(cubelet: np.ndarray, header: fits.Header) -> tuple[np.ndarray]: background = np.zeros(cubelet.shape[0]) * np.nan noise = np.zeros(cubelet.shape[0]) * np.nan for chan, plane in enumerate(data_masked): - plane = plane[np.isfinite(plane)] - if len(plane) == 0: + good_plane = plane[np.isfinite(plane)] + if len(good_plane) == 0: continue clipped_plane = sigma_clip( - plane, sigma=3, cenfunc=fitted_mean, stdfunc=fitted_std, maxiters=None + good_plane, sigma=3, cenfunc=fitted_mean, stdfunc=fitted_std, maxiters=None ) background[chan], noise[chan] = norm.fit(clipped_plane.compressed()) @@ -402,7 +400,7 @@ def fit_stokes_I( fit_dict=None, ) - elif do_own_fit: + if do_own_fit: logger.info("Doing own fit") fit_dict = fit_pl(freq=freq, flux=iarr, fluxerr=rmsi, nterms=abs(polyOrd)) @@ -415,15 +413,14 @@ def fit_stokes_I( fit_dict=fit_dict, ) - else: - return StokesIFitResult( - alpha=None, - amplitude=None, - x_0=None, - model_repr=None, - modStokesI=None, - fit_dict=None, - ) + return StokesIFitResult( + alpha=None, + amplitude=None, + x_0=None, + model_repr=None, + modStokesI=None, + fit_dict=None, + ) def update_rmtools_dict( @@ -526,10 +523,10 @@ def rmsynthoncut1d( comp = comp_tuple[1] beam = dict(beam_tuple[1]) - iname = comp["Source_ID"] - cname = comp["Gaussian_ID"] - ra = comp["RA"] - dec = comp["Dec"] + iname = str(comp["Source_ID"]) + cname = str(comp["Gaussian_ID"]) + ra = float(comp["RA"]) + dec = float(comp["Dec"]) coord = SkyCoord(ra * u.deg, dec * u.deg) field_dict = beam["beams"][field] @@ -548,7 +545,7 @@ def rmsynthoncut1d( operation = {"$set": {"rm_outputs_1d.$.rmsynth1d": False}} return pymongo.UpdateOne(myquery, operation, upsert=True) - prefix = f"{os.path.dirname(stokes_spectra.i.filename)}/{cname}" + prefix = stokes_spectra.i.filename.parent / cname # Filter by RMS for outlier rejection filtered_stokes_spectra = sigma_clip_spectra(stokes_spectra) @@ -589,8 +586,8 @@ def rmsynthoncut1d( data.append(filtered_stokes_spectra.__getattribute__(stokes).rms) # Run 1D RM-synthesis on the spectra - np.savetxt(f"{prefix}.dat", np.vstack(data).T, delimiter=" ") - np.savetxt(f"{prefix}_bkg.dat", np.vstack(bkg_data).T, delimiter=" ") + np.savetxt(f"{prefix.as_posix()}.dat", np.vstack(data).T, delimiter=" ") + np.savetxt(f"{prefix.as_posix()}_bkg.dat", np.vstack(bkg_data).T, delimiter=" ") try: logger.info(f"Using 
{fit_function} to fit Stokes I") @@ -610,7 +607,7 @@ def rmsynthoncut1d( verbose=rm_verbose, debug=debug, fit_function=fit_function, - prefixOut=prefix, + prefixOut=prefix.as_posix(), ) except Exception as err: traceback.print_tb(err.__traceback__) @@ -667,7 +664,7 @@ def rmsynthoncut1d( elif isinstance(v, np.bool_): mDict[k] = bool(v) - do_RMsynth_1D.saveOutput(mDict, aDict, prefix, rm_verbose) + do_RMsynth_1D.saveOutput(mDict, aDict, prefix.as_posix(), rm_verbose) myquery = {"Gaussian_ID": cname} @@ -678,15 +675,15 @@ def rmsynthoncut1d( head_dict["COMMENT"] = str(head_dict["COMMENT"]) logger.debug(f"Heading for {cname} is {pformat(head_dict)}") - outer_dir = os.path.basename(os.path.dirname(filtered_stokes_spectra.i.filename)) + outer_dir = Path(filtered_stokes_spectra.i.filename.parent.name) newvalues = { "field": save_name, "rm1dfiles": { - "FDF_dirty": os.path.join(outer_dir, f"{cname}_FDFdirty.dat"), - "RMSF": os.path.join(outer_dir, f"{cname}_RMSF.dat"), - "weights": os.path.join(outer_dir, f"{cname}_weight.dat"), - "summary_dat": os.path.join(outer_dir, f"{cname}_RMsynth.dat"), - "summary_json": os.path.join(outer_dir, f"{cname}_RMsynth.json"), + "FDF_dirty": (outer_dir / f"{cname}_FDFdirty.dat").as_posix(), + "RMSF": (outer_dir / f"{cname}_RMSF.dat").as_posix(), + "weights": (outer_dir / f"{cname}_weight.dat").as_posix(), + "summary_dat": (outer_dir / f"{cname}_RMsynth.dat").as_posix(), + "summary_json": (outer_dir / f"{cname}_RMsynth.json").as_posix(), }, "rmsynth1d": True, "header": head_dict, @@ -744,7 +741,6 @@ def main( username: str | None = None, password: str | None = None, dimension: str = "1d", - verbose: bool = True, database: bool = False, limit: int | None = None, savePlots: bool = False, diff --git a/arrakis/utils/database.py b/arrakis/utils/database.py index 8b5b4985..eaa28165 100644 --- a/arrakis/utils/database.py +++ b/arrakis/utils/database.py @@ -64,9 +64,9 @@ def test_db( ) as dbclient: # type: pymongo.MongoClient try: dbclient.list_database_names() - except pymongo.errors.ServerSelectionTimeoutError: + except pymongo.errors.ServerSelectionTimeoutError as err: msg = "Please ensure 'mongod' is running" - raise Exception(msg) + raise Exception(msg) from err logger.info("MongoDB connection succesful!") diff --git a/arrakis/utils/exceptions.py b/arrakis/utils/exceptions.py index 4f3dd336..4e01db2b 100644 --- a/arrakis/utils/exceptions.py +++ b/arrakis/utils/exceptions.py @@ -1,4 +1,3 @@ -#!/usr/bin/env python """Errors and exceptions""" from __future__ import annotations diff --git a/arrakis/utils/fitsutils.py b/arrakis/utils/fitsutils.py index 7d219977..1274f38c 100644 --- a/arrakis/utils/fitsutils.py +++ b/arrakis/utils/fitsutils.py @@ -1,10 +1,8 @@ -#!/usr/bin/env python """FITS utilities""" from __future__ import annotations import warnings -from glob import glob from pathlib import Path from typing import Any @@ -13,11 +11,9 @@ from astropy.io import fits from astropy.utils.exceptions import AstropyWarning from astropy.wcs import WCS -from spectral_cube import SpectralCube from spectral_cube.utils import SpectralCubeWarning from arrakis.logger import logger -from arrakis.utils.io import gettable from FRion.correct import find_freq_axis warnings.filterwarnings(action="ignore", category=SpectralCubeWarning, append=True) @@ -119,77 +115,3 @@ def getfreq( logger.info(f"Saving to {outfile}") np.savetxt(outfile, np.array(freq)) return freq, outfile - - -def getdata(cubedir="./", tabledir="./", mapdata=None, verbose=True): - """Get the spectral and source-finding 
data. - - Args: - cubedir: Directory containing data cubes in FITS format. - tabledir: Directory containing Selavy results. - mapdata: 2D FITS image which corresponds to Selavy table. - - Kwargs: - verbose (bool): Whether to print messages. - - Returns: - datadict (dict): Dictionary of necessary astropy tables and - Spectral cubes. - - """ - if cubedir[-1] == "/": - cubedir = cubedir[:-1] - - if tabledir[-1] == "/": - tabledir = tabledir[:-1] - # Glob out the necessary files - # Data cubes - icubes = glob(f"{cubedir}/image.restored.i.*contcube*linmos.fits") - qcubes = glob(f"{cubedir}/image.restored.q.*contcube*linmos.fits") - ucubes = glob(f"{cubedir}/image.restored.u.*contcube*linmos.fits") - vcubes = glob(f"{cubedir}/image.restored.v.*contcube*linmos.fits") - - cubes = [icubes, qcubes, ucubes, vcubes] - # Selavy images - selavyfits = mapdata - # Get selvay data from VOTab - i_tab, voisle = gettable(tabledir, "islands", verbose=verbose) # Selvay VOTab - components, tablename = gettable(tabledir, "components", verbose=verbose) - - logger.info(f"Getting spectral data from: {cubes}\n") - logger.info(f"Getting source location data from: {selavyfits}\n") - - # Read data using Spectral cube - i_taylor = SpectralCube.read(selavyfits, mode="denywrite") - wcs_taylor = WCS(i_taylor.header) - i_cube = SpectralCube.read(icubes[0], mode="denywrite") - wcs_cube = WCS(i_cube.header) - q_cube = SpectralCube.read(qcubes[0], mode="denywrite") - u_cube = SpectralCube.read(ucubes[0], mode="denywrite") - if len(vcubes) != 0: - v_cube = SpectralCube.read(vcubes[0], mode="denywrite") - else: - v_cube = None - # Mask out using Stokes I == 0 -- seems to be the current fill value - mask = ~(i_cube == 0 * u.jansky / u.beam) - i_cube = i_cube.with_mask(mask) - mask = ~(q_cube == 0 * u.jansky / u.beam) - q_cube = q_cube.with_mask(mask) - mask = ~(u_cube == 0 * u.jansky / u.beam) - u_cube = u_cube.with_mask(mask) - - return { - "i_tab": i_tab, - "i_tab_comp": components, - "i_taylor": i_taylor, - "wcs_taylor": wcs_taylor, - "wcs_cube": wcs_cube, - "i_cube": i_cube, - "q_cube": q_cube, - "u_cube": u_cube, - "v_cube": v_cube, - "i_file": icubes[0], - "q_file": qcubes[0], - "u_file": ucubes[0], - "v_file": vcubes[0], - } diff --git a/arrakis/utils/fitting.py b/arrakis/utils/fitting.py index e83c3cb1..b8bd3bb3 100644 --- a/arrakis/utils/fitting.py +++ b/arrakis/utils/fitting.py @@ -1,4 +1,3 @@ -#!/usr/bin/env python """Fitting utilities""" from __future__ import annotations diff --git a/arrakis/utils/io.py b/arrakis/utils/io.py index 5a17f8cd..deebff25 100644 --- a/arrakis/utils/io.py +++ b/arrakis/utils/io.py @@ -1,4 +1,3 @@ -#!/usr/bin/env python """I/O utilities""" from __future__ import annotations @@ -6,19 +5,14 @@ import logging import os import shlex -import stat import subprocess as sp import warnings -from glob import glob from pathlib import Path -from astropy.table import Table from astropy.utils.exceptions import AstropyWarning from spectral_cube.utils import SpectralCubeWarning -from tqdm.auto import tqdm from arrakis.logger import TqdmToLogger, logger -from arrakis.utils.exceptions import SameFileError, SpecialFileError from arrakis.utils.typing import PathLike warnings.filterwarnings(action="ignore", category=SpectralCubeWarning, append=True) @@ -62,7 +56,7 @@ def prsync(wild_src: str, tgt: str, ncores: int): os.system(f"ls -d {wild_src} | xargs -n 1 -P {ncores} -I% rsync -rvh % {tgt}") -def try_symlink(src: str, dst: str, verbose=True): +def try_symlink(src: str, dst: str): """Create symlink if it doesn't 
exist Args: @@ -76,106 +70,3 @@ def try_symlink(src: str, dst: str, verbose=True): logger.info(f"Made symlink '{dst}'.") except FileExistsError: logger.info(f"Symlink '{dst}' exists.") - - -def try_mkdir(dir_path: str, verbose=True): - """Create directory if it doesn't exist - - Args: - dir_path (str): Path to directory - verbose (bool, optional): Verbose output. Defaults to True. - """ - # Create output dir if it doesn't exist - try: - os.mkdir(dir_path) - logger.info(f"Made directory '{dir_path}'.") - except FileExistsError: - logger.info(f"Directory '{dir_path}' exists.") - - -def gettable(tabledir: str, keyword: str, verbose=True) -> tuple[Table, str]: - """Get a table from a directory given a keyword to glob. - - Args: - tabledir (str): Directory. - keyword (str): Keyword to glob for. - verbose (bool, optional): Verbose output. Defaults to True. - - Returns: - Tuple[Table, str]: Table and it's file location. - """ - if tabledir[-1] == "/": - tabledir = tabledir[:-1] - # Glob out the necessary files - files = glob(f"{tabledir}/*.{keyword}*.xml") # Selvay VOTab - filename = files[0] - logger.info(f"Getting table data from {filename}...") - - # Get selvay data from VOTab - table = Table.read(filename, format="votable") - table = table.to_pandas() - str_df = table.select_dtypes([object]) - str_df = str_df.stack().str.decode("utf-8").unstack() - for col in str_df: - table[col] = str_df[col] - return table, filename - - -def _samefile(src, dst): - # Macintosh, Unix. - if hasattr(os.path, "samefile"): - try: - return os.path.samefile(src, dst) - except OSError: - return False - return None - - -def copyfile(src, dst, *, follow_symlinks=True, verbose=True): - """Copy data from src to dst. - - If follow_symlinks is not set and src is a symbolic link, a new - symlink will be created instead of copying the file it points to. - - """ - if _samefile(src, dst): - msg = f"{src!r} and {dst!r} are the same file" - raise SameFileError(msg) - - for fn in [src, dst]: - try: - st = os.stat(fn) - except OSError: - # File most likely does not exist - pass - else: - # XXX What about other special files? (sockets, devices...) 
- if stat.S_ISFIFO(st.st_mode): - msg = f"`{fn}` is a named pipe" - raise SpecialFileError(msg) - - if not follow_symlinks and os.path.islink(src): - os.symlink(os.readlink(src), dst) - else: - with open(src, "rb") as fsrc, open(dst, "wb") as fdst: - copyfileobj(fsrc, fdst, verbose=verbose) - return dst - - -def copyfileobj(fsrc, fdst, length=16 * 1024, verbose=True): - # copied = 0 - total = os.fstat(fsrc.fileno()).st_size - with tqdm( - total=total, - disable=(not verbose), - unit_scale=True, - desc="Copying file", - file=TQDM_OUT, - ) as pbar: - while True: - buf = fsrc.read(length) - if not buf: - break - fdst.write(buf) - copied = len(buf) - pbar.update(copied) diff --git a/arrakis/utils/json.py b/arrakis/utils/json.py index c8ca006c..e5b6d651 100644 --- a/arrakis/utils/json.py +++ b/arrakis/utils/json.py @@ -1,4 +1,3 @@ -#!/usr/bin/env python """JSON utilities""" from __future__ import annotations @@ -23,15 +22,15 @@ class MyEncoder(json.JSONEncoder): def default(self, obj): # pylint: disable=E0202 if isinstance(obj, np.integer): return int(obj) - elif isinstance(obj, np.floating): + if isinstance(obj, np.floating): return float(obj) - elif isinstance(obj, complex): + if isinstance(obj, complex): return (obj.real, obj.imag) - elif isinstance(obj, np.ndarray): + if isinstance(obj, np.ndarray): return obj.tolist() - elif isinstance(obj, fits.Header): + if isinstance(obj, fits.Header): return head2dict(obj) - elif dataclasses.is_dataclass(obj): + if dataclasses.is_dataclass(obj): return dataclasses.asdict(obj) - else: - return super().default(obj) + + return super().default(obj) diff --git a/arrakis/utils/meta.py b/arrakis/utils/meta.py index d2eabcf2..ccbb2e22 100644 --- a/arrakis/utils/meta.py +++ b/arrakis/utils/meta.py @@ -63,9 +63,9 @@ def yes_or_no(question: str) -> bool: reply = str(input(question + " (y/n): ")).lower().strip() if reply[:1] == "y": return True - elif reply[:1] == "n": + if reply[:1] == "n": return False - else: - msg = "Please answer 'y' or 'n'" - raise ValueError(msg) + + msg = "Please answer 'y' or 'n'" + raise ValueError(msg) return None diff --git a/arrakis/utils/msutils.py b/arrakis/utils/msutils.py index d14332e5..581b8358 100644 --- a/arrakis/utils/msutils.py +++ b/arrakis/utils/msutils.py @@ -1,4 +1,3 @@ -#!/usr/bin/env python """MeasurementSet utilities""" from __future__ import annotations diff --git a/arrakis/utils/pipeline.py b/arrakis/utils/pipeline.py index 7e8bd861..f5b5a536 100644 --- a/arrakis/utils/pipeline.py +++ b/arrakis/utils/pipeline.py @@ -1,4 +1,3 @@ -#!/usr/bin/env python """Pipeline and flow utility functions""" from __future__ import annotations @@ -80,7 +79,7 @@ def upload_image_as_artifact_task( image_type in SUPPORTED_IMAGE_TYPES ), f"{image_path} has type {image_type}, and is not supported. 
Supported types are {SUPPORTED_IMAGE_TYPES}" - with open(image_path, "rb") as open_image: + with image_path.open("rb") as open_image: logger.info(f"Encoding {image_path} in base64") image_base64 = base64.b64encode(open_image.read()).decode() @@ -355,10 +354,12 @@ def __init__( loop_runner.run_sync(self.listen) def _draw_bar(self, remaining, all, **kwargs): + _ = kwargs update_ct = (all - remaining) - self.tqdm.n self.tqdm.update(update_ct) def _draw_stop(self, **kwargs): + _ = kwargs self.tqdm.close() diff --git a/arrakis/utils/typing.py b/arrakis/utils/typing.py index 63acce00..19b6ee2f 100644 --- a/arrakis/utils/typing.py +++ b/arrakis/utils/typing.py @@ -1,4 +1,3 @@ -#!/usr/bin/env python3 """Typing utilities""" from __future__ import annotations diff --git a/arrakis/wsclean_rmsynth.py b/arrakis/wsclean_rmsynth.py index 09e094d6..77c354c7 100755 --- a/arrakis/wsclean_rmsynth.py +++ b/arrakis/wsclean_rmsynth.py @@ -289,7 +289,7 @@ def proper_rm_clean( np.save("fdf_clean.npy", fdf_clean) np.save("cc_vec.npy", cc_vec) - raise + # raise # Calculate the spectrum quarr = np.sum(cc_vec[:, np.newaxis] * np.exp(2.0j * np.outer(phis, ays)), axis=0) From a1d5a44409567ee973682812e7ab490343f9c244 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 22 Jul 2024 13:42:45 +0000 Subject: [PATCH 05/17] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- arrakis/frion.py | 2 +- arrakis/utils/fitsutils.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/arrakis/frion.py b/arrakis/frion.py index e4c4bcff..50477c2f 100755 --- a/arrakis/frion.py +++ b/arrakis/frion.py @@ -16,6 +16,7 @@ import numpy as np import pymongo from astropy.time import Time, TimeDelta +from FRion import correct, predict from prefect import flow, task from tqdm.auto import tqdm @@ -28,7 +29,6 @@ ) from arrakis.utils.fitsutils import getfreq from arrakis.utils.pipeline import generic_parser, logo_str, workdir_arg_parser -from FRion import correct, predict logger.setLevel(logging.INFO) TQDM_OUT = TqdmToLogger(logger, level=logging.INFO) diff --git a/arrakis/utils/fitsutils.py b/arrakis/utils/fitsutils.py index 1274f38c..1a128f2d 100644 --- a/arrakis/utils/fitsutils.py +++ b/arrakis/utils/fitsutils.py @@ -11,10 +11,10 @@ from astropy.io import fits from astropy.utils.exceptions import AstropyWarning from astropy.wcs import WCS +from FRion.correct import find_freq_axis from spectral_cube.utils import SpectralCubeWarning from arrakis.logger import logger -from FRion.correct import find_freq_axis warnings.filterwarnings(action="ignore", category=SpectralCubeWarning, append=True) warnings.simplefilter("ignore", category=AstropyWarning) From 94ec64ff28acf6d301cd7736c5fb9e12dcdcb107 Mon Sep 17 00:00:00 2001 From: "Alec Thomson (S&A, Kensington WA)" Date: Thu, 25 Jul 2024 11:53:34 +0800 Subject: [PATCH 06/17] Ruffage --- arrakis/__init__.py | 0 arrakis/cleanup.py | 21 ++-- arrakis/columns_possum.py | 3 +- arrakis/cutout.py | 103 +++++++++++++++--- arrakis/data/__init__.py | 1 + arrakis/frion.py | 108 ++++++++++++++----- arrakis/imager.py | 139 +++++++++++++----------- arrakis/init_database.py | 48 +++++---- arrakis/linmos.py | 106 +++++++++++------- arrakis/logger.py | 67 +++++++++--- arrakis/makecat.py | 204 +++++++++++++++++++++++------------ arrakis/merge_fields.py | 120 ++++++++++++++------- arrakis/process_region.py | 17 ++- arrakis/process_spice.py | 27 +++-- arrakis/rmclean_oncuts.py | 28 +++-- 
arrakis/rmsynth_oncuts.py | 134 ++++++++++++++++------- arrakis/utils/__init__.py | 1 + arrakis/utils/coordinates.py | 4 +- arrakis/utils/database.py | 15 +-- arrakis/utils/exceptions.py | 14 ++- arrakis/utils/fitsutils.py | 4 +- arrakis/utils/fitting.py | 4 +- arrakis/utils/io.py | 41 +++---- arrakis/utils/json.py | 11 +- arrakis/utils/meta.py | 39 ++++++- arrakis/utils/msutils.py | 16 +-- arrakis/utils/pipeline.py | 134 +++++++++-------------- arrakis/utils/plotting.py | 13 +-- arrakis/utils/typing.py | 12 +-- arrakis/validate.py | 117 ++++++++++++++++++-- arrakis/wsclean_rmsynth.py | 3 +- pyproject.toml | 6 ++ 32 files changed, 1051 insertions(+), 509 deletions(-) mode change 100755 => 100644 arrakis/__init__.py diff --git a/arrakis/__init__.py b/arrakis/__init__.py old mode 100755 new mode 100644 diff --git a/arrakis/cleanup.py b/arrakis/cleanup.py index 911da3f4..b97d1e32 100755 --- a/arrakis/cleanup.py +++ b/arrakis/cleanup.py @@ -25,11 +25,10 @@ # @task(name="Purge cublets") def purge_cubelet_beams(filepath: Path) -> Path: - """Clean up beam images + """Clean up beam images. Args: - workdir (str): Directory containing images - stoke (str): Stokes parameter + filepath (Path): Path to the beam image """ # Clean up beam images logger.critical(f"Removing {filepath}") @@ -40,10 +39,11 @@ def purge_cubelet_beams(filepath: Path) -> Path: @task(name="Make cutout tarball") def make_cutout_tarball(cutdir: Path, overwrite: bool = False) -> Path: - """Make a tarball of the cutouts directory + """Make a tarball of the cutouts directory. Args: cutdir (Path): Directory containing cutouts + overwrite (bool): Overwrite existing tarball Returns: Path: Path to the tarball @@ -76,13 +76,12 @@ def main( datadir: Path, overwrite: bool = False, ) -> None: - """Clean up beam images flow + """Clean up beam images flow. Args: datadir (Path): Directory with sub dir 'cutouts' overwrite (bool): Overwrite existing tarball """ - cutdir = datadir / "cutouts" # First, make a tarball of the cutouts @@ -117,6 +116,14 @@ def main( def cleanup_parser(parent_parser: bool = False) -> argparse.ArgumentParser: + """Create a parser for the cleanup stage. + + Args: + parent_parser (bool, optional): If a parent parser. Defaults to False. + + Returns: + argparse.ArgumentParser: The parser + """ # Help string to be shown using the -h option descStr = f""" {logo_str} @@ -144,7 +151,7 @@ def cleanup_parser(parent_parser: bool = False) -> argparse.ArgumentParser: def cli(): - """Command-line interface""" + """Command-line interface.""" gen_parser = generic_parser(parent_parser=True) clean_parser = cleanup_parser(parent_parser=True) parser = argparse.ArgumentParser( diff --git a/arrakis/columns_possum.py b/arrakis/columns_possum.py index 0b2491a5..e5c8dc97 100755 --- a/arrakis/columns_possum.py +++ b/arrakis/columns_possum.py @@ -1,6 +1,5 @@ #!/usr/bin/env python3 -""" -Column names from RM-tools to catalogue +"""Column names from RM-tools to catalogue. @author: cvaneck """ diff --git a/arrakis/cutout.py b/arrakis/cutout.py index a3f9dabe..46a5b477 100755 --- a/arrakis/cutout.py +++ b/arrakis/cutout.py @@ -1,5 +1,5 @@ #!/usr/bin/env python -"""Produce cutouts from RACS cubes""" +"""Produce cutouts from RACS cubes.""" from __future__ import annotations @@ -54,7 +54,7 @@ class CutoutArgs(Struct): - """Arguments for cutout function""" + """Arguments for cutout function.""" """Name of the source""" ra_left: float @@ -78,6 +78,20 @@ def cutout_weight( beam_num: int, dryrun=False, ) -> pymongo.UpdateOne: + """Cutout the weight image. 
+ + Args: + image_name (Path): Image name + source_id (str): Source ID + cutout_args (CutoutArgs | None): Cutout arguments + field (str): Field name + stoke (str): Stokes parameter + beam_num (int): Beam number + dryrun (bool, optional): If doing a dryrun. Defaults to False. + + Returns: + pymongo.UpdateOne: Update query + """ # Update database myquery = {"Source_ID": source_id} @@ -220,12 +234,11 @@ def get_args( source: pd.Series, outdir: Path, ) -> CutoutArgs | None: - """Get arguments for cutout function + """Get arguments for cutout function. Args: comps (pd.DataFrame): List of mongo entries for RACS components in island - beam (Dict): Mongo entry for the RACS beam - island_id (str): RACS island ID + source (pd.Series): Mongo entry for RACS island outdir (Path): Input directory Raises: @@ -235,7 +248,6 @@ def get_args( Returns: List[CutoutArgs]: List of cutout arguments for cutout function """ - logger.setLevel(logging.INFO) island_id = source.Source_ID @@ -248,9 +260,9 @@ def get_args( outdir.mkdir(parents=True, exist_ok=True) # Find image size - ras: u.Quantity = comps.RA.values * u.deg - decs: u.Quantity = comps.Dec.values * u.deg - majs: list[float] = comps.Maj.values * u.arcsec + ras: u.Quantity = comps.RA.to_numpy() * u.deg + decs: u.Quantity = comps.Dec.to_numpy() * u.deg + majs: u.Quantity = comps.Maj.to_numpy() * u.arcsec coords = SkyCoord(ras, decs, frame="icrs") padder = np.max(majs) @@ -343,7 +355,29 @@ def worker( pad: float = 3, username: str | None = None, password: str | None = None, -): +) -> list[pymongo.UpdateOne]: + """Cutout worker. + + Args: + host (str): MongoDB host + epoch (int): Epoch + source (pd.Series): Source dataframe + comps (pd.DataFrame): Components dataframe + outdir (Path): Output directory + image_name (Path): Image name + data_in_mem (np.ndarray): Pre-read image data + old_header (fits.Header): Old header + cube (SpectralCube): Cube + field (str): Field name + beam_num (int): Beam number + stoke (str): Stokes parameter + pad (float, optional): PSF padding. Defaults to 3. + username (str | None, optional): MongoDB username. Defaults to None. + password (str | None, optional): MongoDB password. Defaults to None. + + Returns: + list[pymongo.UpdateOne]: List of update queries + """ _, _, comp_col = get_db( host=host, epoch=epoch, username=username, password=password ) @@ -393,14 +427,38 @@ def big_cutout( password: str | None = None, limit: int | None = None, ) -> list[pymongo.UpdateOne]: + """Make cutouts in parallel. + + Args: + sources (pd.DataFrame): Source dataframe + comps (pd.DataFrame): Components dataframe + beam_num (int): Beam number + stoke (str): Stokes parameter + datadir (Path): Directory with images + outdir (Path): Output directory + host (str): MongoDB host + epoch (int): Epoch + field (str): Field name + pad (float, optional): PSF padding. Defaults to 3. + username (str | None, optional): MongoDB username. Defaults to None. + password (str | None, optional): MondoDB password. Defaults to None. + limit (int | None, optional): Limit number of sources. Defaults to None. 
+ + Raises: + FileNotFoundError: If no images found + FileExistsError: If more than one image found + + Returns: + list[pymongo.UpdateOne]: List of update queries + """ wild = f"image.restored.{stoke.lower()}*contcube*beam{beam_num:02}.conv.fits" images = list(datadir.glob(wild)) if len(images) == 0: msg = f"No images found matching '{wild}'" - raise Exception(msg) - elif len(images) > 1: + raise FileNotFoundError(msg) + if len(images) > 1: msg = f"More than one image found matching '{wild}'. Files {images=}" - raise Exception(msg) + raise FileExistsError(msg) image_name = images[0] @@ -467,12 +525,15 @@ def cutout_islands( field (str): RACS field name. directory (Path): Directory to store cutouts. host (str): MongoDB host. + epoch (int): RACS epoch. + sbid (int, optional): SBID. Defaults to None. username (str, optional): Mongo username. Defaults to None. password (str, optional): Mongo password. Defaults to None. verbose (bool, optional): Verbose output. Defaults to True. pad (int, optional): Number of beamwidths to pad cutouts. Defaults to 3. stokeslist (List[str], optional): Stokes parameters to cutout. Defaults to None. dryrun (bool, optional): Do everything except write FITS files. Defaults to True. + limit (int, optional): Limit to this many islands. Defaults to None. """ if stokeslist is None: stokeslist = ["I", "Q", "U", "V"] @@ -580,7 +641,7 @@ def cutout_islands( def main(args: argparse.Namespace) -> None: - """Main script + """Main script. Args: args (argparse.Namespace): Command-line args @@ -604,6 +665,14 @@ def main(args: argparse.Namespace) -> None: def cutout_parser(parent_parser: bool = False) -> argparse.ArgumentParser: + """Cutout parser. + + Args: + parent_parser (bool, optional): Parent parser. Defaults to False. + + Returns: + argparse.ArgumentParser: Argument parser + """ descStr = f""" {logo_str} @@ -612,8 +681,8 @@ def cutout_parser(parent_parser: bool = False) -> argparse.ArgumentParser: If Stokes V is present, it will be squished into RMS spectra. To use with MPI: - mpirun -n $NPROCS python -u cutout.py $cubedir $tabledir - $outdir --mpi + mpirun -n $NPROCS python -u cutout.py $cubedir $tabledir + $outdir --mpi """ # Parse the command line options @@ -640,7 +709,7 @@ def cutout_parser(parent_parser: bool = False) -> argparse.ArgumentParser: def cli() -> None: - """Command-line interface""" + """Command-line interface.""" gen_parser = generic_parser(parent_parser=True) work_parser = workdir_arg_parser(parent_parser=True) cut_parser = cutout_parser(parent_parser=True) diff --git a/arrakis/data/__init__.py b/arrakis/data/__init__.py index e69de29b..d9f8a6e5 100644 --- a/arrakis/data/__init__.py +++ b/arrakis/data/__init__.py @@ -0,0 +1 @@ +"""Arrakis data.""" diff --git a/arrakis/frion.py b/arrakis/frion.py index e4c4bcff..acb476e2 100755 --- a/arrakis/frion.py +++ b/arrakis/frion.py @@ -1,11 +1,10 @@ #!/usr/bin/env python3 -"""Correct for the ionosphere in parallel""" +"""Correct for the ionosphere in parallel.""" from __future__ import annotations import argparse import logging -import os from pathlib import Path from pprint import pformat from typing import Callable @@ -35,43 +34,61 @@ class Prediction(Struct): - """FRion prediction""" + """FRion prediction. + + Attributes: + predict_file (Path): FRion prediction file + update (pymongo.UpdateOne): Pymongo update query + + """ - predict_file: str + predict_file: Path + """FRion prediction file""" update: pymongo.UpdateOne + """Pymongo update query""" class FrionResults(Struct): + """FRion results. 
+ + Attributes: + prediction (Prediction): FRion prediction + correction (pymongo.UpdateOne): Pymongo update query + + """ + prediction: Prediction + """FRion prediction""" correction: pymongo.UpdateOne + """Pymongo update query""" @task(name="FRion correction") def correct_worker( - beam: dict, outdir: str, field: str, prediction: Prediction, island: dict + beam: dict, outdir: Path, field: str, prediction: Prediction, island: dict ) -> pymongo.UpdateOne: - """Apply FRion corrections to a single island + """Apply FRion corrections to a single island. Args: beam (Dict): MongoDB beam document outdir (str): Output directory field (str): RACS field name - predict_file (str): FRion prediction file - island_id (str): RACS island ID + prediction (Prediction): FRion prediction + island (str): Island document Returns: pymongo.UpdateOne: Pymongo update query """ predict_file = prediction.predict_file island_id = island["Source_ID"] - qfile = os.path.join(outdir, beam["beams"][field]["q_file"]) - ufile = os.path.join(outdir, beam["beams"][field]["u_file"]) + qfile = outdir / str(beam["beams"][field]["q_file"]) + ufile = outdir / str(beam["beams"][field]["u_file"]) qout = beam["beams"][field]["q_file"].replace(".fits", ".ion.fits") uout = beam["beams"][field]["u_file"].replace(".fits", ".ion.fits") - qout_f = os.path.join(outdir, qout) - uout_f = os.path.join(outdir, uout) + qout_f = outdir / str(qout) + uout_f = outdir / str(uout) correct.apply_correction_to_files( qfile, ufile, predict_file, qout_f, uout_f, overwrite=True @@ -97,14 +114,13 @@ def predict_worker( end_time: Time, freq: np.ndarray, cutdir: Path, - plotdir: Path, server: str = "ftp://ftp.aiub.unibe.ch/CODE/", prefix: str = "", formatter: str | Callable | None = None, proxy_server: str | None = None, pre_download: bool = False, ) -> Prediction: - """Make FRion prediction for a single island + """Make FRion prediction for a single island. Args: island (Dict): Pymongo island document @@ -114,14 +130,18 @@ def predict_worker( end_time (Time): End time of the observation freq (np.ndarray): Array of frequencies with units cutdir (str): Cutout directory - plotdir (str): Plot directory + server (str, optional): IONEX server. Defaults to "ftp://ftp.aiub.unibe.ch/CODE/". + prefix (str, optional): IONEX prefix. Defaults to "". + formatter (str | Callable | None, optional): IONEX formatter. Defaults to None. + proxy_server (str | None, optional): Proxy server. Defaults to None. + pre_download (bool, optional): Pre-download IONEX files. Defaults to False. Returns: - Tuple[str, pymongo.UpdateOne]: FRion prediction file and pymongo update query + Prediction: predict_file and update """ logger.setLevel(logging.INFO) - ifile: Path = cutdir / beam["beams"][field]["i_file"] + ifile: Path = cutdir / str(beam["beams"][field]["i_file"]) i_dir = ifile.parent iname = island["Source_ID"] ra = island["RA"] @@ -174,8 +194,10 @@ def predict_worker( msg = f"Could not find IONEX file with prefixes {_prefixes_to_try}" raise FileNotFoundError(msg) - predict_file = os.path.join(i_dir, f"{iname}_ion.txt") - predict.write_modulation(freq_array=freq, theta=theta, filename=predict_file) + predict_file = i_dir / f"{iname}_ion.txt" + predict.write_modulation( + freq_array=freq, theta=theta, filename=predict_file.as_posix() + ) logger.info(f"Prediction file: {predict_file}") myquery = {"Source_ID": iname} @@ -200,6 +222,15 @@ def predict_worker( @task(name="Index beams") def index_beams(island: dict, beams: list[dict]) -> dict: + """Index beams by island ID. 
+ + Args: + island (dict): Island document + beams (list[dict]): List of beam documents + + Returns: + dict: Beam document + """ island_id = island["Source_ID"] beam_idx = next(i for i, b in enumerate(beams) if b["Source_ID"] == island_id) return beams[beam_idx] @@ -223,6 +254,26 @@ def serial_loop( ionex_formatter: str | Callable | None, ionex_predownload: bool, ) -> FrionResults: + """Serial loop for FRion. + + Args: + island (dict): Island document + field (str): Field name + beam (dict): Beam document + start_time (Time): Start time + end_time (Time): End time + freq_hz_array (np.ndarray): Frequencies in Hz + cutdir (Path): Cutout directory + plotdir (Path): Plot directory + ionex_server (str): IONEX server + ionex_prefix (str): IONEX prefix + ionex_proxy_server (str | None): IONEX proxy server + ionex_formatter (str | Callable | None): IONEX formatter + ionex_predownload (bool): Pre-download IONEX files + + Returns: + FrionResults: _description_ + """ prediction = predict_worker.fn( island=island, field=field, @@ -266,7 +317,7 @@ def main( ionex_predownload: bool = False, limit: int | None = None, ): - """FRion flow + """FRion flow. Args: field (str): RACS field name @@ -278,6 +329,7 @@ def main( password (str, optional): Mongo passwrod. Defaults to None. database (bool, optional): Update database. Defaults to False. ionex_server (str, optional): IONEX server. Defaults to "ftp://ftp.aiub.unibe.ch/CODE/". + ionex_prefix (str, optional): IONEX prefix. Defaults to "codg". ionex_proxy_server (str, optional): Proxy server. Defaults to None. ionex_formatter (Union[str, Callable], optional): IONEX formatter. Defaults to "ftp.aiub.unibe.ch". ionex_predownload (bool, optional): Pre-download IONEX files. Defaults to False. @@ -337,9 +389,9 @@ def main( msg = f"More than one SELECT=1 for {field} - try supplying SBID." raise ValueError(msg) - elif field_col.count_documents(query_3) == 0: + if field_col.count_documents(query_3) == 0: logger.error(f"No data for {field} with {query_3}, trying without SELECT=1.") - query_3 = query_3 = {"$and": [{"FIELD_NAME": f"{field}"}]} + query_3 = {"$and": [{"FIELD_NAME": f"{field}"}]} if sbid is not None: query_3["$and"].append({"SBID": sbid}) field_data = field_col.find_one({"FIELD_NAME": f"{field}"}) @@ -353,7 +405,7 @@ def main( end_time = start_time + TimeDelta(field_data["SCAN_TINT"] * u.second) freq = getfreq( - os.path.join(cutdir, f"{beams[0]['beams'][f'{field}']['q_file']}"), + cutdir / f"{beams[0]['beams'][f'{field}']['q_file']}", ) if limit is not None: @@ -428,6 +480,14 @@ def main( def frion_parser(parent_parser: bool = False) -> argparse.ArgumentParser: + """Create a parser for FRion. + + Args: + parent_parser (bool, optional): If parent parser. Defaults to False. 
+ + Returns: + argparse.ArgumentParser: The parser + """ # Help string to be shown using the -h option descStr = f""" {logo_str} @@ -481,7 +541,7 @@ def frion_parser(parent_parser: bool = False) -> argparse.ArgumentParser: def cli(): - """Command-line interface""" + """Command-line interface.""" import warnings from astropy.utils.exceptions import AstropyWarning diff --git a/arrakis/imager.py b/arrakis/imager.py index 0a6e6008..1d560a6f 100755 --- a/arrakis/imager.py +++ b/arrakis/imager.py @@ -1,16 +1,14 @@ #!/usr/bin/env python3 -"""Arrkis imager""" +"""Arrkis imager.""" from __future__ import annotations import argparse import hashlib import logging -import os import pickle import shutil from concurrent.futures import ThreadPoolExecutor -from glob import glob from pathlib import Path from subprocess import CalledProcessError from typing import Any @@ -59,20 +57,35 @@ class ImageSet(Struct): - """Container to organise files related to t he imaging of a measurement set.""" + """Container to organise files related to t he imaging of a measurement set. + + Attributes: + ms (Path): Path to the measurement set that was imaged. + prefix (str): Prefix used for the wsclean output files. + image_lists (dict[str, list[Path]]): Dictionary of lists of images. The keys are the polarisations and the values are the list of images for that polarisation. + aux_lists (dict[tuple[str, str], list[Path]]): Dictionary of lists of auxillary images. The keys are a tuple of the polarisation and the image type, and the values are the list of images for that polarisation and image type. + + """ ms: Path """Path to the measurement set that was imaged.""" prefix: str """Prefix used for the wsclean output files.""" - image_lists: dict[str, list[str]] + image_lists: dict[str, list[Path]] """Dictionary of lists of images. The keys are the polarisations and the values are the list of images for that polarisation.""" - aux_lists: dict[tuple[str, str], list[str]] | None = None + aux_lists: dict[tuple[str, str], list[Path]] | None = None """Dictionary of lists of auxillary images. The keys are a tuple of the polarisation and the image type, and the values are the list of images for that polarisation and image type.""" class MFSImage(Struct): - """Representation of a multi-frequency synthesis image.""" + """Representation of a multi-frequency synthesis image. + + Attributes: + image (np.ndarray): The image data. + model (np.ndarray): The model data. + residual (np.ndarray): The residual data. + + """ image: np.ndarray """The image data.""" @@ -86,6 +99,16 @@ class MFSImage(Struct): def get_pol_axis_task( ms: Path, feed_idx: int | None = None, col: str = "RECEPTOR_ANGLE" ) -> float: + """Get the polarisation axis angle from the measurement set. + + Args: + ms (Path): Path to the measurement set. + feed_idx (int | None, optional): Feed index. Defaults to None. + col (str, optional): Receptor column. Defaults to "RECEPTOR_ANGLE". + + Returns: + float: Polarisation axis angle in degrees. 
+ """ return get_pol_axis(ms=ms, feed_idx=feed_idx, col=col).to(u.deg).value @@ -117,8 +140,8 @@ def merge_imagesets(image_sets: list[ImageSet | None]) -> ImageSet: image_set.prefix == prefix ), f"{image_set.prefix=} does not match {prefix=}" - image_lists = {} - aux_lists = {} + image_lists: dict[str, Path] = {} + aux_lists: dict[tuple[str, str], list[Path]] = {} for image_set in image_sets: for pol, images in image_set.image_lists.items(): @@ -213,7 +236,7 @@ def get_wsclean(wsclean: Path | str) -> Path: """Pull wsclean image from dockerhub (or wherver). Args: - version (str, optional): wsclean image tag. Defaults to "3.1". + wsclean (Path | str): Path to wsclean image or dockerhub image. Returns: Path: Path to wsclean image. @@ -225,7 +248,7 @@ def get_wsclean(wsclean: Path | str) -> Path: def cleanup_imageset(purge: bool, image_set: ImageSet) -> None: - """Delete images associated with an input ImageSet + """Delete images associated with an input ImageSet. Args: purge (bool): Whether files will be deleted or skipped. @@ -242,22 +265,16 @@ def cleanup_imageset(purge: bool, image_set: ImageSet) -> None: logger.critical(f"Removing {pol=} images for {image_set.ms}") for image in image_list: logger.critical(f"Removing {image}") - try: - os.remove(image) - except FileNotFoundError: - logger.critical(f"{image} not available for deletion. ") + image.unlink(missing_ok=True) # The aux images are the same between the native images and the smoothed images, # they were just copied across directly without modification if image_set.aux_lists: logger.critical("Removing auxillary images. ") - for (pol, _aux), aux_list in image_set.aux_lists.items(): + for (_pol, _aux), aux_list in image_set.aux_lists.items(): for aux_image in aux_list: - try: - logger.critical(f"Removing {aux_image}") - os.remove(aux_image) - except FileNotFoundError: - logger.critical(f"{aux_image} not available for deletion. 
") + logger.critical(f"Removing {aux_image}") + aux_image.unlink(missing_ok=True) return @@ -322,7 +339,7 @@ def image_beam( disable_pol_local_rms: bool = False, disable_pol_force_mask_rounds: bool = False, ) -> ImageSet: - """Image a single beam""" + """Image a single beam.""" logger = get_run_logger() # Evaluate the temp directory if a ENV variable is used temp_dir_images = parse_env_path(temp_dir_images) @@ -491,14 +508,15 @@ def image_beam( # Update the prefix prefix = out_dir / prefix.name - prefix_str = prefix.resolve().as_posix() + prefix_base = prefix.name + prefix_dir = prefix.parent # Check rms of image to check for divergence for pol in pols: - mfs_image = ( - f"{prefix_str}-MFS-image.fits" + mfs_image = prefix_dir / ( + f"{prefix_base}-MFS-image.fits" if len(pols) == 1 - else f"{prefix_str}-MFS-{pol}-image.fits" + else f"{prefix_base}-MFS-{pol}-image.fits" ) rms = mad_std(fits.getdata(mfs_image), ignore_nan=True) if rms > 1: @@ -508,25 +526,25 @@ def image_beam( ) # Get images - image_lists = {} - aux_lists = {} + image_lists: dict[str, Path] = {} + aux_lists: dict[tuple[str, str], list[Path]] = {} aux_suffixes = suffixes[1:] for pol in pols: imglob = ( - f"{prefix_str}-*[0-9]-image.fits" + f"{prefix_base}-*[0-9]-image.fits" if len(pols) == 1 - else f"{prefix_str}-*[0-9]-{pol}-image.fits" + else f"{prefix_base}-*[0-9]-{pol}-image.fits" ) - image_list = sorted(glob(imglob)) + image_list = sorted(prefix_dir.glob(imglob)) image_lists[pol] = image_list logger.info(f"Found {len(image_list)} images for {pol=} {ms}.") for aux in aux_suffixes: aux_list = ( - sorted(glob(f"{prefix_str}-*[0-9]-{aux}.fits")) + sorted(prefix_dir.glob(f"{prefix_base}-*[0-9]-{aux}.fits")) if len(pols) == 1 or aux == "psf" - else sorted(glob(f"{prefix_str}-*[0-9]-{pol}-{aux}.fits")) + else sorted(prefix_dir.glob(f"{prefix_base}-*[0-9]-{pol}-{aux}.fits")) ) aux_lists[(pol, aux)] = aux_list @@ -534,7 +552,7 @@ def image_beam( logger.info("Constructing ImageSet") image_set = ImageSet( - ms=ms, prefix=prefix_str, image_lists=image_lists, aux_lists=aux_lists + ms=ms, prefix=prefix_base, image_lists=image_lists, aux_lists=aux_lists ) logger.debug(f"{image_set=}") @@ -550,7 +568,7 @@ def make_cube( pol_angle_deg: float, aux_mode: str | None = None, ) -> tuple[Path, Path]: - """Make a cube from the images""" + """Make a cube from the images.""" logger = get_run_logger() logger.info(f"Creating cube for {pol=} {image_set.ms=}") @@ -586,16 +604,16 @@ def make_cube( # Create a cube name old_name = image_list[0] - out_dir = os.path.dirname(old_name) - old_base = os.path.basename(old_name) + out_dir = old_name.parent + old_base = old_name.name new_base = old_base b_idx = new_base.find("beam") + len("beam") + 2 sub = new_base[b_idx:] new_base = new_base.replace(sub, ".conv.fits") new_base = new_base.replace("image", f"image.{image_type}.{pol.lower()}") - new_name = os.path.join(out_dir, new_base) + new_name = out_dir / new_base # Deserialise beam - with open(common_beam_pkl, "rb") as f: + with common_beam_pkl.open("rb") as f: common_beam = pickle.load(f) new_header = common_beam.attach_to_header(new_header) fits.writeto(new_name, data_cube, new_header, overwrite=True) @@ -607,9 +625,9 @@ def make_cube( # 0 1234.5 # 1 6789.0 # etc. 
- new_w_name = new_name.replace( - f"image.{image_type}", f"weights.{image_type}" - ).replace(".fits", ".txt") + new_w_name = Path( + new_name.as_posix().replace(f"image.{image_type}", f"weights.{image_type}") + ).with_suffix(".txt") data = { "Channel": np.arange(len(rmss_arr)), "Weight": 1 / rmss_arr**2, # Want inverse variance @@ -622,11 +640,11 @@ def make_cube( @task(name="Get Beam", persist_result=True) def get_beam(image_set: ImageSet, cutoff: float | None) -> Path: - """Derive a common resolution across all images within a set of ImageSet + """Derive a common resolution across all images within a set of ImageSet. Args: image_set (ImageSet): ImageSet that a common resolution will be derived for - cuttoff (float, optional): The maximum major axis of the restoring beam that is allowed when + cutoff (float, optional): The maximum major axis of the restoring beam that is allowed when searching for the lowest common beam. Images whose restoring beam's major acis is larger than this are ignored. Defaults to None. @@ -659,7 +677,7 @@ def get_beam(image_set: ImageSet, cutoff: float | None) -> Path: # serialise the beam common_beam_pkl = Path(f"beam_{image_hash}.pkl") - with open(common_beam_pkl, "wb") as f: + with common_beam_pkl.open("wb") as f: logger.info(f"Creating {common_beam_pkl}") pickle.dump(common_beam, f) @@ -673,14 +691,15 @@ def smooth_imageset( cutoff: float | None = None, aux_mode: str | None = None, ) -> ImageSet: - """Smooth all images described within an ImageSet to a desired resolution + """Smooth all images described within an ImageSet to a desired resolution. Args: image_set (ImageSet): Container whose image_list will be convolved to common resolution common_beam_pkl (Path): Location of pickle file with beam description cutoff (Optional[float], optional): PSF cutoff passed to the beamcon_2D worker. Defaults to None. - aux_model (Optional[str], optional): The image type in the `aux_lists` property of `image_set` that contains the images to smooth. If + aux_mode (Optional[str], optional): The image type in the `aux_lists` property of `image_set` that contains the images to smooth. If not set then the `image_lists` property of `image_set` is used. Defaults to None. + Returns: ImageSet: A copy of `image_set` pointing to the smoothed images. Note the `aux_images` property is not carried forward. """ @@ -688,7 +707,7 @@ def smooth_imageset( logger = get_run_logger() # Deserialise the beam - with open(common_beam_pkl, "rb") as f: + with common_beam_pkl.open("rb") as f: logger.info(f"Loading common beam from {common_beam_pkl}") common_beam = pickle.load(f) @@ -742,8 +761,9 @@ def smooth_imageset( def cleanup( purge: bool, image_sets: list[ImageSet], ignore_files: list[Any] | None = None ) -> None: - """Utility to remove all images described by an collection of ImageSets. Internally - called `cleanup_imageset`. + """Utility to remove all images described by an collection of ImageSets. + + Internally called `cleanup_imageset`. Args: purge (bool): Whether files are actually removed or skipped. @@ -767,8 +787,7 @@ def cleanup( @task(name="Fix MeasurementSet Directions") def fix_ms(ms: Path) -> Path: - """Apply the corrections to the FEED table of a measurement set that - is required for the ASKAP measurement sets. + """Apply the corrections to the FEED table of a measurement set that is required for the ASKAP measurement sets. Args: ms (Path): Path to the measurement set to fix. 
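For reference, the common-beam handling above (derive once in get_beam, pickle it, re-load it in make_cube and smooth_imageset) follows a pattern that can be sketched roughly as below; the function name, the input arrays, and the output path are illustrative rather than taken from this patch:

    import pickle
    from pathlib import Path

    import astropy.units as u
    from radio_beam import Beams

    def derive_and_pickle_common_beam(majors, minors, pas, out_pkl: Path) -> Path:
        # Collect the per-image restoring beams and find the smallest common beam
        beams = Beams(
            major=majors * u.arcsec,
            minor=minors * u.arcsec,
            pa=pas * u.deg,
        )
        common_beam = beams.common_beam()
        # Serialise so later tasks can attach the same beam to every output header
        with out_pkl.open("wb") as f:
            pickle.dump(common_beam, f)
        return out_pkl

    # Later, e.g. inside make_cube or smooth_imageset:
    #     with common_beam_pkl.open("rb") as f:
    #         common_beam = pickle.load(f)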
@@ -782,12 +801,14 @@ def fix_ms(ms: Path) -> Path: @task(name="Fix MeasurementSet Correlations") def fix_ms_askap_corrs(ms: Path, *args, **kwargs) -> Path: - """Applies a correction to raw telescope polarisation products to rotate them - to the wsclean espected form. This is essentially related to the third-axis of - ASKAP and reorientating its 'X' and 'Y's. + """Applies a correction to raw telescope polarisation products to rotate them to the wsclean espected form. + + This is essentially related to the third-axis of ASKAP and reorientating its 'X' and 'Y's. Args: ms (Path): Path to the measurement set to be corrected. + *args: Additional arguments to pass to the correction function. + **kwargs: Additional keyword arguments to pass to the correction function. Returns: Path: Path of the measurementt set containing the corrections. @@ -796,7 +817,7 @@ def fix_ms_askap_corrs(ms: Path, *args, **kwargs) -> Path: logger.info(f"Correcting {ms!s} correlations for wsclean. ") - fix_ms_corrs(ms=ms, *args, **kwargs) + fix_ms_corrs(ms, *args, **kwargs) return ms @@ -840,7 +861,7 @@ def main( disable_pol_local_rms: bool = False, disable_pol_force_mask_rounds: bool = False, ): - """Arrakis imager flow + """Arrakis imager flow. Args: msdir (Path): Path to the directory containing the MS files. @@ -880,7 +901,6 @@ def main( disable_pol_local_rms (bool, optional): Disable local RMS for polarisation images. Defaults to False. disable_pol_force_mask_rounds (bool, optional): Disable force mask rounds for polarisation images. Defaults to False. """ - simage = get_wsclean(wsclean=wsclean_path) logger.info(f"Searching {msdir} for MS matching {ms_glob_pattern}.") @@ -1084,7 +1104,6 @@ def imager_parser(parent_parser: bool = False) -> argparse.ArgumentParser: Returns: argparse.ArgumentParser: Arguments required for the imager routine """ - # Help string to be shown using the -h option descStr = f""" {logo_str} @@ -1303,7 +1322,7 @@ def imager_parser(parent_parser: bool = False) -> argparse.ArgumentParser: def cli(): - """Command-line interface""" + """Command-line interface.""" im_parser = imager_parser(parent_parser=True) work_parser = workdir_arg_parser(parent_parser=True) diff --git a/arrakis/init_database.py b/arrakis/init_database.py index 54679d43..d3f75233 100755 --- a/arrakis/init_database.py +++ b/arrakis/init_database.py @@ -1,5 +1,5 @@ #!/usr/bin/env python3 -"""Create the Arrakis database""" +"""Create the Arrakis database.""" from __future__ import annotations @@ -28,7 +28,7 @@ def source2beams(ra: float, dec: float, database: Table, max_sep: float = 1) -> Table: - """Find RACS beams that contain a given source position + """Find RACS beams that contain a given source position. Args: ra (float): RA of source in degrees. @@ -46,8 +46,9 @@ def source2beams(ra: float, dec: float, database: Table, max_sep: float = 1) -> def ndix_unique(x: np.ndarray) -> tuple[np.ndarray, list[np.ndarray]]: - """Find the N-dimensional array of indices of the unique values in x - From https://stackoverflow.com/questions/54734545/indices-of-unique-values-in-array + """Find the N-dimensional array of indices of the unique values in x. + + From https://stackoverflow.com/questions/54734545/indices-of-unique-values-in-array. Args: x (np.ndarray): Array of values. 
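The "indices of unique values" approach referenced in the ndix_unique docstring above can be demonstrated with a small, self-contained NumPy example (toy data only):

    import numpy as np

    x = np.array(["A", "B", "A", "C", "B", "A"])
    order = np.argsort(x, kind="stable")
    uniq, first = np.unique(x[order], return_index=True)
    groups = dict(zip(uniq, np.split(order, first[1:])))
    # groups == {'A': array([0, 2, 5]), 'B': array([1, 4]), 'C': array([3])}

The same grouping is what lets each unique source be matched back to every row (here, every beam) that contains it.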
@@ -68,7 +69,7 @@ def ndix_unique(x: np.ndarray) -> tuple[np.ndarray, list[np.ndarray]]: def cat2beams( mastercat: Table, database: Table, max_sep: float = 1 ) -> tuple[np.ndarray, np.ndarray, Angle]: - """Find the separations between sources in the master catalogue and the RACS beams + """Find the separations between sources in the master catalogue and the RACS beams. Args: mastercat (Table): Master catalogue table. @@ -100,7 +101,7 @@ def source_database( username: str | None = None, password: str | None = None, ) -> tuple[InsertManyResult, InsertManyResult]: - """Insert sources into the database + """Insert sources into the database. Following https://medium.com/analytics-vidhya/how-to-upload-a-pandas-dataframe-to-mongodb-ffa18c0953c1 @@ -108,6 +109,7 @@ def source_database( islandcat (Table): Island catalogue table. compcat (Table): Component catalogue table. host (str): MongoDB host IP. + epoch (int): RACS epoch number. username (str, optional): Mongo username. Defaults to None. password (str, optional): Mongo host. Defaults to None. @@ -121,7 +123,7 @@ def source_database( if isinstance(df_i["Source_ID"][0], bytes): logger.info("Decoding strings!") str_df = df_i.select_dtypes([object]) - str_df = str_df.stack().str.decode("utf-8").unstack() + str_df = str_df.melt().str.decode("utf-8").pivot_table() for col in str_df: df_i[col] = str_df[col] @@ -148,7 +150,7 @@ def source_database( if isinstance(df_c["Source_ID"][0], bytes): logger.info("Decoding strings!") str_df = df_c.select_dtypes([object]) - str_df = str_df.stack().str.decode("utf-8").unstack() + str_df = str_df.melt().str.decode("utf-8").pivot_table() for col in str_df: df_c[col] = str_df[col] @@ -179,9 +181,10 @@ def beam_database( username: str | None = None, password: str | None = None, ) -> InsertManyResult: - """Insert beams into the database + """Insert beams into the database. Args: + database_path (Path): Path to RACS database. islandcat (Table): Island catalogue table. host (str): MongoDB host IP. username (str, optional): Mongo username. Defaults to None. @@ -218,9 +221,10 @@ def beam_database( def get_catalogue(survey_dir: Path, epoch: int = 0) -> Table: - """Get the RACS catalogue for a given epoch + """Get the RACS catalogue for a given epoch. Args: + survey_dir (Path): Path to RACS database. epoch (int, optional): Epoch number. Defaults to 0. Returns: @@ -262,11 +266,12 @@ def get_catalogue(survey_dir: Path, epoch: int = 0) -> Table: def get_beams(mastercat: Table, database: Table, epoch: int = 0) -> list[dict]: - """Get beams from the master catalogue + """Get beams from the master catalogue. Args: mastercat (Table): Master catalogue table. database (Table): RACS database table. + epoch (int, optional): RACS epoch number. Defaults to 0. Returns: List[Dict]: List of beam dictionaries. 
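Where source_database decodes byte-string columns above, the usual pandas round trip is stack / str.decode / unstack; a minimal sketch on a toy frame is shown below. As a caution, DataFrame.melt() returns a frame (which has no .str accessor) and pivot_table() needs explicit arguments, so the stack-based form may still be the safer idiom here.

    import pandas as pd

    df = pd.DataFrame({"Source_ID": [b"RACS_0000+00A"], "Tile_ID": [b"0000+00"]})
    str_df = df.select_dtypes([object])   # byte columns are object dtype
    decoded = str_df.stack().str.decode("utf-8").unstack()
    for col in decoded:
        df[col] = decoded[col]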
@@ -294,7 +299,7 @@ def get_beams(mastercat: Table, database: Table, epoch: int = 0) -> list[dict]: fields, ) - beam_list = [] + beam_list: list[dict] = [] for _i, (val, idx) in enumerate( tqdm(zip(vals, ixs), total=len(vals), desc="Getting beams", file=TQDM_OUT) ): @@ -304,10 +309,10 @@ def get_beams(mastercat: Table, database: Table, epoch: int = 0) -> list[dict]: beams = database[seps[0][idx.astype(int)]] for _j, field in enumerate(np.unique(beams["FIELD_NAME"])): ndx = beams["FIELD_NAME"] == field - field = field.replace("_test4_1.05_", "_") if epoch == 0 else field + correct_field = field.replace("_test4_1.05_", "_") if epoch == 0 else field beam_dict.update( { - field: { + correct_field: { "beam_list": list(beams["BEAM_NUM"][ndx]), "SBIDs": list(np.unique(beams["SBID"][ndx])), "DR1": bool(np.unique(in_dr1[seps[0][idx.astype(int)]][ndx])), @@ -335,7 +340,7 @@ def beam_inf( username: str | None = None, password: str | None = None, ) -> InsertManyResult: - """Get the beam information""" + """Get the beam information.""" tabs: list[Table] = [] for row in tqdm(database, desc="Reading beam info", file=TQDM_OUT): try: @@ -378,7 +383,7 @@ def read_racs_database( epoch: int, table: str, ) -> Table: - """Read the RACS database from CSVs or postgresql + """Read the RACS database from CSVs or postgresql. Args: survey_dir (Path): Path to RACS database (i.e. 'askap_surveys/racs' repo). @@ -415,7 +420,7 @@ def field_database( username: str | None = None, password: str | None = None, ) -> tuple[InsertManyResult, InsertManyResult]: - """Reset and load the field database + """Reset and load the field database. Args: survey_dir (Path): Path to RACS database (i.e. 'askap_surveys/racs' repo). @@ -432,8 +437,8 @@ def field_database( database["COMMENT"] = database["COMMENT"].astype(str) # Remove rows with SBID < 0 database = database[database["SBID"] >= 0] - df = database.to_pandas() - field_list_dict = df.to_dict("records") + database_df = database.to_pandas() + field_list_dict = database_df.to_dict("records") logger.info("Loading fields into mongo...") field_col = get_field_db( host=host, epoch=epoch, username=username, password=password @@ -472,12 +477,13 @@ def main( epochs: list[int] = 0, force: bool = False, ) -> None: - """Main script + """Main script. Args: load (bool, optional): Load the database. Defaults to False. islandcat (Union[str, None], optional): Island catalogue. Defaults to None. compcat (Union[str, None], optional): Component catalogue. Defaults to None. + database_path (Union[Path, None], optional): Path to RACS database. Defaults to None. host (str, optional): Mongo host. Defaults to "localhost". username (Union[str, None], optional): Mongo username. Defaults to None. password (Union[str, None], optional): Mongo password. Defaults to None. 
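The field_database step above boils down to turning an astropy Table into records and bulk-inserting them with pymongo; a rough, self-contained sketch follows (the database and collection names here are placeholders, not necessarily those used by Arrakis):

    import pandas as pd
    from pymongo import MongoClient

    def load_field_records(df: pd.DataFrame, host: str, username=None, password=None):
        # One document per table row
        records = df.to_dict("records")
        client = MongoClient(host=host, username=username, password=password)
        collection = client["spice"]["fields"]
        collection.delete_many({})   # reset before re-loading
        return collection.insert_many(records)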
@@ -570,7 +576,7 @@ def main( def cli(): - """Command-line interface""" + """Command-line interface.""" import argparse # Help string to be shown using the -h option diff --git a/arrakis/linmos.py b/arrakis/linmos.py index f1473f7c..fcaa984f 100755 --- a/arrakis/linmos.py +++ b/arrakis/linmos.py @@ -1,5 +1,5 @@ #!/usr/bin/env python3 -"""Run LINMOS on cutouts in parallel""" +"""Run LINMOS on cutouts in parallel.""" from __future__ import annotations @@ -8,7 +8,6 @@ import os import shlex import warnings -from glob import glob from pathlib import Path from pprint import pformat from typing import NamedTuple as Struct @@ -40,7 +39,7 @@ class ImagePaths(Struct): - """Class to hold image paths""" + """Class to hold image paths.""" images: list[Path] """List of image paths""" @@ -55,11 +54,11 @@ def find_images( stoke: str, datadir: Path, ) -> ImagePaths: - """Find the images and weights for a given field and stokes parameter + """Find the images and weights for a given field and stokes parameter. Args: field (str): Field name. - beams (dict): Beam information. + beams_row (tuple[int, pd.Series]): Row from beams collection. stoke (str): Stokes parameter. datadir (Path): Data directory. @@ -113,10 +112,10 @@ def find_images( def smooth_images( image_dict: dict[str, ImagePaths], ) -> dict[str, ImagePaths]: - """Smooth cubelets to a common resolution + """Smooth cubelets to a common resolution. Args: - image_list (ImagePaths): List of cubelets to smooth. + image_dict (dict[str, ImagePaths]): Cubelets to smooth. Returns: ImagePaths: Smoothed cubelets. @@ -148,8 +147,8 @@ def genparset( stoke: str, datadir: Path, holofile: Path | None = None, -) -> str: - """Generate parset for LINMOS +) -> Path: + """Generate parset for LINMOS. Args: image_paths (ImagePaths): List of images and weights. @@ -161,7 +160,7 @@ def genparset( Exception: If no files are found. Returns: - str: Path to parset file. + Path: Path to parset file. """ logger.setLevel(logging.INFO) @@ -191,7 +190,7 @@ def genparset( linmos_image_str = f"{first_image[:first_image.find('beam')]}linmos" linmos_weight_str = f"{first_weight[:first_weight.find('beam')]}linmos" - parset_file = os.path.join(parset_dir, f"linmos_{stoke}.in") + parset_file = parset_dir / f"linmos_{stoke}.in" parset = f"""linmos.names = {image_string} linmos.weights = {weight_string} linmos.imagetype = fits @@ -215,7 +214,7 @@ def genparset( else: logger.warning("No holography file provided - not correcting leakage!") - with open(parset_file, "w") as f: + with parset_file.open("w") as f: f.write(parset) return parset_file @@ -223,9 +222,9 @@ def genparset( @task(name="Run linmos") def linmos( - parset: str | None, fieldname: str, image: str, holofile: Path + parset: Path | None, fieldname: str, image: Path, holofile: Path ) -> pymongo.UpdateOne | None: - """Run linmos + """Run linmos. Args: parset (str): Path to parset file. 
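For context on genparset above: building a linmos parset from pathlib Paths amounts to stripping the .fits suffix (as genparset does), joining the names into a bracketed list, and writing the keys out with pathlib. An illustrative sketch under those assumptions, with only the core keys shown:

    from pathlib import Path

    def write_minimal_parset(images: list[Path], weights: list[Path], parset_dir: Path, stoke: str) -> Path:
        # linmos is given image names without the .fits extension
        image_string = "[" + ",".join(im.with_suffix("").as_posix() for im in images) + "]"
        weight_string = "[" + ",".join(wt.with_suffix("").as_posix() for wt in weights) + "]"
        parset_file = parset_dir / f"linmos_{stoke}.in"
        parset_file.write_text(
            f"linmos.names = {image_string}\n"
            f"linmos.weights = {weight_string}\n"
            "linmos.imagetype = fits\n"
        )
        return parset_file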
@@ -246,51 +245,58 @@ def linmos( if parset is None: return None - workdir = os.path.dirname(parset) - rootdir = os.path.split(workdir)[0] - parset_name = os.path.basename(parset) - source = os.path.basename(workdir) + # Yes this a bit overly complicated parsing of the file paths + # But I want to be sure that the paths are correct + # And I am but a grug + workdir = parset.parent + rootdir = workdir.parent + parset_name = parset.name + source = workdir.name stoke = parset_name[parset_name.find(".in") - 1] - log_file = parset.replace(".in", ".log") - linmos_command = shlex.split(f"linmos -c {parset}") + log_file = parset.with_suffix(".log") + linmos_command = shlex.split(f"linmos -c {parset.as_posix()}") holo_folder = holofile.parent output = sclient.execute( - image=image, + image=image.as_posix(), command=linmos_command, - bind=f"{rootdir}:{rootdir},{holo_folder}:{holo_folder}", + bind=f"{rootdir.as_posix()}:{rootdir.as_posix()},{holo_folder.as_posix()}:{holo_folder.as_posix()}", return_result=True, quiet=False, stream=True, ) - with open(log_file, "w") as f: + with log_file.open("w") as f: for line in output: # We could log this, but it's a lot of output # We seem to be DDoS'ing the Prefect server # logger.info(line) f.write(line) - new_files = glob(f"{workdir}/*.cutout.image.restored.{stoke.lower()}*.linmos.fits") + new_files = list( + workdir.glob(f"*.cutout.image.restored.{stoke.lower()}*.linmos.fits") + ) if len(new_files) != 1: msg = f"LINMOS file not found! -- check {log_file}?" - raise Exception(msg) + raise FileNotFoundError(msg) - new_file = os.path.abspath(new_files[0]) - outer = os.path.basename(os.path.dirname(new_file)) - inner = os.path.basename(new_file) - new_file = os.path.join(outer, inner) + new_file = new_files[0].absolute() + outer = Path(new_file.parent.name) + inner = new_file.name + new_file = outer / inner - logger.info(f"Cube now in {workdir}/{inner}") + logger.info(f"Cube now in {(workdir/inner).as_posix()}") query = {"Source_ID": source} - newvalues = {"$set": {f"beams.{fieldname}.{stoke.lower()}_file": new_file}} + newvalues = { + "$set": {f"beams.{fieldname}.{stoke.lower()}_file": new_file.as_posix()} + } return pymongo.UpdateOne(query, newvalues) -def get_yanda(version="1.3.0") -> str: +def get_yanda(version="1.3.0") -> Path: """Pull yandasoft image from dockerhub. Args: @@ -300,7 +306,7 @@ def get_yanda(version="1.3.0") -> str: str: Path to yandasoft image. """ sclient.load(f"docker://csirocass/yandasoft:{version}-galaxy") - return os.path.abspath(sclient.pull()) + return Path(sclient.pull()).absolute() # We reduce the inner loop to a serial call @@ -314,6 +320,21 @@ def serial_loop( holofile: Path, image: Path, ) -> list[pymongo.UpdateOne | None]: + """Serial loop for LINMOS. + + Finds images, generates parsets, and runs LINMOS. + + Args: + field (str): Field name. + beams_row (tuple[int, pd.Series]): Row from beams collection. + stokeslist (list[str]): List of Stokes parameters. + cutdir (Path): Cutout directory. + holofile (Path): Holography file. + image (Path): Path to yandasoft image. + + Returns: + list[pymongo.UpdateOne | None]: List of mongo update objects. + """ results = [] for stoke in stokeslist: image_path = find_images.fn( @@ -331,7 +352,7 @@ def serial_loop( result = linmos.fn( parset=parset, fieldname=field, - image=str(image), + image=image, holofile=holofile, ) results.append(result) @@ -354,12 +375,14 @@ def main( stokeslist: list[str] | None = None, limit: int | None = None, ) -> None: - """LINMOS flow + """LINMOS flow. 
Args: field (str): RACS field name. datadir (str): Data directory. host (str): MongoDB host IP. + epoch (int): Epoch. + sbid (int, optional): SBID. Defaults to None. holofile (str): Path to primary beam file. username (str, optional): Mongo username. Defaults to None. password (str, optional): Mongo password. Defaults to None. @@ -429,6 +452,14 @@ def main( def linmos_parser(parent_parser: bool = False) -> argparse.ArgumentParser: + """Create the linmos parser. + + Args: + parent_parser (bool, optional): If a parent parser. Defaults to False. + + Returns: + argparse.ArgumentParser: The parser. + """ # Help string to be shown using the -h option descStr = f""" {logo_str} @@ -446,7 +477,7 @@ def linmos_parser(parent_parser: bool = False) -> argparse.ArgumentParser: parser = linmos_parser.add_argument_group("linmos arguments") parser.add_argument( - "--holofile", type=str, default=None, help="Path to holography image" + "--holofile", type=Path, default=None, help="Path to holography image" ) parser.add_argument( "--yanda", @@ -464,8 +495,7 @@ def linmos_parser(parent_parser: bool = False) -> argparse.ArgumentParser: def cli(): - """Command-line interface""" - + """Command-line interface.""" gen_parser = generic_parser(parent_parser=True) work_parser = workdir_arg_parser(parent_parser=True) lin_parser = linmos_parser(parent_parser=True) diff --git a/arrakis/logger.py b/arrakis/logger.py index a4846020..8df8a6b6 100755 --- a/arrakis/logger.py +++ b/arrakis/logger.py @@ -1,5 +1,5 @@ #!/usr/bin/env python3 -"""Logging module for arrakis""" +"""Logging module for arrakis.""" from __future__ import annotations @@ -7,32 +7,59 @@ import io import logging +from arrakis.utils.typing import Struct + # https://stackoverflow.com/questions/61324536/python-argparse-with-argumentdefaultshelpformatter-and-rawtexthelpformatter class UltimateHelpFormatter( argparse.RawTextHelpFormatter, argparse.ArgumentDefaultsHelpFormatter -): ... +): + """Combines RawTextHelpFormatter and ArgumentDefaultsHelpFormatter.""" -class TqdmToLogger(io.StringIO): - """ - Output stream for TQDM which will output to logger module instead of - the StdOut. +class Formats(Struct): + """Log formats. 
+ + Attributes: + debug (str): Debug log format + info (str): Info log format + warning (str): Warning log format + error (str): Error log format + critical (str): Critical log format + """ + debug: str + """Debug log format""" + info: str + """Info log format""" + warning: str + """Warning log format""" + error: str + """Error log format""" + critical: str + """Critical log format""" + + +class TqdmToLogger(io.StringIO): + """Output stream for TQDM which will output to logger module instead of the StdOut.""" + logger = None level = None buf = "" def __init__(self, logger, level=None): + """TQDM logger.""" super().__init__() self.logger = logger self.level = level or logging.INFO def write(self, buf): + """Write to the buffer.""" self.buf = buf.strip("\r\n\t ") def flush(self): + """Flush the buffer.""" self.logger.log(self.level, self.buf) @@ -41,6 +68,8 @@ def flush(self): # "SPICE: %(asctime)s.%(msecs)03d %(levelname)s %(module)s - %(funcName)s: %(message)s" # ) class CustomFormatter(logging.Formatter): + """Custom formatter for logging.""" + grey = "\x1b[38;20m" blue = "\x1b[34;20m" green = "\x1b[32;20m" @@ -50,16 +79,24 @@ class CustomFormatter(logging.Formatter): reset = "\x1b[0m" format_str = "%(asctime)s.%(msecs)03d %(module)s - %(funcName)s: %(message)s" - FORMATS = { - logging.DEBUG: f"{blue}SPICE-%(levelname)s{reset} {format_str}", - logging.INFO: f"{green}SPICE-%(levelname)s{reset} {format_str}", - logging.WARNING: f"{yellow}SPICE-%(levelname)s{reset} {format_str}", - logging.ERROR: f"{red}SPICE-%(levelname)s{reset} {format_str}", - logging.CRITICAL: f"{bold_red}SPICE-%(levelname)s{reset} {format_str}", - } + FORMATS = Formats( + debug=f"{blue}SPICE-%(levelname)s{reset} {format_str}", + info=f"{green}SPICE-%(levelname)s{reset} {format_str}", + warning=f"{yellow}SPICE-%(levelname)s{reset} {format_str}", + error=f"{red}SPICE-%(levelname)s{reset} {format_str}", + critical=f"{bold_red}SPICE-%(levelname)s{reset} {format_str}", + ) + + def format(self, record) -> str: + """Format the log record. + + Args: + record (LogRecord): The log record. - def format(self, record): - log_fmt = self.FORMATS.get(record.levelno) + Returns: + str: Formatted log. + """ + log_fmt = self.FORMATS._asdict().get(record.levelname.lower()) formatter = logging.Formatter(log_fmt, "%Y-%m-%d %H:%M:%S") return formatter.format(record) diff --git a/arrakis/makecat.py b/arrakis/makecat.py index 25c15965..13d7f217 100755 --- a/arrakis/makecat.py +++ b/arrakis/makecat.py @@ -1,16 +1,15 @@ #!/usr/bin/env python3 -"""Make an Arrakis catalogue""" +"""Make an Arrakis catalogue.""" from __future__ import annotations import argparse import logging -import os import time import warnings from pathlib import Path from pprint import pformat -from typing import Callable, NamedTuple +from typing import Callable import astropy.units as u import dask.dataframe as dd @@ -44,14 +43,24 @@ upload_image_as_artifact_task, ) from arrakis.utils.plotting import latexify -from arrakis.utils.typing import ArrayLike, TableLike +from arrakis.utils.typing import ArrayLike, Struct, TableLike logger.setLevel(logging.INFO) TQDM_OUT = TqdmToLogger(logger, level=logging.INFO) -class SpectralIndices(NamedTuple): +class SpectralIndices(Struct): + """Specral indices. 
+ + Attributes: + alphas (np.ndarray): Alpha values + alphas_err (np.ndarray): Alpha errors + betas (np.ndarray): Beta values + betas_err (np.ndarray): Beta + + """ + alphas: np.ndarray alphas_err: np.ndarray betas: np.ndarray @@ -59,7 +68,7 @@ class SpectralIndices(NamedTuple): def combinate(data: ArrayLike) -> tuple[ArrayLike, ArrayLike]: - """Return all combinations of data with itself + """Return all combinations of data with itself. Args: data (ArrayLike): Data to combine. @@ -84,8 +93,7 @@ def flag_blended_components(cat: TableLike) -> TableLike: """ def is_blended_component(sub_df: pd.DataFrame) -> pd.DataFrame: - """Return a boolean series indicating whether a component is the maximum - component in a source. + """Return a boolean series indicating whether a component is the maximum component in a source. Args: sub_df (pd.DataFrame): DataFrame containing all components for a source @@ -167,9 +175,9 @@ def is_blended_component(sub_df: pd.DataFrame) -> pd.DataFrame: index=sub_df.index, ) - df = cat.to_pandas() - df = df.set_index("cat_id") - ddf = dd.from_pandas(df, chunksize=1000) + cat_df = cat.to_pandas() + cat_df = cat_df.set_index("cat_id") + ddf = dd.from_pandas(cat_df, chunksize=1000) grp = ddf.groupby("source_id") logger.info("Identifying blended components...") with ProgressBar(): @@ -225,11 +233,22 @@ def is_blended_component(sub_df: pd.DataFrame) -> pd.DataFrame: return cat -def lognorm_from_percentiles(x1, p1, x2, p2): - """Return a log-normal distribuion X parametrized by: +def lognorm_from_percentiles[T](x1: T, p1: T, x2: T, p2: T) -> tuple[T, T]: + """Return a log-normal distribuion 'X' based on percentiles. + + Parametrized by: P(X < p1) = x1 P(X < p2) = x2 + + Args: + x1 (T): Value at p1 + p1 (T): Percentile 1 + x2 (T): Value at p2 + p2 (T): Percentile 2 + + Returns: + tuple[T, T]: Scale and mean of the log-normal distribution """ x1 = np.log(x1) x2 = np.log(x2) @@ -243,7 +262,15 @@ def lognorm_from_percentiles(x1, p1, x2, p2): @task(name="Fix sigma_add") -def sigma_add_fix(tab: TableLike) -> TableLike: +def sigma_add_fix[TableLike](tab: TableLike) -> TableLike: + """Fix sigma_add values. + + Args: + tab (TableLike): Table with sigma_add values + + Returns: + TableLike: Fixed table + """ sigma_Q_low = np.array(tab["sigma_add_Q"] - tab["sigma_add_Q_err_minus"]) sigma_Q_high = np.array(tab["sigma_add_Q"] + tab["sigma_add_Q_err_plus"]) @@ -297,7 +324,7 @@ def sigma_add_fix(tab: TableLike) -> TableLike: def is_leakage(frac: float, sep: float, fit: Callable) -> bool: - """Determine if a source is leakage + """Determine if a source is leakage. Args: frac (float): Polarised fraction @@ -319,16 +346,19 @@ def get_fit_func( do_plot: bool = False, high_snr_cut: float = 30.0, ) -> tuple[Callable, plt.Figure]: - """Fit an envelope to define leakage sources + """Fit an envelope to define leakage sources. Args: tab (TableLike): Catalogue to fit nbins (int, optional): Number of bins along seperation axis. Defaults to 21. + offset (float, optional): Offset for fit. Defaults to 0.002. + degree (int, optional): Degree of polynomial fit. Defaults to 2. + do_plot (bool, optional): Plot the fit. Defaults to False. + high_snr_cut (float, optional): SNR cut for high SNR sources. Defaults to 30.0. Returns: Callable: 3rd order polynomial fit. """ - logger.info(f"Using {high_snr_cut=}.") # Select high SNR sources @@ -432,7 +462,7 @@ def get_fit_func( def compute_local_rm_flag(good_cat: Table, big_cat: Table) -> Table: - """Compute the local RM flag + """Compute the local RM flag. 
Args: good_cat (Table): Table with just good RMs @@ -444,9 +474,9 @@ def compute_local_rm_flag(good_cat: Table, big_cat: Table) -> Table: logger.info("Computing voronoi bins and finding bad RMs") logger.info(f"Number of available sources: {len(good_cat)}.") - df = good_cat.to_pandas() - df = df.reset_index() - df = df.set_index("cat_id") + good_cat_df = good_cat.to_pandas() + good_cat_df = good_cat_df.reset_index() + good_cat_df = good_cat_df.set_index("cat_id") df_out = big_cat.to_pandas() df_out = df_out.reset_index() @@ -456,6 +486,8 @@ def compute_local_rm_flag(good_cat: Table, big_cat: Table) -> Table: try: def sn_func(index, signal=None, noise=None): + # Signal and noise are not used, but required by voronoi_2d_binning + _, _ = signal, noise try: sn = len(np.array(index)) except TypeError: @@ -498,11 +530,10 @@ def sn_func(index, signal=None, noise=None): ) if num_of_bins >= target_bins: break - else: - logger.info( - f"Found {num_of_bins} bins, targeting minimum {target_bins}" - ) - target_sn -= 5 + logger.info( + f"Found {num_of_bins} bins, targeting minimum {target_bins}" + ) + target_sn -= 5 except ValueError as e: if "Not enough S/N in the whole set of pixels." not in str(e): raise e @@ -518,7 +549,7 @@ def sn_func(index, signal=None, noise=None): if not fail: logger.info(f"Found {len(set(bin_number))} bins") - df["bin_number"] = bin_number + good_cat_df["bin_number"] = bin_number # Use sigma clipping to find outliers def masker(x): @@ -527,13 +558,13 @@ def masker(x): index=x.index, ) - perc_g = df.groupby("bin_number").apply( + perc_g = good_cat_df.groupby("bin_number").apply( masker, ) # Put flag into the catalogue - df["local_rm_flag"] = perc_g.reset_index().set_index("cat_id")[0] - df = df.drop(columns=["bin_number"]) - df_out.update(df["local_rm_flag"]) + good_cat_df["local_rm_flag"] = perc_g.reset_index().set_index("cat_id")[0] + good_cat_df = good_cat_df.drop(columns=["bin_number"]) + df_out.update(good_cat_df["local_rm_flag"]) except Exception as e: logger.error(f"Failed to compute local RM flag: {e}") @@ -557,18 +588,25 @@ def masker(x): @task(name="Add cuts and flags") -def cuts_and_flags( +def cuts_and_flags[TableLike]( cat: TableLike, leakage_degree: int = 4, leakage_bins: int = 16, leakage_snr: float = 30.0, ) -> TableLike: - """Cut out bad sources, and add flag columns + """Cut out bad sources, and add flag columns. A flag of 'True' means the source is bad. Args: - cat (rmt): Catalogue to cut and flag + cat (TableLike): Catalogue to cut and flag + leakage_degree (int, optional): Degree of leakage fit. Defaults to 4. + leakage_bins (int, optional): Number of bins for leakage fit. Defaults to 16. + leakage_snr (float, optional): SNR cut for leakage fit. Defaults to 30.0. + + Returns: + TableLike: Catalogue with cuts and flags + """ # SNR flag snr_flag = cat["snr_polint"] < 8 @@ -630,6 +668,14 @@ def cuts_and_flags( @task(name="Get spectral indices") def get_alpha(cat: TableLike) -> SpectralIndices: + """Get spectral indices from a catalogue. 
+ + Args: + cat (TableLike): Catalogue + + Returns: + SpectralIndices: _description_ + """ coefs_str = cat["stokesI_model_coef"] coefs_err_str = cat["stokesI_model_coef_err"] alphas = [] @@ -658,7 +704,22 @@ def get_alpha(cat: TableLike) -> SpectralIndices: @task(name="Get integration times") -def get_integration_time(cat: RMTable, field_col: Collection, sbid: int | None = None): +def get_integration_time( + cat: RMTable, field_col: Collection, sbid: int | None = None +) -> u.Quantity: + """Get the integration times for a given catalogue. + + Args: + cat (RMTable): RM catalogue + field_col (Collection): Field collection + sbid (int | None, optional): SBID. Defaults to None. + + Raises: + ValueError: If no data is found for the given query. + + Returns: + u.Quantity: Integration times + """ logger.warning("Will be stripping the trailing field character prefix. ") field_names = [ name[:-1] if name[-1] in ("A", "B") else name for name in list(cat["tile_id"]) @@ -692,8 +753,8 @@ def get_integration_time(cat: RMTable, field_col: Collection, sbid: int | None = if doc_count == 0: msg = f"No data for query {query}" raise ValueError(msg) - else: - logger.warning("Using SELECT=0 instead.") + + logger.warning("Using SELECT=0 instead.") field_data = list(field_col.find(query, reutrn_vals)) tint_df = pd.DataFrame(field_data) @@ -708,7 +769,7 @@ def get_integration_time(cat: RMTable, field_col: Collection, sbid: int | None = logger.debug(f"Returned results: {tint_df=}") - tints = tint_df.loc[field_names]["SCAN_TINT"].values * u.s + tints: u.Quantity = tint_df.loc[field_names]["SCAN_TINT"].to_numpy() * u.s assert len(tints) == len(field_names), "Mismatch in number of integration times" assert len(tints) == len(cat), "Mismatch in number of integration times and sources" @@ -716,11 +777,12 @@ def get_integration_time(cat: RMTable, field_col: Collection, sbid: int | None = return tints -def add_metadata(vo_table: VOTableFile, filename: str): - """Add metadata to VO Table for CASDA +def add_metadata(vo_table: VOTableFile, filename: Path): + """Add metadata to VO Table for CASDA. Args: vo_table (vot): VO Table object + filename (Path): Output file name Returns: vot: VO Table object with metadata @@ -740,10 +802,7 @@ def add_metadata(vo_table: VOTableFile, filename: str): if len(vo_table.params) > 0: logger.warning(f"{filename} already has params - not adding") return vo_table - _, ext = os.path.splitext(filename) - cat_name = ( - os.path.basename(filename).replace(ext, "").replace(".", "_").replace("-", "_") - ) + cat_name = filename.stem.replace(".", "_").replace("-", "_") idx_fields = "ra,dec,cat_id,source_id" pri_fields = ( "ra,dec,cat_id,source_id,rm,polint,snr_polint,fracpol,stokesI,sigma_add" @@ -776,21 +835,8 @@ def add_metadata(vo_table: VOTableFile, filename: str): return vo_table -def replace_nans(filename: str): - """Replace NaNs in a XML table with a string - - Args: - filename (str): File name - """ - # with open(filename, "r") as f: - # xml = f.read() - # xml = xml.replace("NaN", "null") - # with open(filename, "w") as f: - # f.write(xml) - - def fix_blank_units(rmtab: TableLike) -> TableLike: - """Fix blank units in table + """Fix blank units in table. Args: rmtab (TableLike): TableLike @@ -808,7 +854,13 @@ def fix_blank_units(rmtab: TableLike) -> TableLike: @task(name="Write votable") -def write_votable(rmtab: TableLike, outfile: str) -> None: +def write_votable(rmtab: TableLike, outfile: Path) -> None: + """Write a table to a VO Table. 
+ + Args: + rmtab (TableLike): RM Table + outfile (Path): Ouput file name + """ # Replace bad column names fix_columns = { "catalog": "catalog_name", @@ -824,13 +876,10 @@ def write_votable(rmtab: TableLike, outfile: str) -> None: vo_table.version = "1.3" vo_table = add_metadata(vo_table, outfile) vot.writeto(vo_table, outfile) - # Fix NaNs for CASDA - replace_nans(outfile) def update_tile_separations(rmtab: TableLike, field_col: Collection) -> TableLike: - """ - Update the tile separations in the catalogue + """Update the tile separations in the catalogue. Args: rmtab (TableLike): Table to update @@ -902,13 +951,18 @@ def main( username: str | None = None, password: str | None = None, verbose: bool = True, - outfile: str | None = None, + outfile: Path | None = None, ) -> None: - """Make a catalogue from the Arrakis database flow + """Make a catalogue from the Arrakis database flow. Args: field (str): RACS field name host (str): MongoDB host IP + epoch (int): Epoch of the data + sbid (int, optional): SBID to use. Defaults to None. + leakage_degree (int, optional): Degree of leakage fit. Defaults to 4. + leakage_bins (int, optional): Number of bins for leakage fit. Defaults to 16. + leakage_snr (float, optional): SNR for leakage fit. Defaults to 30.0. username (str, optional): Mongo username. Defaults to None. password (str, optional): Mongo password. Defaults to None. verbose (bool, optional): Verbose output. Defaults to True. @@ -1078,7 +1132,7 @@ def main( rmtab.add_column(new_col) if src == "synth": - for src_id, comp in comps_df.iterrows(): + for _src_id, comp in comps_df.iterrows(): try: data += [comp["rmclean_summary"][col]] except KeyError: @@ -1087,7 +1141,7 @@ def main( rmtab.add_column(new_col) if src == "header": - for src_id, comp in comps_df.iterrows(): + for _src_id, comp in comps_df.iterrows(): data += [comp["header"][col]] new_col = Column(data=data, name=name, dtype=typ, unit=unit) rmtab.add_column(new_col) @@ -1096,7 +1150,7 @@ def main( columns_possum.sourcefinder_columns, desc="Adding BDSF data", file=TQDM_OUT ): data = [] - for src_id, comp in comps_df.iterrows(): + for _, comp in comps_df.iterrows(): data += [comp[selcol]] new_col = Column(data=data, name=selcol) rmtab.add_column(new_col) @@ -1208,7 +1262,7 @@ def main( return logger.info(f"Writing {outfile} to disk") - _, ext = os.path.splitext(outfile) + ext = outfile.suffix if ext in (".xml", ".vot"): write_votable(rmtab, outfile) else: @@ -1219,6 +1273,14 @@ def main( def cat_parser(parent_parser: bool = False) -> argparse.ArgumentParser: + """Make a catalogue parser. + + Args: + parent_parser (bool, optional): If a parent parser. Defaults to False. 
+ + Returns: + argparse.ArgumentParser: The parser + """ # Help string to be shown using the -h option descStr = f""" {logo_str} @@ -1259,7 +1321,7 @@ def cat_parser(parent_parser: bool = False) -> argparse.ArgumentParser: "--catfile", dest="outfile", default=None, - type=str, + type=Path, help="File to save table to.", ) @@ -1267,7 +1329,7 @@ def cat_parser(parent_parser: bool = False) -> argparse.ArgumentParser: def cli(): - """Command-line interface""" + """Command-line interface.""" import argparse from astropy.utils.exceptions import AstropyWarning diff --git a/arrakis/merge_fields.py b/arrakis/merge_fields.py index aa8ecf18..a60ba5f6 100755 --- a/arrakis/merge_fields.py +++ b/arrakis/merge_fields.py @@ -1,10 +1,9 @@ #!/usr/bin/env python3 -"""Merge multiple RACS fields""" +"""Merge multiple RACS fields.""" from __future__ import annotations import argparse -import os from pathlib import Path from pprint import pformat from shutil import copyfile @@ -19,15 +18,23 @@ from arrakis.utils.io import try_mkdir -def make_short_name(name: str) -> str: - return os.path.join(os.path.basename(os.path.dirname(name)), os.path.basename(name)) +def make_short_name(name: Path) -> str: + """Make a short name for a file. + + Args: + name (Path): File name + + Returns: + str: Short name + """ + return (Path(name.parent.name) / name.name).as_posix() @task(name="Copy singleton island") def copy_singleton( beam: dict[str, Any], field_dict: dict[str, Path], merge_name: str, data_dir: Path ) -> list[pymongo.UpdateOne]: - """Copy an island within a single field to the merged field + """Copy an island within a single field to the merged field. Args: beam (dict): Beam document @@ -56,19 +63,23 @@ def copy_singleton( new_dir = data_dir / str(beam["Source_ID"]) new_dir.mkdir(exist_ok=True) - i_file_new = (new_dir / i_file_old.name).replace(".fits", ".edge.linmos.fits") - q_file_new = (new_dir / q_file_old.name).replace(".fits", ".edge.linmos.fits") - u_file_new = (new_dir / u_file_old.name).replace(".fits", ".edge.linmos.fits") + i_file_new = (new_dir / i_file_old.name).with_suffix(".edge.linmos.fits") + q_file_new = (new_dir / q_file_old.name).with_suffix(".edge.linmos.fits") + u_file_new = (new_dir / u_file_old.name).with_suffix(".edge.linmos.fits") for src, dst in zip( [i_file_old, q_file_old, u_file_old], [i_file_new, q_file_new, u_file_new] ): copyfile(src, dst) - src_weight = src.replace(".image.restored.", ".weights.").replace( - ".ion", "" + src_weight = ( + src.as_posix() + .replace(".image.restored.", ".weights.") + .replace(".ion", "") ) - dst_weight = dst.replace(".image.restored.", ".weights.").replace( - ".ion", "" + dst_weight = ( + dst.as_posix() + .replace(".image.restored.", ".weights.") + .replace(".ion", "") ) copyfile(src_weight, dst_weight) @@ -92,7 +103,7 @@ def copy_singletons( beams_col: pymongo.collection.Collection, merge_name: str, ) -> list[pymongo.UpdateOne]: - """Copy islands that don't overlap other fields + """Copy islands that don't overlap other fields. Args: field_dict (Dict[str, str]): Field dictionary @@ -134,23 +145,32 @@ def copy_singletons( def genparset( - old_ims: list, + old_ims: list[Path], stokes: str, - new_dir: str, -) -> str: - imlist = "[" + ",".join([im.replace(".fits", "") for im in old_ims]) + "]" - weightlist = f"[{','.join([im.replace('.fits', '').replace('.image.restored.','.weights.').replace('.ion','') for im in old_ims])}]" + new_dir: Path, +) -> Path: + """Generate a linmos parset file. 
- im_outname = os.path.join(new_dir, os.path.basename(old_ims[0])).replace( - ".fits", ".edge.linmos" - ) + Args: + old_ims (list[Path]): Old images + stokes (str): Stokes parameter + new_dir (Path): Output directory + + Returns: + Path: Path to parset file + """ + imlist = "[" + ",".join([im.with_suffix("").as_posix() for im in old_ims]) + "]" + weightlist = f"[{','.join([im.with_suffix("").as_posix().replace('.image.restored.','.weights.').replace('.ion','') for im in old_ims])}]" + + im_outname = (new_dir / old_ims[0].name).with_suffix(".edge.linmos").as_posix() wt_outname = ( - os.path.join(new_dir, os.path.basename(old_ims[0])) - .replace(".fits", ".edge.linmos") + (new_dir / old_ims[0].name) + .with_suffix(".edge.linmos") + .as_posix() .replace(".image.restored.", ".weights.") ) - parset_file = os.path.join(new_dir, f"edge_linmos_{stokes}.in") + parset_file = new_dir / f"edge_linmos_{stokes}.in" parset = f"""# LINMOS parset linmos.names = {imlist} linmos.weights = {weightlist} @@ -162,16 +182,16 @@ def genparset( linmos.weightstate = Corrected """ - with open(parset_file, "w") as f: + with parset_file.open("w") as f: f.write(parset) return parset_file def merge_multiple_field( - beam: dict, field_dict: dict, merge_name: str, data_dir: Path, image: str + beam: dict, field_dict: dict[str, Path], merge_name: str, data_dir: Path, image: str ) -> list[pymongo.UpdateOne]: - """Merge an island that overlaps multiple fields + """Merge an island that overlaps multiple fields. Args: beam (dict): Beam document @@ -186,21 +206,21 @@ def merge_multiple_field( Returns: List[pymongo.UpdateOne]: Database updates """ - i_files_old = [] - q_files_old = [] - u_files_old = [] + i_files_old: list[Path] = [] + q_files_old: list[Path] = [] + u_files_old: list[Path] = [] updates = [] for field, vals in beam["beams"].items(): if field not in field_dict: continue field_dir = field_dict[field] try: - i_file_old = os.path.join(field_dir, vals["i_file"]) - q_file_old = os.path.join(field_dir, vals["q_file_ion"]) - u_file_old = os.path.join(field_dir, vals["u_file_ion"]) - except KeyError: + i_file_old = field_dir / str(vals["i_file"]) + q_file_old = field_dir / str(vals["q_file_ion"]) + u_file_old = field_dir / str(vals["u_file_ion"]) + except KeyError as e: msg = "Ion files not found. Have you run FRion?" - raise KeyError(msg) + raise KeyError(msg) from e i_files_old.append(i_file_old) q_files_old.append(q_file_old) u_files_old.append(u_file_old) @@ -225,7 +245,7 @@ def merge_multiple_fields( merge_name: str, image: str, ) -> list[pymongo.UpdateOne]: - """Merge multiple islands that overlap multiple fields + """Merge multiple islands that overlap multiple fields. Args: field_dict (Dict[str, str]): Field dictionary @@ -281,6 +301,22 @@ def main( password: str | None = None, yanda="1.3.0", ) -> str: + """Merge multiple RACS fields. + + Args: + fields (list[str]): List of field names. + field_dirs (list[Path]): List of field directories. + merge_name (str): Name of the merged field. + output_dir (Path): Output directory. + host (str): MongoDB host. + epoch (int): Epoch. + username (str | None, optional): MongoDB username. Defaults to None. + password (str | None, optional): MongoDB password. Defaults to None. + yanda (str, optional): Yandasoft version. Defaults to "1.3.0". + + Returns: + str: Intermediate directory + """ logger.debug(f"{fields=}") assert ( @@ -335,6 +371,15 @@ def main( def merge_parser(parent_parser: bool = False) -> argparse.ArgumentParser: + """Merge parser. 
+ + Args: + parent_parser (bool, optional): Whether this is a parent parser. Defaults to False. + + Returns: + argparse.ArgumentParser: Merge parser + + """ # Help string to be shown using the -h option descStr = """ Mosaic RACS beam fields with linmos. @@ -398,8 +443,7 @@ def merge_parser(parent_parser: bool = False) -> argparse.ArgumentParser: def cli(): - """Command-line interface""" - + """Command-line interface.""" m_parser = merge_parser(parent_parser=True) lin_parser = linmos_parser(parent_parser=True) diff --git a/arrakis/process_region.py b/arrakis/process_region.py index 4cacc065..be46206d 100755 --- a/arrakis/process_region.py +++ b/arrakis/process_region.py @@ -1,5 +1,5 @@ #!/usr/bin/env python3 -"""Arrakis multi-field pipeline""" +"""Arrakis multi-field pipeline.""" from __future__ import annotations @@ -33,12 +33,13 @@ def process_merge( args: argparse.Namespace, host: str, inter_dir: Path, task_runner ) -> None: - """Workflow to merge spectra from overlapping fields together + """Workflow to merge spectra from overlapping fields together. Args: args (Namespace): Parameters to use for this process host (str): Address of the mongoDB servicing the processing inter_dir (Path): Location to store data from merged fields + task_runner (TaskRunner): Task runner to use for this process """ previous_future = None previous_future = ( @@ -141,7 +142,7 @@ def process_merge( def main(args: configargparse.Namespace) -> None: - """Main script + """Main script. Args: args (configargparse.Namespace): Command line arguments. @@ -177,6 +178,14 @@ def main(args: configargparse.Namespace) -> None: def pipeline_parser(parent_parser: bool = False) -> argparse.ArgumentParser: + """Pipeline parser. + + Args: + parent_parser (bool, optional): Parent parser. Defaults to False. + + Returns: + argparse.ArgumentParser: Pipeline parser + """ descStr = f""" {logo_str} Arrakis regional pipeline. @@ -217,7 +226,7 @@ def pipeline_parser(parent_parser: bool = False) -> argparse.ArgumentParser: def cli(): - """Command-line interface""" + """Command-line interface.""" # Help string to be shown using the -h option # Parse the command line options diff --git a/arrakis/process_spice.py b/arrakis/process_spice.py index ef8109ca..1eb830ec 100755 --- a/arrakis/process_spice.py +++ b/arrakis/process_spice.py @@ -1,5 +1,5 @@ #!/usr/bin/env python3 -"""Arrakis single-field pipeline""" +"""Arrakis single-field pipeline.""" from __future__ import annotations @@ -33,12 +33,15 @@ @flow(name="Combining+Synthesis on Arrakis") -def process_spice(args, host: str, task_runner: BaseTaskRunner) -> None: - """Workflow to process the SPICE-RACS data +def process_spice( + args: configargparse.Namespace, host: str, task_runner: BaseTaskRunner +) -> None: + """Workflow to process the SPICE-RACS data. Args: args (configargparse.Namespace): Configuration parameters for this run host (str): Host address of the mongoDB. + task_runner (BaseTaskRunner): Task runner for the workflow. """ outfile = f"{args.field}.pipe.test.fits" if args.outfile is None else args.outfile @@ -199,8 +202,7 @@ def process_spice(args, host: str, task_runner: BaseTaskRunner) -> None: def save_args(args: configargparse.Namespace) -> Path: - """Helper function to create a record of the input configuration arguments that - govern the pipeline instance + """Helper function to create a record of the input configuration arguments that govern the pipeline instance. 
Args: args (configargparse.Namespace): Supplied arguments for the Arrakis pipeline instance @@ -221,7 +223,7 @@ def create_dask_runner( dask_config: Path | None, overload: bool = False, ) -> DaskTaskRunner: - """Create a DaskTaskRunner + """Create a DaskTaskRunner. Args: dask_config (Path | None): Configuraiton file for the DaskTaskRunner @@ -260,7 +262,7 @@ def create_dask_runner( def main(args: configargparse.Namespace) -> None: - """Main script + """Main script. Args: args (configargparse.Namespace): Command line arguments. @@ -351,6 +353,15 @@ def main(args: configargparse.Namespace) -> None: def pipeline_parser(parent_parser: bool = False) -> argparse.ArgumentParser: + """Create the parser for the pipeline. + + Args: + parent_parser (bool, optional): If this is a parent parser. Defaults to False. + + Returns: + argparse.ArgumentParser: The parser for the pipeline. + + """ descStr = f""" {logo_str} @@ -408,7 +419,7 @@ def pipeline_parser(parent_parser: bool = False) -> argparse.ArgumentParser: def cli(): - """Command-line interface""" + """Command-line interface.""" # Help string to be shown using the -h option pipe_parser = pipeline_parser(parent_parser=True) diff --git a/arrakis/rmclean_oncuts.py b/arrakis/rmclean_oncuts.py index 15953d1b..9821fe59 100755 --- a/arrakis/rmclean_oncuts.py +++ b/arrakis/rmclean_oncuts.py @@ -1,5 +1,5 @@ #!/usr/bin/env python3 -"""Run RM-synthesis on cutouts in parallel""" +"""Run RM-synthesis on cutouts in parallel.""" from __future__ import annotations @@ -47,7 +47,7 @@ def rmclean1d( rm_verbose=True, window=None, ) -> pymongo.UpdateOne: - """1D RM-CLEAN + """1D RM-CLEAN. Args: field (str): RACS field name. @@ -56,8 +56,10 @@ def rmclean1d( cutoff (float, optional): CLEAN cutouff (in sigma). Defaults to -3. maxIter (int, optional): Maximum CLEAN interation. Defaults to 10000. gain (float, optional): CLEAN gain. Defaults to 0.1. + sbid (int, optional): SBID. Defaults to None. savePlots (bool, optional): Save CLEAN plots. Defaults to False. rm_verbose (bool, optional): Verbose RM-CLEAN. Defaults to True. + window (float, optional): Further CLEAN in mask to this threshold. Defaults to None. Returns: pymongo.UpdateOne: MongoDB update query. @@ -154,11 +156,13 @@ def rmclean3d( gain=0.1, rm_verbose=False, ) -> pymongo.UpdateOne: - """Run RM-CLEAN on 3D cube + """Run RM-CLEAN on 3D cube. Args: + field (str): RACS field name. island (dict): MongoDB island entry. outdir (Path): Output directory. + sbid (int, optional): SBID. Defaults to None. cutoff (float, optional): CLEAN cutoff (in sigma). Defaults to -3. maxIter (int, optional): Max CLEAN iterations. Defaults to 10000. gain (float, optional): CLEAN gain. Defaults to 0.1. @@ -167,7 +171,6 @@ def rmclean3d( Returns: pymongo.UpdateOne: MongoDB update query. """ - iname = island["Source_ID"] prefix = f"{iname}_" rm3dfiles = island["rm_outputs_3d"]["rm3dfiles"] @@ -228,12 +231,14 @@ def main( window=None, rm_verbose=False, ): - """Run RM-CLEAN on cutouts flow + """Run RM-CLEAN on cutouts flow. Args: field (str): RACS field name. outdir (Path): Output directory. host (str): MongoDB host IP. + epoch (int): Epoch number. + sbid (int, optional): SBID. Defaults to None. username (str, optional): Mongo username. Defaults to None. password (str, optional): Mongo password. Defaults to None. dimension (str, optional): Which dimension to run RM-CLEAN. Defaults to "1d". @@ -245,6 +250,7 @@ def main( cutoff (float, optional): CLEAN cutoff (in sigma). Defaults to -3. maxIter (int, optional): Max CLEAN iterations. 
Defaults to 10000. gain (float, optional): Clean gain. Defaults to 0.1. + window (float, optional): Further CLEAN in mask to this threshold. Defaults to None. rm_verbose (bool, optional): Verbose output from RM-CLEAN. Defaults to False. """ outdir = outdir.absolute() / "cutouts" @@ -415,6 +421,15 @@ def main( def clean_parser(parent_parser: bool = False) -> argparse.ArgumentParser: + """Create a parser for RM-CLEAN on cutouts. + + Args: + parent_parser (bool, optional): Parent parser. Defaults to False. + + Returns: + argparse.ArgumentParser: The parser. + + """ # Help string to be shown using the -h option descStr = f""" {logo_str} @@ -458,8 +473,7 @@ def clean_parser(parent_parser: bool = False) -> argparse.ArgumentParser: def cli(): - """Command-line interface""" - + """Command-line interface.""" from astropy.utils.exceptions import AstropyWarning warnings.simplefilter("ignore", category=AstropyWarning) diff --git a/arrakis/rmsynth_oncuts.py b/arrakis/rmsynth_oncuts.py index 28e953fc..fce1440c 100755 --- a/arrakis/rmsynth_oncuts.py +++ b/arrakis/rmsynth_oncuts.py @@ -1,5 +1,5 @@ #!/usr/bin/env python3 -"""Run RM-CLEAN on cutouts in parallel""" +"""Run RM-CLEAN on cutouts in parallel.""" from __future__ import annotations @@ -49,7 +49,7 @@ class Spectrum(Struct): - """Single spectrum""" + """Single spectrum.""" data: np.ndarray """The spectrum data""" @@ -64,7 +64,7 @@ class Spectrum(Struct): class StokesSpectra(Struct): - """Multi Stokes spectra""" + """Multi Stokes spectra.""" i: Spectrum """The Stokes I spectrum""" @@ -75,7 +75,17 @@ class StokesSpectra(Struct): class StokesIFitResult(Struct): - """Stokes I fit results""" + """Stokes I fit results. + + Attributes: + alpha (float | None): The alpha parameter of the fit + amplitude (float | None): The amplitude parameter of the fit + x_0 (float | None): The x_0 parameter of the fit + model_repr (str | None): The model representation of the fit + modStokesI (np.ndarray | None): The model Stokes I spectrum + fit_dict (dict | None): The dictionary of the fit results + + """ alpha: float | None """The alpha parameter of the fit""" @@ -108,21 +118,26 @@ def rmsynthoncut3d( rm_verbose: bool = False, ion: bool = False, ) -> pymongo.UpdateOne: - """3D RM-synthesis + """3D RM-synthesis. Args: - island_id (str): RACS Island ID - freq (list): Frequencies in Hz - host (str): Host of MongoDB - field (str): RACS field ID - database (bool, optional): Update MongoDB. Defaults to False. - phiMax_radm2 (float, optional): Max Faraday depth. Defaults to None. - dPhi_radm2 (float, optional): Faraday dpeth channel width. Defaults to None. - nSamples (int, optional): Samples acorss RMSF. Defaults to 5. - weightType (str, optional): Weighting type. Defaults to 'variance'. - fitRMSF (bool, optional): Fit RMSF. Defaults to False. - not_RMSF (bool, optional): Skip calculation of RMSF. Defaults to False. - rm_verbose (bool, optional): Verbose RMsynth. Defaults to False. + island_id (str): Island ID + beam_tuple (tuple[str, pd.Series]): Beam tuple + outdir (Path): Output directory + freq (np.ndarray): Frequencies in Hz + field (str): Field name + sbid (int | None, optional): SBID. Defaults to None. + phiMax_radm2 (float | None, optional): Maximum Faraday depth. Defaults to None. + dPhi_radm2 (float | None, optional): Faraday depth spacing. Defaults to None. + nSamples (int, optional): Number of RMSF samples. Defaults to 5. + weightType (str, optional): Weight type. Defaults to "variance". + fitRMSF (bool, optional): Fit the RMSF. Defaults to True. 
+ not_RMSF (bool, optional): Ignore the RMSF. Defaults to False. + rm_verbose (bool, optional): Verbose output (RM-Tools). Defaults to False. + ion (bool, optional): Use ion files. Defaults to False. + + Returns: + pymongo.UpdateOne: MongoDB update operation """ beam = dict(beam_tuple[1]) iname = island_id @@ -223,7 +238,7 @@ def rmsynthoncut3d( def cubelet_bane(cubelet: np.ndarray, header: fits.Header) -> tuple[np.ndarray]: - """Background and noise estimation on a cubelet + """Background and noise estimation on a cubelet. Args: cubelet (np.ndarray): 3D array of data @@ -273,7 +288,7 @@ def extract_single_spectrum( field_dict: dict, outdir: Path, ) -> Spectrum: - """Extract a single spectrum from a cubelet""" + """Extract a single spectrum from a cubelet.""" key = f"{stokes}_file_ion" if ion and stokes in ("q", "u") else f"{stokes}_file" filename = outdir / field_dict[key] with fits.open(filename, mode="denywrite", memmap=True) as hdulist: @@ -311,7 +326,7 @@ def extract_all_spectra( field_dict: dict, outdir: Path, ) -> StokesSpectra: - """Extract spectra from cubelets""" + """Extract spectra from cubelets.""" return StokesSpectra( *[ extract_single_spectrum( @@ -329,7 +344,7 @@ def extract_all_spectra( def sigma_clip_spectra( stokes_spectra: StokesSpectra, ) -> StokesSpectra: - """Sigma clip spectra + """Sigma clip spectra. Find outliers in the RMS spectra and set them to NaN @@ -375,6 +390,21 @@ def fit_stokes_I( rmsi: np.ndarray | None = None, polyOrd: int | None = None, ) -> StokesIFitResult: + """Fit stokes I spectrum. + + Args: + freq (np.ndarray): Frequencies in Hz + coord (SkyCoord): Component coordinate + tt0 (str | None, optional): Path to TT0 image. Defaults to None. + tt1 (str | None, optional): Path to TT1 image. Defaults to None. + do_own_fit (bool, optional): Peform own fit (not RM-Tools). Defaults to False. + iarr (np.ndarray | None, optional): Stokes I array. Defaults to None. + rmsi (np.ndarray | None, optional): Stokes I rms array. Defaults to None. + polyOrd (int | None, optional): Polynomial order. Defaults to None. + + Returns: + StokesIFitResult: alpha, amplitude, x_0, model_repr, modStokesI, fit_dict + """ if tt0 and tt1: mfs_i_0 = fits.getdata(tt0, memmap=True) mfs_i_1 = fits.getdata(tt1, memmap=True) @@ -427,7 +457,7 @@ def update_rmtools_dict( mDict: dict, fit_dict: dict, ) -> dict: - """Update the RM-Tools dictionary with the fit results from the Stokes I fit + """Update the RM-Tools dictionary with the fit results from the Stokes I fit. Args: mDict (dict): The RM-Tools dictionary @@ -496,27 +526,34 @@ def rmsynthoncut1d( ion: bool = False, do_own_fit: bool = False, ) -> pymongo.UpdateOne: - """1D RM synthesis + """1D RM-synthesis. Args: - comp_id (str): RACS component ID - outdir (str): Output directory - freq (list): Frequencies in Hz - host (str): MongoDB host - field (str): RACS field - sbid (int, optional): SBID. Defaults to None. - database (bool, optional): Update MongoDB. Defaults to False. - polyOrd (int, optional): Order of fit to I. Defaults to 3. - phiMax_radm2 (float, optional): Max FD. Defaults to None. - dPhi_radm2 (float, optional): Delta FD. Defaults to None. - nSamples (int, optional): Samples across RMSF. Defaults to 5. - weightType (str, optional): Weight type. Defaults to 'variance'. - fitRMSF (bool, optional): Fit RMSF. Defaults to False. 
+ comp_tuple (tuple[str, pd.Series]): Component tuple + beam_tuple (tuple[str, pd.Series]): Beam tuple + outdir (Path): Output directory + freq (np.ndarray): Frequencies in Hz + field (str): Field name + sbid (int | None, optional): SBID. Defaults to None. + polyOrd (int, optional): Polynomial order. Defaults to 3. + phiMax_radm2 (float | None, optional): Max faraday depth. Defaults to None. + dPhi_radm2 (float | None, optional): Faraday depth spacing. Defaults to None. + nSamples (int, optional): Number of RMSF samples. Defaults to 5. + weightType (str, optional): Weight type. Defaults to "variance". + fitRMSF (bool, optional): Fit the RMSF. Defaults to True. noStokesI (bool, optional): Ignore Stokes I. Defaults to False. showPlots (bool, optional): Show plots. Defaults to False. savePlots (bool, optional): Save plots. Defaults to False. - debug (bool, optional): Turn on debug plots. Defaults to False. - rm_verbose (bool, optional): Verbose RMsynth. Defaults to False. + debug (bool, optional): Turn on debug output (RM-Tools). Defaults to False. + rm_verbose (bool, optional): Turn on verbose output (RM-Tools). Defaults to False. + fit_function (str, optional): Type of fit function. Defaults to "log". + tt0 (str | None, optional): Path to TT0 image. Defaults to None. + tt1 (str | None, optional): Path to TT1 image. Defaults to None. + ion (bool, optional): If ion files are used. Defaults to False. + do_own_fit (bool, optional): Do own fit (not RM-Tools). Defaults to False. + + Returns: + pymongo.UpdateOne: MongoDB update operation """ logger.setLevel(logging.INFO) save_name = field if sbid is None else f"{field}_{sbid}" @@ -761,7 +798,7 @@ def main( ion: bool = False, do_own_fit: bool = False, ) -> None: - """Run RMsynth on cutouts flow + """Run RMsynth on cutouts flow. Args: field (str): RACS field @@ -1036,6 +1073,14 @@ def main( def rm_common_parser(parent_parser: bool = False) -> argparse.ArgumentParser: + """Create an argument parser for common RM-synthesis arguments. + + Args: + parent_parser (bool, optional): If a parent parser. Defaults to False. + + Returns: + argparse.ArgumentParser: The argument parser + """ common_parser = argparse.ArgumentParser( formatter_class=UltimateHelpFormatter, add_help=not parent_parser, @@ -1057,6 +1102,14 @@ def rm_common_parser(parent_parser: bool = False) -> argparse.ArgumentParser: def rmsynth_parser(parent_parser: bool = False) -> argparse.ArgumentParser: + """Create an argument parser for RM-synthesis. + + Args: + parent_parser (bool, optional): If a parent parser. Defaults to False. 
+ + Returns: + argparse.ArgumentParser: The argument parser + """ # Help string to be shown using the -h option descStr = f""" {logo_str} @@ -1160,8 +1213,7 @@ def rmsynth_parser(parent_parser: bool = False) -> argparse.ArgumentParser: def cli(): - """Command-line interface""" - + """Command-line interface.""" from astropy.utils.exceptions import AstropyWarning warnings.simplefilter("ignore", category=AstropyWarning) diff --git a/arrakis/utils/__init__.py b/arrakis/utils/__init__.py index e69de29b..d085abb1 100644 --- a/arrakis/utils/__init__.py +++ b/arrakis/utils/__init__.py @@ -0,0 +1 @@ +"""Arrakis utilities.""" diff --git a/arrakis/utils/coordinates.py b/arrakis/utils/coordinates.py index 3caa185c..4461b53c 100644 --- a/arrakis/utils/coordinates.py +++ b/arrakis/utils/coordinates.py @@ -1,4 +1,4 @@ -"""Coordinate utilities""" +"""Coordinate utilities.""" from __future__ import annotations @@ -48,7 +48,7 @@ def deg_to_dms(deg: float) -> dms_tuple: def coord_to_string(coord: SkyCoord) -> tuple[str, str]: - """Convert coordinate to string without astropy + """Convert coordinate to string without astropy. Args: coord (SkyCoord): Coordinate diff --git a/arrakis/utils/database.py b/arrakis/utils/database.py index eaa28165..3ee7e7dc 100644 --- a/arrakis/utils/database.py +++ b/arrakis/utils/database.py @@ -1,4 +1,4 @@ -"""Database utilities""" +"""Database utilities.""" from __future__ import annotations @@ -16,7 +16,7 @@ def validate_sbid_field_pair(field_name: str, sbid: int, field_col: Collection) -> bool: - """Validate field and sbid pair + """Validate field and sbid pair. Args: field_name (str): Field name. @@ -38,7 +38,7 @@ def validate_sbid_field_pair(field_name: str, sbid: int, field_col: Collection) def test_db( host: str, username: str | None = None, password: str | None = None ) -> bool: - """Test connection to MongoDB + """Test connection to MongoDB. Args: host (str): Mongo host IP. @@ -79,10 +79,11 @@ def get_db( username: str | None = None, password: str | None = None, ) -> tuple[Collection, Collection, Collection]: - """Get MongoDBs + """Get MongoDBs. Args: host (str): Mongo host IP. + epoch (int): Epoch number. username (str, optional): Username. Defaults to None. password (str, optional): Password. Defaults to None. @@ -104,10 +105,11 @@ def get_db( def get_field_db(host: str, epoch: int, username=None, password=None) -> Collection: - """Get MongoDBs + """Get MongoDBs. Args: host (str): Mongo host IP. + epoch (int): Epoch. username (str, optional): Username. Defaults to None. password (str, optional): Password. Defaults to None. @@ -126,10 +128,11 @@ def get_field_db(host: str, epoch: int, username=None, password=None) -> Collect def get_beam_inf_db(host: str, epoch: int, username=None, password=None) -> Collection: - """Get MongoDBs + """Get MongoDBs. Args: host (str): Mongo host IP. + epoch (int): Epoch number. username (str, optional): Username. Defaults to None. password (str, optional): Password. Defaults to None. diff --git a/arrakis/utils/exceptions.py b/arrakis/utils/exceptions.py index 4e01db2b..4fd5e636 100644 --- a/arrakis/utils/exceptions.py +++ b/arrakis/utils/exceptions.py @@ -1,4 +1,4 @@ -"""Errors and exceptions""" +"""Errors and exceptions.""" from __future__ import annotations @@ -12,7 +12,7 @@ class Error(OSError): - pass + """Base class for all exceptions raised by this module.""" class SameFileError(Error): @@ -20,18 +20,16 @@ class SameFileError(Error): class SpecialFileError(OSError): - """Raised when trying to do a kind of operation (e.g. 
copying) which is - not supported on a special file (e.g. a named pipe)""" + """Raised when trying to do a kind of operation (e.g. copying) which is not supported on a special file (e.g. a named pipe).""" class ExecError(OSError): - """Raised when a command could not be executed""" + """Raised when a command could not be executed.""" class ReadError(OSError): - """Raised when an archive cannot be read""" + """Raised when an archive cannot be read.""" class RegistryError(Exception): - """Raised when a registry operation with the archiving - and unpacking registeries fails""" + """Raised when a registry operation with the archiving and unpacking registeries fails.""" diff --git a/arrakis/utils/fitsutils.py b/arrakis/utils/fitsutils.py index 1274f38c..b25ca63f 100644 --- a/arrakis/utils/fitsutils.py +++ b/arrakis/utils/fitsutils.py @@ -1,4 +1,4 @@ -"""FITS utilities""" +"""FITS utilities.""" from __future__ import annotations @@ -43,7 +43,7 @@ def head2dict(h: fits.Header) -> dict[str, Any]: def fix_header(cutout_header: fits.Header, original_header: fits.Header) -> fits.Header: - """Make cutout header the same as original header + """Make cutout header the same as original header. Args: cutout_header (fits.Header): Cutout header diff --git a/arrakis/utils/fitting.py b/arrakis/utils/fitting.py index b8bd3bb3..96171855 100644 --- a/arrakis/utils/fitting.py +++ b/arrakis/utils/fitting.py @@ -1,4 +1,4 @@ -"""Fitting utilities""" +"""Fitting utilities.""" from __future__ import annotations @@ -23,6 +23,7 @@ def fitted_mean(data: np.ndarray, axis: int | None = None) -> float: Args: data (np.ndarray): Data array. + axis (int | None, optional): Axis to calculate along. Defaults to None. Returns: float: Mean. @@ -39,6 +40,7 @@ def fitted_std(data: np.ndarray, axis: int | None = None) -> float: Args: data (np.ndarray): Data array. + axis (int | None, optional): Axis to calculate along. Defaults to None. Returns: float: Standard deviation. diff --git a/arrakis/utils/io.py b/arrakis/utils/io.py index deebff25..04b4a7ca 100644 --- a/arrakis/utils/io.py +++ b/arrakis/utils/io.py @@ -1,4 +1,4 @@ -"""I/O utilities""" +"""I/O utilities.""" from __future__ import annotations @@ -23,7 +23,15 @@ def verify_tarball( tarball: str | Path, -): +) -> bool: + """Verify a tarball. + + Args: + tarball (str | Path): Path to tarball. + + Returns: + bool: If tarball is valid. + """ cmd = f"tar -tvf {tarball}" logger.info(f"Verifying tarball {tarball}") popen = sp.Popen(shlex.split(cmd), stderr=sp.PIPE) @@ -48,25 +56,22 @@ def parse_env_path(env_path: PathLike) -> Path: return Path(os.path.expandvars(env_path)) -def rsync(src, tgt): - os.system(f"rsync -rPvh {src} {tgt}") - +def rsync(src: str | Path, tgt: str | Path): + """Rsync a source to a target. -def prsync(wild_src: str, tgt: str, ncores: int): - os.system(f"ls -d {wild_src} | xargs -n 1 -P {ncores} -I% rsync -rvh % {tgt}") + Args: + src (str | Path): Source path + tgt (str | Path): Target path + """ + os.system(f"rsync -rPvh {src} {tgt}") -def try_symlink(src: str, dst: str): - """Create symlink if it doesn't exist +def prsync(wild_src: str, tgt: str | Path, ncores: int): + """Parallel rsync a source to a target. Args: - src (str): Source path - dst (str): Destination path - verbose (bool, optional): Verbose output. Defaults to True. 
+ wild_src (str): Wildcard source path + tgt (str | Path): Target path + ncores (int): Number of cores """ - # Create output dir if it doesn't exist - try: - os.symlink(src, dst) - logger.info(f"Made symlink '{dst}'.") - except FileExistsError: - logger.info(f"Symlink '{dst}' exists.") + os.system(f"ls -d {wild_src} | xargs -n 1 -P {ncores} -I% rsync -rvh % {tgt}") diff --git a/arrakis/utils/json.py b/arrakis/utils/json.py index e5b6d651..e512a6b6 100644 --- a/arrakis/utils/json.py +++ b/arrakis/utils/json.py @@ -1,4 +1,4 @@ -"""JSON utilities""" +"""JSON utilities.""" from __future__ import annotations @@ -20,6 +20,15 @@ class MyEncoder(json.JSONEncoder): """ def default(self, obj): # pylint: disable=E0202 + """Custom JSON encoder. + + Args: + obj (Any): Object to encode. + + Returns: + Any: Encoded object. + + """ if isinstance(obj, np.integer): return int(obj) if isinstance(obj, np.floating): diff --git a/arrakis/utils/meta.py b/arrakis/utils/meta.py index ccbb2e22..b38cf72b 100644 --- a/arrakis/utils/meta.py +++ b/arrakis/utils/meta.py @@ -1,4 +1,4 @@ -"""Generic program utilities""" +"""Generic program utilities.""" from __future__ import annotations @@ -15,17 +15,35 @@ # From https://stackoverflow.com/questions/58065055/floor-and-ceil-with-number-of-decimals#:~:text=The%20function%20np.,a%20number%20with%20zero%20decimals. -def my_ceil(a, precision=0): +def my_ceil[T](a: T, precision=0) -> T: + """Ceil a number to a given precision. + + Args: + a (T): A numeric value to ceil + precision (int, optional): Precision of ceil. Defaults to 0. + + Returns: + T: The ceil of a number + """ return np.true_divide(np.ceil(a * 10**precision), 10**precision) -def my_floor(a, precision=0): +def my_floor[T](a: T, precision=0) -> T: + """Floor a number to a given precision. + + Args: + a (T): A numeric value to floor + precision (int, optional): Precision of floor. Defaults to 0. + + Returns: + T: The floor of a number. + """ return np.true_divide(np.floor(a * 10**precision), 10**precision) # From https://stackoverflow.com/questions/1176136/convert-string-to-python-class-object def class_for_name(module_name: str, class_name: str) -> object: - """Returns a class object given a module name and class name + """Returns a class object given a module name and class name. Args: module_name (str): Module name @@ -42,6 +60,17 @@ def class_for_name(module_name: str, class_name: str) -> object: # stolen from https://stackoverflow.com/questions/32954486/zip-iterators-asserting-for-equal-length-in-python def zip_equal(*iterables): + """Zip iterables and assert they are the same length. + + Args: + *iterables: Iterables to zip + + Yields: + tuple: Zipped iterables + + Raises: + ValueError: If iterables have different lengths + """ sentinel = object() for combo in zip_longest(*iterables, fillvalue=sentinel): if sentinel in combo: @@ -51,7 +80,7 @@ def zip_equal(*iterables): def yes_or_no(question: str) -> bool: - """Ask a yes or no question via input() + """Ask a yes or no question via input(). Args: question (str): Question to ask diff --git a/arrakis/utils/msutils.py b/arrakis/utils/msutils.py index 581b8358..f47c825d 100644 --- a/arrakis/utils/msutils.py +++ b/arrakis/utils/msutils.py @@ -1,4 +1,4 @@ -"""MeasurementSet utilities""" +"""MeasurementSet utilities.""" from __future__ import annotations @@ -20,9 +20,9 @@ def get_pol_axis( ms: Path, feed_idx: int | None = None, col: str = "RECEPTOR_ANGLE" ) -> u.Quantity: - """Get the polarization axis from the ASKAP MS. 
Checks are performed - to ensure this polarisation axis angle is constant throughout the observation. + """Get the polarization axis from the ASKAP MS. + Checks are performed to ensure this polarisation axis angle is constant throughout the observation. Args: ms (Path): The path to the measurement set that will be inspected @@ -59,14 +59,14 @@ def get_pol_axis( def beam_from_ms(ms: str) -> int: - """Work out which beam is in this MS""" + """Work out which beam is in this MS.""" with table(ms, readonly=True, ack=False) as t: vis_feed = t.getcol("FEED1", 0, 1) return vis_feed[0] def field_idx_from_ms(ms: str) -> int: - """Get the field from MS metadata""" + """Get the field from MS metadata.""" with table(f"{ms}/FIELD", readonly=True, ack=False) as field: idxs = list(field.SOURCE_ID) assert len(idxs) == 1 or all( @@ -76,7 +76,7 @@ def field_idx_from_ms(ms: str) -> int: def field_name_from_ms(ms: str) -> str: - """Get the field name from MS metadata""" + """Get the field name from MS metadata.""" with table(f"{ms}/FIELD", readonly=True, ack=False) as field: names = list(field.NAME) assert len(names) == 1, "More than one field in MS" @@ -241,8 +241,10 @@ def wsclean( elliptical_beam: bool = False, ) -> str: """Construct a wsclean command. + If False or None is passed as a parameter, the parameter is not included in the command (i.e. wsclean will assume a default value). + Args: mslist (list): List of MSs to be processed. use_mpi (bool): Use wsclean-mp for parallel processing. @@ -740,10 +742,10 @@ def wsclean( bmin will be set to bmaj. Defaults to False. elliptical_beam (bool, optional): Allow the beam to be elliptical. Default. Defaults to False. + Returns: str: WSClean command """ - arguments = copy.deepcopy(locals()) mslist = arguments.pop("mslist") use_mpi = arguments.pop("use_mpi") diff --git a/arrakis/utils/pipeline.py b/arrakis/utils/pipeline.py index f5b5a536..92dc7216 100644 --- a/arrakis/utils/pipeline.py +++ b/arrakis/utils/pipeline.py @@ -1,4 +1,4 @@ -"""Pipeline and flow utility functions""" +"""Pipeline and flow utility functions.""" from __future__ import annotations @@ -24,7 +24,6 @@ from distributed.utils import LoopRunner from prefect import task from prefect.artifacts import create_markdown_artifact -from prefect_dask import get_dask_client from spectral_cube.utils import SpectralCubeWarning from tornado.ioloop import IOLoop from tqdm.auto import tqdm, trange @@ -59,8 +58,9 @@ def upload_image_as_artifact_task( image_path: Path, description: str | None = None ) -> UUID: - """Create and submit a markdown artifact tracked by prefect for an - input image. Currently supporting png formatted images. + """Create and submit a markdown artifact tracked by prefect for an input image. + + Currently supporting png formatted images. The input image is converted to a base64 encoding, and embedded directly within the markdown string. Therefore, be mindful of the image size as this @@ -95,6 +95,14 @@ def upload_image_as_artifact_task( def workdir_arg_parser(parent_parser: bool = False) -> argparse.ArgumentParser: + """Create a workdir parser. + + Args: + parent_parser (bool, optional): If a parent parser. Defaults to False. + + Returns: + argparse.ArgumentParser: The parser. + """ # Parse the command line options work_parser = argparse.ArgumentParser( add_help=not parent_parser, @@ -111,6 +119,14 @@ def workdir_arg_parser(parent_parser: bool = False) -> argparse.ArgumentParser: def generic_parser(parent_parser: bool = False) -> argparse.ArgumentParser: + """Create a generic parser. 
+ + Args: + parent_parser (bool, optional): If a parent parser. Defaults to False. + + Returns: + argparse.ArgumentParser: The parser. + """ descStr = f""" {logo_str} Generic pipeline options @@ -184,85 +200,10 @@ def generic_parser(parent_parser: bool = False) -> argparse.ArgumentParser: return gen_parser -class performance_report_prefect: - """Gather performance report from prefect_dask - - Basically stolen from: - https://distributed.dask.org/en/latest/_modules/distributed/client.html#performance_report - - This creates a static HTML file that includes many of the same plots of the - dashboard for later viewing. - - The resulting file uses JavaScript, and so must be viewed with a web - browser. Locally we recommend using ``python -m http.server`` or hosting - the file live online. - - Parameters - ---------- - filename: str, optional - The filename to save the performance report locally - - stacklevel: int, optional - The code execution frame utilized for populating the Calling Code section - of the report. Defaults to `1` which is the frame calling ``performance_report_prefect`` - - mode: str, optional - Mode parameter to pass to :func:`bokeh.io.output.output_file`. Defaults to ``None``. - - storage_options: dict, optional - Any additional arguments to :func:`fsspec.open` when writing to a URL. - - Examples - -------- - >>> with performance_report_prefect(filename="myfile.html", stacklevel=1): - ... x.compute() - """ - - def __init__( - self, filename="dask-report.html", stacklevel=1, mode=None, storage_options=None - ): - self.filename = filename - # stacklevel 0 or less - shows dask internals which likely isn't helpful - self._stacklevel = stacklevel if stacklevel > 0 else 1 - self.mode = mode - self.storage_options = storage_options or {} - - async def __aenter__(self): - self.start = time.time() - with get_dask_client() as client: - self.last_count = client.run_on_scheduler( - lambda dask_scheduler: dask_scheduler.monitor.count - ) - client.get_task_stream(start=0, stop=0) # ensure plugin - - async def __aexit__(self, exc_type, exc_value, traceback, code=None): - import fsspec - - with get_dask_client() as client: - if code is None: - code = client._get_computation_code(self._stacklevel + 1) - data = await client.scheduler.performance_report( - start=self.start, last_count=self.last_count, code=code, mode=self.mode - ) - with fsspec.open( - self.filename, mode="w", compression="infer", **self.storage_options - ) as f: - f.write(data) - - def __enter__(self): - with get_dask_client() as client: - client.sync(self.__aenter__) - - def __exit__(self, exc_type, exc_value, traceback): - with get_dask_client() as client: - code = client._get_computation_code(self._stacklevel + 1) - client.sync(self.__aexit__, exc_type, exc_value, traceback, code=code) - - def inspect_client( client: distributed.Client | None = None, ) -> tuple[str, int, int, u.Quantity, int, u.Quantity]: - """_summary_ + """_summary_. Args: client (Union[distributed.Client,None]): Dask client to inspect. @@ -294,6 +235,18 @@ def chunk_dask( progress_text="", verbose=True, ) -> list: + """Run a task in chunks. + + Args: + outputs (list): List of outputs to chunk + batch_size (int, optional): Chunk size. Defaults to 10_000. + task_name (str, optional): Name of task. Defaults to "". + progress_text (str, optional): Description of task. Defaults to "". + verbose (bool, optional): Verbose output. Defaults to True. 
+ + Returns: + list: Completed futures + """ client = get_client() chunk_outputs = [] for i in trange( @@ -312,7 +265,7 @@ def chunk_dask( def delayed_to_da(list_of_delayed: list[Delayed], chunk: int | None = None) -> da.Array: - """Convert list of delayed arrays to a dask array + """Convert list of delayed arrays to a dask array. Args: list_of_delayed (List[delayed]): List of delayed objects @@ -333,7 +286,7 @@ def delayed_to_da(list_of_delayed: list[Delayed], chunk: int | None = None) -> d # stolen from https://github.com/tqdm/tqdm/issues/278 class TqdmProgressBar(ProgressBar): - """Tqdm for Dask""" + """Tqdm for Dask.""" def __init__( self, @@ -345,6 +298,19 @@ def __init__( start=True, **tqdm_kwargs, ): + """Make a Tqdm progress bar. + + Args: + keys (Any): Iterable of keys to track + scheduler (Any | None, optional): scheduler. Defaults to None. + interval (str, optional): update interval. Defaults to "100ms". + loop (Any | None, optional): Loop. Defaults to None. + complete (bool, optional): Complete. Defaults to True. + start (bool, optional): Start. Defaults to True. + + kwargs: + **tqdm_kwargs: Tqdm keyword arguments + """ super().__init__(keys, scheduler, interval, complete) self.tqdm = tqdm(keys, **tqdm_kwargs) self.loop = loop or IOLoop() @@ -364,7 +330,7 @@ def _draw_stop(self, **kwargs): def tqdm_dask(futures_in: distributed.Future, **kwargs) -> None: - """Tqdm for Dask futures""" + """Tqdm for Dask futures.""" futures = futures_of(futures_in) if not isinstance(futures, (set, list)): futures = [futures] @@ -372,7 +338,7 @@ def tqdm_dask(futures_in: distributed.Future, **kwargs) -> None: def port_forward(port: int, target: str) -> None: - """Forward ports to local host + """Forward ports to local host. Args: port (int): port to forward diff --git a/arrakis/utils/plotting.py b/arrakis/utils/plotting.py index 8c8fd2f2..ae466cb4 100644 --- a/arrakis/utils/plotting.py +++ b/arrakis/utils/plotting.py @@ -1,4 +1,4 @@ -"""Plotting utilities""" +"""Plotting utilities.""" from __future__ import annotations @@ -15,13 +15,14 @@ def latexify(fig_width=None, fig_height=None, columns=1): """Set up matplotlib's RC params for LaTeX plotting. + Call this before plotting a figure. - Parameters - ---------- - fig_width : float, optional, inches - fig_height : float, optional, inches - columns : {1, 2} + Args: + fig_width (float, optional): Figure width. Defaults to None. + fig_height (float, optional): Figure height. Defaults to None. + columns (int, optional): Number of columns. 
Defaults + """ from math import sqrt diff --git a/arrakis/utils/typing.py b/arrakis/utils/typing.py index 19b6ee2f..47350450 100644 --- a/arrakis/utils/typing.py +++ b/arrakis/utils/typing.py @@ -1,9 +1,8 @@ -"""Typing utilities""" +"""Typing utilities.""" from __future__ import annotations from pathlib import Path -from typing import TypeVar import numpy as np import pandas as pd @@ -12,9 +11,6 @@ from astropy.units import Quantity from rmtable import RMTable -ArrayLike = TypeVar( - "ArrayLike", np.ndarray, pd.Series, pd.DataFrame, SkyCoord, Quantity -) -TableLike = TypeVar("TableLike", RMTable, Table) -PathLike = TypeVar("PathLike", str, Path) -T = TypeVar("T") +ArrayLike = np.ndarray | pd.Series | pd.DataFrame | SkyCoord | Quantity +TableLike = RMTable | Table +PathLike = str | Path diff --git a/arrakis/validate.py b/arrakis/validate.py index 82510c2c..7520deb3 100755 --- a/arrakis/validate.py +++ b/arrakis/validate.py @@ -1,5 +1,5 @@ #!/usr/bin/env python -"""Make validation plots from a catalogue""" +"""Make validation plots from a catalogue.""" from __future__ import annotations @@ -30,7 +30,13 @@ class GriddedMap(Struct): - """Gridded catalogue data""" + """Gridded catalogue data. + + Attributes: + data (np.ndarray): Gridded data + wcs (WCS): WCS of the gridded data + + """ data: np.ndarray """Gridded data""" @@ -39,7 +45,15 @@ class GriddedMap(Struct): class BinnedMap(Struct): - """Binned catalogue data""" + """Binned catalogue data. + + Attributes: + data (np.ndarray): Binned data + xc (np.ndarray): X bin centres + yc (np.ndarray): Y bin centres + wcs (WCS): WCS of the b + + """ data: np.ndarray """Binned data""" @@ -54,6 +68,17 @@ class BinnedMap(Struct): def make_gridded_map( tab: Table, column: str, npix: int = 512, map_size: u.Quantity = 8 * u.deg ) -> GriddedMap: + """Make a gridded map from a table. + + Args: + tab (Table): The table. + column (str): Reference column. + npix (int, optional): Number of pixels. Defaults to 512. + map_size (u.Quantity, optional): Angular size of map. Defaults to 8*u.deg. + + Returns: + GriddedMap: data, wcs + """ logger.info(f"Making gridded map for {column}") coords = SkyCoord(ra=tab["ra"], dec=tab["dec"], unit="deg") coarse_shape = (npix, npix) @@ -105,6 +130,14 @@ def make_gridded_map( def filter_then_median(arr: T) -> T: + """Filter then median. + + Args: + arr (T): Array of data. + + Returns: + T: Filtered then medianed data. + """ arr_clip = sigma_clip( arr, maxiters=None, sigma=3, cenfunc=np.nanmedian, stdfunc=mad_std ) @@ -118,6 +151,18 @@ def make_binned_map( npix: int = 512, map_size: u.Quantity = 8 * u.deg, ) -> BinnedMap: + """Make a binned map from a table. + + Args: + tab (Table): The table. + column (str): Reference column. + bins (int, optional): Number of bins. Defaults to 15. + npix (int, optional): Number of pixels. Defaults to 512. + map_size (u.Quantity, optional): Angular size of map. Defaults to 8*u.deg. + + Returns: + BinnedMap: data, xc, yc, wcs + """ logger.info(f"Making binned map for {column}") coords = SkyCoord(ra=tab["ra"], dec=tab["dec"], unit="deg") coarse_shape = (npix, npix) @@ -159,6 +204,16 @@ def plot_rms_bkg( npix: int = 512, map_size: u.Quantity = 8 * u.deg, ) -> Figure: + """Make RMS and background plots. + + Args: + tab (Table): Catalogue table. + npix (int, optional): Number of pixels. Defaults to 512. + map_size (u.Quantity, optional): Angular size of map. Defaults to 8*u.deg. + + Returns: + Figure: The RMS and background plot. 
+ """ err_bkg_dict = {} for stokes in "IQU": err_bkg_dict[stokes] = {} @@ -230,6 +285,18 @@ def plot_leakage( npix: int = 512, map_size: u.Quantity = 8 * u.deg, ) -> Figure: + """Make a leakage plot. + + Args: + tab (Table): Catalogue table. + snr_cut (float, optional): SNR cut. Defaults to 50. + bins (int, optional): Number of bins. Defaults to 11. + npix (int, optional): Number of pixels. Defaults to 512. + map_size (u.Quantity, optional): Angular size of map. Defaults to 8*u.deg. + + Returns: + Figure: The leakage plot. + """ hi_i_tab = tab[tab["stokesI"] / tab["stokesI_err"] > snr_cut] hi_i_tab["stokesQ_frac"] = hi_i_tab["stokesQ"] / hi_i_tab["stokesI"] hi_i_tab["stokesU_frac"] = hi_i_tab["stokesU"] / hi_i_tab["stokesI"] @@ -278,7 +345,17 @@ def plot_leakage( def cross_match( my_tab: Table, other_tab: Table, radius: u.Quantity = 1 * u.arcsec -) -> Table: +) -> tuple[Table, Table]: + """Cross-match two tables. + + Args: + my_tab (Table): Our table. + other_tab (Table): Their table. + radius (u.Quantity, optional): Radius of crossmastch. Defaults to 1*u.arcsec. + + Returns: + tuple[Table, Table]: Our matches, their matches. + """ my_coords = SkyCoord(ra=my_tab["ra"], dec=my_tab["dec"], unit="deg") other_coords = SkyCoord(ra=other_tab["ra"], dec=other_tab["dec"], unit="deg") idx, d2d, _ = my_coords.match_to_catalog_sky(other_coords) @@ -294,6 +371,16 @@ def plot_rm( npix: int = 512, map_size: u.Quantity = 8 * u.deg, ) -> Figure: + """RM bubble plot. + + Args: + tab (Table): Catalogue table. + npix (int, optional): Number of pixels in gridded map. Defaults to 512. + map_size (u.Quantity, optional): Angular size of gridded map. Defaults to 8*u.deg. + + Returns: + Figure: Bubble plot figure. + """ good_idx = ( (~tab["snr_flag"]) & (~tab["leakage_flag"]) @@ -435,7 +522,16 @@ def main( map_size: float = 8, snr_cut: float = 50, bins: int = 11, -): +) -> None: + """Validation flow. + + Args: + catalogue_path (Path): Path to the catalogue. + npix (int, optional): Number of pixels in gridded maps. Defaults to 512. + map_size (float, optional): Size of gridded maps in degrees. Defaults to 8. + snr_cut (float, optional): SNR cut for maps. Defaults to 50. + bins (int, optional): Number of bins in gridded maps. Defaults to 11. + """ outdir = catalogue_path.parent tab = Table.read(catalogue_path) @@ -490,6 +586,14 @@ def main( def validation_parser(parent_parser: bool = False) -> argparse.ArgumentParser: + """Parse validation arguments. + + Args: + parent_parser (bool, optional): If this is a parent parser. Defaults to False. 
+ + Returns: + argparse.ArgumentParser: The parser + """ descStr = f""" {logo_str} Arrakis: @@ -517,7 +621,8 @@ def validation_parser(parent_parser: bool = False) -> argparse.ArgumentParser: return val_parser -def cli(): +def cli() -> None: + """Command-line interface.""" catalogue_parser = cat_parser(parent_parser=True) val_parser = validation_parser(parent_parser=True) parser = argparse.ArgumentParser( diff --git a/arrakis/wsclean_rmsynth.py b/arrakis/wsclean_rmsynth.py index 77c354c7..dea2ce18 100755 --- a/arrakis/wsclean_rmsynth.py +++ b/arrakis/wsclean_rmsynth.py @@ -23,7 +23,7 @@ from scipy.optimize import curve_fit -class RMSynthParams(NamedTuple): +class RMSynthParams(Struct): phis: np.ndarray phis_double: np.ndarray lsq: np.ndarray @@ -313,7 +313,6 @@ def simple_clean( Returns: np.ndarray: Spectrum """ - # Fit to PI fdf_p = np.abs(fdf) sigma = fwhm / (2 * np.sqrt(2 * np.log(2))) diff --git a/pyproject.toml b/pyproject.toml index 0306996f..020c1206 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -147,6 +147,7 @@ extend-select = [ "EXE", # flake8-executable "NPY", # NumPy specific rules "PD", # pandas-vet + "D", # flake8-docstrings ] ignore = [ "PLR09", # Too many <...> @@ -160,3 +161,8 @@ isort.required-imports = ["from __future__ import annotations"] [tool.ruff.lint.per-file-ignores] "tests/**" = ["T20"] "noxfile.py" = ["T20"] +# Ignore all rules for wsclean +"arrakis/wsclean_rmsynth.py" = ["ALL"] + +[tool.ruff.lint.pydocstyle] +convention = "google" From b3bab85303be7484ade06c3e96fdfc2c69c3e216 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 25 Jul 2024 03:53:45 +0000 Subject: [PATCH 07/17] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- scripts/casda_prepare.py | 3 ++- scripts/compare_leakage.py | 3 +-- scripts/hello_mpi_world.py | 4 +--- 3 files changed, 4 insertions(+), 6 deletions(-) diff --git a/scripts/casda_prepare.py b/scripts/casda_prepare.py index f18380e2..9c3b575d 100755 --- a/scripts/casda_prepare.py +++ b/scripts/casda_prepare.py @@ -383,7 +383,8 @@ def write_polspec(table: Table, filename: str, overwrite: bool = False): filename : str Name and relative path of the file to save to. overwrite : bool [False] - Overwrite the file if it already exists?""" + Overwrite the file if it already exists? + """ # This is going to be complicated, because the automatic write algorithm # doesn't like variable length arrays. pyfits can support it, it just # needs a little TLC to get it into the correct format. diff --git a/scripts/compare_leakage.py b/scripts/compare_leakage.py index 019916bd..d93abda4 100644 --- a/scripts/compare_leakage.py +++ b/scripts/compare_leakage.py @@ -1,5 +1,4 @@ -""" -The interpolation works as follows: +"""The interpolation works as follows: Take pixels offsets x,y from reference pixel in input image, multiply by axis increments to get offx and offy. 
diff --git a/scripts/hello_mpi_world.py b/scripts/hello_mpi_world.py index f47226bb..797f3ff7 100755 --- a/scripts/hello_mpi_world.py +++ b/scripts/hello_mpi_world.py @@ -1,7 +1,5 @@ #!/usr/bin/env python -""" -Parallel Hello World -""" +"""Parallel Hello World""" from __future__ import annotations From 05198c42dff37712c2c8495b8166b2f18007be83 Mon Sep 17 00:00:00 2001 From: "Alec Thomson (S&A, Kensington WA)" Date: Thu, 25 Jul 2024 11:57:19 +0800 Subject: [PATCH 08/17] Ruff --- arrakis/frion.py | 2 +- arrakis/utils/fitsutils.py | 2 +- pyproject.toml | 5 +++-- scripts/casda_prepare.py | 3 ++- scripts/compare_leakage.py | 3 +-- scripts/copy_cutouts.py | 5 ++--- scripts/copy_cutouts_askap.py | 5 ++--- scripts/fix_dr1_cat.py | 1 - scripts/hello_mpi_world.py | 4 +--- scripts/spica.py | 3 +-- 10 files changed, 14 insertions(+), 19 deletions(-) diff --git a/arrakis/frion.py b/arrakis/frion.py index 9a198fb5..acb476e2 100755 --- a/arrakis/frion.py +++ b/arrakis/frion.py @@ -15,7 +15,6 @@ import numpy as np import pymongo from astropy.time import Time, TimeDelta -from FRion import correct, predict from prefect import flow, task from tqdm.auto import tqdm @@ -28,6 +27,7 @@ ) from arrakis.utils.fitsutils import getfreq from arrakis.utils.pipeline import generic_parser, logo_str, workdir_arg_parser +from FRion import correct, predict logger.setLevel(logging.INFO) TQDM_OUT = TqdmToLogger(logger, level=logging.INFO) diff --git a/arrakis/utils/fitsutils.py b/arrakis/utils/fitsutils.py index c0c5dbe2..b25ca63f 100644 --- a/arrakis/utils/fitsutils.py +++ b/arrakis/utils/fitsutils.py @@ -11,10 +11,10 @@ from astropy.io import fits from astropy.utils.exceptions import AstropyWarning from astropy.wcs import WCS -from FRion.correct import find_freq_axis from spectral_cube.utils import SpectralCubeWarning from arrakis.logger import logger +from FRion.correct import find_freq_axis warnings.filterwarnings(action="ignore", category=SpectralCubeWarning, append=True) warnings.simplefilter("ignore", category=AstropyWarning) diff --git a/pyproject.toml b/pyproject.toml index 020c1206..03074506 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -122,7 +122,7 @@ tar_cubelets = { reference="scripts/tar_cubelets.py", type="file"} create_mongodb = { reference="scripts/create_mongodb.py", type="file"} [tool.ruff] -src = ["arrakis", "scripts", "tests"] +src = ["arrakis"] [tool.ruff.lint] extend-select = [ @@ -159,7 +159,8 @@ isort.required-imports = ["from __future__ import annotations"] # typing-modules = ["cutout_fits._compat.typing"] [tool.ruff.lint.per-file-ignores] -"tests/**" = ["T20"] +"tests/**" = ["ALL"] +"scripts/**" = ["ALL"] "noxfile.py" = ["T20"] # Ignore all rules for wsclean "arrakis/wsclean_rmsynth.py" = ["ALL"] diff --git a/scripts/casda_prepare.py b/scripts/casda_prepare.py index f18380e2..9c3b575d 100755 --- a/scripts/casda_prepare.py +++ b/scripts/casda_prepare.py @@ -383,7 +383,8 @@ def write_polspec(table: Table, filename: str, overwrite: bool = False): filename : str Name and relative path of the file to save to. overwrite : bool [False] - Overwrite the file if it already exists?""" + Overwrite the file if it already exists? + """ # This is going to be complicated, because the automatic write algorithm # doesn't like variable length arrays. pyfits can support it, it just # needs a little TLC to get it into the correct format. 
diff --git a/scripts/compare_leakage.py b/scripts/compare_leakage.py index 019916bd..d93abda4 100644 --- a/scripts/compare_leakage.py +++ b/scripts/compare_leakage.py @@ -1,5 +1,4 @@ -""" -The interpolation works as follows: +"""The interpolation works as follows: Take pixels offsets x,y from reference pixel in input image, multiply by axis increments to get offx and offy. diff --git a/scripts/copy_cutouts.py b/scripts/copy_cutouts.py index 6f8693c5..c9e92f8c 100755 --- a/scripts/copy_cutouts.py +++ b/scripts/copy_cutouts.py @@ -3,11 +3,10 @@ import argparse import os -from arrakis.logger import logger, logging -from arrakis.utils.io import try_mkdir - import copy_data import spica +from arrakis.logger import logger, logging +from arrakis.utils.io import try_mkdir logger.setLevel(logging.INFO) diff --git a/scripts/copy_cutouts_askap.py b/scripts/copy_cutouts_askap.py index d2f4d2f4..e1f9d4cc 100644 --- a/scripts/copy_cutouts_askap.py +++ b/scripts/copy_cutouts_askap.py @@ -3,11 +3,10 @@ import argparse import os -from arrakis.logger import logger, logging -from arrakis.utils.io import try_mkdir - import copy_data import spica +from arrakis.logger import logger, logging +from arrakis.utils.io import try_mkdir logger.setLevel(logging.INFO) diff --git a/scripts/fix_dr1_cat.py b/scripts/fix_dr1_cat.py index 47fe1f69..27da194a 100755 --- a/scripts/fix_dr1_cat.py +++ b/scripts/fix_dr1_cat.py @@ -22,7 +22,6 @@ from astropy.time import Time from astropy.units import cds from rmtable import RMTable - from spica import SPICA diff --git a/scripts/hello_mpi_world.py b/scripts/hello_mpi_world.py index f47226bb..797f3ff7 100755 --- a/scripts/hello_mpi_world.py +++ b/scripts/hello_mpi_world.py @@ -1,7 +1,5 @@ #!/usr/bin/env python -""" -Parallel Hello World -""" +"""Parallel Hello World""" from __future__ import annotations diff --git a/scripts/spica.py b/scripts/spica.py index ff5b69d3..058875c1 100755 --- a/scripts/spica.py +++ b/scripts/spica.py @@ -6,13 +6,12 @@ from glob import glob from pathlib import Path +import copy_data import numpy as np from arrakis.logger import logger, logging from arrakis.utils.io import try_mkdir from astropy.table import Table -import copy_data - logger.setLevel(logging.INFO) racs_area = os.path.abspath("/askapbuffer/payne/mcc381/RACS") From 42f056d4ec743fd6f86c0272b37f94e22175b6c8 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 25 Jul 2024 03:57:28 +0000 Subject: [PATCH 09/17] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- arrakis/frion.py | 2 +- arrakis/utils/fitsutils.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/arrakis/frion.py b/arrakis/frion.py index acb476e2..9a198fb5 100755 --- a/arrakis/frion.py +++ b/arrakis/frion.py @@ -15,6 +15,7 @@ import numpy as np import pymongo from astropy.time import Time, TimeDelta +from FRion import correct, predict from prefect import flow, task from tqdm.auto import tqdm @@ -27,7 +28,6 @@ ) from arrakis.utils.fitsutils import getfreq from arrakis.utils.pipeline import generic_parser, logo_str, workdir_arg_parser -from FRion import correct, predict logger.setLevel(logging.INFO) TQDM_OUT = TqdmToLogger(logger, level=logging.INFO) diff --git a/arrakis/utils/fitsutils.py b/arrakis/utils/fitsutils.py index b25ca63f..c0c5dbe2 100644 --- a/arrakis/utils/fitsutils.py +++ b/arrakis/utils/fitsutils.py @@ -11,10 +11,10 @@ from astropy.io import fits from astropy.utils.exceptions import 
AstropyWarning from astropy.wcs import WCS +from FRion.correct import find_freq_axis from spectral_cube.utils import SpectralCubeWarning from arrakis.logger import logger -from FRion.correct import find_freq_axis warnings.filterwarnings(action="ignore", category=SpectralCubeWarning, append=True) warnings.simplefilter("ignore", category=AstropyWarning) From 831f6009052d39f74c1f1f1f7076e3de7808ff84 Mon Sep 17 00:00:00 2001 From: "Alec Thomson (S&A, Kensington WA)" Date: Thu, 25 Jul 2024 11:58:01 +0800 Subject: [PATCH 10/17] Ignore docs --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 03074506..a17fd6c9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -162,6 +162,7 @@ isort.required-imports = ["from __future__ import annotations"] "tests/**" = ["ALL"] "scripts/**" = ["ALL"] "noxfile.py" = ["T20"] +"docs/**" = ["ALL"] # Ignore all rules for wsclean "arrakis/wsclean_rmsynth.py" = ["ALL"] From 70bd41747fb2357ef06a492adaa1c2d68ad7c865 Mon Sep 17 00:00:00 2001 From: "Alec Thomson (S&A, Kensington WA)" Date: Thu, 25 Jul 2024 11:59:18 +0800 Subject: [PATCH 11/17] Ignores --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index a17fd6c9..3bb45966 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -163,6 +163,7 @@ isort.required-imports = ["from __future__ import annotations"] "scripts/**" = ["ALL"] "noxfile.py" = ["T20"] "docs/**" = ["ALL"] +"submit/**" = ["ALL"] # Ignore all rules for wsclean "arrakis/wsclean_rmsynth.py" = ["ALL"] From 23b81194c0a067957a44a7b69ff550e776a98481 Mon Sep 17 00:00:00 2001 From: "Alec Thomson (S&A, Kensington WA)" Date: Thu, 25 Jul 2024 15:36:47 +0800 Subject: [PATCH 12/17] Ignore --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index 7c21c098..de024cd0 100644 --- a/.gitignore +++ b/.gitignore @@ -146,3 +146,4 @@ setup.py.bak # Test data test/data/testdb/* +.nox From bfd9c19d81de1aae9777624c5e83875986163abb Mon Sep 17 00:00:00 2001 From: "Alec Thomson (S&A, Kensington WA)" Date: Thu, 25 Jul 2024 15:38:04 +0800 Subject: [PATCH 13/17] Ignore --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index de024cd0..cebe648e 100644 --- a/.gitignore +++ b/.gitignore @@ -147,3 +147,4 @@ setup.py.bak # Test data test/data/testdb/* .nox +arrakis/_version.py From 27f7b501d3bb7e3fe8b503cc834893f1b4be17bc Mon Sep 17 00:00:00 2001 From: "Alec Thomson (S&A, Kensington WA)" Date: Thu, 25 Jul 2024 15:40:51 +0800 Subject: [PATCH 14/17] Formatting --- arrakis/frion.py | 6 +++--- arrakis/imager.py | 8 ++++++-- arrakis/init_database.py | 7 ++++++- arrakis/linmos.py | 2 +- arrakis/makecat.py | 9 +++++---- arrakis/merge_fields.py | 11 ++++++++--- arrakis/rmclean_oncuts.py | 4 ++-- arrakis/rmsynth_oncuts.py | 12 ++++++++---- arrakis/utils/fitsutils.py | 2 +- arrakis/utils/pipeline.py | 2 +- arrakis/validate.py | 1 + 11 files changed, 42 insertions(+), 22 deletions(-) diff --git a/arrakis/frion.py b/arrakis/frion.py index 9a198fb5..bff30735 100755 --- a/arrakis/frion.py +++ b/arrakis/frion.py @@ -5,9 +5,9 @@ import argparse import logging +from collections.abc import Callable from pathlib import Path from pprint import pformat -from typing import Callable from typing import NamedTuple as Struct from urllib.error import URLError @@ -15,7 +15,6 @@ import numpy as np import pymongo from astropy.time import Time, TimeDelta -from FRion import correct, predict from prefect import flow, task from tqdm.auto import tqdm @@ 
-28,6 +27,7 @@ ) from arrakis.utils.fitsutils import getfreq from arrakis.utils.pipeline import generic_parser, logo_str, workdir_arg_parser +from FRion import correct, predict logger.setLevel(logging.INFO) TQDM_OUT = TqdmToLogger(logger, level=logging.INFO) @@ -439,7 +439,7 @@ def main( frion_results = [] assert len(islands) == len(beams_cor), "Islands and beams must be the same length" for island, beam in tqdm( - zip(islands, beams_cor), + zip(islands, beams_cor, strict=False), desc="Submitting tasks", file=TQDM_OUT, total=len(islands), diff --git a/arrakis/imager.py b/arrakis/imager.py index 1d560a6f..a17e94a2 100755 --- a/arrakis/imager.py +++ b/arrakis/imager.py @@ -205,7 +205,9 @@ def make_validation_plots(prefix: Path, pols: str) -> None: for stokes in pols: mfs_image = get_mfs_image(prefix_str, stokes) fig, axs = plt.subplots(1, 3, figsize=(15, 5)) - for ax, sub_image, title in zip(axs, mfs_image, ("Image", "Model", "Residual")): + for ax, sub_image, title in zip( + axs, mfs_image, ("Image", "Model", "Residual"), strict=False + ): abs_sub_image = np.abs(sub_image) if title == "Model": norm = ImageNormalize( @@ -951,7 +953,9 @@ def main( ms_list_fixed.append(ms_fix) pol_angles.append(pol_angle_deg) - for ms, ms_fix, pol_angle_deg in zip(mslist, ms_list_fixed, pol_angles): + for ms, ms_fix, pol_angle_deg in zip( + mslist, ms_list_fixed, pol_angles, strict=False + ): # Image with wsclean # split out stokes I and QUV if "I" in pols: diff --git a/arrakis/init_database.py b/arrakis/init_database.py index d3f75233..648bad34 100755 --- a/arrakis/init_database.py +++ b/arrakis/init_database.py @@ -301,7 +301,12 @@ def get_beams(mastercat: Table, database: Table, epoch: int = 0) -> list[dict]: beam_list: list[dict] = [] for _i, (val, idx) in enumerate( - tqdm(zip(vals, ixs), total=len(vals), desc="Getting beams", file=TQDM_OUT) + tqdm( + zip(vals, ixs, strict=False), + total=len(vals), + desc="Getting beams", + file=TQDM_OUT, + ) ): beam_dict = {} name = mastercat[val]["Source_Name"] diff --git a/arrakis/linmos.py b/arrakis/linmos.py index fcaa984f..d9f839e7 100755 --- a/arrakis/linmos.py +++ b/arrakis/linmos.py @@ -100,7 +100,7 @@ def find_images( assert len(image_list) == len(weight_list), "Unequal number of weights and images" - for im, wt in zip(image_list, weight_list): + for im, wt in zip(image_list, weight_list, strict=False): assert ( im.parent.name == wt.parent.name ), "Image and weight are in different areas!" 
diff --git a/arrakis/makecat.py b/arrakis/makecat.py index 13d7f217..750b9ba6 100755 --- a/arrakis/makecat.py +++ b/arrakis/makecat.py @@ -7,9 +7,9 @@ import logging import time import warnings +from collections.abc import Callable from pathlib import Path from pprint import pformat -from typing import Callable import astropy.units as u import dask.dataframe as dd @@ -287,7 +287,7 @@ def sigma_add_fix[TableLike](tab: TableLike) -> TableLike: med, std = np.zeros_like(s_Q), np.zeros_like(s_Q) for i, (_s_Q, _scale_Q, _s_U, _scale_U) in tqdm( - enumerate(zip(s_Q, scale_Q, s_U, scale_U)), + enumerate(zip(s_Q, scale_Q, s_U, scale_U, strict=False)), total=len(s_Q), desc="Calculating sigma_add", file=TQDM_OUT, @@ -440,7 +440,7 @@ def get_fit_func( rasterized=True, ) plt.plot(bins_c, meds, alpha=1, c=color, label="Median", linewidth=2) - for s, ls in zip((1, 2), ("--", ":")): + for s, ls in zip((1, 2), ("--", ":"), strict=False): for r in ("ups", "los"): plt.plot( bins_c, @@ -682,7 +682,7 @@ def get_alpha(cat: TableLike) -> SpectralIndices: alphas_err = [] betas = [] betas_err = [] - for c, c_err in zip(coefs_str, coefs_err_str): + for c, c_err in zip(coefs_str, coefs_err_str, strict=False): coefs = c.split(",") coefs_err = c_err.split(",") # alpha is the 2nd last coefficient @@ -1101,6 +1101,7 @@ def main( columns_possum.input_sources, columns_possum.input_names, columns_possum.output_units, + strict=False, ), total=len(columns_possum.output_cols), desc="Making table by column", diff --git a/arrakis/merge_fields.py b/arrakis/merge_fields.py index a60ba5f6..c00c36d4 100755 --- a/arrakis/merge_fields.py +++ b/arrakis/merge_fields.py @@ -68,7 +68,9 @@ def copy_singleton( u_file_new = (new_dir / u_file_old.name).with_suffix(".edge.linmos.fits") for src, dst in zip( - [i_file_old, q_file_old, u_file_old], [i_file_new, q_file_new, u_file_new] + [i_file_old, q_file_old, u_file_old], + [i_file_new, q_file_new, u_file_new], + strict=False, ): copyfile(src, dst) src_weight = ( @@ -229,7 +231,9 @@ def merge_multiple_field( try_mkdir(new_dir, verbose=False) - for stokes, imlist in zip(["I", "Q", "U"], [i_files_old, q_files_old, u_files_old]): + for stokes, imlist in zip( + ["I", "Q", "U"], [i_files_old, q_files_old, u_files_old], strict=False + ): parset_file = genparset(imlist, stokes, new_dir) update = linmos.fn(parset_file, merge_name, image) updates.append(update) @@ -324,7 +328,8 @@ def main( ), f"List of fields must be the same length as length of field dirs. 
{len(fields)=},{len(field_dirs)=}" field_dict: dict[str, Path] = { - field: field_dir / "cutouts" for field, field_dir in zip(fields, field_dirs) + field: field_dir / "cutouts" + for field, field_dir in zip(fields, field_dirs, strict=False) } image = get_yanda(version=yanda) diff --git a/arrakis/rmclean_oncuts.py b/arrakis/rmclean_oncuts.py index 9821fe59..61b1faf4 100755 --- a/arrakis/rmclean_oncuts.py +++ b/arrakis/rmclean_oncuts.py @@ -115,9 +115,9 @@ def rmclean1d( # Ensure JSON serializable for k, v in outdict.items(): - if isinstance(v, (np.float64, np.float32)): + if isinstance(v, np.float64 | np.float32): outdict[k] = float(v) - elif isinstance(v, (np.int_, np.int32)): + elif isinstance(v, np.int_ | np.int32): outdict[k] = int(v) elif isinstance(v, np.ndarray): outdict[k] = v.tolist() diff --git a/arrakis/rmsynth_oncuts.py b/arrakis/rmsynth_oncuts.py index fce1440c..fa69cc87 100755 --- a/arrakis/rmsynth_oncuts.py +++ b/arrakis/rmsynth_oncuts.py @@ -692,9 +692,9 @@ def rmsynthoncut1d( # Ensure JSON serializable for k, v in mDict.items(): - if isinstance(v, (np.float64, np.float32)): + if isinstance(v, np.float64 | np.float32): mDict[k] = float(v) - elif isinstance(v, (np.int_, np.int32)): + elif isinstance(v, np.int_ | np.int32): mDict[k] = int(v) elif isinstance(v, np.ndarray): mDict[k] = v.tolist() @@ -998,7 +998,11 @@ def main( logger.info(f"Running RMsynth on {n_comp} components") outputs = [] for comp_tuple, beam_tuple in tqdm( - zip(components.iterrows(), beams.loc[components.Source_ID].iterrows()), + zip( + components.iterrows(), + beams.loc[components.Source_ID].iterrows(), + strict=False, + ), total=n_comp, desc="Submitting RMsynth 1D jobs", file=TQDM_OUT, @@ -1033,7 +1037,7 @@ def main( logger.info(f"Running RMsynth on {n_island} islands") outputs = [] for island_id, beam_tuple in tqdm( - zip(island_ids, beams.loc[island_ids].iterrows()), + zip(island_ids, beams.loc[island_ids].iterrows(), strict=False), total=n_island, desc="Submitting RMsynth 3D jobs", file=TQDM_OUT, diff --git a/arrakis/utils/fitsutils.py b/arrakis/utils/fitsutils.py index c0c5dbe2..b25ca63f 100644 --- a/arrakis/utils/fitsutils.py +++ b/arrakis/utils/fitsutils.py @@ -11,10 +11,10 @@ from astropy.io import fits from astropy.utils.exceptions import AstropyWarning from astropy.wcs import WCS -from FRion.correct import find_freq_axis from spectral_cube.utils import SpectralCubeWarning from arrakis.logger import logger +from FRion.correct import find_freq_axis warnings.filterwarnings(action="ignore", category=SpectralCubeWarning, append=True) warnings.simplefilter("ignore", category=AstropyWarning) diff --git a/arrakis/utils/pipeline.py b/arrakis/utils/pipeline.py index 92dc7216..f6ed08c0 100644 --- a/arrakis/utils/pipeline.py +++ b/arrakis/utils/pipeline.py @@ -332,7 +332,7 @@ def _draw_stop(self, **kwargs): def tqdm_dask(futures_in: distributed.Future, **kwargs) -> None: """Tqdm for Dask futures.""" futures = futures_of(futures_in) - if not isinstance(futures, (set, list)): + if not isinstance(futures, set | list): futures = [futures] TqdmProgressBar(futures, **kwargs) diff --git a/arrakis/validate.py b/arrakis/validate.py index 7520deb3..d6591b31 100755 --- a/arrakis/validate.py +++ b/arrakis/validate.py @@ -411,6 +411,7 @@ def plot_rm( ("NVSS", "SPASS"), (nvss_tab, spass_tab), (ax_dict["N"], ax_dict["S"]), + strict=False, ): ax.set_title(label) racs_match, other_match = cross_match(good_tab, other_cat, radius=60 * u.arcsec) From 878d00d50980dfc36cbf22ff4f4729f22ad0e9f2 Mon Sep 17 00:00:00 2001 From: 
"Alec Thomson (S&A, Kensington WA)" Date: Thu, 25 Jul 2024 15:41:02 +0800 Subject: [PATCH 15/17] CICD --- .github/CONTRIBUTING.md | 89 +++++++++++++ .github/dependabot.yml | 16 +-- .github/release.yml | 5 + .github/workflows/cd.yml | 60 +++++++++ .github/workflows/ci.yml | 71 ++++++++++ noxfile.py | 101 ++++++++++++++ pyproject.toml | 276 ++++++++++++++++++++++++--------------- pyproject.toml.bak | 171 ++++++++++++++++++++++++ 8 files changed, 675 insertions(+), 114 deletions(-) create mode 100644 .github/CONTRIBUTING.md create mode 100644 .github/release.yml create mode 100644 .github/workflows/cd.yml create mode 100644 .github/workflows/ci.yml create mode 100644 noxfile.py create mode 100644 pyproject.toml.bak diff --git a/.github/CONTRIBUTING.md b/.github/CONTRIBUTING.md new file mode 100644 index 00000000..a7f329d7 --- /dev/null +++ b/.github/CONTRIBUTING.md @@ -0,0 +1,89 @@ +See the [Scientific Python Developer Guide][spc-dev-intro] for a detailed +description of best practices for developing scientific packages. + +[spc-dev-intro]: https://learn.scientific-python.org/development/ + +# Quick development + +The fastest way to start with development is to use nox. If you don't have nox, +you can use `pipx run nox` to run it without installing, or `pipx install nox`. +If you don't have pipx (pip for applications), then you can install with +`pip install pipx` (the only case were installing an application with regular +pip is reasonable). If you use macOS, then pipx and nox are both in brew, use +`brew install pipx nox`. + +To use, run `nox`. This will lint and test using every installed version of +Python on your system, skipping ones that are not installed. You can also run +specific jobs: + +```console +$ nox -s lint # Lint only +$ nox -s tests # Python tests +$ nox -s docs -- --serve # Build and serve the docs +$ nox -s build # Make an SDist and wheel +``` + +Nox handles everything for you, including setting up an temporary virtual +environment for each run. + +# Setting up a development environment manually + +You can set up a development environment by running: + +```bash +python3 -m venv .venv +source ./.venv/bin/activate +pip install -v -e .[dev] +``` + +If you have the +[Python Launcher for Unix](https://github.com/brettcannon/python-launcher), you +can instead do: + +```bash +py -m venv .venv +py -m install -v -e .[dev] +``` + +# Pre-commit + +You should prepare pre-commit, which will help you by checking that commits pass +required checks: + +```bash +pip install pre-commit # or brew install pre-commit on macOS +pre-commit install # Will install a pre-commit hook into the git repo +``` + +You can also/alternatively run `pre-commit run` (changes only) or +`pre-commit run --all-files` to check even without installing the hook. + +# Testing + +Use pytest to run the unit checks: + +```bash +pytest +``` + +# Coverage + +Use pytest-cov to generate coverage reports: + +```bash +pytest --cov=cutout-fits +``` + +# Building docs + +You can build the docs using: + +```bash +nox -s docs +``` + +You can see a preview with: + +```bash +nox -s docs -- --serve +``` diff --git a/.github/dependabot.yml b/.github/dependabot.yml index b85bc841..6c4b3695 100644 --- a/.github/dependabot.yml +++ b/.github/dependabot.yml @@ -1,13 +1,11 @@ -# To get started with Dependabot version updates, you'll need to specify which -# package ecosystems to update and where the package manifests are located. 
-# Please see the documentation for all configuration options: -# https://docs.github.com/github/administering-a-repository/configuration-options-for-dependency-updates - version: 2 updates: - - package-ecosystem: "pip" # See documentation for possible values - directory: "/" # Location of package manifests + # Maintain dependencies for GitHub Actions + - package-ecosystem: "github-actions" + directory: "/" schedule: interval: "weekly" - reviewers: - - "AlecThomson" + groups: + actions: + patterns: + - "*" diff --git a/.github/release.yml b/.github/release.yml new file mode 100644 index 00000000..9d1e0987 --- /dev/null +++ b/.github/release.yml @@ -0,0 +1,5 @@ +changelog: + exclude: + authors: + - dependabot + - pre-commit-ci diff --git a/.github/workflows/cd.yml b/.github/workflows/cd.yml new file mode 100644 index 00000000..325c02d4 --- /dev/null +++ b/.github/workflows/cd.yml @@ -0,0 +1,60 @@ +name: CD + +on: + workflow_dispatch: + pull_request: + push: + branches: + - main + release: + types: + - published + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +env: + # Many color libraries just need this to be set to any value, but at least + # one distinguishes color depth, where "3" -> "256-bit color". + FORCE_COLOR: 3 + +jobs: + dist: + name: Distribution build + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - uses: hynek/build-and-inspect-python-package@v2 + + publish: + needs: [dist] + name: Publish to PyPI + environment: pypi + permissions: + id-token: write + attestations: write + contents: read + runs-on: ubuntu-latest + if: github.event_name == 'release' && github.event.action == 'published' + + steps: + - uses: actions/download-artifact@v4 + with: + name: Packages + path: dist + + - name: Generate artifact attestation for sdist and wheel + uses: actions/attest-build-provenance@v1.3.3 + with: + subject-path: "dist/*" + + - uses: pypa/gh-action-pypi-publish@release/v1 + # with: + # Remember to tell (test-)pypi about this repo before publishing + # Remove this line to publish to PyPI + # repository-url: https://test.pypi.org/legacy/ diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 00000000..046fba4c --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,71 @@ +name: CI + +on: + workflow_dispatch: + pull_request: + push: + branches: + - main + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +env: + # Many color libraries just need this to be set to any value, but at least + # one distinguishes color depth, where "3" -> "256-bit color". 
+ FORCE_COLOR: 3 + +jobs: + pre-commit: + name: Format + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + - uses: actions/setup-python@v5 + with: + python-version: "3.x" + - uses: pre-commit/action@v3.0.1 + with: + extra_args: --hook-stage manual --all-files + - name: Run PyLint + run: pipx run nox -s pylint -- --output-format=github + + checks: + name: Check Python ${{ matrix.python-version }} on ${{ matrix.runs-on }} + runs-on: ${{ matrix.runs-on }} + needs: [pre-commit] + strategy: + fail-fast: false + matrix: + python-version: ["3.8", "3.12"] + runs-on: [ubuntu-latest, windows-latest, macos-14] + + include: + - python-version: "pypy-3.10" + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + allow-prereleases: true + + - name: Install package + run: python -m pip install uv && uv pip install .[test] --system + + - name: Test package + run: >- + python -m pytest -ra --cov --cov-report=xml --cov-report=term + --durations=20 + + - name: Upload coverage report + uses: codecov/codecov-action@v4.5.0 + with: + token: ${{ secrets.CODECOV_TOKEN }} diff --git a/noxfile.py b/noxfile.py new file mode 100644 index 00000000..0f69efe4 --- /dev/null +++ b/noxfile.py @@ -0,0 +1,101 @@ +"""Nox configuration.""" + +from __future__ import annotations + +import argparse +import shutil +from pathlib import Path + +import nox + +DIR = Path(__file__).parent.resolve() + +nox.needs_version = ">=2024.3.2" +nox.options.sessions = ["lint", "pylint", "tests"] +nox.options.default_venv_backend = "mamba|uv|virtualenv" + + +@nox.session(python="3.10") +def lint(session: nox.Session) -> None: + """Run the linter.""" + session.install("pre-commit") + session.run( + "pre-commit", "run", "--all-files", "--show-diff-on-failure", *session.posargs + ) + + +@nox.session(python="3.10") +def pylint(session: nox.Session) -> None: + """Run PyLint.""" + # This needs to be installed into the package environment, and is slower + # than a pre-commit check + session.conda_install("casacore", "python-casacore", channel="conda-forge") + session.install(".", "pylint>=3.2") + session.run( + "pylint", + "--ignored-classes=astropy.units", + "-d duplicate-code", + "arrakis", + *session.posargs, + ) + + +@nox.session(python="3.10") +def tests(session: nox.Session) -> None: + """Run the unit and regular tests.""" + session.install(".[test]") + session.run("pytest", *session.posargs) + + +@nox.session(reuse_venv=True) +def docs(session: nox.Session) -> None: + """Build the docs. Pass --non-interactive to avoid serving. 
First positional argument is the target directory.""" + parser = argparse.ArgumentParser() + parser.add_argument( + "-b", dest="builder", default="html", help="Build target (default: html)" + ) + parser.add_argument("output", nargs="?", help="Output directory") + args, posargs = parser.parse_known_args(session.posargs) + serve = args.builder == "html" and session.interactive + + session.install("-e.[docs]", "sphinx-autobuild") + + shared_args = ( + "-n", # nitpicky mode + "-T", # full tracebacks + f"-b={args.builder}", + "docs", + args.output or f"docs/_build/{args.builder}", + *posargs, + ) + + if serve: + session.run("sphinx-autobuild", "--open-browser", *shared_args) + else: + session.run("sphinx-build", "--keep-going", *shared_args) + + +@nox.session(python="3.10") +def build_api_docs(session: nox.Session) -> None: + """Build (regenerate) API docs.""" + session.install("sphinx") + session.run( + "sphinx-apidoc", + "-o", + "docs/api/", + "--module-first", + "--no-toc", + "--force", + "arrakis", + ) + + +@nox.session(python="3.10") +def build(session: nox.Session) -> None: + """Build an SDist and wheel.""" + build_path = DIR.joinpath("build") + if build_path.exists(): + shutil.rmtree(build_path) + + session.install("build") + session.run("python", "-m", "build") diff --git a/pyproject.toml b/pyproject.toml index 3bb45966..d89add16 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,125 +1,173 @@ -[tool.poetry] +[build-system] +requires = ["hatchling", "hatch-vcs"] +build-backend = "hatchling.build" + + +[project] name = "arrakis" -version = "0.0.0" # A placeholder +authors = [ + { name = "Alec Thomson", email = "alec.thomson@csiro.au" }, + { name = "Tim Galvin", email = "tim.galvin@csiro.au" }, +] description = "Processing the SPICE." -homepage = "https://research.csiro.au/racs/" -repository = "https://github.com/AlecThomson/arrakis" -documentation = "https://arrakis.readthedocs.io/en/latest/" -authors = ["Alec Thomson", "Tim Galvin"] -license = "BSD-3-Clause" readme = "README.md" -classifiers=[ - "License :: OSI Approved :: BSD License", - "Programming Language :: Python", - "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.8", - "Intended Audience :: Science/Research", - "Topic :: Scientific/Engineering :: Astronomy", - "Development Status :: 3 - Alpha", +license.file = "LICENSE" +requires-python = ">=3.10, <3.11" +classifiers = [ + "Development Status :: 1 - Planning", + "Intended Audience :: Science/Research", + "Intended Audience :: Developers", + "License :: OSI Approved :: BSD License", + "Operating System :: OS Independent", + "Programming Language :: Python", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3 :: Only", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Topic :: Scientific/Engineering", + "Typing :: Typed", ] -packages = [ - { include = "arrakis"}, +dynamic = ["version"] +dependencies = [ + "numpy>=1.26.4", + "astropy>=5", + "bilby", + "ConfigArgParse", + "dask", + "distributed @ git+https://github.com/AlecThomson/distributed@drainclose", + "FRion>=1.1.3", + "h5py", + "ipython", + "matplotlib>=3.8", + "numba", + "numba_progress", + "pandas>=2", + "psutil", + "pymongo", + "pymultinest", + "pytest", + "python_casacore", + "RACS-tools>=3.0.5", + "radio_beam", + "scipy", + "spectral_cube>=0.6.3", + "spython", + "tqdm", + "vorbin", + 
"graphviz", + "bokeh<3", + "prefect>=2", + "prefect-dask", + "RMTable>=1.2.1", + "RM-Tools>=1.4.2", + "PolSpectra>=1.1.0", + "fixms>=0.2.6", + "fitscube>=0.3", + "psycopg2-binary", + "sqlalchemy", + "scikit-image>=0.23", + "setuptools", ] -include = [ - {path='arrakis/configs/*', format=['sdist','wheel']}, +[project.optional-dependencies] +test = [ + "pytest >=6", + "pytest-cov >=3", +] +dev = [ + "pytest >=6", + "pytest-cov >=3", ] - -[tool.poetry.dependencies] -python = ">=3.8" -astropy = ">=5" -bilby = "*" -ConfigArgParse = "*" -dask = "*" -distributed = {git="https://github.com/AlecThomson/distributed", branch="drainclose"} -dask_jobqueue = {version=">=0.8.3", optional=true} -dask_mpi = "*" -FRion = ">=1.1.3" -h5py = "*" -ipython = "*" -matplotlib = ">=3.8" -numba = "*" -numba_progress = "*" -pandas = ">=2" -psutil = "*" -pymongo = "*" -pymultinest = "*" -pytest = "*" -python_casacore = "*" -RACS-tools = ">=3.0.5" -radio_beam = "*" -RMextract = {git = "https://github.com/lofar-astron/RMextract", optional=true} -schwimmbad = "*" -scipy = "*" -spectral_cube = ">=0.6.3" -spython = "*" -tqdm = "*" -vorbin = "*" -graphviz = "*" -bokeh = "<3" -prefect = ">=2" -prefect-dask = "*" -RMTable = ">=1.2.1" -RM-Tools = ">=1.4.2" -PolSpectra = ">=1.1.0" -setuptools = "*" -fixms = ">=0.2.6" -fitscube = ">=0.3" -psycopg2-binary = "*" -sqlalchemy = "*" -scikit-image = ">=0.23" - -[tool.poetry.dev-dependencies] -black = ">=23" -flake8 = ">=5" -isort = ">=5" -mypy = ">=1" -pre-commit = ">=3.2" - -[tool.poetry.extras] docs = [ - "sphinx", - "sphinx_rtd_theme", - "sphinx-book-theme", - "sphinx-autoapi", - "m2r2", - "numpydoc", - "sphinxcontrib-mermaid", + "sphinx", + "sphinx_rtd_theme", + "sphinx-book-theme", + "sphinx-autoapi", + "m2r2", + "numpydoc", + "sphinxcontrib-mermaid", +] +mpi = [ + "dask_mpi", + "mpi4py", ] -RMextract = ["RMextract"] -jobqueue = ["dask_jobqueue"] -[tool.poetry-dynamic-versioning] -enable = true +jobqueue = [ + "dask_jobqueue >=0.8.3", +] -[build-system] -requires = ["poetry-core>=1.0.0", "poetry-dynamic-versioning>=1.0.0,<2.0.0", "numpy"] -build-backend = "poetry_dynamic_versioning.backend" +ion = [ + "RMextract>=0.5.0", +] -[tool.poetry.scripts] +[project.urls] +Homepage = "https://research.csiro.au/racs/" +Repository = "https://github.com/AlecThomson/arrakis" +"Bug Tracker" = "https://github.com/AlecThomson/arrakis/issues" +Discussions = "https://github.com/AlecThomson/arrakis/discussions" +Changelog = "https://github.com/AlecThomson/arrakis/releases" +Documentation = "https://arrakis.readthedocs.io" + +[project.scripts] spice_init = "arrakis.init_database:cli" spice_process = "arrakis.process_spice:cli" spice_region = "arrakis.process_region:cli" spice_cat = "arrakis.makecat:cli" spice_image = "arrakis.imager:cli" -# Misc scripts -casda_prepare = { reference="scripts/casda_prepare.py", type="file"} -check_cutout = { reference="scripts/check_cutout.py", type="file"} -compare_leakage = { reference="scripts/compare_leakage.py", type="file"} -compute_leakage = { reference="scripts/compute_leakage.py", type="file"} -copy_cutouts_askap = { reference="scripts/copy_cutouts_askap.py", type="file"} -copy_cutouts = { reference="scripts/copy_cutouts.py", type="file"} -copy_data = { reference="scripts/copy_data.py", type="file"} -find_row = { reference="scripts/find_row.py", type="file"} -find_sbid = { reference="scripts/find_sbid.py", type="file"} -fix_dr1_cat = { reference="scripts/fix_dr1_cat.py", type="file"} -fix_src_cat = { reference="scripts/fix_src_cat.py", type="file"} 
-hello_mpi_world = { reference="scripts/hello_mpi_world.py", type="file"} -make_links = { reference="scripts/make_links.py", type="file"} -spica = { reference="scripts/spica.py", type="file"} -tar_cubelets = { reference="scripts/tar_cubelets.py", type="file"} -create_mongodb = { reference="scripts/create_mongodb.py", type="file"} +[tool.hatch] +version.source = "vcs" +build.hooks.vcs.version-file = "arrakis/_version.py" +metadata.allow-direct-references = true + +[tool.hatch.envs.default] +features = ["test"] +scripts.test = "pytest {args}" + +[tool.hatch.build.targets.wheel] +packages = ["arrakis"] + +[tool.pytest.ini_options] +minversion = "6.0" +addopts = ["-ra", "--showlocals", "--strict-markers", "--strict-config"] +xfail_strict = true +filterwarnings = [ + "error", +] +log_cli_level = "INFO" +testpaths = [ + "tests", +] + + +[tool.coverage] +run.source = ["arrakis"] +report.exclude_also = [ + '\.\.\.', + 'if typing.TYPE_CHECKING:', +] + +[tool.mypy] +files = ["arrakis", "tests"] +python_version = "3.8" +warn_unused_configs = true +strict = true +enable_error_code = ["ignore-without-code", "redundant-expr", "truthy-bool"] +warn_unreachable = true +disallow_untyped_defs = false +disallow_incomplete_defs = false + +[[tool.mypy.overrides]] +module = "arrakis.*" +disallow_untyped_defs = true +disallow_incomplete_defs = true + +[[tool.mypy.overrides]] +module = "astropy.*" +ignore_missing_imports = true [tool.ruff] src = ["arrakis"] @@ -156,7 +204,7 @@ ignore = [ ] isort.required-imports = ["from __future__ import annotations"] # Uncomment if using a _compat.typing backport -# typing-modules = ["cutout_fits._compat.typing"] +# typing-modules = ["fitscube._compat.typing"] [tool.ruff.lint.per-file-ignores] "tests/**" = ["ALL"] @@ -169,3 +217,21 @@ isort.required-imports = ["from __future__ import annotations"] [tool.ruff.lint.pydocstyle] convention = "google" + + +[tool.pylint] +py-version = "3.8" +ignore-paths = [".*/_version.py"] +reports.output-format = "colorized" +similarities.ignore-imports = "yes" +messages_control.disable = [ + "design", + "fixme", + "line-too-long", + "missing-module-docstring", + "missing-function-docstring", + "wrong-import-position", +] + +[tool.codespell] +ignore-words-list = "datas" diff --git a/pyproject.toml.bak b/pyproject.toml.bak new file mode 100644 index 00000000..3bb45966 --- /dev/null +++ b/pyproject.toml.bak @@ -0,0 +1,171 @@ +[tool.poetry] +name = "arrakis" +version = "0.0.0" # A placeholder +description = "Processing the SPICE." 
+homepage = "https://research.csiro.au/racs/" +repository = "https://github.com/AlecThomson/arrakis" +documentation = "https://arrakis.readthedocs.io/en/latest/" +authors = ["Alec Thomson", "Tim Galvin"] +license = "BSD-3-Clause" +readme = "README.md" +classifiers=[ + "License :: OSI Approved :: BSD License", + "Programming Language :: Python", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.8", + "Intended Audience :: Science/Research", + "Topic :: Scientific/Engineering :: Astronomy", + "Development Status :: 3 - Alpha", +] +packages = [ + { include = "arrakis"}, +] + +include = [ + {path='arrakis/configs/*', format=['sdist','wheel']}, +] + +[tool.poetry.dependencies] +python = ">=3.8" +astropy = ">=5" +bilby = "*" +ConfigArgParse = "*" +dask = "*" +distributed = {git="https://github.com/AlecThomson/distributed", branch="drainclose"} +dask_jobqueue = {version=">=0.8.3", optional=true} +dask_mpi = "*" +FRion = ">=1.1.3" +h5py = "*" +ipython = "*" +matplotlib = ">=3.8" +numba = "*" +numba_progress = "*" +pandas = ">=2" +psutil = "*" +pymongo = "*" +pymultinest = "*" +pytest = "*" +python_casacore = "*" +RACS-tools = ">=3.0.5" +radio_beam = "*" +RMextract = {git = "https://github.com/lofar-astron/RMextract", optional=true} +schwimmbad = "*" +scipy = "*" +spectral_cube = ">=0.6.3" +spython = "*" +tqdm = "*" +vorbin = "*" +graphviz = "*" +bokeh = "<3" +prefect = ">=2" +prefect-dask = "*" +RMTable = ">=1.2.1" +RM-Tools = ">=1.4.2" +PolSpectra = ">=1.1.0" +setuptools = "*" +fixms = ">=0.2.6" +fitscube = ">=0.3" +psycopg2-binary = "*" +sqlalchemy = "*" +scikit-image = ">=0.23" + +[tool.poetry.dev-dependencies] +black = ">=23" +flake8 = ">=5" +isort = ">=5" +mypy = ">=1" +pre-commit = ">=3.2" + +[tool.poetry.extras] +docs = [ + "sphinx", + "sphinx_rtd_theme", + "sphinx-book-theme", + "sphinx-autoapi", + "m2r2", + "numpydoc", + "sphinxcontrib-mermaid", +] +RMextract = ["RMextract"] +jobqueue = ["dask_jobqueue"] + +[tool.poetry-dynamic-versioning] +enable = true + +[build-system] +requires = ["poetry-core>=1.0.0", "poetry-dynamic-versioning>=1.0.0,<2.0.0", "numpy"] +build-backend = "poetry_dynamic_versioning.backend" + +[tool.poetry.scripts] +spice_init = "arrakis.init_database:cli" +spice_process = "arrakis.process_spice:cli" +spice_region = "arrakis.process_region:cli" +spice_cat = "arrakis.makecat:cli" +spice_image = "arrakis.imager:cli" + +# Misc scripts +casda_prepare = { reference="scripts/casda_prepare.py", type="file"} +check_cutout = { reference="scripts/check_cutout.py", type="file"} +compare_leakage = { reference="scripts/compare_leakage.py", type="file"} +compute_leakage = { reference="scripts/compute_leakage.py", type="file"} +copy_cutouts_askap = { reference="scripts/copy_cutouts_askap.py", type="file"} +copy_cutouts = { reference="scripts/copy_cutouts.py", type="file"} +copy_data = { reference="scripts/copy_data.py", type="file"} +find_row = { reference="scripts/find_row.py", type="file"} +find_sbid = { reference="scripts/find_sbid.py", type="file"} +fix_dr1_cat = { reference="scripts/fix_dr1_cat.py", type="file"} +fix_src_cat = { reference="scripts/fix_src_cat.py", type="file"} +hello_mpi_world = { reference="scripts/hello_mpi_world.py", type="file"} +make_links = { reference="scripts/make_links.py", type="file"} +spica = { reference="scripts/spica.py", type="file"} +tar_cubelets = { reference="scripts/tar_cubelets.py", type="file"} +create_mongodb = { reference="scripts/create_mongodb.py", type="file"} + +[tool.ruff] +src = ["arrakis"] + 
+[tool.ruff.lint] +extend-select = [ + "B", # flake8-bugbear + "I", # isort + "ARG", # flake8-unused-arguments + "C4", # flake8-comprehensions + "EM", # flake8-errmsg + "ICN", # flake8-import-conventions + # "G", # flake8-logging-format + "PGH", # pygrep-hooks + "PIE", # flake8-pie + "PL", # pylint + "PT", # flake8-pytest-style + "PTH", # flake8-use-pathlib + "RET", # flake8-return + "RUF", # Ruff-specific + "SIM", # flake8-simplify + "T20", # flake8-print + "UP", # pyupgrade + "YTT", # flake8-2020 + "EXE", # flake8-executable + "NPY", # NumPy specific rules + "PD", # pandas-vet + "D", # flake8-docstrings +] +ignore = [ + "PLR09", # Too many <...> + "PLR2004", # Magic value used in comparison + "ISC001", # Conflicts with formatter +] +isort.required-imports = ["from __future__ import annotations"] +# Uncomment if using a _compat.typing backport +# typing-modules = ["cutout_fits._compat.typing"] + +[tool.ruff.lint.per-file-ignores] +"tests/**" = ["ALL"] +"scripts/**" = ["ALL"] +"noxfile.py" = ["T20"] +"docs/**" = ["ALL"] +"submit/**" = ["ALL"] +# Ignore all rules for wsclean +"arrakis/wsclean_rmsynth.py" = ["ALL"] + +[tool.ruff.lint.pydocstyle] +convention = "google" From c4214edcb589e149a677799fa2a7b940923860b8 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 25 Jul 2024 07:59:00 +0000 Subject: [PATCH 16/17] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- arrakis/frion.py | 2 +- arrakis/utils/fitsutils.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/arrakis/frion.py b/arrakis/frion.py index bff30735..02427d7d 100755 --- a/arrakis/frion.py +++ b/arrakis/frion.py @@ -15,6 +15,7 @@ import numpy as np import pymongo from astropy.time import Time, TimeDelta +from FRion import correct, predict from prefect import flow, task from tqdm.auto import tqdm @@ -27,7 +28,6 @@ ) from arrakis.utils.fitsutils import getfreq from arrakis.utils.pipeline import generic_parser, logo_str, workdir_arg_parser -from FRion import correct, predict logger.setLevel(logging.INFO) TQDM_OUT = TqdmToLogger(logger, level=logging.INFO) diff --git a/arrakis/utils/fitsutils.py b/arrakis/utils/fitsutils.py index b25ca63f..c0c5dbe2 100644 --- a/arrakis/utils/fitsutils.py +++ b/arrakis/utils/fitsutils.py @@ -11,10 +11,10 @@ from astropy.io import fits from astropy.utils.exceptions import AstropyWarning from astropy.wcs import WCS +from FRion.correct import find_freq_axis from spectral_cube.utils import SpectralCubeWarning from arrakis.logger import logger -from FRion.correct import find_freq_axis warnings.filterwarnings(action="ignore", category=SpectralCubeWarning, append=True) warnings.simplefilter("ignore", category=AstropyWarning) From 626d20985e98fa7078e54606faaaee744092089c Mon Sep 17 00:00:00 2001 From: "Alec Thomson (S&A, Kensington WA)" Date: Thu, 25 Jul 2024 16:19:56 +0800 Subject: [PATCH 17/17] Version bumpy --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index d89add16..5f4e6d3c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -64,7 +64,7 @@ dependencies = [ "RMTable>=1.2.1", "RM-Tools>=1.4.2", "PolSpectra>=1.1.0", - "fixms>=0.2.6", + "fixms>=0.2.9", "fitscube>=0.3", "psycopg2-binary", "sqlalchemy",