Skip to content

Commit

Permalink
Merge pull request #7 from ratt-ru/corrflags
Browse files Browse the repository at this point in the history
Parallel plotting in surfchi2 + update python requires
  • Loading branch information
landmanbester authored Jun 24, 2024
2 parents 02a8ca1 + 4b3329d commit 9276bd3
Show file tree
Hide file tree
Showing 5 changed files with 196 additions and 160 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,16 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.8", "3.9", "3.10"]
python-version: ["3.10", "3.11"]

steps:
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version}}

- name: Checkout source
uses: actions/checkout@v2
uses: actions/checkout@v3
with:
fetch-depth: 1

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
long_description_content_type="text/markdown",
url="https://github.com/ratt-ru/surfvis",
packages=find_packages(),
python_requires='>=3.7',
python_requires='>=3.10',
install_requires=requirements,
classifiers=[
"Programming Language :: Python :: 3",
Expand Down
3 changes: 2 additions & 1 deletion surfvis/flagchi2.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,4 +111,5 @@ def main():
columns=[options.fcol, 'FLAG_ROW'],
rechunk=True)

dask.compute(writes)
with ProgressBar():
dask.compute(writes)
265 changes: 111 additions & 154 deletions surfvis/surfchi2.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,14 @@
import dask
import dask.array as da
from dask.diagnostics import ProgressBar
from surfvis.utils import surfchisq
from surfvis.utils import surfchisq, surfchisq_plot
from daskms import xds_from_storage_ms as xds_from_ms
from daskms import xds_from_storage_table as xds_from_table
from daskms.experimental.zarr import xds_from_zarr, xds_to_zarr
# might make for cooler histograms but doesn't work out of the box
from astropy.visualization import hist
from pathlib import Path
import concurrent.futures as cf


# COMMAND LINE OPTIONS
Expand Down Expand Up @@ -50,6 +52,10 @@ def create_parser():
def main():
(options,args) = create_parser().parse_args()

print('Input Options:')
for key, value in vars(options).items():
print(' %25s = %s' % (key, value))

if options.dataout == '':
options.dataout = os.getcwd() + '/chi2'

Expand Down Expand Up @@ -111,39 +117,37 @@ def main():

ridx = np.zeros(len(row_chunks))
ridx[1:] = np.cumsum(row_chunks)[0:-1]
rbin_idx.append(da.from_array(ridx.astype(int), chunks=1))
rbin_counts.append(da.from_array(row_chunks, chunks=1))
rbin_idx.append(ridx.astype(int))
rbin_counts.append(row_chunks)

ntime = ut.size
tidx = np.arange(0, ntime, utpc)
tbin_idx.append(da.from_array(tidx.astype(int), chunks=1))
tbin_idx.append(tidx.astype(int))
tidx2 = np.append(tidx, ntime)
tcounts = tidx2[1:] - tidx2[0:-1]
tbin_counts.append(da.from_array(tcounts, chunks=1))
tbin_counts.append(tcounts)

t0 = ut[tidx]
t0s.append(da.from_array(t0, chunks=1))
t0s.append(t0)
tf = ut[tidx + tcounts -1]
tfs.append(da.from_array(tf, chunks=1))
tfs.append(tf)

fidx = np.arange(0, nchan, options.nfreqs)
fbin_idx.append(da.from_array(fidx, chunks=1))
fbin_idx.append(fidx)
fidx2 = np.append(fidx, nchan)
fcounts = fidx2[1:] - fidx2[0:-1]
fbin_counts.append(da.from_array(fcounts, chunks=1))
fbin_counts.append(fcounts)

schema = {}
schema[options.rcol] = {'dims': ('chan', 'corr')}
schema[options.wcol] = {'dims': ('chan', 'corr')}
schema[options.fcol] = {'dims': ('chan', 'corr')}

xds = xds_from_ms(msname,
columns=[options.rcol, options.wcol, options.fcol,
'ANTENNA1', 'ANTENNA2', 'TIME'],
chunks=chunks,
group_cols=['FIELD_ID', 'DATA_DESC_ID', 'SCAN_NUMBER'],
table_schema=schema)

columns=[options.rcol, options.wcol, options.fcol,'ANTENNA1', 'ANTENNA2', 'TIME'],
chunks=chunks,
group_cols=['FIELD_ID', 'DATA_DESC_ID', 'SCAN_NUMBER'],
table_schema=schema)
if options.use_corrs is None:
print('Using only diagonal correlations')
if len(xds[0].corr) > 1:
Expand All @@ -155,137 +159,89 @@ def main():
print(f"Using correlations {use_corrs}")
ncorr = len(use_corrs)

out_ds = []
idts = []
for i, ds in enumerate(xds):
ds = ds.sel(corr=use_corrs)

resid = ds.get(options.rcol).data
if options.wcol == 'SIGMA_SPECTRUM':
weight = 1.0/ds.get(options.wcol).data**2
else:
weight = ds.get(options.wcol).data
flag = ds.get(options.fcol).data
ant1 = ds.ANTENNA1.data
ant2 = ds.ANTENNA2.data

