Skip to content

Commit

Permalink
ssins multi-file chunking support for docker memory issue
Browse files Browse the repository at this point in the history
  • Loading branch information
d3v-null committed Sep 27, 2024
1 parent 39fa2fd commit 0fced77
Show file tree
Hide file tree
Showing 2 changed files with 194 additions and 86 deletions.
93 changes: 55 additions & 38 deletions demo/03_mwalib.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,41 +5,58 @@
from pandas import DataFrame
import sys

# Flat-script version: read a metafits file and dump its coarse-channel and
# antenna tables as tab-separated files.

# the metafits path is taken from the last command-line argument
metafits = sys.argv[-1]
# where to put the channel and antenna info
ctx = MetafitsContext(metafits)
# attributes copied from each coarse channel, one column per attribute
header = [
    "gpubox_number",
    "rec_chan_number",
    "chan_start_hz",
    "chan_centre_hz",
    "chan_end_hz",
]
df = DataFrame({h: [getattr(c, h) for c in ctx.metafits_coarse_chans] for h in header})
# output path is derived from the input path by substituting ".metafits"
channels = metafits.replace(".metafits", "-channels.tsv")
df.to_csv(channels, index=False, sep="\t")
print(f"wrote channels to {channels}")

# attributes copied from each antenna, one column per attribute
header = [
    "ant",
    "tile_id",
    "tile_name",
    "electrical_length_m",
    "east_m",
    "north_m",
    "height_m",
]
df = DataFrame({h: [getattr(a, h) for a in ctx.antennas] for h in header})
# an antenna counts as flagged when either of its two rf inputs is flagged
df["flagged"] = [a.rfinput_x.flagged | a.rfinput_y.flagged for a in ctx.antennas]
# get elements from antenna rfinput_x, assuming it's the same as rfinput_y
rfheader = ["rec_number", "flavour", "has_whitening_filter"]
for h in rfheader:
    df[h] = [getattr(a.rfinput_x, h) for a in ctx.antennas]
# rec_type is "ReceiverType.RRI", I want just "RRI"
df["rec_type"] = [
    str(a.rfinput_x.rec_type).replace("ReceiverType.", "") for a in ctx.antennas
]

# same path substitution as for the channel table; float columns are written
# with an explicit sign and three decimals
antennas = metafits.replace(".metafits", "-antennas.tsv")
df.to_csv(antennas, index=False, sep="\t", float_format="%+8.3f")
print(f"wrote antennas to {antennas}")

def get_channel_df(ctx: MetafitsContext):
    """Tabulate the coarse channels of a metafits context.

    Returns a DataFrame with one row per entry of
    ``ctx.metafits_coarse_chans`` and one column per attribute named in
    ``columns`` below, in that order.
    """
    columns = (
        "gpubox_number",
        "rec_chan_number",
        "chan_start_hz",
        "chan_centre_hz",
        "chan_end_hz",
    )
    table = {}
    for name in columns:
        # each column is that attribute read off every coarse channel
        table[name] = [getattr(chan, name) for chan in ctx.metafits_coarse_chans]
    return DataFrame(table)


def get_antenna_df(ctx: MetafitsContext):
    """Tabulate the antennas of a metafits context.

    Returns a DataFrame with one row per entry of ``ctx.antennas``. Columns
    are, in order: the per-antenna attributes, a combined ``flagged`` column,
    the attributes read from ``rfinput_x``, and ``rec_type``.
    """
    antenna_attrs = (
        "ant",
        "tile_id",
        "tile_name",
        "electrical_length_m",
        "east_m",
        "north_m",
        "height_m",
    )
    # read from rfinput_x only, assuming rfinput_y carries the same values
    rfinput_attrs = ("rec_number", "flavour", "has_whitening_filter")

    table = {
        name: [getattr(tile, name) for tile in ctx.antennas]
        for name in antenna_attrs
    }
    # an antenna is flagged when either of its rf inputs is flagged
    table["flagged"] = [
        tile.rfinput_x.flagged | tile.rfinput_y.flagged for tile in ctx.antennas
    ]
    for name in rfinput_attrs:
        table[name] = [getattr(tile.rfinput_x, name) for tile in ctx.antennas]
    # rec_type prints as "ReceiverType.RRI"; keep only the "RRI" part
    prefix = "ReceiverType."
    table["rec_type"] = [
        str(tile.rfinput_x.rec_type).replace(prefix, "") for tile in ctx.antennas
    ]
    return DataFrame(table)


