diff --git a/corgidrp/data.py b/corgidrp/data.py index f1afd7c9..7217c690 100644 --- a/corgidrp/data.py +++ b/corgidrp/data.py @@ -1,9 +1,14 @@ import os +import warnings import numpy as np import numpy.ma as ma import astropy.io.fits as fits import astropy.time as time import pandas as pd +import pyklip +from pyklip.instruments.Instrument import Data as pyKLIP_Data +from pyklip.instruments.utils.wcsgen import generate_wcs +from astropy import wcs import copy import corgidrp @@ -144,19 +149,19 @@ def add_error_term(self, input_error, err_name): Updates Dataset.all_err. Args: - input_error (np.array): 2-d or 3-d error layer + input_error (np.array): per-frame or per-dataset error layer err_name (str): name of the uncertainty layer """ - if input_error.ndim == 3: + if input_error.ndim == self.all_data.ndim: for i,frame in enumerate(self.frames): frame.add_error_term(input_error[i], err_name) - elif input_error.ndim ==2: + elif input_error.ndim == self.all_data.ndim - 1: for frame in self.frames: frame.add_error_term(input_error, err_name) else: - raise ValueError("input_error is not either a 2D or 3D array.") + raise ValueError("input_error is not either a 2D or 3D array for 2D data, or a 3D or 4D array for 3D data.") # Preserve pointer links between Dataset.all_err and Image.err self.all_err = np.array([frame.err for frame in self.frames]) @@ -236,9 +241,6 @@ def split_dataset(self, prihdr_keywords=None, exthdr_keywords=None): return split_datasets, unique_vals - - - class Image(): """ Base class for 2-D image data. 
Data can be created by passing in the data/header explicitly, or @@ -286,8 +288,8 @@ def __init__(self, data_or_filepath, pri_hdr=None, ext_hdr=None, err = None, dq if err is not None: if np.shape(self.data) != np.shape(err)[-self.data.ndim:]: raise ValueError("The shape of err is {0} while we are expecting shape {1}".format(err.shape[-self.data.ndim:], self.data.shape)) - #we want to have a 3 dim error array - if err.ndim > 2: + #we want to have an extra dimension in the error array + if err.ndim == self.data.ndim+1: self.err = err else: self.err = err.reshape((1,)+err.shape) @@ -295,7 +297,7 @@ def __init__(self, data_or_filepath, pri_hdr=None, ext_hdr=None, err = None, dq err_hdu = hdulist.pop("ERR") self.err = err_hdu.data self.err_hdr = err_hdu.header - if self.err.ndim == 2: + if self.err.ndim == self.data.ndim: self.err = self.err.reshape((1,)+self.err.shape) else: self.err = np.zeros((1,)+self.data.shape) @@ -348,7 +350,7 @@ def __init__(self, data_or_filepath, pri_hdr=None, ext_hdr=None, err = None, dq if np.shape(self.data) != np.shape(err)[-self.data.ndim:]: raise ValueError("The shape of err is {0} while we are expecting shape {1}".format(err.shape[-self.data.ndim:], self.data.shape)) #we want to have a 3 dim error array - if err.ndim > 2: + if err.ndim == self.data.ndim + 1: self.err = err else: self.err = err.reshape((1,)+err.shape) @@ -516,7 +518,7 @@ def get_masked_data(self): def add_error_term(self, input_error, err_name): """ - Add a layer of a specific additive uncertainty on the 3-dim error array extension + Add a layer of a specific additive uncertainty on the 3- or 4-dim error array extension and update the combined uncertainty in the first layer. Update the error header and assign the error name. 
@@ -524,18 +526,22 @@ def add_error_term(self, input_error, err_name): in the configuration file Args: - input_error (np.array): 2-d error layer + input_error (np.array): error layer with same shape as data err_name (str): name of the uncertainty layer """ - if input_error.ndim != 2 or input_error.shape != self.data.shape: - raise ValueError("we expect a 2-dimensional error layer with dimensions {0}".format(self.data.shape)) + ndim = self.data.ndim + if not (input_error.ndim==2 or input_error.ndim==3) or input_error.shape != self.data.shape: + raise ValueError("we expect a 2-dimensional or 3-dimensional error layer with dimensions {0}".format(self.data.shape)) #first layer is always the updated combined error - self.err[0,:,:] = np.sqrt(self.err[0,:,:]**2 + input_error**2) + if ndim == 2: + self.err[0,:,:] = np.sqrt(self.err[0,:,:]**2 + input_error**2) + elif ndim == 3: + self.err[0,:,:,:] = np.sqrt(self.err[0,:,:,:]**2 + input_error**2) self.err_hdr["Layer_1"] = "combined_error" if corgidrp.track_individual_errors: - #append new error as layer on 3D cube + #append new error as layer on 3D or 4D cube self.err=np.append(self.err, [input_error], axis=0) layer = str(self.err.shape[0]) @@ -601,7 +607,6 @@ def add_extension_hdu(self, name, data = None, header=None): self.hdu_names.append(name) self.hdu_list.append(new_hdu) - class Dark(Image): """ Dark calibration frame for a given exposure time and EM gain. @@ -659,7 +664,6 @@ def __init__(self, data_or_filepath, pri_hdr=None, ext_hdr=None, input_dataset=N if self.ext_hdr['DATATYPE'] != 'Dark': raise ValueError("File that was loaded was not a Dark file.") - class FlatField(Image): """ Master flat generated from raster scan of uranus or Neptune. 
class PyKLIPDataset(pyKLIP_Data):
    """
    A pyKLIP instrument class for Roman Coronagraph Instrument data.

    # TODO: Add more bandpasses, modes to self.wave_hlc
    #       Add wcs header info!

    Attrs:
        input: Stack of science images, recentered so the star sits at the
            array center.
        centers: Star center location of every image in ``input``.
        filenums: Running index of every image in ``input``.
        filenames: File name (plus ``_INT<n>`` suffix) of every image.
        PAs: Position angle of every image (taken from the ``ROLL`` keyword).
        wvs: Central wavelength of every image.
        wcs: One astropy WCS object per image.
        IWA: Inner working angle.
        OWA: Outer working angle.
        psflib: pyKLIP PSF library built from the reference observations, or
            None when no reference dataset was given.
        output: PSF subtracted pyKLIP dataset.
    """

    ####################
    ### Constructors ###
    ####################

    def __init__(self,
                 dataset,
                 psflib_dataset=None,
                 highpass=False):
        """
        Initialize the pyKLIP instrument class for space telescope data.

        # TODO: Determine inner working angle based on PAM positions
            - Inner working angle based on Focal plane mask (starts with HLC)
              + color filter ('1F') for primary mode
            - Outer working angle based on field stop? (should be R1C1 or
              R1C3 for primary mode)

        Args:
            dataset (corgidrp.data.Dataset):
                Dataset containing input science observations.
            psflib_dataset (corgidrp.data.Dataset, optional):
                Dataset containing input reference observations. The default
                is None.
            highpass (bool, optional):
                Toggle to do highpass filtering. Defaults to False.
        """
        # Initialize pyKLIP Data class.
        super().__init__()

        # Central wavelength (meters) for each supported CFAMNAME value.
        self.wave_hlc = {'1F': 575e-9}

        # Read science and reference files.
        self.readdata(dataset, psflib_dataset, highpass)

    ################################
    ### Instance Required Fields ###
    ################################

    @property
    def input(self):
        return self._input

    @input.setter
    def input(self, newval):
        self._input = newval

    @property
    def centers(self):
        return self._centers

    @centers.setter
    def centers(self, newval):
        self._centers = newval

    @property
    def filenums(self):
        return self._filenums

    @filenums.setter
    def filenums(self, newval):
        self._filenums = newval

    @property
    def filenames(self):
        return self._filenames

    @filenames.setter
    def filenames(self, newval):
        self._filenames = newval

    @property
    def PAs(self):
        return self._PAs

    @PAs.setter
    def PAs(self, newval):
        self._PAs = newval

    @property
    def wvs(self):
        return self._wvs

    @wvs.setter
    def wvs(self, newval):
        self._wvs = newval

    @property
    def wcs(self):
        return self._wcs

    @wcs.setter
    def wcs(self, newval):
        self._wcs = newval

    @property
    def IWA(self):
        return self._IWA

    @IWA.setter
    def IWA(self, newval):
        self._IWA = newval

    @property
    def OWA(self):
        return self._OWA

    @OWA.setter
    def OWA(self, newval):
        self._OWA = newval

    @property
    def psflib(self):
        return self._psflib

    @psflib.setter
    def psflib(self, newval):
        self._psflib = newval

    @property
    def output(self):
        return self._output

    @output.setter
    def output(self, newval):
        self._output = newval

    ###############
    ### Helpers ###
    ###############

    @staticmethod
    def _read_frame(frame):
        """Return (cube, centers, filenames, pix_scale) for one corgidrp frame."""
        shead = frame.ext_hdr

        # Promote single images to a one-slice cube so that every frame can be
        # concatenated along the leading axis.
        cube = frame.data
        if cube.ndim == 2:
            cube = cube[np.newaxis, :]
        if cube.ndim != 3:
            raise UserWarning('Requires 2D/3D data cube')
        nints = cube.shape[0]

        # NOTE(review): the original comment labels this value "arcsec"; the
        # unit of PLTSCALE is not visible from here -- confirm the factor 1000.
        pix_scale = shead['PLTSCALE'] * 1000.

        # One (x, y) star location per slice, flattened; reshaped by the caller.
        centers = np.array([shead['STARLOCX'], shead['STARLOCY']] * nints)

        basename = os.path.split(frame.pri_hdr['FILENAME'])[1]
        filenames = [basename + '_INT%.0f' % (j + 1) for j in range(nints)]

        return cube, centers, filenames, pix_scale

    @staticmethod
    def _recenter(cube, centers):
        """Shift every image of cube in place so its center is the array center; return that center."""
        # Array center in (x, y) order, hence the reversal of the (y, x) shape.
        new_center = ((np.array(cube.shape[1:]) - 1) / 2.)[::-1]
        for i, image in enumerate(cube):
            cube[i] = pyklip.klip.align_and_scale(image, new_center=new_center,
                                                  old_center=centers[i])
            centers[i] = new_center
        return new_center

    ###############
    ### Methods ###
    ###############

    def readdata(self,
                 dataset,
                 psflib_dataset,
                 highpass=False):
        """
        Read the input science (and optional reference) observations.

        Args:
            dataset (corgidrp.data.Dataset):
                Dataset containing input science observations.
            psflib_dataset (corgidrp.data.Dataset, optional):
                Dataset containing input reference observations, or None.
            highpass (bool, optional):
                Toggle to do highpass filtering. Defaults to False.

        Raises:
            UserWarning: if an input is not a corgidrp Dataset, is empty, is
                not Roman/CGI data, uses an unconfigured CFAM position, or if
                image shapes / pixel scales are inconsistent.
        """
        # Check input.
        if not isinstance(dataset, corgidrp.data.Dataset):
            raise UserWarning('Input dataset is not a corgidrp Dataset object.')
        if len(dataset) == 0:
            raise UserWarning('No science frames in the input dataset.')
        if psflib_dataset is not None and not isinstance(psflib_dataset, corgidrp.data.Dataset):
            raise UserWarning('Input psflib_dataset is not a corgidrp Dataset object.')

        input_all = []
        centers_all = []  # pix
        filenames_all = []
        PAs_all = []  # deg
        wvs_all = []  # m
        wcs_all = []
        sci_pixscales = []

        # Iterate over frames in dataset
        for frame in dataset:
            phead = frame.pri_hdr
            shead = frame.ext_hdr

            TELESCOP = phead['TELESCOP']
            INSTRUME = phead['INSTRUME']
            CFAMNAME = shead['CFAMNAME']

            cube, centers, filenames, pix_scale = self._read_frame(frame)
            NINTS = cube.shape[0]

            sci_pixscales.append(pix_scale)
            input_all.append(cube)
            centers_all.append(centers)
            filenames_all += filenames
            PAs_all += [shead['ROLL']] * NINTS

            if TELESCOP != "ROMAN" or INSTRUME != "CGI":
                raise UserWarning('Data is not from Roman Space Telescope Coronagraph Instrument.')

            # Get center wavelengths
            try:
                CWAVEL = self.wave_hlc[CFAMNAME]
            except KeyError as err:
                raise UserWarning(f'CFAM position {CFAMNAME} is not configured in corgidrp.data.PyKLIPDataset .') from err
            wvs_all += [CWAVEL] * NINTS

            # pyklip will look for wcs.cd, so make sure that attribute exists
            wcs_obj = wcs.WCS(header=shead, naxis=shead['WCSAXES'])
            if not hasattr(wcs_obj.wcs, 'cd'):
                # FITS convention: CD_ij = CDELT_i * PC_ij, i.e. every *row* of
                # PC is scaled by the increment of its axis.
                wcs_obj.wcs.cd = wcs_obj.wcs.cdelt[:, np.newaxis] * wcs_obj.wcs.pc

            wcs_all += [wcs_obj.deepcopy() for _ in range(NINTS)]

        try:
            input_all = np.concatenate(input_all)
        except ValueError as err:
            raise UserWarning('Unable to concatenate images. Some science files do not have matching image shapes') from err
        centers_all = np.concatenate(centers_all).reshape(-1, 2)
        filenames_all = np.array(filenames_all)
        filenums_all = np.arange(len(filenames_all))
        PAs_all = np.array(PAs_all)
        wvs_all = np.array(wvs_all).astype(np.float32)
        wcs_all = np.array(wcs_all)
        PIXSCALE = np.unique(np.array(sci_pixscales))
        if len(PIXSCALE) != 1:
            raise UserWarning('Some science files do not have matching pixel scales')

        # wavelength / 6.5, converted from radians to arcsec, in units of the
        # pixel scale.
        # NOTE(review): 6.5 looks like a telescope diameter in meters carried
        # over from another instrument class -- confirm the value for Roman.
        iwa_all = np.min(wvs_all) / 6.5 * 180. / np.pi * 3600. / PIXSCALE[0]  # pix
        owa_all = np.sum(np.array(input_all.shape[1:]) / 2.)  # pix

        # Recenter science images so that the star is at the center of the array.
        self._recenter(input_all, centers_all)

        # Assign pyKLIP variables.
        self._input = input_all
        self._centers = centers_all
        self._filenames = filenames_all
        self._filenums = filenums_all
        self._PAs = PAs_all
        self._wvs = wvs_all
        self._wcs = wcs_all
        self._IWA = iwa_all
        self._OWA = owa_all

        # Prepare reference library
        if psflib_dataset is None:
            self._psflib = None
            return

        if len(psflib_dataset) == 0:
            raise UserWarning('No reference frames in the input psflib_dataset.')

        psflib_data_all = []
        psflib_centers_all = []  # pix
        psflib_filenames_all = []
        ref_pixscales = []

        for frame in psflib_dataset:
            cube, centers, filenames, pix_scale = self._read_frame(frame)
            psflib_data_all.append(cube)
            psflib_centers_all.append(centers)
            psflib_filenames_all += filenames
            ref_pixscales.append(pix_scale)

        # The reference frames must share the pixel scale of the science frames.
        ref_pixscales = np.unique(np.array(ref_pixscales))
        if len(ref_pixscales) != 1 or ref_pixscales[0] != PIXSCALE[0]:
            raise UserWarning('Some reference files do not have matching pixel scales')

        try:
            psflib_data_all = np.concatenate(psflib_data_all)
        except ValueError as err:
            raise UserWarning('Some reference files do not have matching image shapes') from err
        psflib_centers_all = np.concatenate(psflib_centers_all).reshape(-1, 2)
        psflib_filenames_all = np.array(psflib_filenames_all)

        # Recenter reference images.
        new_center = self._recenter(psflib_data_all, psflib_centers_all)

        # Append science data.
        psflib_data_all = np.append(psflib_data_all, self._input, axis=0)
        psflib_centers_all = np.append(psflib_centers_all, self._centers, axis=0)
        psflib_filenames_all = np.append(psflib_filenames_all, self._filenames, axis=0)

        # Initialize PSF library.
        psflib = pyklip.rdi.PSFLibrary(psflib_data_all, new_center, psflib_filenames_all,
                                       compute_correlation=True, highpass=highpass)

        # Prepare PSF library.
        psflib.prepare_library(self)

        # Assign pyKLIP variables.
        self._psflib = psflib

    def savedata(self,
                 filepath,
                 data,
                 klipparams=None,
                 filetype='',
                 zaxis=None,
                 more_keywords=None):
        """
        Function to save the data products that will be called internally by
        pyKLIP.

        Args:
            filepath (path):
                Path of the output FITS file.
            data (3D-array):
                KLIP-subtracted data of shape (nkl, ny, nx).
            klipparams (str, optional):
                PyKLIP keyword arguments used for the KLIP subtraction. The
                default is None.
            filetype (str, optional):
                Data type of the pyKLIP product. The default is ''.
            zaxis (list, optional):
                List of KL modes used for the KLIP subtraction. The default is
                None.
            more_keywords (dict, optional):
                Dictionary of additional header keywords to be written to the
                output FITS file. The default is None.
        """
        # Make FITS file.
        hdul = fits.HDUList()
        hdul.append(fits.PrimaryHDU(data))
        header = hdul[0].header

        # Write all used files to header. Ignore duplicates.
        filenames = np.unique(self.filenames)
        header['DRPNFILE'] = (np.size(filenames), 'Num raw files used in pyKLIP')
        for i, filename in enumerate(filenames):
            if i >= 1000:
                warnings.warn('Too many files to be written to header, skipping')
                break
            header['FILE_{0}'.format(i)] = filename + '.fits'

        # Write PSF subtraction parameters and pyKLIP version to header.
        pyklipver = getattr(pyklip, '__version__', 'unknown')
        header['PSFSUB'] = ('pyKLIP', 'PSF Subtraction Algo')
        header.add_history('Reduced with pyKLIP using commit {0}'.format(pyklipver))
        header['CREATOR'] = 'pyKLIP-{0}'.format(pyklipver)
        header['pyklipv'] = (pyklipver, 'pyKLIP version that was used')
        if klipparams is not None:
            header['PSFPARAM'] = (klipparams, 'KLIP parameters')
            header.add_history('pyKLIP reduction with parameters {0}'.format(klipparams))

        # Write z-axis units to header if necessary.
        if zaxis is not None and 'KL Mode' in filetype:
            header['CTYPE3'] = 'KLMODES'
            for i, klmode in enumerate(zaxis):
                header['KLMODE{0}'.format(i)] = (klmode, 'KL Mode of slice {0}'.format(i))

        # Write extra keywords to header if necessary.
        if more_keywords is not None:
            for hdr_key, hdr_val in more_keywords.items():
                header[hdr_key] = hdr_val

        # Update image center. output_centers is not assigned anywhere in this
        # class; presumably pyKLIP sets it during the reduction -- TODO confirm.
        center = self.output_centers[0]
        header.update({'PSFCENTX': center[0], 'PSFCENTY': center[1]})
        header.update({'CRPIX1': center[0] + 1, 'CRPIX2': center[1] + 1})
        header.add_history('Image recentered to {0}'.format(str(center)))

        # Write FITS file; always release the HDU list, even if writing fails.
        try:
            hdul.writeto(filepath, overwrite=True)
        finally:
            hdul.close()
def nan_flags(dataset, threshold=1):
    """Replaces each DQ-flagged pixel (>= the given threshold) in the dataset with np.nan.

    A 'DQ flagged' error term that is NaN at the same pixels (and zero
    elsewhere) is added through ``dataset.add_error_term``.

    Args:
        dataset (corgidrp.data.Dataset): input dataset.
        threshold (int, optional): DQ threshold to replace with nans. Defaults to 1.

    Returns:
        corgidrp.data.Dataset: dataset with flagged pixels replaced.
    """
    dataset_out = dataset.copy()

    # Boolean mask of bad pixels (no need to materialize index arrays).
    bad = dataset_out.all_dq >= threshold
    dataset_out.all_data[bad] = np.nan

    new_error = np.zeros_like(dataset_out.all_data)
    new_error[bad] = np.nan
    dataset_out.add_error_term(new_error, 'DQ flagged')

    return dataset_out


def flag_nans(dataset, flag_val=1):
    """Assigns a DQ flag to each nan pixel in the dataset.

    The flag is OR-ed into the existing DQ value, so flags that were already
    set on a NaN pixel are kept instead of being overwritten.

    Args:
        dataset (corgidrp.data.Dataset): input dataset.
        flag_val (int, optional): DQ value to assign. Defaults to 1.

    Returns:
        corgidrp.data.Dataset: dataset with nan values flagged.
    """
    dataset_out = dataset.copy()

    bad = np.isnan(dataset_out.all_data)

    # Go through int64 so the bitwise OR also works when the DQ array is
    # stored as floats; the assignment casts back to the array's own dtype.
    # Index assignment keeps the update in place on all_dq.
    current = dataset_out.all_dq[bad].astype(np.int64)
    dataset_out.all_dq[bad] = np.bitwise_or(current, int(flag_val))

    return dataset_out
def do_psf_subtraction(input_dataset, reference_star_dataset=None,
                       mode=None, annuli=1, subsections=1, movement=1,
                       numbasis=None, outdir='KLIP_SUB', fileprefix="",
                       do_crop=True,
                       crop_sizexy=None
                       ):
    """
    Perform PSF subtraction on the dataset. Optionally using a reference star dataset.

    TODO:
        Handle nans & propagate DQ array
        What info is missing from output dataset headers?
        Add comments to new ext header cards

    Args:
        input_dataset (corgidrp.data.Dataset): a dataset of Images (L3-level)
        reference_star_dataset (corgidrp.data.Dataset, optional): a dataset of Images of the reference
            star. If None, frames of input_dataset are split on the "PSFREF" primary header keyword
            (0: science, 1: reference).
        mode (str, optional): pyKLIP PSF subtraction mode, one of 'ADI', 'RDI', 'ADI+RDI'. Mode will be
            chosen autonomously if not specified.
        annuli (int, optional): number of concentric annuli to run separate subtractions on. Defaults to 1.
        subsections (int, optional): number of angular subsections to run separate subtractions on. Defaults to 1.
        movement (int, optional): KLIP movement parameter. Defaults to 1.
        numbasis (int or list of int, optional): number of KLIP modes to retain. Defaults to [1,4,8,16].
        outdir (str or path, optional): path to output directory. Defaults to "KLIP_SUB".
        fileprefix (str, optional): prefix of saved output files. Defaults to "".
        do_crop (bool): whether to crop data before PSF subtraction. Defaults to True.
        crop_sizexy (list of int, optional): Desired size to crop the images to before PSF subtraction. Defaults to
            None, which results in the step choosing a crop size based on the imaging mode.

    Returns:
        corgidrp.data.Dataset: a version of the input dataset with the PSF subtraction applied (L4-level)

    Raises:
        UserWarning: if no science frames (PSFREF == 0) are found in input_dataset.
        AssertionError: if the science dataset is empty or mode is not one of the configured modes.
    """
    sci_dataset = input_dataset.copy()

    # Use input reference dataset if provided
    if reference_star_dataset is not None:
        ref_dataset = reference_star_dataset.copy()

    # Try getting PSF references via the "PSFREF" header kw
    else:
        split_datasets, unique_vals = sci_dataset.split_dataset(prihdr_keywords=["PSFREF"])
        unique_vals = np.array(unique_vals)

        # flatnonzero(...)[0] yields a true scalar index; int() on a 1-element
        # array is deprecated in recent NumPy releases.
        sci_idx = np.flatnonzero(unique_vals == 0)
        if sci_idx.size == 0:
            raise UserWarning('No science files found in input dataset.')
        sci_dataset = split_datasets[int(sci_idx[0])]

        ref_idx = np.flatnonzero(unique_vals == 1)
        ref_dataset = split_datasets[int(ref_idx[0])] if ref_idx.size else None

    # AssertionError is raised explicitly (rather than via `assert`) so the
    # check survives `python -O` while keeping the exception type unchanged.
    if not len(sci_dataset) > 0:
        raise AssertionError("Science dataset has no data.")

    # Choose PSF subtraction mode if unspecified
    if mode is None:
        if ref_dataset is not None and len(sci_dataset) == 1:
            mode = 'RDI'
        elif ref_dataset is not None:
            mode = 'ADI+RDI'
        else:
            mode = 'ADI'
    elif mode not in ('RDI', 'ADI+RDI', 'ADI'):
        raise AssertionError(f"Mode {mode} is not configured.")

    # Format numbasis: default list is created per call, scalars are wrapped.
    if numbasis is None:
        numbasis = [1, 4, 8, 16]
    elif isinstance(numbasis, (int, np.integer)):
        numbasis = [numbasis]

    # Set up outdir
    outdir = os.path.join(outdir, mode)
    os.makedirs(outdir, exist_ok=True)

    # Crop data
    if do_crop:
        sci_dataset = crop(sci_dataset, sizexy=crop_sizexy)
        ref_dataset = None if ref_dataset is None else crop(ref_dataset, sizexy=crop_sizexy)

    # Mask data where DQ > 0, let pyklip deal with the nans
    sci_dataset_masked = nan_flags(sci_dataset)
    ref_dataset_masked = None if ref_dataset is None else nan_flags(ref_dataset)

    # Run pyklip
    pyklip_dataset = data.PyKLIPDataset(sci_dataset_masked, psflib_dataset=ref_dataset_masked)
    pyklip.parallelized.klip_dataset(pyklip_dataset, outputdir=outdir,
                                     annuli=annuli, subsections=subsections, movement=movement,
                                     numbasis=numbasis, calibrate_flux=False, mode=mode,
                                     psf_library=pyklip_dataset._psflib,
                                     fileprefix=fileprefix)

    # Construct corgiDRP dataset from pyKLIP result
    result_fpath = os.path.join(outdir, f'{fileprefix}-KLmodes-all.fits')
    pyklip_data = fits.getdata(result_fpath)
    pyklip_hdr = fits.getheader(result_fpath)

    # TODO: Handle errors correctly
    err = np.zeros([1, *pyklip_data.shape])
    # Integer DQ array (filled in below by flag_nans) rather than inheriting
    # the float dtype of the image data.
    dq = np.zeros(pyklip_data.shape, dtype=int)

    # Collapse sci_dataset headers
    pri_hdr = sci_dataset[0].pri_hdr.copy()
    ext_hdr = sci_dataset[0].ext_hdr.copy()

    # Add relevant info from the pyklip headers (public card accessor):
    skip_kws = {'PSFCENTX', 'PSFCENTY', 'CREATOR', 'CTYPE3'}
    for card in pyklip_hdr.cards:
        if card.keyword not in skip_kws:
            ext_hdr.set(card.keyword, card.value, card.comment)

    # Record KLIP algorithm explicitly
    pri_hdr.set('KLIP_ALG', mode)

    # Add info from pyklip to ext_hdr
    ext_hdr['STARLOCX'] = pyklip_hdr['PSFCENTX']
    ext_hdr['STARLOCY'] = pyklip_hdr['PSFCENTY']

    # Re-insert the science frame's history as a single entry with the line
    # breaks removed.
    if "HISTORY" in sci_dataset[0].ext_hdr:
        history_str = str(sci_dataset[0].ext_hdr['HISTORY'])
        ext_hdr['HISTORY'] = ''.join(history_str.split('\n'))

    # Construct Image and Dataset object
    frame = data.Image(pyklip_data,
                       pri_hdr=pri_hdr, ext_hdr=ext_hdr,
                       err=err, dq=dq)

    dataset_out = data.Dataset([frame])

    # Flag nans in the dq array and then add nans to the error array
    dataset_out = flag_nans(dataset_out, flag_val=1)
    dataset_out = nan_flags(dataset_out, threshold=1)

    history_msg = f'PSF subtracted via pyKLIP {mode}.'

    dataset_out.update_after_processing_step(history_msg)

    return dataset_out
frame.pri_hdr.append(('TARGET', pl), end=True) + frame.pri_hdr.set('TARGET', pl) frame.pri_hdr.append(('FILTER', band), end=True) if filedir is not None: frame.save(filedir=filedir, filename=filepattern.format(i)) @@ -651,11 +652,27 @@ def create_default_headers(arrtype="SCI", vistype="TDEMO"): NAXIS2 = 2200 # fill in prihdr + prihdr['AUXFILE'] = 'mock_auxfile.fits' prihdr['OBSID'] = 0 prihdr['BUILD'] = 0 # prihdr['OBSTYPE'] = arrtype prihdr['VISTYPE'] = vistype prihdr['MOCK'] = True + prihdr['TELESCOP'] = 'ROMAN' + prihdr['INSTRUME'] = 'CGI' + prihdr['OBSNAME'] = 'MOCK' + prihdr['TARGET'] = 'MOCK' + prihdr['OBSNUM'] = '000' + prihdr['CAMPAIGN'] = '000' + prihdr['PROGNUM'] = '00000' + prihdr['SEGMENT'] = '000' + prihdr['VISNUM'] = '000' + prihdr['EXECNUM'] = '00' + prihdr['VISITID'] = prihdr['PROGNUM'] + prihdr['EXECNUM'] + prihdr['CAMPAIGN'] + prihdr['SEGMENT'] + prihdr['OBSNUM'] + prihdr['VISNUM'] + prihdr['PSFREF'] = False + prihdr['SIMPLE'] = True + prihdr['NAXIS'] = 0 + # fill in exthdr exthdr['NAXIS'] = 2 @@ -698,12 +715,22 @@ def create_default_headers(arrtype="SCI", vistype="TDEMO"): exthdr['CFAM_V'] = 1.0 exthdr['DPAM_H'] = 1.0 exthdr['DPAM_V'] = 1.0 + exthdr['CFAMNAME'] = '1F' # Color filter for band 1 + exthdr['DPAMNAME'] = 'IMAGING' + exthdr['FPAMNAME'] = 'HLC12_C2R1' # Focal plane mask for NFOV + exthdr['FSAMNAME'] = 'R1C1' # Circular field stop for NFOV + exthdr['LSAMNAME'] = 'NFOV' # Lyot stop for NFOV observations + exthdr['SPAMNAME'] = 'OPEN' # Used for NFOV observations + + + exthdr['DATETIME'] = '2024-01-01T11:00:00.000Z' exthdr['HIERARCH DATA_LEVEL'] = "L1" exthdr['MISSING'] = False exthdr['BUNIT'] = "" return prihdr, exthdr + def create_badpixelmap_files(filedir=None, col_bp=None, row_bp=None): """ Create simulated bad pixel map data. Code value is 4. 
def gaussian_array(array_shape=(50, 50), sigma=2.5, amp=100., xoffset=0., yoffset=0.):
    """Generate a 2D array containing a gaussian surface (for mock PSF data).

    The surface is ``amp / (2 * pi * sigma**2) * exp(-r**2 / (2 * sigma**2))``,
    where ``r`` is the distance from the (offset) center in pixels. ``amp`` is
    therefore the volume under the continuous surface, and the peak value is
    ``amp / (2 * pi * sigma**2)``.

    Args:
        array_shape (list-like of int, optional): Shape of the desired array in
            pixels, in numpy (n_rows, n_cols) order, i.e. (y, x). Defaults to
            (50, 50).
        sigma (float, optional): Standard deviation of the gaussian curve, in
            pixels. Defaults to 2.5.
        amp (float, optional): Amplitude (integrated volume) of the gaussian
            surface. Defaults to 100.
        xoffset (float, optional): x (column) offset of the gaussian from the
            array center, in pixels. Defaults to 0.
        yoffset (float, optional): y (row) offset of the gaussian from the
            array center, in pixels. Defaults to 0.

    Returns:
        np.array: 2D array of shape ``(array_shape[0], array_shape[1])``
        holding the gaussian surface.
    """
    n_rows = array_shape[0]
    n_cols = array_shape[1]

    # Pixel-center coordinates measured from the geometric center of the array.
    # x runs along columns (axis 1) and y along rows (axis 0), so the output
    # has exactly the requested shape even when it is not square.
    x_coords = np.linspace(-n_cols / 2 + 0.5, n_cols / 2 - 0.5, n_cols)
    y_coords = np.linspace(-n_rows / 2 + 0.5, n_rows / 2 - 0.5, n_rows)
    x, y = np.meshgrid(x_coords, y_coords)

    # Squared distance from the gaussian center; no square root is needed
    # because only r**2 enters the exponent.
    dist_sq = (x - xoffset)**2 + (y - yoffset)**2

    return np.exp(-(dist_sq / (2.0 * sigma**2))) * amp / (2.0 * np.pi * sigma**2)
* sigma**2)) + psf = gaussian_array((stampsize,stampsize),sigma,flux) # inject the star into the image sim_data[ymin:ymax + 1, xmin:xmax + 1] += psf @@ -2113,3 +2161,205 @@ def create_flux_image(star_flux, fwhm, cal_factor, filedir=None, color_cor = 1., frame.save(filedir=filedir, filename=filename) return frame + +default_wcs_string = """WCSAXES = 2 / Number of coordinate axes +CRPIX1 = 0.0 / Pixel coordinate of reference point +CRPIX2 = 0.0 / Pixel coordinate of reference point +CDELT1 = 1.0 / Coordinate increment at reference point +CDELT2 = 1.0 / Coordinate increment at reference point +CRVAL1 = 0.0 / Coordinate value at reference point +CRVAL2 = 0.0 / Coordinate value at reference point +LATPOLE = 90.0 / [deg] Native latitude of celestial pole +MJDREF = 0.0 / [d] MJD of fiducial time +""" + + +def create_psfsub_dataset(n_sci,n_ref,roll_angles,darkhole_scifiles=None,darkhole_reffiles=None, + wcs_header = None, + data_shape = [100,100], + centerxy = None, + outdir = None, + st_amp = 100., + noise_amp = 1., + ref_psf_spread=1. , + pl_contrast=1e-3 + ): + """Generate a mock science and reference dataset ready for the PSF subtraction step. + TODO: reference a central pixscale number, rather than hard code. + + Args: + n_sci (int): number of science frames, must be >= 1. + n_ref (int): nummber of reference frames, must be >= 0. + roll_angles (list-like): list of the roll angles of each science and reference + frame, with the science frames listed first. + darkhole_scifiles (list of str, optional): Filepaths to the darkhole science frames. + If not provided, a noisy 2D gaussian will be used instead. Defaults to None. + darkhole_reffiles (list of str, optional): Filepaths to the darkhole reference frames. + If not provided, a noisy 2D gaussian will be used instead. Defaults to None. + wcs_header (astropy.fits.Header, optional): Fits header object containing WCS + information. If not provided, a mock header will be created. Defaults to None. 
+ data_shape (list of int): desired shape of data array. Must have length 2. Defaults to + [100,100]. + centerxy (list of float): Desired PSF center in xy order. Must have length 2. Defaults + to image center. + outdir (str, optional): Desired output directory. If not provided, data will not be + saved. Defaults to None. + st_amp (float): Amplitude of stellar psf added to fake data. Defaults to 100. + noise_amp (float): Amplitude of gaussian noise added to fake data. Defaults to 1. + ref_psf_spread (float): Fractional increase in gaussian PSF width between science and + reference PSFs. Defaults to 1. + pl_contrast (float): Flux ratio between planet and starlight incident on the detector. + Defaults to 1e-3. + + + Returns: + tuple: corgiDRP science Dataset object and reference Dataset object. + """ + + assert len(data_shape) == 2 + + if roll_angles is None: + roll_angles = [0.] * (n_sci+n_ref) + + # mask_center = np.array(data_shape)/2 + # star_pos = mask_center + pixscale = 0.0218 # arcsec + + # Build each science/reference frame + sci_frames = [] + ref_frames = [] + for i in range(n_sci+n_ref): + + # Create default headers + prihdr, exthdr = create_default_headers() + + # Read in darkhole data, if provided + if i=n_sci and not darkhole_reffiles is None: + fpath = darkhole_reffiles[i-n_sci] + _,fname = os.path.split(fpath) + darkhole = fits.getdata(fpath) + fill_value = np.nanmin(darkhole) + img_data = np.full(data_shape,fill_value) + + # Overwrite center of array with the darkhole data + cr_psf_pix = np.array(darkhole.shape) / 2 - 0.5 + if centerxy is None: + full_arr_center = np.array(img_data.shape) // 2 + else: + full_arr_center = (centerxy[1],centerxy[0]) + start_psf_ind = full_arr_center - np.array(darkhole.shape) // 2 + img_data[start_psf_ind[0]:start_psf_ind[0]+darkhole.shape[0],start_psf_ind[1]:start_psf_ind[1]+darkhole.shape[1]] = darkhole + psfcenty, psfcentx = cr_psf_pix + start_psf_ind + + # Otherwise generate a 2D gaussian for a fake PSF + else: + 
sci_sigma = 2.5 + ref_sigma = sci_sigma * ref_psf_spread + pl_amp = st_amp * pl_contrast + + label = 'ref' if i>= n_sci else 'sci' + sigma = ref_sigma if i>= n_sci else sci_sigma + fname = f'MOCK_{label}_roll{roll_angles[i]}.fits' + arr_center = np.array(data_shape) / 2 - 0.5 + if centerxy is None: + psfcenty,psfcentx = arr_center + else: + psfcentx,psfcenty = centerxy + + psf_off_xy = (psfcentx-arr_center[1],psfcenty-arr_center[0]) + img_data = gaussian_array(array_shape=data_shape, + xoffset=psf_off_xy[0], + yoffset=psf_off_xy[1], + sigma=sigma, + amp=st_amp) + + # Add some noise + rng = np.random.default_rng(seed=123+2*i) + noise = rng.normal(0,noise_amp,img_data.shape) + img_data += noise + + # Add fake planet to sci files + if i 0: + ref_dataset = data.Dataset(ref_frames) + else: + ref_dataset = None + + # Save datasets if outdir was provided + if not outdir is None: + if not os.path.exists(outdir): + os.makedirs(outdir) + + sci_dataset.save(filedir=outdir, filenames=['mock_psfsub_L2b_sci_input_dataset.fits']) + if len(ref_frames) > 0: + ref_dataset.save(filedir=outdir, filenames=['mock_psfsub_L2b_ref_input_dataset.fits']) + + return sci_dataset,ref_dataset diff --git a/requirements.txt b/requirements.txt index 71d4d114..9b69cd8a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,3 +8,4 @@ photutils >= 2.0.0 statsmodels pyklip emccd-detect +tqdm diff --git a/tests/test_crop.py b/tests/test_crop.py index 67fd8ce2..867d8856 100644 --- a/tests/test_crop.py +++ b/tests/test_crop.py @@ -218,7 +218,7 @@ def test_non_nfov_input(): test_dataset = make_test_dataset(shape=[100,100],centxy=[50.5,50.5]) for frame in test_dataset: - frame.pri_hdr['LSAMNAME'] = 'WFOV' + frame.ext_hdr['LSAMNAME'] = 'WFOV' try: _ = crop(test_dataset,sizexy=20,centerxy=None) diff --git a/tests/test_err_dq.py b/tests/test_err_dq.py index 7d2dd4a4..966d5f59 100644 --- a/tests/test_err_dq.py +++ b/tests/test_err_dq.py @@ -26,6 +26,15 @@ dqhd = fits.Header() dqhd["CASE"] = "test" +data_3d 
= np.ones([2,1024,1024]) * 2 +err_3d = np.zeros([2,1024,1024]) +err1_3d = np.ones([2,1024,1024]) +err2_3d = err1_3d.copy() +err3_3d = np.ones([2,1,1024,1024]) * 0.5 +dq_3d = np.zeros([2,1024,1024], dtype = int) +dq1_3d = dq_3d.copy() +dq1_3d[0,0,0] = 1 + old_err_tracking = corgidrp.track_individual_errors # use default parameters detector_params = DetectorParams({}) @@ -118,6 +127,22 @@ def test_add_error_term(): assert image_test.err_hdr["Layer_2"] == "error_noid" assert image_test.err_hdr["Layer_3"] == "error_nuts" + image_3d = Image(data_3d,prhd,exthd,err_3d,dq_3d,errhd,dqhd) + image_3d.add_error_term(err1_3d, "error_noid") + assert image_3d.err[0,0,0,0] == err_3d[0,0,0] + image_3d.add_error_term(err2_3d, "error_nuts") + assert image_3d.err.shape == (3,2,1024,1024) + assert image_3d.err[0,0,0,0] == np.sqrt(err1_3d[0,0,0]**2 + err2_3d[0,0,0]**2) + image_3d.save(filename="test_image3d.fits") + + image_test_3d = Image('test_image3d.fits') + assert np.array_equal(image_test_3d.dq, dq_3d) + assert np.array_equal(image_test_3d.err, image_3d.err) + assert image_test_3d.err.shape == (3,2,1024,1024) + assert image_test_3d.err_hdr["Layer_1"] == "combined_error" + assert image_test_3d.err_hdr["Layer_2"] == "error_noid" + assert image_test_3d.err_hdr["Layer_3"] == "error_nuts" + def test_err_dq_dataset(): """ test the behavior of the err and data arrays in the dataset @@ -150,7 +175,6 @@ def test_get_masked_data(): assert masked_data.mean()==2 assert masked_data.sum()==image2.data.sum()-2 - def test_err_adderr_notrack(): """ test the initialization of error and adding errors when we are not tracking @@ -172,7 +196,6 @@ def test_err_adderr_notrack(): assert image1.err.shape == (1,1024,1024) assert image1.err[0,0,0] == np.sqrt(err1[0,0]**2 + err2[0,0]**2) - def test_read_many_errors_notrack(): """ Check that we can successfully discard errors when reading in a frame with multiple errors @@ -188,7 +211,6 @@ def test_read_many_errors_notrack(): with pytest.raises(KeyError): 
"""Tests for the pyKLIP dataset wrapper, NaN/DQ flagging and the PSF subtraction step."""
import numpy as np
import pytest
from scipy.ndimage import rotate, shift

