Skip to content

Commit

Permalink
Merge pull request #54 from tbirdso/convergence-criteria
Browse files Browse the repository at this point in the history
ENH: Add convergence criteria and fix registration verbosity issue
  • Loading branch information
thewtex authored May 19, 2021
2 parents 70b31c1 + 49269bf commit daa1d72
Show file tree
Hide file tree
Showing 8 changed files with 596 additions and 1,079 deletions.
1,413 changes: 506 additions & 907 deletions examples/MeshToMeshRegistration.ipynb

Large diffs are not rendered by default.

5 changes: 2 additions & 3 deletions src/hasi/hasi/align.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ def register_template_to_sample(template_mesh:itk.Mesh,
verbose=False):
from .pointsetentropyregistrar import PointSetEntropyRegistrar

registrar = PointSetEntropyRegistrar()
registrar = PointSetEntropyRegistrar(verbose=verbose)
metric = itk.EuclideanDistancePointSetToPointSetMetricv4[itk.PointSet[itk.F,3]].New()
transform = itk.Euler3DTransform[itk.D].New()

Expand All @@ -218,8 +218,7 @@ def register_template_to_sample(template_mesh:itk.Mesh,
metric=metric,
transform=transform,
learning_rate=learning_rate,
max_iterations=max_iterations,
verbose=verbose)
max_iterations=max_iterations)
return deformed_mesh


Expand Down
61 changes: 28 additions & 33 deletions src/hasi/hasi/diffeoregistrar.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,7 @@

class DiffeoRegistrar(MeshToMeshRegistrar):

# Common definitions
MAX_ITERATIONS = 200

# Type definitions for function annotations
Dimension = 3
PixelType = itk.F
ImageType = itk.Image[PixelType, Dimension]
Expand All @@ -41,48 +39,50 @@ class DiffeoRegistrar(MeshToMeshRegistrar):
FixedPointSetType = itk.PointSet[itk.F, Dimension]

MeshType = itk.Mesh[itk.F, Dimension]

# Definitions for diffeometric registration
TransformType = itk.DisplacementFieldTransform[itk.F, Dimension]

VectorPixelType = itk.Vector[itk.F, Dimension]
DisplacementFieldType = itk.Image[VectorPixelType, Dimension]
DiffeoRegistrationFilterType = \
itk.DiffeomorphicDemonsRegistrationFilter[ImageType,
ImageType,
DisplacementFieldType]

# Default function values
MAX_ITERATIONS = 200
STANDARD_DEVIATIONS = 1.0
TransformType = \
itk.DisplacementFieldTransform[itk.F, Dimension]

MAX_RMS_ERROR = 0.0

def __init__(self):
self.initialize()
def __init__(self, verbose:bool=False):
super(self.__class__,self).__init__(verbose=verbose)

# Expose method to reset persistent objects as desired
def initialize(self):
VectorPixelType = itk.Vector[itk.F, self.Dimension]
DisplacementFieldType = itk.Image[VectorPixelType, self.Dimension]
DiffeoRegistrationFilterType = \
itk.DiffeomorphicDemonsRegistrationFilter[self.ImageType,
self.ImageType,
DisplacementFieldType]
self.filter = DiffeoRegistrationFilterType.New(
self.filter = self.DiffeoRegistrationFilterType.New(
StandardDeviations=self.STANDARD_DEVIATIONS)

# Print iteration updates
if(self.verbose):
def print_iteration():
print(f'Iteration: {self.filter.GetElapsedIterations()}'
f' Metric: {self.filter.GetMetric()}'
f' RMS Change: {self.filter.GetRMSChange()}')
self.filter.AddObserver(itk.ProgressEvent(),print_iteration)

# Register meshes with diffeomorphic demons algorithm
def register(self,
template_mesh:MeshType,
target_mesh:MeshType,
filepath:str=None,
verbose=False,
max_iterations=MAX_ITERATIONS) -> (TransformType, MeshType):
max_iterations:int=MAX_ITERATIONS,
max_rms_error:float=MAX_RMS_ERROR,
verbose=False) -> (TransformType, MeshType):

(template_image, target_image) = self.mesh_to_image([template_mesh, target_mesh])

