Skip to content

Commit

Permalink
Enable time-of-flight indexing and Laue/ToF refinement (dials#2662)
Browse files Browse the repository at this point in the history
* Enable time-of-flight indexing and Laue/time-of-flight refinement.
  • Loading branch information
toastisme authored Jun 24, 2024
1 parent db15ca2 commit d04235d
Show file tree
Hide file tree
Showing 24 changed files with 2,513 additions and 256 deletions.
1 change: 1 addition & 0 deletions newsfragments/2662.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add classes to support time-of-flight and Laue indexing and refinement.
36 changes: 32 additions & 4 deletions src/dials/algorithms/indexing/indexer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import iotbx.phil
import libtbx
from cctbx import sgtbx
from dxtbx.model import ExperimentList, ImageSequence
from dxtbx.model import ExperimentList, ImageSequence, tof_helpers

import dials.util
from dials.algorithms.indexing import (
Expand Down Expand Up @@ -873,7 +873,7 @@ def _xyzcal_mm_to_px(self, experiments, reflections):
refined_reflections = reflections.select(imgset_sel)
panel_numbers = flex.size_t(refined_reflections["panel"])
xyzcal_mm = refined_reflections["xyzcal.mm"]
x_mm, y_mm, z_rad = xyzcal_mm.parts()
x_mm, y_mm, z = xyzcal_mm.parts()
xy_cal_mm = flex.vec2_double(x_mm, y_mm)
xy_cal_px = flex.vec2_double(len(xy_cal_mm))
for i_panel in range(len(expt.detector)):
Expand All @@ -884,10 +884,18 @@ def _xyzcal_mm_to_px(self, experiments, reflections):
)
x_px, y_px = xy_cal_px.parts()
if expt.scan is not None:
z_px = expt.scan.get_array_index_from_angle(z_rad, deg=False)
if expt.scan.has_property("time_of_flight"):
tof = expt.scan.get_property("time_of_flight")
frames = list(range(len(tof)))
tof_to_frame = tof_helpers.tof_to_frame_interpolator(tof, frames)
z.set_selected(z < min(tof), min(tof))
z.set_selected(z > max(tof), max(tof))
z_px = flex.double(tof_to_frame(z))
else:
z_px = expt.scan.get_array_index_from_angle(z, deg=False)
else:
# must be a still image, z centroid not meaningful
z_px = z_rad
z_px = z
xyzcal_px = flex.vec3_double(x_px, y_px, z_px)
reflections["xyzcal.px"].set_selected(imgset_sel, xyzcal_px)

Expand Down Expand Up @@ -941,6 +949,25 @@ def find_max_cell(self):
self.params.max_cell = params.multiplier * max(uc_params[:3])
logger.info("Using max_cell: %.1f Angstrom", self.params.max_cell)
else:

convert_reflections_z_to_deg = True
all_tof_experiments = False
for expt in self.experiments:
if expt.scan is not None and expt.scan.has_property(
"time_of_flight"
):
all_tof_experiments = True
elif all_tof_experiments:
raise ValueError(
"Cannot find max cell for ToF and non-ToF experiments at the same time"
)

if all_tof_experiments:
if params.step_size < 100:
logger.info("Setting default ToF step size to 500 usec")
params.step_size = 500
convert_reflections_z_to_deg = False

self.params.max_cell = find_max_cell(
self.reflections,
max_cell_multiplier=params.multiplier,
Expand All @@ -952,6 +979,7 @@ def find_max_cell(self):
filter_ice=params.filter_ice,
filter_overlaps=params.filter_overlaps,
overlaps_border=params.overlaps_border,
convert_reflections_z_to_deg=convert_reflections_z_to_deg,
).max_cell
logger.info("Found max_cell: %.1f Angstrom", self.params.max_cell)

Expand Down
2 changes: 1 addition & 1 deletion src/dials/algorithms/indexing/lattice_search/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ def choose_best_orientation_matrix(self, candidate_orientation_matrices):
experiments = ExperimentList()
for i_expt, expt in enumerate(self.experiments):
# XXX Not sure if we still need this loop over self.experiments
if expt.scan is not None:
if expt.scan is not None and expt.scan.has_property("oscillation"):
start, end = expt.scan.get_oscillation_range()
if (end - start) > 360:
# only use reflections from the first 360 degrees of the scan
Expand Down
2 changes: 2 additions & 0 deletions src/dials/algorithms/indexing/max_cell.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ def find_max_cell(
filter_ice=True,
filter_overlaps=True,
overlaps_border=0,
convert_reflections_z_to_deg=True,
):
logger.debug("Finding suitable max_cell based on %i reflections", len(reflections))
# Exclude potential ice-ring spots from nearest neighbour analysis if needed
Expand Down Expand Up @@ -63,6 +64,7 @@ def find_max_cell(
percentile=nearest_neighbor_percentile,
histogram_binning=histogram_binning,
nn_per_bin=nn_per_bin,
convert_reflections_z_to_deg=convert_reflections_z_to_deg,
)
except AssertionError as e:
raise DialsIndexError("Failure in nearest neighbour analysis:\n" + str(e))
Expand Down
4 changes: 2 additions & 2 deletions src/dials/algorithms/indexing/model_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ def score_by_volume(self, reverse=False):
def score_by_rmsd_xy(self, reverse=False):
# smaller rmsds = better
rmsd_x, rmsd_y, rmsd_z = flex.vec3_double(
s.rmsds for s in self.all_solutions
s.rmsds[:3] for s in self.all_solutions
).parts()
rmsd_xy = flex.sqrt(flex.pow2(rmsd_x) + flex.pow2(rmsd_y))
score = flex.log(rmsd_xy) / math.log(2)
Expand Down Expand Up @@ -275,7 +275,7 @@ def __str__(self):
perm = flex.sort_permutation(combined_scores)

rmsd_x, rmsd_y, rmsd_z = flex.vec3_double(
s.rmsds for s in self.all_solutions
s.rmsds[:3] for s in self.all_solutions
).parts()
rmsd_xy = flex.sqrt(flex.pow2(rmsd_x) + flex.pow2(rmsd_y))

Expand Down
18 changes: 11 additions & 7 deletions src/dials/algorithms/indexing/nearest_neighbor.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ def __init__(
percentile=None,
histogram_binning="linear",
nn_per_bin=5,
convert_reflections_z_to_deg=True,
):
self.tolerance = tolerance # Margin of error for max unit cell estimate
from scitbx.array_family import flex
Expand All @@ -28,7 +29,10 @@ def __init__(
else:
entering_flags = flex.bool(reflections.size(), True)
rs_vectors = reflections["rlp"]
phi_deg = reflections["xyzobs.mm.value"].parts()[2] * (180 / math.pi)

z = reflections["xyzobs.mm.value"].parts()[2]
if convert_reflections_z_to_deg:
z = z * (180 / math.pi)

d_spacings = flex.double()
# nearest neighbor analysis
Expand All @@ -38,16 +42,16 @@ def __init__(
sel_imageset = reflections["imageset_id"] == imageset_id
if sel_imageset.count(True) == 0:
continue
phi_min = flex.min(phi_deg.select(sel_imageset))
phi_max = flex.max(phi_deg.select(sel_imageset))
d_phi = phi_max - phi_min
n_steps = max(int(math.ceil(d_phi / step_size)), 1)
z_min = flex.min(z.select(sel_imageset))
z_max = flex.max(z.select(sel_imageset))
d_z = z_max - z_min
n_steps = max(int(math.ceil(d_z / step_size)), 1)

for n in range(n_steps):
sel_step = (
sel_imageset
& (phi_deg >= (phi_min + n * step_size))
& (phi_deg < (phi_min + (n + 1) * step_size))
& (z >= (z_min + n * step_size))
& (z < (z_min + (n + 1) * step_size))
)

for entering in (True, False):
Expand Down
142 changes: 142 additions & 0 deletions src/dials/algorithms/refinement/parameterisation/beam_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,3 +156,145 @@ def get_state(self):

# only a single beam exists, so no multi_state_elt argument is allowed
return matrix.col(self._model.get_s0())


class LaueBeamMixin:
"""Mix-in class defining some functionality unique to Laue beam parameterisations
that can be shared by static and scan-varying versions"""

@staticmethod
def _build_p_list(unit_s0, goniometer, parameter_type=Parameter):
"""Build the list of parameters, using the parameter_type callback to
select between versions of the Parameter class"""

# Set up the parameters
if goniometer:
spindle = matrix.col(goniometer.get_rotation_axis())
unit_s0_plane_dir2 = unit_s0.cross(spindle).normalize()
unit_s0_plane_dir1 = unit_s0_plane_dir2.cross(unit_s0).normalize()
else:
unit_s0_plane_dir1 = unit_s0.ortho().normalize()
unit_s0_plane_dir2 = unit_s0.cross(unit_s0_plane_dir1).normalize()

# rotation around unit_s0_plane_dir1
mu1 = parameter_type(0.0, unit_s0_plane_dir1, "angle (mrad)", "Mu1")
# rotation around unit_s0_plane_dir2
mu2 = parameter_type(0.0, unit_s0_plane_dir2, "angle (mrad)", "Mu2")

# build the parameter list in a specific, maintained order
p_list = [mu1, mu2]

return p_list

@staticmethod
def _compose_core(is0, ipn, mu1, mu2, mu1_axis, mu2_axis):

# convert angles to radians
mu1rad, mu2rad = mu1 / 1000.0, mu2 / 1000.0

# compose rotation matrices and their first order derivatives
Mu1 = (mu1_axis).axis_and_angle_as_r3_rotation_matrix(mu1rad, deg=False)
dMu1_dmu1 = dR_from_axis_and_angle(mu1_axis, mu1rad, deg=False)

Mu2 = (mu2_axis).axis_and_angle_as_r3_rotation_matrix(mu2rad, deg=False)
dMu2_dmu2 = dR_from_axis_and_angle(mu2_axis, mu2rad, deg=False)

# compose new state
Mu21 = Mu2 * Mu1
unit_s0 = (Mu21 * is0).normalize()
pn_new_dir = (Mu21 * ipn).normalize()

# calculate derivatives of the beam direction wrt angles:
# 1) derivative wrt mu1
dMu21_dmu1 = Mu2 * dMu1_dmu1
dunit_s0_new_dir_dmu1 = dMu21_dmu1 * is0

# 2) derivative wrt mu2
dMu21_dmu2 = dMu2_dmu2 * Mu1
dunit_s0_new_dir_dmu2 = dMu21_dmu2 * is0

# calculate derivatives of the attached beam vector, converting
# parameters back to mrad
dunit_s0_dval = [
dunit_s0_new_dir_dmu1 / 1000.0,
dunit_s0_new_dir_dmu2 / 1000.0,
unit_s0,
]

return (unit_s0, pn_new_dir), dunit_s0_dval


class LaueBeamParameterisation(ModelParameterisation, LaueBeamMixin):
"""A parameterisation of a Laue Beam model, where wavelength is ignored.
The Beam direction is parameterised using angles expressed in
mrad. A goniometer can be provided (if
present in the experiment) to ensure a consistent definition of the beam
rotation angles with respect to the spindle-beam plane."""

def __init__(self, beam, goniometer=None, experiment_ids=None):
"""Initialise the BeamParameterisation object
Args:
beam: A dxtbx PolychromaticBeam object to be parameterised.
goniometer: An optional dxtbx Goniometer object. Defaults to None.
experiment_ids (list): The experiment IDs affected by this
parameterisation. Defaults to None, which is replaced by [0].
"""
# The state of the beam model consists of the unit s0 vector that it is
# modelling. The initial state is the direction of this vector at the point
# of initialisation, plus the direction of the orthogonal polarization
# normal vector. Future states are composed by rotations around axes
# perpendicular to that direction.

# Set up the initial state
if experiment_ids is None:
experiment_ids = [0]
unit_s0 = matrix.col(beam.get_unit_s0())
istate = {
"unit_s0": matrix.col(unit_s0),
"polarization_normal": matrix.col(beam.get_polarization_normal()),
}

# build the parameter list
p_list = self._build_p_list(unit_s0, goniometer)

# set up the base class
ModelParameterisation.__init__(
self, beam, istate, p_list, experiment_ids=experiment_ids
)

# call compose to calculate all the derivatives
self.compose()

return

def compose(self):

# extract direction from the initial state
ius0 = self._initial_state["unit_s0"]
ipn = self._initial_state["polarization_normal"]

# extract parameters from the internal list
mu1, mu2 = self._param

# calculate new s0 and derivatives
(unit_s0, pn), self._dstate_dp = self._compose_core(
ius0,
ipn,
mu1.value,
mu2.value,
mu1_axis=mu1.axis,
mu2_axis=mu2.axis,
)

# now update the model with its new s0 and polarization vector
self._model.set_unit_s0(unit_s0)
self._model.set_polarization_normal(pn)

return

def get_state(self):

# only a single beam exists, so no multi_state_elt argument is allowed
return matrix.col(self._model.get_unit_s0())
28 changes: 23 additions & 5 deletions src/dials/algorithms/refinement/parameterisation/configure.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import re

import libtbx
from dxtbx.model import PolychromaticBeam
from libtbx.phil import parse

from dials.algorithms.refinement import DialsRefineConfigError
Expand All @@ -15,11 +16,12 @@
phil_str as sv_phil_str,
)
from dials.algorithms.refinement.refinement_helpers import string_sel
from dials.algorithms.refinement.reflection_manager import LaueReflectionManager
from dials.algorithms.refinement.restraints.restraints_parameterisation import (
uc_phil_str as uc_restraints_phil_str,
)

from .beam_parameters import BeamParameterisation
from .beam_parameters import BeamParameterisation, LaueBeamParameterisation
from .crystal_parameters import (
CrystalOrientationParameterisation,
CrystalUnitCellParameterisation,
Expand All @@ -31,6 +33,7 @@
)
from .goniometer_parameters import GoniometerParameterisation
from .prediction_parameters import (
LauePredictionParameterisation,
XYPhiPredictionParameterisation,
XYPhiPredictionParameterisationSparse,
)
Expand Down Expand Up @@ -433,8 +436,15 @@ def _parameterise_beams(options, experiments, analysis):
experiment_ids=exp_ids,
)
else:
# Parameterise scan static beam, passing the goniometer
beam_param = BeamParameterisation(beam, goniometer, experiment_ids=exp_ids)
if isinstance(beam, PolychromaticBeam):
beam_param = LaueBeamParameterisation(
beam, goniometer, experiment_ids=exp_ids
)
else:
# Parameterise scan static beam, passing the goniometer
beam_param = BeamParameterisation(
beam, goniometer, experiment_ids=exp_ids
)

# Set the model identifier to name the parameterisation
beam_param.model_identifier = f"Beam{ibeam + 1}"
Expand All @@ -454,7 +464,9 @@ def _parameterise_beams(options, experiments, analysis):
fix_list.append("Mu1")
if "out_spindle_plane" in options.beam.fix:
fix_list.append("Mu2")
if "wavelength" in options.beam.fix:
if "wavelength" in options.beam.fix and not isinstance(
beam, PolychromaticBeam
):
fix_list.append("nu")

if fix_list:
Expand Down Expand Up @@ -818,10 +830,16 @@ def build_prediction_parameterisation(
analysis = _centroid_analysis(options, experiments, reflection_manager)

# Parameterise each unique model
beam_params = _parameterise_beams(options, experiments, analysis)
xl_ori_params, xl_uc_params = _parameterise_crystals(options, experiments, analysis)
det_params = _parameterise_detectors(options, experiments, analysis)
gon_params = _parameterise_goniometers(options, experiments, analysis)
beam_params = _parameterise_beams(options, experiments, analysis)

if isinstance(reflection_manager, LaueReflectionManager):
PredParam = LauePredictionParameterisation
return PredParam(
experiments, det_params, beam_params, xl_ori_params, xl_uc_params
)

# Build the prediction equation parameterisation
if do_stills: # doing stills
Expand Down
Loading

0 comments on commit d04235d

Please sign in to comment.