from corgidrp.data import Dataset, Image, PyKLIPDataset
from corgidrp.detector import flag_nans, nan_flags
from corgidrp.l3_to_l4 import do_psf_subtraction
from corgidrp.mocks import create_default_headers, create_psfsub_dataset

## Helper functions/quantities

def create_circular_mask(h, w, center=None, r=None):
    """Creates a circular mask

    Args:
        h (int): array height
        w (int): array width
        center (list of float, optional): Center of mask in (x, y) order.
            Defaults to the center of the array.
        r (float, optional): radius of mask. Defaults to the minimum distance
            from the center to the edge of the array.

    Returns:
        np.array: boolean array with True inside the circle, False outside.
    """

    if center is None:  # use the middle of the image
        center = (w/2, h/2)
    if r is None:  # use the smallest distance between the center and image walls
        r = min(center[0], center[1], w-center[0], h-center[1])

    Y, X = np.ogrid[:h, :w]
    dist_from_center = np.sqrt((X - center[0])**2 + (Y - center[1])**2)

    return dist_from_center <= r


def _check_centering_and_rolls(pyklip_dataset, mock_sci, rolls):
    """Run the centering and roll/filename assertions shared by the ADI, RDI
    and ADI+RDI pyKLIP dataset tests.

    Assumes mock_sci was created with centerxy=(50.5,51.5), i.e. shifted by
    1 pixel in x and 2 pixels in y.

    Args:
        pyklip_dataset (corgidrp.data.PyKLIPDataset): dataset under test.
        mock_sci (corgidrp.data.Dataset): science dataset it was built from.
        rolls (list): roll angles passed to create_psfsub_dataset, science
            frames first.
    """
    # Check image is centered properly
    for i, image in enumerate(pyklip_dataset._input):
        assert mock_sci.all_data[i,2:,1:] == pytest.approx(image[:-2,:-1]), f"Frame {i} centered improperly."

    # Check roll assignments and filenames match up for sci dataset
    for r, roll in enumerate(pyklip_dataset._PAs):
        assert roll == rolls[r]
        assert pyklip_dataset._filenames[r] == f'MOCK_sci_roll{roll}.fits_INT1', f"Incorrect roll assignment for frame {r}."