self.filter.SetFixedImage(template_image)
self.filter.SetMovingImage(target_image)

self.filter.SetNumberOfIterations(max_iterations)

def print_iteration():
metric = self.filter.GetMetric()
print(f'{metric}')

if(verbose):
self.filter.AddObserver(itk.ProgressEvent(),print_iteration)
self.filter.SetMaximumRMSError(max_rms_error)

# Run registration
self.filter.Update()
Expand All @@ -100,9 +100,4 @@ def print_iteration():
Transform=transform)
transform_filter.Update()

if filepath is not None:
itk.meshwrite(transform_filter.GetOutput(),filepath)
if(verbose):
print(f'Wrote resulting mesh to {filepath}')

return (transform, transform_filter.GetOutput())
57 changes: 20 additions & 37 deletions src/hasi/hasi/meansquaresregistrar.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,7 @@
from .meshtomeshregistrar import MeshToMeshRegistrar

class MeanSquaresRegistrar(MeshToMeshRegistrar):
# Common definitions
MAX_ITERATIONS = 200

# Type definitions for function annotations
Dimension = 3
PixelType = itk.F
ImageType = itk.Image[PixelType, Dimension]
Expand All @@ -41,7 +39,8 @@ class MeanSquaresRegistrar(MeshToMeshRegistrar):

MeshType = itk.Mesh[itk.F, Dimension]

# Definitions for default registration
# Default function values
MAX_ITERATIONS = 200
V_SPLINE_ORDER = 3
GRID_NODES_IN_ONE_DIMENSION = 4

Expand All @@ -56,26 +55,32 @@ class MeanSquaresRegistrar(MeshToMeshRegistrar):

TransformType = itk.BSplineTransform[itk.D, Dimension, V_SPLINE_ORDER]


def __init__(self):
super(self.__class__,self).__init__()
def __init__(self, verbose:bool=False):
super(self.__class__,self).__init__(verbose=verbose)

def initialize(self):
# Optimizer is exposed so that calling scripts can reference with custom observers
self.optimizer = itk.LBFGSBOptimizerv4.New(CostFunctionConvergenceFactor=self.COST_FN_CONVERGENCE_FACTOR,
GradientConvergenceTolerance=self.GRADIENT_CONVERGENCE_TOLERANCE,
self.optimizer = itk.LBFGSBOptimizerv4.New(
MaximumNumberOfFunctionEvaluations=self.MAX_FUNCTION_EVALUATIONS,
MaximumNumberOfCorrections=self.MAX_CORRECTIONS)
# TODO initialize verbose output with observer

# Monitor optimization via observer
if(self.verbose):
def print_iteration():
print(f'Iteration: {self.optimizer.GetCurrentIteration()}'
f' Metric: {self.optimizer.GetCurrentMetricValue()}'
f' Infinity Norm: {self.optimizer.GetInfinityNormOfProjectedGradient()}')
self.optimizer.AddObserver(itk.IterationEvent(),
print_iteration)

# Register two 3D images with an LBFGSB optimizer
def register(self,
template_mesh:MeshType,
target_mesh:MeshType,
filepath:str=None,
verbose=False,
num_iterations=MAX_ITERATIONS) -> (TransformType, MeshType):
num_iterations:int=MAX_ITERATIONS,
convergence_factor:float=COST_FN_CONVERGENCE_FACTOR,
gradient_convergence_tolerance:float=GRADIENT_CONVERGENCE_TOLERANCE) \
-> (TransformType, MeshType):

(template_image, target_image) = self.mesh_to_image([template_mesh, target_mesh])