def main():
    """Write the channel and antenna tables of a metafits file as TSV.

    The metafits path is the last command-line argument. Two files are
    written, ``<base>-channels.tsv`` and ``<base>-antennas.tsv``, where
    ``<base>`` is the metafits path with a trailing ``.metafits`` removed.
    Exits with a usage message when no argument is given.
    """
    if len(sys.argv) < 2:
        # without an argument, sys.argv[-1] would be this script itself
        sys.exit(f"usage: {sys.argv[0]} <metafits>")
    metafits = sys.argv[-1]
    ctx = MetafitsContext(metafits)

    # Strip only a trailing ".metafits". str.replace would rewrite every
    # occurrence in the path and, when there is none, return the input path
    # unchanged, so the TSV would overwrite the input file.
    suffix = ".metafits"
    base = metafits[: -len(suffix)] if metafits.endswith(suffix) else metafits

    df_ch = get_channel_df(ctx)
    channels = f"{base}-channels.tsv"
    df_ch.to_csv(channels, index=False, sep="\t")
    print(f"wrote channels to {channels}")

    df_ant = get_antenna_df(ctx)
    antennas = f"{base}-antennas.tsv"
    # float columns are written with an explicit sign and three decimals
    df_ant.to_csv(antennas, index=False, sep="\t", float_format="%+8.3f")
    print(f"wrote antennas to {antennas}")


if __name__ == "__main__":
    main()
187 changes: 139 additions & 48 deletions demo/04_ssins.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,54 +6,68 @@
from pyuvdata import UVData
from SSINS import SS, INS, MF
import os
from os.path import splitext, dirname
import matplotlib as mpl
from matplotlib import pyplot as plt
import numpy as np
from astropy.time import Time
import argparse
from itertools import groupby
import re

from matplotlib.axis import Axis