def _assert_init_rejected(sci_dataset, ref_dataset):
    """Assert that building a PyKLIPDataset from these inputs raises UserWarning.

    Args:
        sci_dataset: object passed as the science dataset.
        ref_dataset: object passed as psflib_dataset.
    """
    with pytest.raises(UserWarning):
        _ = PyKLIPDataset(sci_dataset, psflib_dataset=ref_dataset)


iwa_lod = 3.
owa_lod = 9.7
d = 2.36  # m
lam = 573.8e-9  # m
pixscale_arcsec = 0.0218

# lam / d is an angle in radians; 206265 arcsec per radian converts it to
# arcsec before dividing by the pixel scale.
iwa_pix = iwa_lod * lam / d * 206265 / pixscale_arcsec
owa_pix = owa_lod * lam / d * 206265 / pixscale_arcsec

st_amp = 100.
noise_amp = 1e-11
pl_contrast = 1e-4

## pyKLIP data class tests

def test_pyklipdata_ADI():
    """Tests that pyklip dataset centers frame, assigns rolls, and initializes PSF library properly for ADI data.
    """

    rolls = [0,90]
    # Init with center shifted by 1 pixel in x, 2 pixels in y
    mock_sci, mock_ref = create_psfsub_dataset(2, 0, rolls,
                                               centerxy=(50.5,51.5))

    pyklip_dataset = PyKLIPDataset(mock_sci, psflib_dataset=mock_ref)

    _check_centering_and_rolls(pyklip_dataset, mock_sci, rolls)

    # Check ref library is None
    assert pyklip_dataset.psflib is None, "pyklip_dataset.psflib is not None, even though no reference dataset was provided."


