Skip to content

Commit

Permalink
Refactor function signatures
Browse files Browse the repository at this point in the history
  • Loading branch information
garryod committed Jul 12, 2022
1 parent 71636a9 commit 9ec169a
Show file tree
Hide file tree
Showing 17 changed files with 153 additions and 207 deletions.
31 changes: 14 additions & 17 deletions src/adcorr/angular_efficiency.py
Original file line number Diff line number Diff line change
@@ -1,42 +1,39 @@
from typing import Any, Tuple, TypeVar, cast
from typing import Tuple, cast

from numpy import cos, dtype, exp, ndarray
from numpy import cos, exp

from .utils.geometry import scattering_angles

FrameDType = TypeVar("FrameDType", bound=dtype)
FramesShape = TypeVar("FramesShape", bound=Any)
from .utils.typing import Frames


def correct_angular_efficiency(
frames: ndarray[FramesShape, FrameDType],
frames: Frames,
beam_center: Tuple[float, float],
pixel_sizes: Tuple[float, float],
distance: float,
absorption_coefficient: float,
thickness: float,
) -> ndarray[FramesShape, FrameDType]:
) -> Frames:
"""Corrects for loss due to the angular efficiency of the detector head.
Corrects for loss due to the angular efficiency of the detector head, as described
in section 3.xiii and appendix C of 'The modular small-angle X-ray scattering data
correction sequence' [https://doi.org/10.1107/S1600576717015096].
Args:
frames (ndarray[FramesShape, FrameDType]): A stack of frames to be
corrected.
beam_center (Tuple[float, float]): The center position of the beam in pixels.
pixel_sizes (Tuple[float, float]): The real space size of a detector pixel.
distance (float): The distance between the detector and the sample head.
absorption_coefficient (float): The coefficient of absorption for a given
material at a given photon energy.
thickness (float): The thickness of the detector head material.
frames: A stack of frames to be corrected.
beam_center: The center position of the beam in pixels.
pixel_sizes: The real space size of a detector pixel.
distance: The distance between the detector and the sample head.
absorption_coefficient: The coefficient of absorption for a given material at a
given photon energy.
thickness: The thickness of the detector head material.
Returns:
ndarray[FramesShape, FrameDType]: The corrected stack of frames.
The corrected stack of frames.
"""
if absorption_coefficient <= 0.0:
raise ValueError("absorption coefficient must positive.")
raise ValueError("Absorption Coefficient must positive.")
if thickness <= 0.0:
raise ValueError("Thickness must be positive.")

Expand Down
23 changes: 5 additions & 18 deletions src/adcorr/background_subtraction.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,18 @@
from typing import Tuple, TypeVar
from .utils.typing import Frame, Frames

from numpy import dtype, ndarray

FrameDType = TypeVar("FrameDType", bound=dtype)
NumFrames = TypeVar("NumFrames", bound=int)
FrameWidth = TypeVar("FrameWidth", bound=int)
FrameHeight = TypeVar("FrameHeight", bound=int)


def subtract_background(
foreground_frames: ndarray[Tuple[NumFrames, FrameWidth, FrameHeight], FrameDType],
background_frame: ndarray[Tuple[FrameWidth, FrameHeight], FrameDType],
) -> ndarray[Tuple[NumFrames, FrameWidth, FrameHeight], FrameDType]:
def subtract_background(foreground_frames: Frames, background_frame: Frame) -> Frames:
"""Subtract a background frame from a sequence of foreground frames.
Subtract a background frame from a sequence of foreground frames, as detailed in
section 3.4.6 of 'Everything SAXS: small-angle scattering pattern collection and
correction' [https://doi.org/10.1088/0953-8984/25/38/383201].
Args:
foreground_frames (ndarray[Tuple[NumFrames, FrameWidth, FrameHeight],
FrameDType]): A sequence of foreground frames to be corrected.
background_frame (ndarray[Tuple[FrameWidth, FrameHeight], FrameDType]): The
background which is to be corrected for.
foreground_frames: A sequence of foreground frames to be corrected.
background_frame: The background which is to be corrected for.
Returns:
ndarray[Tuple[NumFrames, FrameWidth, FrameHeight], FrameDType]: A sequence of
corrected frames.
A sequence of corrected frames.
"""
return foreground_frames - background_frame
28 changes: 11 additions & 17 deletions src/adcorr/dark_current.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,17 @@
from typing import Literal, Tuple, TypeVar, Union
from typing import Literal, Tuple, Union

