Skip to content

Commit

Permalink
Merge pull request #278 from roman-corgi/crop
Browse files Browse the repository at this point in the history
Crop Step Function
  • Loading branch information
maxwellmb authored Jan 18, 2025
2 parents 6208510 + dc273b2 commit aad3489
Show file tree
Hide file tree
Showing 2 changed files with 347 additions and 0 deletions.
127 changes: 127 additions & 0 deletions corgidrp/l3_to_l4.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from corgidrp import data
from scipy.ndimage import rotate as rotate_scipy # to avoid duplicated name
from scipy.ndimage import shift
import warnings
import numpy as np
import glob

Expand Down Expand Up @@ -36,6 +37,132 @@ def find_star(input_dataset):

return input_dataset.copy()

def crop(input_dataset,sizexy=60,centerxy=None):
"""
Crop the Images in a Dataset to a desired field of view. Default behavior is to
crop the image to the dark hole region, centered on the pixel intersection nearest
to the star location. Assumes 3D Image data is a stack of 2D data arrays, so only
crops the last two indices. Currently only configured for HLC mode.
TODO:
- Pad with nans if you try to crop outside the array (handle err & DQ too)
- Option to crop to an odd data array and center on a pixel?
Args:
input_dataset (corgidrp.data.Dataset): a dataset of Images (any level)
sizexy (int or array of int): desired frame size, if only one number is provided the
desired shape is assumed to be square, otherwise xy order. Defaults to 60.
centerxy (float or array of float): desired center (xy order), should be a pixel intersection (a.k.a
half-integer) otherwise the function rounds to the nearest intersection. Defaults to the
"STARLOCX/Y" header values.
Returns:
corgidrp.data.Dataset: a version of the input dataset cropped to the desired FOV.
"""

# Copy input dataset
dataset = input_dataset.copy()

# Require even data shape
if not np.all(np.array(sizexy)%2==0):
raise UserWarning('Even sizexy is required.')

# Need to loop over frames and reinit dataset because array sizes change
frames_out = []

for frame in dataset:
prihdr = frame.pri_hdr
exthdr = frame.ext_hdr
dqhdr = frame.dq_hdr
errhdr = frame.err_hdr

# Require that mode is HLC for now
if not prihdr['MODE'] == 'HLC':
raise UserWarning('Crop function is currently only configured for mode HLC.')

# Assign new array sizes and center location
frame_shape = frame.data.shape
if isinstance(sizexy,int):
sizexy = [sizexy]*2
if isinstance(centerxy,float):
centerxy = [centerxy] * 2
elif centerxy is None:
if ("STARLOCX" in exthdr.keys()) and ("STARLOCY" in exthdr.keys()):
centerxy = np.array([exthdr["STARLOCX"],exthdr["STARLOCY"]])
else: raise ValueError('centerxy not provided but STARLOCX/Y are missing from image extension header.')

# Round to centerxy to nearest half-pixel
centerxy = np.array(centerxy)
if not np.all((centerxy-0.5)%1 == 0):
old_centerxy = centerxy.copy()
centerxy = np.round(old_centerxy-0.5)+0.5
warnings.warn(f'Desired center {old_centerxy} is not at the intersection of 4 pixels. Centering on the nearest intersection {centerxy}')

# Crop the data
start_ind = (centerxy + 0.5 - np.array(sizexy)/2).astype(int)
end_ind = (centerxy + 0.5 + np.array(sizexy)/2).astype(int)
x1,y1 = start_ind
x2,y2 = end_ind

# Check if cropping outside the FOV
xleft_pad = -x1 if (x1<0) else 0
xrright_pad = x2-frame_shape[-1]+1 if (x2 > frame_shape[-1]) else 0
ybelow_pad = -y1 if (y1<0) else 0
yabove_pad = y2-frame_shape[-2]+1 if (y2 > frame_shape[-2]) else 0

if np.any(np.array([xleft_pad,xrright_pad,ybelow_pad,yabove_pad])> 0) :
raise ValueError("Trying to crop to a region outside the input data array. Not yet configured.")

if frame.data.ndim == 2:
cropped_frame_data = frame.data[y1:y2,x1:x2]
cropped_frame_err = frame.err[:,y1:y2,x1:x2]
cropped_frame_dq = frame.dq[y1:y2,x1:x2]
elif frame.data.ndim == 3:
cropped_frame_data = frame.data[:,y1:y2,x1:x2]
cropped_frame_err = frame.err[:,:,y1:y2,x1:x2]
cropped_frame_dq = frame.dq[:,y1:y2,x1:x2]
else:
raise ValueError('Crop function only supports 2D or 3D frame data.')