def test_pyklipdata_RDI():
    """Tests that pyklip dataset centers frame, assigns rolls, and initializes PSF library properly for RDI data.
    """
    rolls = [45,180]
    n_sci = 1
    n_ref = 1
    # Init with center shifted
    mock_sci, mock_ref = create_psfsub_dataset(n_sci, n_ref, rolls, centerxy=(50.5,51.5))

    pyklip_dataset = PyKLIPDataset(mock_sci, psflib_dataset=mock_ref)

    _check_centering_and_rolls(pyklip_dataset, mock_sci, rolls)

    # Check ref library shape
    assert pyklip_dataset._psflib.master_library.shape[0] == n_sci+n_ref


def test_pyklipdata_ADIRDI():
    """Tests that pyklip dataset centers frame, assigns rolls, and initializes PSF library properly for ADI+RDI data.
    """
    rolls = [45,-45,180]
    n_sci = 2
    n_ref = 1
    # Init with center shifted by 1 pixel in x, 2 pixels in y
    mock_sci, mock_ref = create_psfsub_dataset(n_sci, n_ref, rolls,
                                               centerxy=(50.5,51.5))

    pyklip_dataset = PyKLIPDataset(mock_sci, psflib_dataset=mock_ref)

    _check_centering_and_rolls(pyklip_dataset, mock_sci, rolls)

    # Check ref library shape
    assert pyklip_dataset._psflib.master_library.shape[0] == n_sci+n_ref