from numpy import atleast_1d, dtype, expand_dims, floating, ndarray

FrameDType = TypeVar("FrameDType", bound=dtype)
NumFrames = TypeVar("NumFrames", bound=int)
FrameWidth = TypeVar("FrameWidth", bound=int)
FrameHeight = TypeVar("FrameHeight", bound=int)
from .utils.typing import Frames, NumFrames


def correct_dark_current(
frames: ndarray[Tuple[NumFrames, FrameWidth, FrameHeight], FrameDType],
frames: Frames,
count_times: ndarray[Tuple[Union[NumFrames, Literal[1]]], dtype[floating]],
base_dark_current: float,
temporal_dark_current: float,
flux_dependant_dark_current: float,
) -> ndarray[Tuple[NumFrames, FrameWidth, FrameHeight], FrameDType]:
) -> Frames:
"""Correct by subtracting base, temporal and flux-dependant dark currents.
Correct for incident dark current by subtracting a baselike, time dependant and a
Expand All @@ -23,18 +20,15 @@ def correct_dark_current(
[https://doi.org/10.1088/0953-8984/25/38/383201].
Args:
frames (ndarray[Tuple[NumFrames, FrameWidth, FrameHeight], FrameDType]): A
stack of frames to be corrected.
count_times (ndarray[Tuple[Union[NumFrames, Literal[1]]], dtype[floating]]):
The period over which photons are counted for each frame.
base_dark_current (float): The dark current flux, irrespective of time.
temporal_dark_current (float): The dark current flux, as a factor of time.
flux_dependant_dark_current (float): The dark current flux, as a factor of
incident flux.
frames: A stack of frames to be corrected.
count_times: The period over which photons are counted for each frame.
base_dark_current: The dark current flux, irrespective of time.
temporal_dark_current: The dark current flux, as a factor of time.
flux_dependant_dark_current: The dark current flux, as a factor of incident
flux.
Returns:
ndarray[Tuple[NumFrames, FrameWidth, FrameHeight], FrameDType]: The
corrected stack of frames.
The corrected stack of frames.
"""
if (count_times <= 0).any():
raise ValueError("Count times must be positive.")
Expand Down
29 changes: 11 additions & 18 deletions src/adcorr/deadtime.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,17 @@
from typing import Any, Literal, Tuple, TypeVar, Union, cast
from typing import Any, Literal, Tuple, Union, cast

from numpy import atleast_1d, complexfloating, dtype, expand_dims, floating, ndarray
from scipy.special import lambertw

FrameDType = TypeVar("FrameDType", bound=dtype)
NumFrames = TypeVar("NumFrames", bound=int)
FrameWidth = TypeVar("FrameWidth", bound=int)
FrameHeight = TypeVar("FrameHeight", bound=int)
from .utils.typing import Frames, NumFrames


def correct_deadtime(
frames: ndarray[Tuple[NumFrames, FrameWidth, FrameHeight], FrameDType],
frames: Frames,
count_times: ndarray[Tuple[Union[NumFrames, Literal[1]]], dtype[floating]],
minimum_pulse_separation: float,
minimum_arrival_separation: float,
) -> ndarray[Tuple[NumFrames, FrameWidth, FrameHeight], FrameDType]:
) -> Frames:
"""Correct for detector deadtime by scaling counts to account for overlapping events.
Correct for detector deadtime by iteratively solving for the number of incident
Expand All @@ -24,20 +21,16 @@ def correct_deadtime(
[https://doi.org/10.1088/0953-8984/25/38/383201].
Args:
frames (ndarray[Tuple[NumFrames, FrameWidth, FrameHeight], FrameDType]): A
stack of frames to be corrected.
count_times (ndarray[Tuple[Union[NumFrames, Literal[1]]], dtype[floating]]):
The period over which photons are counted for each frame.
minimum_pulse_separation (float): The minimum time difference required between
a prior pulse and the current pulse for the current pulse to be recorded
frames: A stack of frames to be corrected.
count_times: The period over which photons are counted for each frame.
minimum_pulse_separation: The minimum time difference required between a prior
pulse and the current pulse for the current pulse to be recorded correctly.
minimum_arrival_separation: The minimum time difference required between the
current pulse and a subsequent pulse for the current pulse to be recorded
correctly.
minimum_arrival_separation (float): The minimum time difference required
between the current pulse and a subsequent pulse for the current pulse to
be recorded correctly.
Returns:
ndarray[Tuple[NumFrames, FrameWidth, FrameHeight], FrameDType]: The corrected
stack of frames.
The corrected stack of frames.
"""
if (count_times <= 0).any():
raise ValueError("Count times must be positive.")
Expand Down
22 changes: 8 additions & 14 deletions src/adcorr/displaced_volume.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,24 @@
from typing import Tuple, TypeVar, cast
from typing import cast