# Update headers
exthdr["NAXIS1"] = sizexy[0]
exthdr["NAXIS2"] = sizexy[1]
dqhdr["NAXIS1"] = sizexy[0]
dqhdr["NAXIS2"] = sizexy[1]
errhdr["NAXIS1"] = sizexy[0]
errhdr["NAXIS2"] = sizexy[1]
errhdr["NAXIS3"] = cropped_frame_err.shape[-3]
if frame.data.ndim == 3:
exthdr["NAXIS3"] = frame.data.shape[0]
dqhdr["NAXIS3"] = frame.dq.shape[0]
errhdr["NAXIS4"] = frame.err.shape[0]

updated_hdrs = []
if ("STARLOCX" in exthdr.keys()):
exthdr["STARLOCX"] -= x1
exthdr["STARLOCY"] -= y1
updated_hdrs.append('STARLOCX/Y')
if ("MASKLOCX" in exthdr.keys()):
exthdr["MASKLOCX"] -= x1
exthdr["MASKLOCY"] -= y1
updated_hdrs.append('MASKLOCX/Y')
if ("CRPIX1" in prihdr.keys()):
prihdr["CRPIX1"] -= x1
prihdr["CRPIX2"] -= y1
updated_hdrs.append('CRPIX1/2')
new_frame = data.Image(cropped_frame_data,prihdr,exthdr,cropped_frame_err,cropped_frame_dq,frame.err_hdr,frame.dq_hdr)
frames_out.append(new_frame)

output_dataset = data.Dataset(frames_out)

history_msg = f"""Frames cropped to new shape {output_dataset[0].data.shape} on center {centerxy}.\
Updated header kws: {", ".join(updated_hdrs)}."""

output_dataset.update_after_processing_step(history_msg)

return output_dataset