def test_pyklipdata_badtelescope():
    """Tests that pyklip data class initialization fails if data does not come from Roman.
    """
    mock_sci, mock_ref = create_psfsub_dataset(1, 1, [0,0])
    mock_sci[0].pri_hdr['TELESCOP'] = "HUBBLE"

    _assert_init_rejected(mock_sci, mock_ref)


def test_pyklipdata_badinstrument():
    """Tests that pyklip data class initialization fails if data does not come from Coronagraph Instrument.
    """
    mock_sci, mock_ref = create_psfsub_dataset(1, 1, [0,0])
    mock_sci[0].pri_hdr['INSTRUME'] = "WFI"

    _assert_init_rejected(mock_sci, mock_ref)


def test_pyklipdata_badcfamname():
    """Tests that pyklip data class raises an error if the CFAM position is not a valid position name.
    """
    mock_sci, mock_ref = create_psfsub_dataset(1, 1, [0,0])
    mock_sci[0].ext_hdr['CFAMNAME'] = "BAD"

    _assert_init_rejected(mock_sci, mock_ref)


def test_pyklipdata_notdataset():
    """Tests that pyklip data class raises an error if the input is not a corgidrp dataset object.
    """
    mock_sci, _ = create_psfsub_dataset(1, 0, [0])

    # A plain list in place of the reference dataset
    _assert_init_rejected(mock_sci, [])

    # A plain list in place of the science dataset
    _assert_init_rejected([], None)