Expand Down Expand Up @@ -103,18 +108,8 @@ def register(self,
self.optimizer.SetUpperBound([0] * number_of_parameters)
self.optimizer.SetLowerBound([0] * number_of_parameters)

# Monitor optimization via observer
def print_iteration():
iteration = self.optimizer.GetCurrentIteration()
metric = self.optimizer.GetCurrentMetricValue()
infnorm = self.optimizer.GetInfinityNormOfProjectedGradient()
print(f'{iteration} {metric} {infnorm}')

# FIXME adds a duplicate observer if multiple calls without re-initialization
if(verbose):
self.optimizer.AddObserver(itk.IterationEvent(),
print_iteration)

self.optimizer.SetCostFunctionConvergenceFactor(convergence_factor)
self.optimizer.SetGradientConvergenceTolerance(gradient_convergence_tolerance)
self.optimizer.SetNumberOfIterations(num_iterations)

# Define object to handle image registration
Expand All @@ -138,19 +133,7 @@ def print_iteration():
# Registration likely attempts to set scales by default,
# no observed impact on performance from this warning
registration.Update()

# Report results
if(verbose):
print('Solution = ' + str(list(transform.GetParameters())))

# Update template
transformed_mesh = itk.transform_mesh_filter(template_mesh, transform=transform)

# Write out
# TODO move away from monolithic design, leave write responsibility to user
if filepath is not None:
itk.meshwrite(transformed_mesh, filepath)
if(verbose):
print(f'Wrote resulting mesh to {filepath}')

return (transform, transformed_mesh)
3 changes: 2 additions & 1 deletion src/hasi/hasi/meshtomeshregistrar.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ class MeshToMeshRegistrar:
MeshType = itk.Mesh[itk.F, Dimension]
PointSetType = itk.PointSet[itk.F, Dimension]

def __init__(self):
def __init__(self,verbose:bool=False):
self.verbose = verbose
self.initialize()

# Abstract method for developer to explicitly reset internal state objects
Expand Down
81 changes: 27 additions & 54 deletions src/hasi/hasi/pointsetentropyregistrar.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,23 +26,25 @@
from .meshtomeshregistrar import MeshToMeshRegistrar

class PointSetEntropyRegistrar(MeshToMeshRegistrar):
# Common definitions
MAX_ITERATIONS = 200

# Type definitions for function annotations
Dimension = 3
PixelType = itk.F
ImageType = itk.Image[PixelType, Dimension]
TemplateImageType = ImageType
TargetImageType = ImageType

MeshType = itk.Mesh[itk.F, Dimension]

# Class definitions
PointSetType = itk.PointSet[itk.F,Dimension]
TransformType = itk.AffineTransform[itk.D, Dimension]

def __init__(self):
super(self.__class__,self).__init__()
# Default function values
MAX_ITERATIONS = 200
LEARNING_RATE = 1.0
CONVERGENCE_WINDOW_SIZE = 1
MINIMUM_CONVERGENCE_VALUE = 1.0

def __init__(self,verbose:bool=False):
super(self.__class__,self).__init__(verbose=verbose)

def initialize(self):
# Optimizer is exposed so that calling scripts can reference with custom observers
Expand All @@ -53,55 +55,38 @@ def initialize(self):
DoEstimateLearningRateOnce=False,
DoEstimateLearningRateAtEachIteration=False,
ReturnBestParametersAndValue=True)

if self.verbose:
def print_iteration():
print(f'Iteration: {self.optimizer.GetCurrentIteration()} Metric: {self.optimizer.GetCurrentMetricValue()}')
self.optimizer.AddObserver(itk.AnyEvent(), print_iteration)

def register(self,
template_mesh:MeshType=None,
template_point_set:PointSetType=None,
target_mesh:MeshType=None,
target_point_set:PointSetType=None,
filepath:str=None,
verbose=False,
metric=None,
transform=None,
resample_rate=1.0,
max_iterations=MAX_ITERATIONS,
learning_rate=1.0,
resample_from_target=False) -> (TransformType, PointSetType):
max_iterations:int=MAX_ITERATIONS,
learning_rate:float=LEARNING_RATE,
minimum_convergence_value:float=MINIMUM_CONVERGENCE_VALUE,
convergence_window_size:int=CONVERGENCE_WINDOW_SIZE,
resample_from_target:bool=False) \
-> (TransformType, PointSetType):

# Verify a template and target were passed in
if(not (template_mesh or template_point_set) or not (target_mesh or target_point_set)):
raise Exception('Registration requires both a template and a target!')

# Need both a mesh and a point set representing the template and the transform
if(not template_mesh):
template_mesh = self.MeshType.New(Points=template_point_set.GetPoints())
else:
template_point_set = self.PointSetType.New(Points=template_mesh.GetPoints())

if(not target_mesh):
target_mesh = self.MeshType.New(Points=target_point_set.GetPoints())
else:
target_point_set = self.PointSetType.New(Points=target_mesh.GetPoints())

