Initial commit - copy of Matthew Willson preliminary package
Gilles Orban de Xivry committed Apr 14, 2022
1 parent 6785f6b commit 7680286
Showing 11 changed files with 537 additions and 1 deletion.
31 changes: 30 additions & 1 deletion README.md
@@ -1,2 +1,31 @@
# PSI
PSI package for focal-plane wavefront sensing, developed primarily for ELT/METIS
Python implementation of Phase Sorting Interferometry (PSI) for estimating non-common path errors in High Contrast Imaging (HCI) instruments.

This PSI package is developed primarily for ELT/METIS, but also for VLT/ERIS.


## Legacy contribution
This work is based on the initial work of Emiel Por, followed up by the work of Matthew Willson†:
- [fepsi](https://github.com/mkenworthy/fepsi) : written by Emiel Por and modified by Matthew Kenworthy (Leiden University)
- [psi](https://github.com/mwillson-astro/PSI/tree/master) : preliminary package developed by Matthew Willson† for METIS (Liège University)

The initial commit of this repository is a copy of [psi](https://github.com/mwillson-astro/PSI/tree/master).

## References
- ["Focal Plane Wavefront Sensing Using Residual Adaptive Optics Speckles" by Codona and Kenworthy (2013)](https://iopscience.iop.org/article/10.1088/0004-637X/767/2/100), ApJ, 767, 100.

## Dependencies

The code requires the following Python packages:
* `numpy`
* `scipy`
* `astropy`
* `matplotlib`
* `scikit-image` (for the image resizing in `aperture/realistic.py`)
* `hcipy`

The `ffmpeg` tools are required for generating movies.

[HCIPy](https://github.com/ehpor/hcipy) is a Python software package written and developed by Emiel Por for performing end-to-end simulations of high contrast imaging instruments for astronomy.

It can be installed from PyPI with:
```
pip install hcipy
```
19 changes: 19 additions & 0 deletions __init__.py
@@ -0,0 +1,19 @@
# Import all submodules.
from . import aperture
from . import psi

# Import all core submodules in default namespace.
from .aperture import *
from .psi import *

# Export default namespaces.
__all__ = []
__all__.extend(aperture.__all__)
__all__.extend(psi.__all__)

from pkg_resources import get_distribution, DistributionNotFound
try:
    __version__ = get_distribution(__name__).version
except DistributionNotFound:
    # package is not installed
    pass
3 changes: 3 additions & 0 deletions aperture/__init__.py
@@ -0,0 +1,3 @@
__all__ = ['make_COMPASS_aperture', 'resize_img', 'pad_img', 'crop_img']

from .realistic import *
127 changes: 127 additions & 0 deletions aperture/realistic.py
@@ -0,0 +1,127 @@
import os
import warnings

import numpy as np
from astropy.io import fits
# resize is assumed to come from scikit-image (matches the preserve_range /
# anti_aliasing keywords used below).
from skimage.transform import resize
# from ..field import CartesianGrid, UnstructuredCoords, make_hexagonal_grid, Field
# from .generic import *
# These two imports may be needed in the future for making custom apertures, so they are left in.
from hcipy.field import CartesianGrid, UnstructuredCoords, make_hexagonal_grid, Field
from hcipy.aperture.generic import *

def make_vlt_aperture():
    pass

def make_subaru_aperture():
    pass

def make_lbt_aperture():
    pass

def make_elt_aperture():
    pass

def make_COMPASS_aperture(npupil=256, input_folder='/Users/matt/Documents/METIS/TestArea/fepsi/COMPASSPhaseScreens/Test/', nimg=720):
    '''Create an aperture from a COMPASS product.

    Parameters
    ----------
    npupil : scalar
        Number of pixels across each dimension of the array.
    input_folder : string
        Location of the aperture file from COMPASS.
    nimg : scalar
        The size the aperture file needs to be cut down to, as it comes with some padding.
        720 should be the default unless something changes in the COMPASS products.

    Returns
    -------
    Field generator
        The resized COMPASS aperture.
    '''
    mask = fits.getdata(os.path.join(input_folder, 'mask_256.fits'))
    if mask.shape[0] < nimg:
        mask = crop_img(mask, nimg, verbose=False)
    mask_pupil = resize_img(mask, npupil)
    #mask_pupil[mask_pupil<0.8] = 0
    #mask_pupil = mask_pupil.transpose() # Testing for wind direction dependencies. Should be commented out.
    aperture = np.ravel(mask_pupil)

    # Leftover hard-coded values; they have no effect on the returned aperture.
    nimg = 720
    npupil = 256

    def func(grid):
        return Field(aperture, grid)
    return func

def resize_img(img, new_size, preserve_range=True, mode='reflect',
               anti_aliasing=True):
    ''' Resize an image. Handles even and odd sizes.
    '''
    requirement = "new_size must be an int or a tuple/list of size 2."
    assert type(new_size) in [int, tuple, list], requirement
    if type(new_size) is int:
        new_size = (new_size, new_size)
    else:
        assert len(new_size) == 2, requirement
    assert img.ndim in [2, 3], 'image must be a frame (2D) or a cube (3D)'
    if img.ndim == 3:
        new_size = (len(img), *new_size)
    if new_size != img.shape:
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")  # when anti_aliasing=False, and NaNs
            img = np.float32(resize(np.float32(img), new_size,
                                    preserve_range=preserve_range, mode=mode,
                                    anti_aliasing=anti_aliasing))
    return img

def pad_img(img, padded_size, pad_value=0):
    ''' Pad an img with a value (default is zero). Handles even and odd sizes.
    '''
    requirement = "padded_size must be an int or a tuple/list of size 2."
    assert type(padded_size) in [int, tuple, list], requirement
    if type(padded_size) is int:
        (x1, y1) = (padded_size, padded_size)
    else:
        assert len(padded_size) == 2, requirement
        (x1, y1) = padded_size
    (x2, y2) = img.shape
    # determine padding region
    assert not (x1 < x2 or y1 < y2), "padding region can't be smaller than image size."
    dx = int((x1 - x2)/2)
    dy = int((y1 - y2)/2)
    padx = (dx, dx) if (x1-x2) % 2 == 0 else (dx+1, dx)
    pady = (dy, dy) if (y1-y2) % 2 == 0 else (dy+1, dy)
    # pad image
    img = np.pad(img, [padx, pady], mode='constant', constant_values=pad_value)
    return img

def crop_img(img, new_size, margin=0, verbose=True):
    ''' Crop an img to a new size. Handles even and odd sizes.
    Can add an optional margin of length 1, 2 (x,y) or 4 (x1,x2,y1,y2).
    '''
    requirement = "new_size must be an int or a tuple/list of size 2."
    assert type(new_size) in [int, tuple, list], requirement
    if type(new_size) is int:
        (x1, y1) = (new_size, new_size)
    else:
        assert len(new_size) == 2, requirement
        (x1, y1) = new_size
    (x2, y2) = img.shape
    if not np.any(np.array([x1, y1]) < np.array([x2, y2])):
        if verbose == True:
            print('crop size is larger than img size')
    else:
        # determine cropping region
        dx = int((x2 - x1)/2)
        dy = int((y2 - y1)/2)
        cropx = (dx, dx) if (x2-x1) % 2 == 0 else (dx+1, dx)
        cropy = (dy, dy) if (y2-y1) % 2 == 0 else (dy+1, dy)
        # check for margins
        requirement2 = "margin must be an int or a tuple/list of size 2 or 4."
        assert type(margin) in [int, tuple, list], requirement2
        if type(margin) is int:
            (mx1, mx2, my1, my2) = (margin, margin, margin, margin)
        elif len(margin) == 2:
            (mx1, mx2, my1, my2) = (margin[0], margin[0], margin[1], margin[1])
        else:
            assert len(margin) == 4, requirement2
            (mx1, mx2, my1, my2) = margin
        # crop image
        img = img[cropx[0]-mx1:-cropx[1]+mx2, cropy[0]-my1:-cropy[1]+my2]
    return img
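A minimal usage sketch for the image helpers above (not part of the commit); the array shape is arbitrary and the import path assumes the repository root is installed as a package named `psi`:

```python
import numpy as np
# Hypothetical import path: assumes the repository root is importable as `psi`.
from psi.aperture.realistic import resize_img, pad_img, crop_img

img = np.random.rand(720, 720)     # stand-in for a COMPASS mask

small = resize_img(img, 256)       # down-sample to 256 x 256
padded = pad_img(small, 512)       # zero-pad up to 512 x 512
central = crop_img(padded, 256)    # cut back to the central 256 x 256
assert central.shape == (256, 256)
```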
Binary file added psi/.DS_Store
Binary file not shown.
11 changes: 11 additions & 0 deletions psi/__init__.py
@@ -0,0 +1,11 @@
__all__ = ['crop_img', 'resize_img', 'process_screen', 'load_file', 'loadNCPA', 'load_and_process', 'get_contrast_curve', 'remove_piston', 'gauss_2Dalt', 'resetVariables', 'resetPSIVariables']
__all__ += ['makeFilters', 'makeMatrices', 'makeOpticalSystem', 'makeZerns']
__all__ += ['prop_image', 'prop_wf']
__all__ += ['PSI']
__all__ += ['processCorrection','processCorrection_Original']

from .psi_utils import *
from .makeGrids import *
from .processCorrection import *
from .psi_propagations import *
from .psi import *
49 changes: 49 additions & 0 deletions psi/makeGrids.py
@@ -0,0 +1,49 @@
import numpy as np
from scipy.signal import convolve2d
from .psi_utils import gauss_2Dalt
from hcipy.aperture import circular_aperture, make_obstructed_circular_aperture
from hcipy.coronagraphy import VortexCoronagraph
from hcipy.math_util import inverse_tikhonov
from hcipy.mode_basis import ModeBasis, make_zernike_basis
from hcipy.optics import Apodizer, make_gaussian_influence_functions, OpticalSystem

def makeFilters(grid, type="back_prop", sigma=0.05, cent_obs=0.27, outer_vin=1.0):
    """Build a circular or obstructed-circular filter on the given grid,
    optionally smoothed with a Gaussian kernel.

    Uses circular_aperture / make_obstructed_circular_aperture (hcipy),
    gauss_2Dalt (psi_utils) and convolve2d (scipy).
    """
    if type == "back_prop":
        filter_ = circular_aperture(15)(grid)  # 15 is an arbitrary number set by Emiel.
    elif type == "ncpa":
        filter_ = make_obstructed_circular_aperture(1*0.7, 1.8*cent_obs)(grid)
    elif type == "reset":
        filter_ = make_obstructed_circular_aperture(1*outer_vin, 1.8*cent_obs)(grid)
    else:
        raise ValueError("Type of filter not given!!")
    sigma_ = int(sigma * np.sqrt(filter_.size))
    if sigma_ != 0:
        # Smooth the filter edges with a Gaussian kernel.
        kernel_ = gauss_2Dalt(size_x=int(np.sqrt(filter_.size)), sigma_x=sigma_)
        filter_.shape = (int(np.sqrt(filter_.size)), int(np.sqrt(filter_.size)))
        filter_ = convolve2d(filter_, kernel_, mode='same')
        filter_ = filter_.ravel()
    return filter_

def makeMatrices(grid_, ao_acts_, aperture_, reconstruction_normalisation_):
    # Gaussian influence functions for the available DM actuators (spacing 1.2 / ao_acts_).
    ao_modes_ = make_gaussian_influence_functions(grid_, ao_acts_, 1.2 / ao_acts_)
    ao_modes_ = ModeBasis([mode * aperture_ for mode in ao_modes_])
    transformation_matrix_ = ao_modes_.transformation_matrix
    reconstruction_matrix_ = inverse_tikhonov(transformation_matrix_, reconstruction_normalisation_)
    return transformation_matrix_, reconstruction_matrix_

def makeOpticalSystem(grid_):
    lyot_stop_ = make_obstructed_circular_aperture(0.98, 0.3)(grid_)
    coro_ = OpticalSystem([VortexCoronagraph(grid_, 2), Apodizer(lyot_stop_)])
    return coro_

def makeZerns(n_zerns, grid, start):
    zernike_modes_ = make_zernike_basis(n_zerns, 1, grid, start)
    reconstructor_zernike_ = inverse_tikhonov(zernike_modes_.transformation_matrix, 1e-3)
    modal_means_ = np.ones(n_zerns)
    return zernike_modes_, reconstructor_zernike_, modal_means_
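A minimal sketch (not part of the commit) of how these helpers might be wired together. The grid size, actuator count and regularisation values are illustrative, the import path assumes the repository root is installed as a package named `psi`, and `makeFilters` additionally needs `gauss_2Dalt` from `psi/psi_utils.py`, which is part of the package but not shown in this diff view:

```python
from hcipy import make_pupil_grid, circular_aperture
# Hypothetical import path: assumes the repository root is importable as `psi`.
from psi.makeGrids import makeFilters, makeMatrices, makeOpticalSystem, makeZerns

pupil_grid = make_pupil_grid(128, 1)
aperture = circular_aperture(1)(pupil_grid)

reset_filter = makeFilters(pupil_grid, type="reset", sigma=0.05)  # smoothed, obstructed pupil mask
trans_m, recon_m = makeMatrices(pupil_grid, 20, aperture, 1e-3)   # 20x20 actuators, Tikhonov rcond 1e-3
coro = makeOpticalSystem(pupil_grid)                              # vortex coronagraph + Lyot stop
zerns, recon_zern, modal_means = makeZerns(20, pupil_grid, 4)     # 20 Zernikes, starting at mode 4
```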
98 changes: 98 additions & 0 deletions psi/processCorrection.py
@@ -0,0 +1,98 @@
import numpy as np

def processCorrection(ncpa_estimate_, correction_, correction_normalisation, aperture_, ncpa_filter_aperture_, recon_zern_recon,
                      zern_recon, modal_means_init_, modal_means_, reset=False, resetFilter=None, ZernCorrection=True, ms=0):
    m_corrections = recon_zern_recon.dot(ncpa_estimate_ * aperture_)

    # Have a look at the Zernikes which are produced, starting from the 4th Zernike (no piston or tip/tilt).
    if np.size(ms) != 1:
        m_mean = np.mean(m_corrections)
        m_std = np.std(m_corrections)
        ms += m_corrections  # *modal_means_

    ncpa_estimate_temp = np.copy(ncpa_estimate_)
    new_correction_max = np.max(zern_recon.transformation_matrix.dot(m_corrections*modal_means_init_))

    lo_freq_estimate = zern_recon.transformation_matrix.dot(m_corrections*modal_means_)
    lo_freq_corr = np.max(lo_freq_estimate) / new_correction_max
    lo_freq_estimate *= lo_freq_corr
    hi_freq_estimate = (ncpa_estimate_ - lo_freq_estimate) * ncpa_filter_aperture_

    # Correcting by power rather than max
    # ncpa_estimate_temp = np.copy(ncpa_estimate_)
    # new_correction_sum = np.sum(zern_recon.transformation_matrix.dot(m_corrections*modal_means_init_))
    # lo_freq_corr = np.sum(lo_freq_estimate) / new_correction_sum
    # lo_freq_estimate *= lo_freq_corr

    if ZernCorrection == True:
        new_mask = aperture_ - ncpa_filter_aperture_  # Mask of just the problem areas; everything else is treated as high frequency.
        new_estimate = (lo_freq_estimate*new_mask*0.5 + ncpa_estimate_*ncpa_filter_aperture_)
        new_estimate *= np.max(new_estimate) / new_correction_max
        correction_ += -correction_normalisation * new_estimate

        # correction_ += -(correction_normalisation * new_mask * lo_freq_estimate * 0.5) -(correction_normalisation * hi_freq_estimate)
    else:
        # freq_estimate = zern_recon.transformation_matrix.dot(m_corrections*modal_means_)
        # freq_estimate *= np.max(freq_estimate) / new_correction_max
        # correction_ += -(correction_normalisation * (ncpa_estimate_-freq_estimate))
        # freq_estimate = zern_recon.transformation_matrix.dot(m_corrections*modal_means_)
        # freq_estimate *= lo_freq_corr
        # hi_freq_estimate = (ncpa_estimate_ - freq_estimate) * (aperture_-ncpa_filter_aperture_)
        # # freq_estimate *= np.max(freq_estimate) / new_correction_max
        # correction_ += -(correction_normalisation * freq_estimate) +(correction_normalisation * hi_freq_estimate)/10.

        m_corrections = recon_zern_recon.dot(ncpa_estimate_ * aperture_)
        ncpa_estimate_temp = np.copy(ncpa_estimate_)
        new_correction_max = np.max(zern_recon.transformation_matrix.dot(m_corrections*modal_means_init_))

        lo_freq_estimate = zern_recon.transformation_matrix.dot(m_corrections*modal_means_)
        lo_freq_estimate *= np.max(lo_freq_estimate) / new_correction_max
        hi_freq_estimate = (ncpa_estimate_ - lo_freq_estimate)/10. * ncpa_filter_aperture_

        correction_ += -(correction_normalisation * lo_freq_estimate)  # -(correction_normalisation * hi_freq_estimate)

        # correction_ += -(correction_normalisation * lo_freq_estimate) -(correction_normalisation * hi_freq_estimate)

    if reset == True:
        try:
            correction_ *= resetFilter
        except TypeError:  # resetFilter is None (multiplying by None raises TypeError)
            print("No reset filter given.")

    # correction_ += -correction_normalisation*ncpa_estimate_
    if np.size(ms) != 1:
        return correction_, ms
    else:
        return correction_

def processCorrection_Original(ncpa_estimate_, correction_, correction_normalisation, aperture_, ncpa_filter_aperture_, recon_zern_recon,
                               zern_recon, modal_means_init_, modal_means_, reset=False, resetFilter=None, ZernCorrection=True, ms=0):

    m_corrections = recon_zern_recon.dot(ncpa_estimate_ * aperture_)
    ncpa_estimate_temp = np.copy(ncpa_estimate_)
    new_correction_max = np.max(zern_recon.transformation_matrix.dot(m_corrections*modal_means_init_))

    lo_freq_estimate = zern_recon.transformation_matrix.dot(m_corrections*modal_means_)
    lo_freq_estimate *= np.max(lo_freq_estimate) / new_correction_max
    hi_freq_estimate = (ncpa_estimate_ - lo_freq_estimate)/10. * ncpa_filter_aperture_

    correction_ += -(correction_normalisation * lo_freq_estimate) - (correction_normalisation * hi_freq_estimate)

    # correction_ += -correction_normalisation*ncpa_estimate_
    if np.size(ms) != 1:
        return correction_, ms
    else:
        return correction_

#####
# Original

# m_corrections = recon_zern_recon.dot(ncpa_estimate_ * aperture_)
# ncpa_estimate_temp = np.copy(ncpa_estimate_)
# new_correction_max = np.max(zern_recon.transformation_matrix.dot(m_corrections*modal_means_init_))
#
# lo_freq_estimate = zern_recon.transformation_matrix.dot(m_corrections*modal_means_)
# lo_freq_estimate *= np.max(lo_freq_estimate) / new_correction_max
# hi_freq_estimate = (ncpa_estimate_ - lo_freq_estimate)/10. * ncpa_filter_aperture_
#
# correction_ += -(correction_normalisation * lo_freq_estimate) -(correction_normalisation * hi_freq_estimate)
19 changes: 19 additions & 0 deletions psi/psi.py
@@ -0,0 +1,19 @@
import numpy as np
from hcipy.optics import Wavefront

def PSI(ref_, img_, I_sum, phi_sum, phi_2, phi_I, filter_, grid, prop_, i_):  # , ncpa_filter_, normalisation_):
    # Accumulate the running sums over frames.
    phi_I += ref_ * img_
    phi_2 += np.abs(ref_)**2
    phi_sum += ref_
    I_sum += img_

    psi_estimate_ = (phi_I - phi_sum * I_sum / (i_+1)) / (phi_2 - np.abs(phi_sum / (i_+1))**2)

    wf = Wavefront(psi_estimate_ * filter_)

    wf.electric_field *= np.exp(-2j * grid.as_('polar').theta)  # This line solves the phase wrapping
    pup = prop_.backward(wf)
    ncpa_estimate_ = pup.electric_field.imag

    return ncpa_estimate_, I_sum, phi_sum, phi_2, phi_I, psi_estimate_
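A minimal sketch (not part of the commit) of the per-frame accumulation loop around `PSI()`. The grids, propagator and random speckle fields are stand-ins for the real telemetry pipeline (`prop_wf` / `prop_image`), and the import path assumes the repository root is installed as a package named `psi`:

```python
import numpy as np
from hcipy import (make_pupil_grid, make_focal_grid, circular_aperture,
                   FraunhoferPropagator, Wavefront)
# Hypothetical import path: assumes the repository root is importable as `psi`.
from psi.psi import PSI

pupil_grid = make_pupil_grid(128, 1)
focal_grid = make_focal_grid(q=4, num_airy=16)
prop = FraunhoferPropagator(pupil_grid, focal_grid)
aperture = circular_aperture(1)(pupil_grid)
back_prop_filter = circular_aperture(15)(focal_grid)  # same shape as makeFilters(type="back_prop")

I_sum = phi_sum = phi_2 = phi_I = 0
for i in range(20):
    # Stand-ins for the reconstructed speckle field and the science image.
    ref = prop(Wavefront(aperture * np.exp(1j * 0.1 * np.random.randn(pupil_grid.size)))).electric_field
    img = np.abs(ref)**2
    # The estimate only becomes meaningful after a few frames (the first frame gives NaNs).
    ncpa_estimate, I_sum, phi_sum, phi_2, phi_I, psi_est = PSI(
        ref, img, I_sum, phi_sum, phi_2, phi_I, back_prop_filter, focal_grid, prop, i)
```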
32 changes: 32 additions & 0 deletions psi/psi_propagations.py
@@ -0,0 +1,32 @@
import numpy as np
from hcipy.optics import Wavefront
from hcipy.propagation import Propagator
from hcipy.statistics import large_poisson

def prop_image(wf_post_coro_, ncpa_, correction_, prop_, coro_, wv=0, noise_=True, n_pho=8.88e+04/1000, background_subtract=False):
    # Add NCPA and correction to the post-coronagraph wavefront.
    wf_post_coro_.electric_field *= np.exp(1j * ncpa_) * np.exp(1j * correction_)  # * np.exp(1j * wv)
    if coro_ is None:
        img_oneEx_ = prop_(wf_post_coro_).power
    else:
        img_oneEx_ = prop_(coro_(wf_post_coro_)).power
    if noise_ == True:
        image_ = large_poisson(img_oneEx_) + np.random.poisson(n_pho, img_oneEx_.shape)
    else:
        image_ = img_oneEx_
    if background_subtract == True:
        image_ -= np.full(image_.shape, n_pho)
    return image_

def prop_wf(wf_post_ao_, aperture_, prop_, coro_, recon_m, trans_m):
    # * aperture_ is different between METIS and ERIS
    wfs_measurement_ = recon_m.dot(np.angle(wf_post_ao_.electric_field / wf_post_ao_.electric_field.mean()) * aperture_)
    reconstructed_pupil = aperture_ * np.exp(1j * trans_m.dot(wfs_measurement_))
    reconstructed_pupil /= np.exp(1j * np.angle(reconstructed_pupil.mean()))
    reconstructed_pupil -= aperture_

    if coro_ is None:
        reconstructed_electric_field_ = prop_(Wavefront(reconstructed_pupil)).electric_field
    else:
        reconstructed_electric_field_ = prop_(coro_(Wavefront(reconstructed_pupil))).electric_field
    return reconstructed_electric_field_
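A minimal sketch (not part of the commit) of calling `prop_image` without a coronagraph; the grids, aperture and zero phase maps are illustrative, and the import path assumes the repository root is installed as a package named `psi`:

```python
import numpy as np
from hcipy import (make_pupil_grid, make_focal_grid, circular_aperture,
                   FraunhoferPropagator, Wavefront)
# Hypothetical import path: assumes the repository root is importable as `psi`.
from psi.psi_propagations import prop_image

pupil_grid = make_pupil_grid(128, 1)
focal_grid = make_focal_grid(q=4, num_airy=16)
prop = FraunhoferPropagator(pupil_grid, focal_grid)

aperture = circular_aperture(1)(pupil_grid)
wf = Wavefront(aperture)

ncpa = np.zeros(pupil_grid.size)        # placeholder NCPA phase map [rad]
correction = np.zeros(pupil_grid.size)  # placeholder DM correction [rad]

# No coronagraph (coro_ = None); photon noise and background added with the default flux.
image = prop_image(wf, ncpa, correction, prop, None, noise_=True)
```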