# ncorr = resid.shape[0]

# time = ds.TIME.values
# utime = np.unique(time)

# spw = xds_from_table(msname + '::SPECTRAL_WINDOW')
# freq = spw[0].CHAN_FREQ.values

field = ds.FIELD_ID
ddid = ds.DATA_DESC_ID
scan = ds.SCAN_NUMBER

tmp = surfchisq(resid, weight, flag, ant1, ant2,
rbin_idx[i], rbin_counts[i],
fbin_idx[i], fbin_counts[i])

d = xr.Dataset(
data_vars={'data': (('time', 'freq', 'corr', 'p', 'q', '2'), tmp),
'fbin_idx': (('freq'), fbin_idx[i]),
'fbin_counts': (('freq'), fbin_counts[i]),
'tbin_idx': (('time'), tbin_idx[i]),
'tbin_counts': (('time'), tbin_counts[i])},
attrs = {'FIELD_ID': ds.FIELD_ID,
'DATA_DESC_ID': ds.DATA_DESC_ID,
'SCAN_NUMBER': ds.SCAN_NUMBER},
# coords={'time': (('time'), utime),
# 'freq': (('freq'), freq),
# 'corr': (('corr'), np.arange(ncorr))}
)

idt = f'::F{ds.FIELD_ID}_D{ds.DATA_DESC_ID}_S{ds.SCAN_NUMBER}'
out_ds.append(xds_to_zarr(d, options.dataout + idt))
idts.append(idt)


dask.compute(out_ds)

# primitive plotting
if options.imagesout is not None:
foldername = options.imagesout.rstrip('/')
if not os.path.isdir(foldername):
os.system('mkdir '+ foldername)

for idt in idts:
xds = xds_from_zarr(options.dataout + idt)
for ds in xds:
field = ds.FIELD_ID
if not os.path.isdir(foldername + f'/field{field}'):
os.system('mkdir '+ foldername + f'/field{field}')

spw = ds.DATA_DESC_ID
if not os.path.isdir(foldername + f'/field{field}' + f'/spw{spw}'):
os.system('mkdir '+ foldername + f'/field{field}' + f'/spw{spw}')

scan = ds.SCAN_NUMBER
if not os.path.isdir(foldername + f'/field{field}' + f'/spw{spw}' + f'/scan{scan}'):
os.system('mkdir '+ foldername + f'/field{field}' + f'/spw{spw}'+ f'/scan{scan}')

tmp = ds.data.values
tbin_idx = ds.tbin_idx.values
tbin_counts = ds.tbin_counts.values
fbin_idx = ds.fbin_idx.values
fbin_counts = ds.fbin_counts.values

ntime, nfreq, ncorr, _, _, _ = tmp.shape

basename = foldername + f'/field{field}' + f'/spw{spw}'+ f'/scan{scan}/'
if len(os.listdir(basename)):
print(f"Removing contents of {basename} folder")
os.system(f'rm {basename}*.png')
for t in range(ntime):
for f in range(nfreq):
for c in range(ncorr):
chi2 = tmp[t, f, c, :, :, 0]
N = tmp[t, f, c, :, :, 1]
chi2_dof = np.zeros_like(chi2)
chi2_dof[N>0] = chi2[N>0]/N[N>0]
chi2_dof[N==0] = np.nan
t0 = tbin_idx[t]
tf = tbin_idx[t] + tbin_counts[t]
chan0 = fbin_idx[f]
chanf = fbin_idx[f] + fbin_counts[f]
makeplot(chi2_dof, basename + f't{t}_f{f}_c{c}.png',
f't {t0}-{tf}, chan {chan0}-{chanf}, corr {c}')

# reduce over corr
chi2 = np.nansum(tmp[t, f, (0, -1), :, :, 0], axis=0)
N = np.nansum(tmp[t, f, (0, -1), :, :, 1], axis=0)
chi2_dof = np.zeros_like(chi2)
chi2_dof[N>0] = chi2[N>0]/N[N>0]
chi2_dof[N==0] = np.nan
t0 = tbin_idx[t]
tf = tbin_idx[t] + tbin_counts[t]
chan0 = fbin_idx[f]
chanf = fbin_idx[f] + fbin_counts[f]
makeplot(chi2_dof, basename + f't{t}_f{f}.png',
f't {t0}-{tf}, chan {chan0}-{chanf}')

# reduce over freq
chi2 = np.nansum(tmp[t, :, (0, -1), :, :, 0], axis=(0,1))
N = np.nansum(tmp[t, :, (0, -1), :, :, 1], axis=(0,1))
chi2_dof = np.zeros_like(chi2)
chi2_dof[N>0] = chi2[N>0]/N[N>0]
chi2_dof[N==0] = np.nan
t0 = tbin_idx[t]
tf = tbin_idx[t] + tbin_counts[t]
makeplot(chi2_dof, basename + f't{t}.png',
f't {t0}-{tf}')