# Optionally resample to improve performance
if(resample_rate != 1.0):
template_point_set = self.randomly_sample_mesh_points(mesh=template_point_set, sampling_rate=resample_rate)
target_point_set = self.randomly_sample_mesh_points(mesh=target_point_set, sampling_rate=resample_rate)
template_point_set = self.PointSetType.New(Points=template_mesh.GetPoints())
target_point_set = self.PointSetType.New(Points=target_mesh.GetPoints())

# Define registration components
if not transform:
transform = self.TransformType.New()
transform.SetIdentity()

#template_image = self.mesh_to_image(template_mesh)
# Set physical dimensions for transform
#fixed_physical_dimensions = list()
#fixed_origin = list(template_image.GetOrigin())

#for i in range(0, self.Dimension):
# fixed_physical_dimensions.append(template_image.GetSpacing()[i] * \
# (template_image.GetLargestPossibleRegion().GetSize()[i] - 1))

# Default to JHCT point set entropy metric
if not metric:
metric = itk.JensenHavrdaCharvatTsallisPointSetToPointSetMetricv4[self.PointSetType].New(
FixedPointSet=template_point_set,
Expand All @@ -113,12 +98,10 @@ def register(self,
CovarianceKNeighborhood=5,
EvaluationKNeighborhood=10,
Alpha=1.1)
#metric.SetVirtualDomainFromImage( template_image );

metric.SetFixedPointSet(template_point_set)
metric.SetMovingPointSet(target_point_set)
metric.SetMovingTransform(transform)

metric.Initialize()

# Define scales to guide gradient descent steps
Expand All @@ -129,19 +112,15 @@ def register(self,
VirtualDomainPointSet=metric.GetVirtualTransformedPointSet())

self.optimizer.SetMetric(metric)
self.optimizer.SetNumberOfIterations(max_iterations)
self.optimizer.SetScalesEstimator(shift_scale_estimator)
self.optimizer.SetLearningRate(learning_rate)

def print_iteration():
print(f'{self.optimizer.GetCurrentIteration()} {self.optimizer.GetCurrentMetricValue()}')

if verbose:
self.optimizer.AddObserver(itk.AnyEvent(), print_iteration)
self.optimizer.SetMinimumConvergenceValue(minimum_convergence_value)
self.optimizer.SetConvergenceWindowSize(convergence_window_size)
self.optimizer.SetNumberOfIterations(max_iterations)

self.optimizer.StartOptimization()

if(verbose):
if(self.verbose):
print(f'Number of iterations run: {self.optimizer.GetCurrentIteration()}')
print(f'Final metric value: {self.optimizer.GetCurrentMetricValue()}')
print(f'Final transform position: {list(self.optimizer.GetCurrentPosition())}')
Expand All @@ -155,10 +134,4 @@ def print_iteration():
if(resample_from_target):
registered_template_mesh = self.resample_template_from_target(registered_template_mesh, target_mesh)

# Write out
if filepath is not None:
itk.meshwrite(registered_template_mesh,filepath)
if(verbose):
print(f'Wrote resulting point set to {filepath}')

return (transform, registered_template_mesh)
7 changes: 3 additions & 4 deletions src/hasi/test_align.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,8 @@

import itk

MEANSQUARES_METRIC_MAXIMUM_THRESHOLD = 0.015
DIFFEO_METRIC_MAXIMUM_THRESHOLD = 0.25
POINT_SET_METRIC_MAXIMUM_THRESHOLD = 0.0001
MIN_REFINE_DISTANCE = 0.1
MAX_REFINE_DISTANCE = 1.0
MAX_ITERATIONS = 200

TEMPLATE_MESH_FILE = 'test/Input/906-R-atlas.obj'
Expand Down Expand Up @@ -182,4 +181,4 @@ def test_refine_template_from_population():

# Distance matches expectation
distance = get_pairwise_hausdorff_distance(mesh_result, template_mesh)
assert(distance > 0.7 and distance < 1.0)
assert(MIN_REFINE_DISTANCE < distance < MAX_REFINE_DISTANCE)
Loading

0 comments on commit daa1d72

Please sign in to comment.