Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rewrite abscal.match_times() to be much faster when dealing with very long lists of modelfiles #982

Merged
merged 5 commits into from
Jan 23, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 71 additions & 22 deletions hera_cal/abscal.py
Original file line number Diff line number Diff line change
Expand Up @@ -2538,28 +2538,77 @@
Returns:
matched_modelfiles : type=list, list of modelfiles that overlap w/ datafile in LST
"""
# get lst arrays
data_dlst, data_dtime, data_lsts, data_times = io.get_file_times(datafile, filetype=filetype)
model_dlsts, model_dtimes, model_lsts, model_times = io.get_file_times(modelfiles, filetype=filetype)

# shift model files relative to first file & first index if needed
for ml in model_lsts:
ml[ml < model_lsts[0][0]] += 2 * np.pi
# also ensure that ml is increasing
ml[ml < ml[0]] += 2 * np.pi
# get model start and stop, buffering by dlst / 2
model_starts = np.asarray([ml[0] - md / 2.0 for ml, md in zip(model_lsts, model_dlsts)])
model_ends = np.asarray([ml[-1] + md / 2.0 for ml, md in zip(model_lsts, model_dlsts)])

# shift data relative to model if needed
data_lsts[data_lsts < model_starts[0]] += 2 * np.pi
# make sure monotonically increasing.
data_lsts = np.unwrap(data_lsts)
# select model files
match = np.asarray(modelfiles)[(model_starts < data_lsts[-1] + atol)
& (model_ends > data_lsts[0] - atol)]

return match
# get first modelfile time arrays
m0_dlst, m0_dtime, m0_lsts, m0_times = io.get_file_times(modelfiles[0], filetype=filetype)

# get data time arrays
data_dlst, _, data_lsts, _ = io.get_file_times(datafile, filetype=filetype)

def unwrap(lsts, branch_cut):
    """Shift LSTs up by 2*pi, in place, relative to a branch cut and to their own first entry.

    Args:
        lsts : array of LSTs that supports boolean-mask indexing; modified in place
        branch_cut : value below which an entry is moved up by one full turn (2*pi)

    Returns:
        None; the result is written back into lsts
    """
    full_turn = 2 * np.pi

    # lift every entry that falls below the branch cut by one full turn
    below_cut = lsts < branch_cut
    lsts[below_cut] += full_turn

    # then lift anything still smaller than the (possibly just shifted) first entry,
    # so that no entry ends up below where the array starts
    behind_first = lsts < lsts[0]
    lsts[behind_first] += full_turn
unwrap(data_lsts, m0_lsts[0] - m0_dlst / 2) # shift data relative to model

# There is an edge-case in which the first data lst is just below the first model lst
# but after the last model lst, and the other data lsts are above the first model lst.
# In this case, we will keep thinking we are too low in LST, so we move up until we've
# exhausted all the model files, but we should have been moving down.
# The end of the model coverage is set by the LAST model file, so read modelfiles[-1]
# here (reading modelfiles[0] would only repeat the m0_* read and test against the end
# of the first file instead).
# NOTE(review): mlast_lsts is compared as returned by io.get_file_times, without the
# unwrap() that is applied to the other model files -- confirm that is intended.
mlast_dlst, _, mlast_lsts, _ = io.get_file_times(modelfiles[-1], filetype=filetype)
if (
    data_lsts[0] > mlast_lsts[-1] + mlast_dlst / 2
    and data_lsts[-1] - 2 * np.pi > m0_lsts[0] - m0_dlst / 2
):
    # the data starts past the end of the models, yet one wrap lower it would still
    # reach above the start of the first model file: pull it down by 2*pi (in place)
    data_lsts -= 2 * np.pi

def lst_overlap(m_lsts, m_dlst):
    """Classify one model file's LST span relative to the data's LST span.

    The model span is padded by m_dlst / 2 on each side and the data span (read from
    the enclosing scope's data_lsts) is padded by atol on each side.

    Args:
        m_lsts : LSTs of the model file; only the first and last entries are used
        m_dlst : LST spacing of the model file

    Returns:
        'during' if the padded spans overlap, either as-is or with the model span
        shifted by -2*pi or +2*pi; otherwise 'before' if the padded model span ends
        below the padded start of the data, else 'after'
    """
    half_bin = m_dlst / 2
    model_lo = m_lsts[0] - half_bin
    model_hi = m_lsts[-1] + half_bin
    data_lo = data_lsts[0] - atol
    data_hi = data_lsts[-1] + atol

    # try the model span as-is and one full turn either side of it
    for wrap in (0, -2 * np.pi, 2 * np.pi):
        if (model_lo + wrap < data_hi) and (model_hi + wrap > data_lo):
            return 'during'

    # no overlap at any wrap: report which side of the data the unshifted model lies on
    return 'before' if model_hi < data_lo else 'after'

Check warning on line 2572 in hera_cal/abscal.py

View check run for this annotation

Codecov / codecov/patch

hera_cal/abscal.py#L2572

Added line #L2572 was not covered by tests

# binary search to find a matching file
idx = len(modelfiles) // 2
step_size = max([len(modelfiles) // 2, 1])
indicies_tried = set([])
while True:
# get lsts for model file at idx
if (idx >= len(modelfiles)) or (idx < 0) or (idx in indicies_tried):
return [] # no match found
m_dlst, _, m_lsts, _ = io.get_file_times(modelfiles[idx], filetype=filetype)
indicies_tried.add(idx)
unwrap(m_lsts, m0_lsts[0])

# figure out whether this model file overlaps with the data
overlap = lst_overlap(m_lsts, m_dlst)
if overlap == 'during':
break
else:
step_size = max([abs(step_size) // 2, 1]) * (1 if overlap == 'before' else -1)
idx += step_size

# now loop over neighboring files, going forward and backwards simultaneously, looking for additional overlaps
match = set([modelfiles[idx]])
for i in range(1, len(modelfiles)):
m1_dlst, _, m1_lsts, _ = io.get_file_times(modelfiles[(idx - i) % len(modelfiles)], filetype=filetype)
unwrap(m1_lsts, m0_lsts[0])
m2_dlst, _, m2_lsts, _ = io.get_file_times(modelfiles[(idx + i) % len(modelfiles)], filetype=filetype)
unwrap(m2_lsts, m0_lsts[0])
overlap1 = lst_overlap(m1_lsts, m1_dlst)
overlap2 = lst_overlap(m2_lsts, m2_dlst)
if overlap1 == 'during':
match.add(modelfiles[(idx - i) % len(modelfiles)])
if overlap2 == 'during':
match.add(modelfiles[(idx + i) % len(modelfiles)])
if (overlap1 != 'during') and (overlap2 != 'during'):
break

# return match in the same order as the files appeared in modelfiles
return sorted(match, key=modelfiles.index)


def cut_bls(datacontainer, bls=None, min_bl_cut=None, max_bl_cut=None, inplace=False):
Expand Down
3 changes: 3 additions & 0 deletions hera_cal/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -1795,6 +1795,9 @@ def get_file_times(filepaths, filetype='uvh5'):
time_array = time_array[baseline_array == most_common_bl_num]
lst_array = lst_array[baseline_array == most_common_bl_num]

# Handle phase wraps
lst_array[lst_array < lst_array[0]] += 2 * np.pi

# figure out dtime and dlst, handling the case where a diff cannot be done.
if len(time_array) > 1:
int_time = np.median(np.diff(time_array))
Expand Down
8 changes: 8 additions & 0 deletions hera_cal/tests/test_abscal.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,9 +234,17 @@ def test_match_times(self):
# test basic execution
relevant_mfiles = abscal.match_times(dfiles[0], mfiles, filetype='miriad')
assert len(relevant_mfiles) == 2
relevant_mfiles = abscal.match_times(dfiles[0], mfiles[0:1], filetype='miriad')
assert len(relevant_mfiles) == 1
relevant_mfiles = abscal.match_times(dfiles[0], mfiles[1:2], filetype='miriad')
assert len(relevant_mfiles) == 1
# test basic execution
relevant_mfiles = abscal.match_times(dfiles[1], mfiles, filetype='miriad')
assert len(relevant_mfiles) == 1
relevant_mfiles = abscal.match_times(dfiles[1], mfiles[1:2], filetype='miriad')
assert len(relevant_mfiles) == 1
relevant_mfiles = abscal.match_times(dfiles[1], mfiles[0:1], filetype='miriad')
assert len(relevant_mfiles) == 0
# test no overlap
mfiles = sorted(glob.glob(os.path.join(DATA_PATH, 'zen.2457698.40355.xx.HH.uvcA')))
relevant_mfiles = abscal.match_times(dfiles[0], mfiles, filetype='miriad')
Expand Down
Loading