# now the entire scan
chi2 = np.nansum(tmp[:, :, (0, -1), :, :, 0], axis=(0,1,2))
N = np.nansum(tmp[:, :, (0, -1), :, :, 1], axis=(0,1,2))
chi2_dof = np.zeros_like(chi2)
chi2_dof[N>0] = chi2[N>0]/N[N>0]
chi2_dof[N==0] = np.nan
makeplot(chi2_dof, basename + f'scan.png',
f'scan {scan}.png')
chi2s = {}
counts = {}
futures = []
foldername = options.imagesout.rstrip('/')
with cf.ProcessPoolExecutor(max_workers=options.nthreads) as executor:
for i, ds in enumerate(xds):
field = ds.FIELD_ID
spw = ds.DATA_DESC_ID
scan = ds.SCAN_NUMBER

basename = foldername + f'/field{field}' + f'/spw{spw}'+ f'/scan{scan}/'
odir = Path(basename).resolve()
odir.mkdir(parents=True, exist_ok=True)

ntime = tbin_idx[i].size
nfreq = fbin_idx[i].size
ncorr = len(use_corrs)
for t in range(ntime):
for f in range(nfreq):
for c in range(ncorr):
t0 = tbin_idx[i][t]
tf = t0 + tbin_counts[i][t]
chan0 = fbin_idx[i][f]
chanf = chan0 + fbin_counts[i][f]
row0 = rbin_idx[i][t]
rowf = rbin_idx[i][t] + rbin_counts[i][t]
Inu = slice(chan0, chanf)
Irow = slice(row0, rowf)
dso = ds[{'row': Irow, 'chan': Inu}]
# import ipdb; ipdb.set_trace()
dso = dso.sel(corr=use_corrs)
resid = dso.get(options.rcol).data
if options.wcol == 'SIGMA_SPECTRUM':
weight = 1.0/dso.get(options.wcol).data**2
else:
weight = dso.get(options.wcol).data
flag = dso.get(options.fcol).data
ant1 = dso.ANTENNA1.data
ant2 = dso.ANTENNA2.data
t0 = tbin_idx[i][t]
tf = t0 + tbin_counts[i][t]
chan0 = fbin_idx[i][f]
chanf = chan0 + fbin_counts[i][f]
fut = executor.submit(surfchisq_plot, resid, weight, flag, ant1, ant2,
field, spw, scan,
basename + f't{t}_f{f}_c{c}.png',
f't {t0}-{tf}, chan {chan0}-{chanf}, corr {c}')
futures.append(fut)

# to reduce over time, freq and corr at the end
nant = np.maximum(ant1.compute().max(), ant2.compute().max()) + 1
chi2s[f'field{field}_spw{spw}_scan{scan}'] = np.zeros((nant, nant), dtype=float)
counts[f'field{field}_spw{spw}_scan{scan}'] = np.zeros((nant, nant), dtype=float)
print(f"Submitted field{field}_spw{spw}_scan{scan}")

# reduce per scan
num_completed = 0
num_futures = len(futures)
for fut in cf.as_completed(futures):
num_completed += 1
print(f"\rProcessing: {num_completed}/{num_futures}", end='', flush=True)
try:
field, spw, scan, chi2, count = fut.result()
chi2s[f'field{field}_spw{spw}_scan{scan}'] += chi2
counts[f'field{field}_spw{spw}_scan{scan}'] += count
except Exception as e:
raise e

# LB - is it worth doing this in parallel?
print("Plotting per scan")
for key, val in chi2s.items():
field, spw, scan = key.split('_')
field = field.strip('field')
spw = spw.strip('spw')
scan = scan.strip('scan')
count = counts[key]
chi2_dof = np.zeros_like(val)
chi2_dof[count>0] = val[count>0]/count[count>0]
chi2_dof[count<=0] = np.nan

basename = foldername + f'/field{field}' + f'/spw{spw}'+ f'/scan{scan}/'
makeplot(chi2_dof, basename + f'combined.png',
f'scan {scan}.png')

def makeplot(data, name, subt):
nant, _ = data.shape
Expand All @@ -305,14 +261,15 @@ def makeplot(data, name, subt):

rax = divider.append_axes("right", size="50%", pad=0.025)
x = data[~ np.isnan(data)]
hist(x, bins='scott', ax=rax, histtype='stepfilled',
alpha=0.5, density=False)
rax.set_yticks([])
rax.tick_params(axis='y', which='both',
bottom=False, top=False,
labelbottom=False)
rax.tick_params(axis='x', which='both',
length=1, width=1, labelsize=8)
if x.any():
hist(x, bins='scott', ax=rax, histtype='stepfilled',
alpha=0.5, density=False)
rax.set_yticks([])
rax.tick_params(axis='y', which='both',
bottom=False, top=False,
labelbottom=False)
rax.tick_params(axis='x', which='both',
length=1, width=1, labelsize=8)

fig.suptitle(subt, fontsize=20)
plt.savefig(name, dpi=250)
Expand Down
Loading

0 comments on commit 9276bd3

Please sign in to comment.