def get_parser():
parser = argparse.ArgumentParser()
# arguments for UVData.read()
group_uvd_read = parser.add_argument_group("UVData.read")
group_uvd_read.add_argument("files", nargs="+")
group_uvd_read.add_argument(
# arguments for SS.read()
group_read = parser.add_argument_group("SS.read")
group_read.add_argument(
"files",
nargs="+",
help="raw .fits (with .metafits), .uvfits supported",
)
group_read.add_argument(
"--no-diff",
default=False,
action="store_true",
help="don't difference visibilities in time (sky-subtract)",
)
group_uvd_read.add_argument(
group_read.add_argument(
"--no-flag-init",
default=False,
action="store_true",
help="skip flagging of edge channels, quack time",
)
group_uvd_read.add_argument(
group_read.add_argument(
"--remove-coarse-band",
default=False,
action="store_true",
help="Correct coarse PFB passband (resolution must be > 10kHz)",
)
group_uvd_read.add_argument(
group_read.add_argument(
"--correct-van-vleck",
default=False,
action="store_true",
help="Correct van vleck quantization artifacts in legacy correlator. slow!",
)
group_uvd_read.add_argument(
group_read.add_argument(
"--include-flagged-ants",
default=False,
action="store_true",
help="Include flagged antenna when reading raw files",
)
group_read.add_argument(
"--flag-choice",
default=None,
nargs=1,
choices=["original"],
help="original = apply flags from visibilities before running ssins (not recommended)",
)

# arguments for UVData.select()
group_uvd_sel = parser.add_argument_group("UVData.select")
group_mutex = group_uvd_sel.add_mutually_exclusive_group()
# arguments for SS.select()
group_sel = parser.add_argument_group("SS.select")
group_mutex = group_sel.add_mutually_exclusive_group()
group_mutex.add_argument(
"--sel-ants",
default=[],
Expand All @@ -69,7 +83,7 @@ def get_parser():
help="antenna names to skip, default: none",
)

group_uvd_sel.add_argument(
group_sel.add_argument(
"--sel-pols",
default=[],
nargs="*",
Expand Down Expand Up @@ -126,6 +140,63 @@ def get_parser():
return parser


def group_by_filetype(paths):
    """Group paths by their file extension.

    Returns a dict mapping each extension (including the leading dot, as
    given by ``os.path.splitext``) to the list of paths carrying it. Keys are
    in sorted order; paths within a group keep their input order and are
    returned unmodified.

    Trailing path separators are ignored when classifying, so a directory
    dataset passed as ``"obs.ms/"`` is grouped under ``".ms"`` rather than
    under the empty extension.
    """
    trailing_seps = os.sep + (os.altsep or "")

    def filetype_classifier(path):
        # splitext only looks after the last separator, so "obs.ms/" would
        # otherwise yield an empty extension
        _, ext = splitext(path.rstrip(trailing_seps))
        return ext

    return {
        k: [*v]
        for k, v in groupby(
            # groupby only merges adjacent items, hence the sort on the same key
            sorted(paths, key=filetype_classifier), key=filetype_classifier
        )
    }


def group_raw_by_channel(metafits, raw_fits):
    """Group raw fits paths by receiver coarse channel number.

    The channel is taken from the second-to-last ``_``-separated token of
    each path:

    - ``gpuboxNN`` is looked up in the metafits channel table and mapped to
      its ``rec_chan_number``;
    - ``chNNN`` is used directly as the channel number.

    Returns a dict mapping the channel number (always an ``int``) to the
    sorted list of paths for that channel, with keys in ascending order.

    Raises ``UserWarning`` if a path has no recognisable channel token, or
    if a gpubox number is absent from the metafits channel table.
    """
    import importlib
    import sys

    # 03_mwalib.py sits next to this script. Its name starts with a digit, so
    # it cannot be loaded with a plain ``import`` statement.
    script_dir = dirname(__file__)
    if script_dir not in sys.path:
        sys.path.insert(0, script_dir)
    mwalib_tools = importlib.import_module("03_mwalib")
    ctx = mwalib_tools.MetafitsContext(metafits)
    df_ch = mwalib_tools.get_channel_df(ctx)

    def channel_classifier(path):
        tokens = path.split("_")
        # a path without two "_" tokens is reported below, not as IndexError
        ch_token = tokens[-2] if len(tokens) >= 2 else path
        if match := re.match(r"gpubox(\d+)", ch_token):
            channel = df_ch[df_ch["gpubox_number"] == int(match[1])]
            if len(channel) == 0:
                raise UserWarning(f"no match of gpubox{match[1]} in {df_ch}")
            return int(channel.rec_chan_number.iloc[0])
        elif match := re.match(r"ch(\d+)", ch_token):
            # int, like the gpubox branch, so channels sort numerically
            return int(match[1])
        else:
            raise UserWarning(f"unknown channel token {ch_token}")

    # classify each path once, then emit groups in ascending channel order
    groups = {}
    for path in raw_fits:
        groups.setdefault(channel_classifier(path), []).append(path)
    return {ch: sorted(groups[ch]) for ch in sorted(groups)}


def mwalib_get_common_times(metafits, raw_fits, good=True):
    """Return the common timesteps of a set of raw files as an astropy Time.

    Opens a ``CorrelatorContext`` on ``metafits`` and ``raw_fits`` and reads
    the GPS time of every timestep listed in
    ``common_good_timestep_indices`` (``good=True``) or
    ``common_timestep_indices`` (``good=False``). The resulting times are
    shifted earlier by half the spacing between the first two of them.

    Indexes the second time, so at least two timesteps are needed.
    """
    from mwalib import CorrelatorContext

    with CorrelatorContext(metafits, raw_fits) as corr_ctx:
        if good:
            indices = corr_ctx.common_good_timestep_indices
        else:
            indices = corr_ctx.common_timestep_indices
        # gps_time_ms is in milliseconds; Time(format="gps") takes seconds
        gps_seconds = [
            corr_ctx.timesteps[idx].gps_time_ms / 1000.0 for idx in indices
        ]
    times = Time(gps_seconds, format="gps", scale="utc")
    half_step = (times[1] - times[0]) / 2.0
    times -= half_step
    return times


def get_unflagged_ants(ss: UVData, args):
all_ant_nums = np.array(ss.antenna_numbers)
all_ant_names = np.array(ss.antenna_names)
Expand Down Expand Up @@ -326,50 +397,70 @@ def main():
args = parser.parse_args()
print(f"{args=}")

if len(args.files) > 1 and args.files[-1].endswith(".fits"):
metafits = args.files[0]
# output name is basename of metafits, or uvfits if provided
base, _ = os.path.splitext(metafits)
elif len(args.files) == 1:
vis = args.files[-1]
base, _ = os.path.splitext(vis)
else:
parser.print_usage()
exit(1)
file_groups = group_by_filetype(args.files)

print(f"reading from {args.files=}")
print(f"reading from {file_groups=}")
# sky-subtract https://ssins.readthedocs.io/en/latest/sky_subtract.html
ss = SS()
ss.read(
args.files,
read_data=True,
diff=(not args.no_diff), # difference timesteps
remove_coarse_band=args.remove_coarse_band, # does not work with low freq res
correct_van_vleck=args.correct_van_vleck, # slow
remove_flagged_ants=(not args.include_flagged_ants), # remove flagged antennas
flag_init=(not args.no_flag_init),
)
read_kwargs = {
"diff": (not args.no_diff), # difference timesteps
"remove_coarse_band": args.remove_coarse_band, # does not work with low freq res
"correct_van_vleck": args.correct_van_vleck, # slow
"remove_flagged_ants": (not args.include_flagged_ants), # remove flagged ants
"flag_init": (not args.no_flag_init),
"ant_str": args.spectrum_type,
"flag_choice": args.flag_choice,
}

# output name is basename of metafits, first uvfits or first ms if provided
base = None
# metafits and mwaf flag files only used if raw fits supplied
metafits = None
raw_fits = None
if ".fits" in file_groups:
if ".metafits" not in file_groups:
raise UserWarning(f"fits supplied, but no metafits in {args.files}")
if len(file_groups[".metafits"]) > 1:
raise UserWarning(f"multiple metafits supplied in {args.files}")
metafits = file_groups[".metafits"][0]
base, _ = splitext(metafits)
raw_fits = file_groups[".fits"]
if len(raw_fits) > 1:
times = mwalib_get_common_times(metafits, raw_fits)
time_array = times.jd.astype(float)
# group and read raw by channel to save memory
raw_channel_groups = group_raw_by_channel(metafits, raw_fits)
for ch in sorted([*raw_channel_groups.keys()]):
ss_ = type(ss)()
ss_.read(
[metafits, *raw_channel_groups[ch]],
read_data=True,
times=time_array,
**read_kwargs,
)
if ss.data_array is None:
ss = ss_
else:
ss.__add__(ss_, inplace=True)
else:
ss.read([metafits, *raw_fits], read_data=True, **read_kwargs)
elif ".uvfits" in file_groups and ".ms" in file_groups:
raise UserWarning(f"both ms and uvfits in {args.files}")
elif ".uvfits" in file_groups or ".ms" in file_groups:
vis = file_groups.get(".uvfits", []) + file_groups.get(".ms", [])
base, _ = os.path.splitext(vis[0])
ss.read(vis, read_data=True, **read_kwargs)
else:
parser.print_usage()
exit(1)

unflagged_ants = get_unflagged_ants(ss, args)

select_kwargs = {
"inplace": False,
}
if args.spectrum_type == "cross":
select_kwargs["bls"] = [
(a, b)
for (a, b) in ss.get_antpairs()
if a != b and (b in unflagged_ants or a in unflagged_ants)
]
else:
select_kwargs["bls"] = [
(a, b) for (a, b) in ss.get_antpairs() if a == b and a in unflagged_ants
]
select_kwargs = {}
if args.sel_pols:
select_kwargs["polarizations"] = args.sel_pols

ss = ss.select(**select_kwargs)
ss.apply_flags(flag_choice="original")
ss.select(inplace=True, **select_kwargs)
ss.apply_flags(flag_choice=args.flag_choice)

plt.style.use("dark_background")
cmap = mpl.colormaps.get_cmap("viridis")
Expand Down

0 comments on commit 0fced77

Please sign in to comment.