def test_pyklipdata_badimgshapes():
    """Tests that pyklip data class enforces all data frames to have the same shape.
    """
    mock_sci, mock_ref = create_psfsub_dataset(2, 0, [0,0])

    mock_sci[0].data = np.zeros((5,5))

    _assert_init_rejected(mock_sci, mock_ref)


def test_pyklipdata_multiplepixscales():
    """Tests that pyklip data class enforces that each frame has the same pixel scale.
    """
    mock_sci, mock_ref = create_psfsub_dataset(2, 0, [0,0])

    mock_sci[0].ext_hdr["PLTSCALE"] = 10

    _assert_init_rejected(mock_sci, mock_ref)

# DQ flagging tests

def make_test_data(frame_shape, n_frames=1):
    """Makes a test corgidrp Dataset of all zeros with the desired
    frame shape and number of frames.

    Args:
        frame_shape (listlike of int): 2D or 3D image shape desired.
        n_frames (int, optional): Number of frames. Defaults to 1.

    Returns:
        corgidrp.Dataset: mock dataset of all zeros.
    """

    frames = []
    for _ in range(n_frames):
        prihdr, exthdr = create_default_headers()
        im_data = np.zeros(frame_shape).astype(np.float64)
        frames.append(Image(im_data, pri_hdr=prihdr, ext_hdr=exthdr))

    return Dataset(frames)


