Skip to content

Commit

Permalink
Merge pull request #25 from uncbiag/feat-add-training
Browse files Browse the repository at this point in the history
Feat add training
  • Loading branch information
HastingsGreer authored Oct 29, 2024
2 parents baa3843 + f2e870f commit a9b1310
Show file tree
Hide file tree
Showing 5 changed files with 1,834 additions and 18 deletions.
55 changes: 37 additions & 18 deletions src/unigradicon/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,23 +15,27 @@
from icon_registration.mermaidlite import compute_warped_image_multiNC
import icon_registration.itk_wrapper



input_shape = [1, 1, 175, 175, 175]

class GradientICONSparse(network_wrappers.RegistrationModule):
def __init__(self, network, similarity, lmbda):
def __init__(self, network, similarity, lmbda, use_label=False):

super().__init__()

self.regis_net = network
self.lmbda = lmbda
self.similarity = similarity
self.use_label = use_label

def forward(self, image_A, image_B):
def forward(self, image_A, image_B, label_A=None, label_B=None):

assert self.identity_map.shape[2:] == image_A.shape[2:]
assert self.identity_map.shape[2:] == image_B.shape[2:]
if self.use_label:
label_A = image_A if label_A is None else label_A
label_B = image_B if label_B is None else label_B
assert self.identity_map.shape[2:] == label_A.shape[2:]
assert self.identity_map.shape[2:] == label_B.shape[2:]

# Tag used elsewhere for optimization.
# Must be set at beginning of forward b/c not preserved by .cuda() etc
Expand Down Expand Up @@ -75,10 +79,29 @@ def forward(self, image_A, image_B):
1,
zero_boundary=True
)

similarity_loss = self.similarity(
self.warped_image_A, image_B
) + self.similarity(self.warped_image_B, image_A)

if self.use_label:
self.warped_label_A = compute_warped_image_multiNC(
torch.cat([label_A, inbounds_tag], axis=1) if inbounds_tag is not None else label_A,
self.phi_AB_vectorfield,
self.spacing,
1,
)

self.warped_label_B = compute_warped_image_multiNC(
torch.cat([label_B, inbounds_tag], axis=1) if inbounds_tag is not None else label_B,
self.phi_BA_vectorfield,
self.spacing,
1,
)

similarity_loss = self.similarity(
self.warped_label_A, label_B
) + self.similarity(self.warped_label_B, label_A)
else:
similarity_loss = self.similarity(
self.warped_image_A, image_B
) + self.similarity(self.warped_image_B, image_A)

if len(self.input_shape) - 2 == 3:
Iepsilon = (
Expand Down Expand Up @@ -142,8 +165,10 @@ def forward(self, image_A, image_B):

def clean(self):
del self.phi_AB, self.phi_BA, self.phi_AB_vectorfield, self.phi_BA_vectorfield, self.warped_image_A, self.warped_image_B
if self.use_label:
del self.warped_label_A, self.warped_label_B

def make_network(input_shape, include_last_step=False, lmbda=1.5, loss_fn=icon.LNCC(sigma=5)):
def make_network(input_shape, include_last_step=False, lmbda=1.5, loss_fn=icon.LNCC(sigma=5), use_label=False):
dimension = len(input_shape) - 2
inner_net = icon.FunctionFromVectorField(networks.tallUNet2(dimension=dimension))

Expand All @@ -155,7 +180,7 @@ def make_network(input_shape, include_last_step=False, lmbda=1.5, loss_fn=icon.L
if include_last_step:
inner_net = icon.TwoStepRegistration(inner_net, icon.FunctionFromVectorField(networks.tallUNet2(dimension=dimension)))

net = GradientICONSparse(inner_net, loss_fn, lmbda=lmbda)
net = GradientICONSparse(inner_net, loss_fn, lmbda=lmbda, use_label=use_label)
net.assign_identity_map(input_shape)
return net

Expand Down Expand Up @@ -237,7 +262,7 @@ def preprocess(image, modality="ct", segmentation=None):
min_ = -1000
max_ = 1000
image = itk.CastImageFilter[type(image), itk.Image[itk.F, 3]].New()(image)
image = itk.clamp_image_filter(image, Bounds=(-1000, 1000))
image = itk.clamp_image_filter(image, Bounds=(min_, max_))
elif modality == "mri":
image = itk.CastImageFilter[type(image), itk.Image[itk.F, 3]].New()(image)
min_, _ = itk.image_intensity_min_max(image)
Expand Down Expand Up @@ -358,10 +383,4 @@ def warp_command():
use_reference_image=True,
reference_image=fixed
)
itk.imwrite(warped_moving_image, args.warped_moving_out)






itk.imwrite(warped_moving_image, args.warped_moving_out)
Loading

0 comments on commit a9b1310

Please sign in to comment.