from numpy import dtype, ndarray
from .utils.typing import Frames

FrameDType = TypeVar("FrameDType", bound=dtype)
FramesShape = TypeVar("FramesShape", bound=Tuple[int, int, int])


def correct_displaced_volume(
frames: ndarray[FramesShape, FrameDType],
displaced_fraction: float,
) -> ndarray[FramesShape, FrameDType]:
def correct_displaced_volume(frames: Frames, displaced_fraction: float) -> Frames:
"""Correct for displaced volume of solvent by multiplying signal by retained fraction.
Correct for displaced volume of solvent by multiplying signal by the retained
fraction, as detailed in section 3.xviii and appendix B of `The modular small-angle
fraction, as detailed in section 3.xviii and appendix B of 'The modular small-angle
X-ray scattering data correction sequence'
[https://doi.org/10.1107/S1600576717015096].
Args:
frames (ndarray[FramesShape, FrameDType]): A stack of frames to be corrected.
displaced_fraction (float): The fraction of solvent displaced by the analyte.
frames: A stack of frames to be corrected.
displaced_fraction: The fraction of solvent displaced by the analyte.
Returns:
ndarray[FramesShape, FrameDType]: The corrected stack of frames.
The corrected stack of frames.
"""
if displaced_fraction < 0.0 or displaced_fraction > 1.0:
raise ValueError("Displaced Fraction must be in interval [0, 1].")

return cast(ndarray[FramesShape, FrameDType], frames * (1.0 - displaced_fraction))
return cast(Frames, frames * (1.0 - displaced_fraction))
21 changes: 6 additions & 15 deletions src/adcorr/flatfield.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,22 @@
from typing import Tuple, TypeVar

from numpy import dtype, floating, ndarray

FrameDType = TypeVar("FrameDType", bound=dtype)
NumFrames = TypeVar("NumFrames", bound=int)
FrameWidth = TypeVar("FrameWidth", bound=int)
FrameHeight = TypeVar("FrameHeight", bound=int)
from .utils.typing import Frames, FrameShape


def correct_flatfield(
frames: ndarray[Tuple[NumFrames, FrameHeight, FrameWidth], FrameDType],
flatfield: ndarray[Tuple[FrameHeight, FrameWidth], dtype[floating]],
) -> ndarray[Tuple[NumFrames, FrameHeight, FrameWidth], FrameDType]:
frames: Frames, flatfield: ndarray[FrameShape, dtype[floating]]
) -> Frames:
"""Apply multiplicative flatfield correction, to correct for inter-pixel sensitivity.
Apply multiplicative flatfield correction, to correct for inter-pixel sensitivity,
as described in section 3.xii of 'The modular small-angle X-ray scattering data
correction sequence' [https://doi.org/10.1107/S1600576717015096].
Args:
frames (ndarray[Tuple[NumFrames, FrameHeight, FrameWidth], FrameDType]): A
stack of frames to be corrected.
flatfield (ndarray[Tuple[FrameHeight, FrameWidth], dtype[floating]]): The
multiplicative flatfield correction to be applied.
frames: A stack of frames to be corrected.
flatfield: The multiplicative flatfield correction to be applied.
Returns:
ndarray[Tuple[NumFrames, FrameHeight, FrameWidth], FrameDType]: The corrected
stack of frames.
The corrected stack of frames.
"""
return frames * flatfield
17 changes: 6 additions & 11 deletions src/adcorr/flux_and_transmission.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,21 @@
from typing import Any, TypeVar
from numpy import expand_dims, sum