def test_nanflags_2D():
    """Test detector.nan_flags() on 2D data.
    """

    mock_dataset = make_test_data([10,10], n_frames=2)
    mock_dataset.all_dq[0,4,6] = 1
    mock_dataset.all_dq[1,5,7] = 1

    expected_data = np.zeros([2,10,10])
    expected_data[0,4,6] = np.nan
    expected_data[1,5,7] = np.nan

    expected_err = np.zeros([2,1,10,10])
    expected_err[0,0,4,6] = np.nan
    expected_err[1,0,5,7] = np.nan

    nanned_dataset = nan_flags(mock_dataset, threshold=1)

    assert np.array_equal(nanned_dataset.all_data, expected_data, equal_nan=True), \
        '2D nan_flags test produced unexpected result'
    assert np.array_equal(nanned_dataset.all_err, expected_err, equal_nan=True), \
        '2D nan_flags test produced unexpected result for ERR array'


def test_nanflags_3D():
    """Test detector.nan_flags() on 3D data.
    """

    mock_dataset = make_test_data([3,10,10], n_frames=2)
    mock_dataset.all_dq[0,0,4,6] = 1
    mock_dataset.all_dq[1,:,5,7] = 1

    expected_data = np.zeros([2,3,10,10])
    expected_data[0,0,4,6] = np.nan
    expected_data[1,:,5,7] = np.nan

    expected_err = np.zeros([2,1,3,10,10])
    expected_err[0,0,0,4,6] = np.nan
    expected_err[1,0,:,5,7] = np.nan

    nanned_dataset = nan_flags(mock_dataset, threshold=1)

    assert np.array_equal(nanned_dataset.all_data, expected_data, equal_nan=True), \
        '3D nan_flags test produced unexpected result'
    assert np.array_equal(nanned_dataset.all_err, expected_err, equal_nan=True), \
        '3D nan_flags test produced unexpected result for ERR array'


def test_nanflags_mixed_dqvals():
    """Test detector.nan_flags() on 3D data with some DQ values below the threshold.
    """

    mock_dataset = make_test_data([3,10,10], n_frames=2)
    mock_dataset.all_dq[0,0,4,6] = 1
    mock_dataset.all_dq[1,:,5,7] = 2

    # Only the pixels flagged with 2 reach the threshold used below
    expected_data = np.zeros([2,3,10,10])
    expected_data[1,:,5,7] = np.nan

    expected_err = np.zeros([2,1,3,10,10])
    expected_err[1,0,:,5,7] = np.nan

    nanned_dataset = nan_flags(mock_dataset, threshold=2)

    assert np.array_equal(nanned_dataset.all_data, expected_data, equal_nan=True), \
        'nan_flags with mixed dq values produced unexpected result'
    assert np.array_equal(nanned_dataset.all_err, expected_err, equal_nan=True), \
        '3D nan_flags with mixed dq values produced unexpected result for ERR array'


def test_flagnans_2D():
    """Test detector.flag_nans() on 2D data.
    """

    mock_dataset = make_test_data([10,10], n_frames=2)
    mock_dataset.all_data[0,4,6] = np.nan
    mock_dataset.all_data[1,5,7] = np.nan

    expected_dq = np.zeros([2,10,10])
    expected_dq[0,4,6] = 1
    expected_dq[1,5,7] = 1

    flagged_dataset = flag_nans(mock_dataset)

    assert np.array_equal(flagged_dataset.all_dq, expected_dq, equal_nan=True), \
        '2D flag_nans test produced unexpected result for DQ array'


def test_flagnans_3D():
    """Test detector.flag_nans() on 3D data.
    """

    mock_dataset = make_test_data([3,10,10], n_frames=2)
    mock_dataset.all_data[0,0,4,6] = np.nan
    mock_dataset.all_data[1,:,5,7] = np.nan

    expected_dq = np.zeros([2,3,10,10])
    expected_dq[0,0,4,6] = 1
    expected_dq[1,:,5,7] = 1

    flagged_dataset = flag_nans(mock_dataset)

    assert np.array_equal(flagged_dataset.all_dq, expected_dq, equal_nan=True), \
        '3D flag_nans test produced unexpected result for DQ array'


def test_flagnans_flagval2():
    """Test detector.flag_nans() on 3D data with a non-default DQ value.
    """

    mock_dataset = make_test_data([3,10,10], n_frames=2)
    mock_dataset.all_data[0,0,4,6] = np.nan
    mock_dataset.all_data[1,:,5,7] = np.nan

    expected_dq = np.zeros([2,3,10,10])
    expected_dq[0,0,4,6] = 2
    expected_dq[1,:,5,7] = 2

    flagged_dataset = flag_nans(mock_dataset, flag_val=2)

    assert np.array_equal(flagged_dataset.all_dq, expected_dq, equal_nan=True), \
        '3D flag_nans test with flag_val=2 produced unexpected result for DQ array'

## PSF subtraction step tests

def test_psf_sub_split_dataset():
    """Tests that psf subtraction step can correctly split an input dataset into
    science and reference dataset, if they are not passed in separately.
    """

    # Sci & Ref
    numbasis = [1,4,8]
    rolls = [270+13,270-13,0,0]
    mock_sci, mock_ref = create_psfsub_dataset(2, 2, rolls,
                                               st_amp=st_amp,
                                               noise_amp=noise_amp,
                                               pl_contrast=pl_contrast)

    # combine mock_sci and mock_ref into 1 dataset
    frames = [*mock_sci, *mock_ref]
    mock_sci_and_ref = Dataset(frames)

    # Pass combined dataset to do_psf_subtraction
    result = do_psf_subtraction(mock_sci_and_ref,
                                numbasis=numbasis,
                                fileprefix='test_single_dataset',
                                do_crop=False)

    # Should choose ADI+RDI
    for frame in result:
        assert frame.pri_hdr['KLIP_ALG'] == 'ADI+RDI', \
            f"Chose {frame.pri_hdr['KLIP_ALG']} instead of 'ADI+RDI' mode when provided 2 science images and 2 references."

    # Try passing only science frames
    result = do_psf_subtraction(mock_sci,
                                numbasis=numbasis,
                                fileprefix='test_sci_only_dataset',
                                do_crop=False)

    # Should choose ADI
    for frame in result:
        assert frame.pri_hdr['KLIP_ALG'] == 'ADI', \
            f"Chose {frame.pri_hdr['KLIP_ALG']} instead of 'ADI' mode when provided 2 science images and no references."

    # pass only reference frames (should fail)
    with pytest.raises(UserWarning):
        _ = do_psf_subtraction(mock_ref,
                               numbasis=numbasis,
                               fileprefix='test_ref_only_dataset',
                               do_crop=False)


def test_psf_sub_ADI_nocrop():
    """Tests that psf subtraction step correctly identifies an ADI dataset (multiple rolls, no references),
    that overall counts decrease, that the KLIP result matches the analytical expectation, and that the
    output data shape is correct.
    """

    numbasis = [1]
    rolls = [270+13,270-13]
    mock_sci, mock_ref = create_psfsub_dataset(2, 0, rolls,
                                               st_amp=st_amp,
                                               noise_amp=noise_amp,
                                               pl_contrast=pl_contrast)

    result = do_psf_subtraction(mock_sci, mock_ref,
                                numbasis=numbasis,
                                fileprefix='test_ADI',
                                do_crop=False)

    # Each science frame minus the other, derotated by its roll, then averaged
    derotated_diff_0 = rotate(mock_sci[0].data - mock_sci[1].data, -rolls[0], reshape=False, cval=0)
    derotated_diff_1 = rotate(mock_sci[1].data - mock_sci[0].data, -rolls[1], reshape=False, cval=0)
    analytical_result = shift((derotated_diff_0 + derotated_diff_1) / 2,
                              [0.5,0.5],
                              cval=np.nan)

    for i, frame in enumerate(result):

        # Overall counts should decrease
        assert np.nansum(mock_sci[0].data) > np.nansum(frame.data), \
            f"ADI subtraction resulted in increased counts for frame {i}."

        # Result should match analytical result. Written as "<=" so that an
        # all-NaN difference (nanmax == NaN) fails instead of passing silently.
        max_abs_diff = np.nanmax(np.abs(frame.data - analytical_result))
        assert max_abs_diff <= 1e-5, \
            "Absolute difference between ADI result and analytical result is greater than 1e-5."

        assert frame.pri_hdr['KLIP_ALG'] == 'ADI', \
            f"Chose {frame.pri_hdr['KLIP_ALG']} instead of 'ADI' mode when provided 2 science images and no references."

    # Check expected data shape
    expected_data_shape = (1, len(numbasis), *mock_sci[0].data.shape)
    assert result.all_data.shape == expected_data_shape, \
        f"Result data shape was {result.all_data.shape} instead of expected {expected_data_shape} after ADI subtraction."