def do_psf_subtraction(input_dataset, reference_star_dataset=None):
"""
Expand Down
220 changes: 220 additions & 0 deletions tests/test_crop.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,220 @@
import numpy as np
import pytest
from corgidrp.data import Dataset, Image
from corgidrp.l3_to_l4 import crop
from corgidrp.mocks import create_default_headers

def make_test_dataset(shape=[100,100],centxy=None):
"""
Make 2D or 3D test data.
Args:
shape (arraylike, optional): data shape. Defaults to [100,100].
centxy (arraylike,optional): location of 4 pixel dot. Defaults to center of array.
Returns:
corgidrp.data.Dataset: test data with a 2x2 "PSF" at location centxy.
"""
shape = np.array(shape)

test_arr = np.zeros(shape)
if centxy is None:
cent = np.array(shape)/2 - 0.5
else:
cent = [centxy[-i] for i in np.array(range(len(centxy)))+1]

prihdr,exthdr = create_default_headers()
exthdr['STARLOCX'] = cent[1]
exthdr['STARLOCY'] = cent[0]
exthdr['MASKLOCX'] = cent[1]
exthdr['MASKLOCY'] = cent[0]
exthdr['CRPIX1'] = cent[1] + 1
exthdr['CRPIX2'] = cent[0] + 1
prihdr['MODE'] = 'HLC'

if len(shape) == 2:
test_arr[int(cent[0]-0.5):int(cent[0]+1.5),int(cent[1]-0.5):int(cent[1]+1.5)] = 1

elif len(shape) == 3:
test_arr[:,int(cent[0]-0.5):int(cent[0]+1.5),int(cent[1]-0.5):int(cent[1]+1.5)] = 1

test_dataset = Dataset([Image(test_arr,prihdr,exthdr)])

return test_dataset

goal_arr = np.zeros((10,10))
goal_arr[4:6,4:6] = 1

goal_rect_arr = np.zeros((10,20))
goal_rect_arr[4:6,9:11] = 1

def test_2d_square_center_crop():
""" Test cropping to the center of a square using the header keywords "STARLOCX/Y".
"""

test_dataset = make_test_dataset(shape=[100,100],centxy=[49.5,49.5])
cropped_test_dataset = crop(test_dataset,sizexy=10,centerxy=None)

if not cropped_test_dataset[0].data == pytest.approx(goal_arr):
raise Exception("Unexpected result for 2D square crop test.")

def test_manual_center_crop():
""" Test overriding crop location using centerxy argument.
"""

test_dataset = make_test_dataset(shape=[100,100],centxy=[49.5,49.5])
cropped_test_dataset = crop(test_dataset,sizexy=10,centerxy=[50.5,50.5])

offset_goal_arr = np.zeros((10,10))
offset_goal_arr[3:5,3:5] = 1

if not cropped_test_dataset[0].data == pytest.approx(offset_goal_arr):
raise Exception("Unexpected result for manual crop test.")

def test_2d_square_offcenter_crop():
""" Test cropping off-center square data.
"""

test_dataset = make_test_dataset(shape=[100,100],centxy=[24.5,49.5])
cropped_test_dataset = crop(test_dataset,sizexy=10,centerxy=None)

if not cropped_test_dataset[0].data == pytest.approx(goal_arr):
raise Exception("Unexpected result for 2D square offcenter crop test.")

def test_2d_rect_offcenter_crop():
""" Test cropping off-center non-square data.
"""
test_dataset = make_test_dataset(shape=[100,40],centxy=[24.5,49.5])
cropped_test_dataset = crop(test_dataset,sizexy=[20,10],centerxy=None)

if not cropped_test_dataset[0].data == pytest.approx(goal_rect_arr):
raise Exception("Unexpected result for 2D rect offcenter crop test.")

def test_3d_rect_offcenter_crop():
""" Test cropping 3D off-center non-square data.
"""
test_dataset = make_test_dataset(shape=[3,100,40],centxy=[24.5,49.5])
cropped_test_dataset = crop(test_dataset,sizexy=[20,10],centerxy=None)

goal_rect_arr3d = np.array([goal_rect_arr,goal_rect_arr,goal_rect_arr])

if not cropped_test_dataset[0].data == pytest.approx(goal_rect_arr3d):
raise Exception("Unexpected result for 2D rect offcenter crop test.")


def test_edge_of_FOV():
""" Test cropping right at the edge of the data array.
"""
test_dataset = make_test_dataset(shape=[100,100],centxy=[94.5,94.5])
cropped_test_dataset = crop(test_dataset,sizexy=10,centerxy=None)

if not cropped_test_dataset[0].data == pytest.approx(goal_arr):
raise Exception("Unexpected result for edge of FOV crop test.")

def test_outside_FOV():
""" Test cropping over the edge of the data array.
"""

test_dataset = make_test_dataset(shape=[100,100],centxy=[95.5,95.5])

with pytest.raises(ValueError):
_ = crop(test_dataset,sizexy=10,centerxy=None)

def test_nonhalfinteger_centxy():
""" Test trying to center the crop not on a pixel intersection.
"""
test_dataset = make_test_dataset(shape=[100,100],centxy=[49.5,49.5])
cropped_test_dataset = crop(test_dataset,sizexy=10,centerxy=[49.7,49.7])

if not cropped_test_dataset[0].data == pytest.approx(goal_arr):
raise Exception("Unexpected result for non half-integer crop test.")

def test_header_updates_2d():
""" Test that the header values are updated correctly.
"""

test_dataset = make_test_dataset(shape=[100,100],centxy=[49.5,49.5])
test_dataset[0].ext_hdr["MASKLOCX"] = 49.5
test_dataset[0].ext_hdr["MASKLOCY"] = 49.5
test_dataset[0].pri_hdr["CRPIX1"] = 50.5
test_dataset[0].pri_hdr["CRPIX2"] = 50.5

cropped_test_dataset = crop(test_dataset,sizexy=10,centerxy=None)

if not cropped_test_dataset[0].ext_hdr["STARLOCX"] == 4.5:
raise Exception("Frame header kw STARLOCX not updated correctly.")
if not cropped_test_dataset[0].ext_hdr["STARLOCY"] == 4.5:
raise Exception("Frame header kw STARLOCY not updated correctly.")
if not cropped_test_dataset[0].ext_hdr["MASKLOCX"] == 4.5:
raise Exception("Frame header kw MASKLOCX not updated correctly.")
if not cropped_test_dataset[0].ext_hdr["MASKLOCY"] == 4.5:
raise Exception("Frame header kw MASKLOCY not updated correctly.")
if not cropped_test_dataset[0].pri_hdr["CRPIX1"] == 5.5:
raise Exception("Frame header kw CRPIX1 not updated correctly.")
if not cropped_test_dataset[0].pri_hdr["CRPIX2"] == 5.5:
raise Exception("Frame header kw CRPIX2 not updated correctly.")
if not cropped_test_dataset[0].ext_hdr["NAXIS1"] == 10:
raise Exception("Frame header kw NAXIS1 not updated correctly.")
if not cropped_test_dataset[0].ext_hdr["NAXIS2"] == 10:
raise Exception("Frame header kw NAXIS2 not updated correctly.")
if not cropped_test_dataset[0].err_hdr["NAXIS1"] == 10:
raise Exception("Frame err header kw NAXIS1 not updated correctly.")
if not cropped_test_dataset[0].err_hdr["NAXIS2"] == 10:
raise Exception("Frame err header kw NAXIS2 not updated correctly.")
if not cropped_test_dataset[0].dq_hdr["NAXIS1"] == 10:
raise Exception("Frame dq header kw NAXIS1 not updated correctly.")
if not cropped_test_dataset[0].dq_hdr["NAXIS2"] == 10:
raise Exception("Frame dq header kw NAXIS2 not updated correctly.")

def test_header_updates_3d():
""" Test that the header values are updated correctly.
"""

test_dataset = make_test_dataset(shape=[3,100,100],centxy=[49.5,49.5])
test_dataset[0].ext_hdr["MASKLOCX"] = 49.5
test_dataset[0].ext_hdr["MASKLOCY"] = 49.5
test_dataset[0].pri_hdr["CRPIX1"] = 50.5
test_dataset[0].pri_hdr["CRPIX2"] = 50.5

cropped_test_dataset = crop(test_dataset,sizexy=10,centerxy=None)

if not cropped_test_dataset[0].ext_hdr["STARLOCX"] == 4.5:
raise Exception("Frame header kw STARLOCX not updated correctly.")
if not cropped_test_dataset[0].ext_hdr["STARLOCY"] == 4.5:
raise Exception("Frame header kw STARLOCY not updated correctly.")
if not cropped_test_dataset[0].ext_hdr["MASKLOCX"] == 4.5:
raise Exception("Frame header kw MASKLOCX not updated correctly.")
if not cropped_test_dataset[0].ext_hdr["MASKLOCY"] == 4.5:
raise Exception("Frame header kw MASKLOCY not updated correctly.")
if not cropped_test_dataset[0].pri_hdr["CRPIX1"] == 5.5:
raise Exception("Frame header kw CRPIX1 not updated correctly.")
if not cropped_test_dataset[0].pri_hdr["CRPIX2"] == 5.5:
raise Exception("Frame header kw CRPIX2 not updated correctly.")
if not cropped_test_dataset[0].ext_hdr["NAXIS1"] == 10:
raise Exception("Frame header kw NAXIS1 not updated correctly.")
if not cropped_test_dataset[0].ext_hdr["NAXIS2"] == 10:
raise Exception("Frame header kw NAXIS2 not updated correctly.")
if not cropped_test_dataset[0].ext_hdr["NAXIS3"] == 3:
raise Exception("Frame header kw NAXIS3 not updated correctly.")
if not cropped_test_dataset[0].dq_hdr["NAXIS1"] == 10:
raise Exception("Frame dq header kw NAXIS1 not updated correctly.")
if not cropped_test_dataset[0].dq_hdr["NAXIS2"] == 10:
raise Exception("Frame dq header kw NAXIS2 not updated correctly.")
if not cropped_test_dataset[0].dq_hdr["NAXIS3"] == 3:
raise Exception("Frame dq header kw NAXIS3 not updated correctly.")
if not cropped_test_dataset[0].err_hdr["NAXIS1"] == 10:
raise Exception("Frame err header kw NAXIS1 not updated correctly.")
if not cropped_test_dataset[0].err_hdr["NAXIS2"] == 10:
raise Exception("Frame err header kw NAXIS2 not updated correctly.")
if not cropped_test_dataset[0].err_hdr["NAXIS3"] == 3:
raise Exception("Frame err header kw NAXIS3 not updated correctly.")

if __name__ == "__main__":
test_2d_square_center_crop()
test_2d_square_offcenter_crop()
test_2d_rect_offcenter_crop()
test_edge_of_FOV()
test_outside_FOV()
test_nonhalfinteger_centxy()
test_header_updates_2d()
test_header_updates_3d()

0 comments on commit aad3489

Please sign in to comment.