from numpy import dtype, expand_dims, ndarray, sum
from .utils.typing import Frames

FrameDType = TypeVar("FrameDType", bound=dtype)
FramesShape = TypeVar("FramesShape", bound=Any)


def normalize_transmitted_flux(
frames: ndarray[FramesShape, FrameDType]
) -> ndarray[FramesShape, FrameDType]:
def normalize_transmitted_flux(frames: Frames) -> Frames:
"""Normalize for incident flux and transmissibility by scaling photon counts.
Normalize for incident flux and transmissibility by scaling photon counts with
respect to the total observed flux, as detailed in section 4 of `The modular small-
respect to the total observed flux, as detailed in section 4 of 'The modular small-
angle X-ray scattering data correction sequence'
[https://doi.org/10.1107/S1600576717015096].
Args:
frames (ndarray[FramesShape, FrameDType]): A stack of frames to be normalized.
frames: A stack of frames to be normalized.
Returns:
ndarray[FramesShape, FrameDType]: The normalized stack of frames.
The normalized stack of frames.
"""
frame_flux = expand_dims(sum(frames, axis=(-1, -2)), (-1, -2))
return frames / frame_flux
15 changes: 4 additions & 11 deletions src/adcorr/frame_average.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,16 @@
from math import prod
from typing import Any, TypeVar

from numpy import dtype, ndarray
from .utils.typing import Frames

FrameDType = TypeVar("FrameDType", bound=dtype)
FrameShape = TypeVar("FrameShape", bound=Any)


def average_all_frames(
frames: ndarray[FrameShape, FrameDType]
) -> ndarray[FrameShape, FrameDType]:
def average_all_frames(frames: Frames) -> Frames:
"""Average all frames over the leading axis.
Args:
frames (ndarray[FrameShape, FrameDType]): A stack of frames to be averaged.
frames: A stack of frames to be averaged.
Returns:
ndarray[FrameShape, FrameDType]: A frame containing the average pixel values
of all frames in the stack.
A frame containing the average pixel values of all frames in the stack.
"""
return frames.reshape(
[frames.size // prod(frames.shape[-2:]), *frames.shape[-2:]]
Expand Down
20 changes: 7 additions & 13 deletions src/adcorr/frame_time.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,26 @@
from typing import Literal, Tuple, TypeVar, Union
from typing import Literal, Tuple, Union

from numpy import atleast_1d, dtype, expand_dims, floating, ndarray

FrameDType = TypeVar("FrameDType", bound=dtype)
NumFrames = TypeVar("NumFrames", bound=int)
FrameWidth = TypeVar("FrameWidth", bound=int)
FrameHeight = TypeVar("FrameHeight", bound=int)
from .utils.typing import Frames, NumFrames


def normalize_frame_time(
frames: ndarray[Tuple[NumFrames, FrameWidth, FrameHeight], FrameDType],
frames: Frames,
count_times: ndarray[Tuple[Union[NumFrames, Literal[1]]], dtype[floating]],
) -> ndarray[Tuple[NumFrames, FrameWidth, FrameHeight], FrameDType]:
) -> Frames:
"""Normalize for detector frame rate by scaling photon counts according to count time.
Normalize for detector frame rate by scaling photon counts according to count time,
as detailed in section 3.4.3 of 'Everything SAXS: small-angle scattering pattern
collection and correction' [https://doi.org/10.1088/0953-8984/25/38/383201].
Args:
frames (ndarray[Tuple[NumFrames, FrameWidth, FrameHeight], FrameDType]): A
stack of frames to be normalized.
count_times (ndarray[Tuple[Union[TimesShape, Literal[1]]], dtype[floating]]):
The period over which photons are counted for each frame.
frames: A stack of frames to be normalized.
count_times: The period over which photons are counted for each frame.
Returns:
ndarray[Tuple[NumFrames, FrameWidth, FrameHeight], FrameDType]: The normalized
stack of frames.
The normalized stack of frames.
"""
if (count_times <= 0).any():
raise ValueError("Count times must be positive.")
Expand Down
Loading

0 comments on commit 9ec169a

Please sign in to comment.