def test_psf_sub_RDI_nocrop():
    """Tests that psf subtraction step correctly identifies an RDI dataset (single roll, 1 or more references),
    that overall counts decrease, that the KLIP result matches the analytical expectation, and that the
    output data shape is correct.
    """
    numbasis = [1]
    rolls = [13,0]

    mock_sci, mock_ref = create_psfsub_dataset(1, 1, rolls, ref_psf_spread=1.,
                                               centerxy=(49.5,49.5),
                                               pl_contrast=pl_contrast,
                                               noise_amp=noise_amp,
                                               st_amp=st_amp)

    result = do_psf_subtraction(mock_sci, mock_ref,
                                numbasis=numbasis,
                                fileprefix='test_RDI',
                                do_crop=False)

    analytical_result = rotate(mock_sci[0].data - mock_ref[0].data, -rolls[0], reshape=False, cval=np.nan)

    for i, frame in enumerate(result):

        # Blank out everything inside the inner working angle around the star
        mask = create_circular_mask(*frame.data.shape[-2:], r=iwa_pix,
                                    center=(frame.ext_hdr['STARLOCX'], frame.ext_hdr['STARLOCY']))
        masked_frame = np.where(mask, np.nan, frame.data)

        # Overall counts should decrease
        assert np.nansum(mock_sci[0].data) > np.nansum(frame.data), \
            f"RDI subtraction resulted in increased counts for frame {i}."

        # The step should choose mode RDI based on having 1 roll and 1 reference.
        assert frame.pri_hdr['KLIP_ALG'] == 'RDI', \
            f"Chose {frame.pri_hdr['KLIP_ALG']} instead of 'RDI' mode when provided 1 science image and 1 reference."

        # Frame should match analytical result outside of the IWA (after correcting for the median offset)
        max_abs_diff = np.nanmax(np.abs((masked_frame - np.nanmedian(frame.data)) - analytical_result))
        assert max_abs_diff < 1e-5, "RDI subtraction did not produce expected analytical result."

    # Check expected data shape
    expected_data_shape = (1, len(numbasis), *mock_sci[0].data.shape)
    assert result.all_data.shape == expected_data_shape, \
        f"Result data shape was {result.all_data.shape} instead of expected {expected_data_shape} after RDI subtraction."


def test_psf_sub_ADIRDI_nocrop():
    """Tests that psf subtraction step correctly identifies an ADI+RDI dataset (multiple rolls, 1 or more references),
    that overall counts decrease, that the KLIP result matches the analytical expectation for 1 KL mode, and that the
    output data shape is correct.
    """

    numbasis = [1,2,3,4]
    rolls = [13,-13,0]
    mock_sci, mock_ref = create_psfsub_dataset(2, 1, rolls,
                                               st_amp=st_amp,
                                               noise_amp=noise_amp,
                                               pl_contrast=pl_contrast)

    analytical_result1 = (rotate(mock_sci[0].data - (mock_sci[1].data/2+mock_ref[0].data/2), -rolls[0], reshape=False, cval=0)
                          + rotate(mock_sci[1].data - (mock_sci[0].data/2+mock_ref[0].data/2), -rolls[1], reshape=False, cval=0)) / 2
    analytical_result2 = (rotate(mock_sci[0].data - mock_sci[1].data, -rolls[0], reshape=False, cval=0)
                          + rotate(mock_sci[1].data - mock_sci[0].data, -rolls[1], reshape=False, cval=0)) / 2
    analytical_results = [analytical_result1, analytical_result2]

    result = do_psf_subtraction(mock_sci, mock_ref,
                                numbasis=numbasis,
                                fileprefix='test_ADI+RDI',
                                do_crop=False)

    for i, frame in enumerate(result):

        # Blank out everything inside the inner working angle around the star
        mask = create_circular_mask(*frame.data.shape[-2:], r=iwa_pix,
                                    center=(frame.ext_hdr['STARLOCX'], frame.ext_hdr['STARLOCY']))
        masked_frame = np.where(mask, np.nan, frame.data)

        # Overall counts should decrease
        assert np.nansum(mock_sci[0].data) > np.nansum(frame.data), \
            f"ADI+RDI subtraction resulted in increased counts for frame {i}."

        # Corgidrp should know to choose ADI+RDI mode
        assert frame.pri_hdr['KLIP_ALG'] == 'ADI+RDI', \
            f"Chose {frame.pri_hdr['KLIP_ALG']} instead of 'ADI+RDI' mode when provided 2 science images and 1 reference."

        # Frame should match analytical result outside of the IWA (after correcting for the median offset) for KL mode 1
        # NOTE(review): i indexes output frames, not KL modes, and the 2D
        # analytical result broadcasts against every KL-mode slice of
        # frame.data -- confirm that comparing all slices is intended.
        if i == 0:
            max_abs_diff = np.nanmax(np.abs((masked_frame - np.nanmedian(frame.data)) - analytical_results[i]))
            assert max_abs_diff < 1e-5, "ADI+RDI subtraction did not produce expected analytical result."

    # Check expected data shape
    expected_data_shape = (1, len(numbasis), *mock_sci[0].data.shape)
    assert result.all_data.shape == expected_data_shape, \
        f"Result data shape was {result.all_data.shape} instead of expected {expected_data_shape} after ADI+RDI subtraction."


def test_psf_sub_withcrop():
    """Tests that psf subtraction step results in the correct data shape when
    cropping by default, and that overall counts decrease.
    """

    numbasis = [1,2]
    rolls = [270+13,270-13]
    mock_sci, mock_ref = create_psfsub_dataset(2, 0, rolls, pl_contrast=1e-3)

    result = do_psf_subtraction(mock_sci, mock_ref,
                                numbasis=numbasis,
                                fileprefix='test_withcrop')

    for i, frame in enumerate(result):

        # Overall counts should decrease
        assert np.nansum(mock_sci[0].data) > np.nansum(frame.data), \
            f"PSF subtraction resulted in increased counts for frame {i}."

    # Check expected data shape
    expected_data_shape = (1, len(numbasis), 60, 60)
    assert result.all_data.shape == expected_data_shape, \
        f"Result data shape was {result.all_data.shape} instead of expected {expected_data_shape} after ADI subtraction."


def test_psf_sub_badmode():
    """Tests that psf subtraction step fails correctly if an unconfigured mode is supplied (e.g. SDI).
    """

    numbasis = [1,2,3,4]
    rolls = [13,-13,0]
    mock_sci, mock_ref = create_psfsub_dataset(2, 1, rolls,
                                               st_amp=st_amp,
                                               noise_amp=noise_amp,
                                               pl_contrast=pl_contrast)

    with pytest.raises(Exception):
        _ = do_psf_subtraction(mock_sci, mock_ref,
                               numbasis=numbasis,
                               mode='SDI',
                               fileprefix='test_SDI',
                               do_crop=False)


if __name__ == '__main__':
    test_pyklipdata_ADI()
    test_pyklipdata_RDI()
    test_pyklipdata_ADIRDI()
    test_pyklipdata_badtelescope()
    test_pyklipdata_badinstrument()
    test_pyklipdata_badcfamname()
    test_pyklipdata_notdataset()
    test_pyklipdata_badimgshapes()
    test_pyklipdata_multiplepixscales()

    test_nanflags_2D()
    test_nanflags_3D()
    test_nanflags_mixed_dqvals()
    test_flagnans_2D()
    test_flagnans_3D()
    test_flagnans_flagval2()

    test_psf_sub_split_dataset()

    test_psf_sub_ADI_nocrop()
    test_psf_sub_RDI_nocrop()
    test_psf_sub_ADIRDI_nocrop()
    test_psf_sub_withcrop()
    test_psf_sub_badmode()