Add preprocessing and postprocessing to forward method
thodkatz committed Dec 20, 2024
1 parent b6d63c7 commit 3dc2864
Showing 2 changed files with 55 additions and 1 deletion.
2 changes: 1 addition & 1 deletion setup.py
@@ -33,7 +33,7 @@
        "numpy<2", # pytorch 2.2.2-py3.9_0 for macos is compiled with numpy 1.*
        "protobuf",
        "pydantic>=2.7.0,<2.10",
        "pytorch3dunet",
        # "pytorch3dunet", # todo: this doesn't exist
        "pyyaml",
        "xarray",
    ],
54 changes: 54 additions & 0 deletions tiktorch/trainer.py
@@ -8,6 +8,7 @@

import torch
import yaml
from pytorch3dunet.augment.transforms import Compose, Normalize, Standardize, ToTensor
from pytorch3dunet.datasets.utils import get_train_loaders
from pytorch3dunet.unet3d.losses import get_loss_criterion
from pytorch3dunet.unet3d.metrics import get_evaluation_metric
@@ -97,6 +98,8 @@ def __init__(
        self,
        model,
        device,
        in_channels,
        out_channels,
        optimizer,
        lr_scheduler,
        loss_criterion,
@@ -139,6 +142,8 @@ def __init__(
            pre_trained=pre_trained,
            **kwargs,
        )
        self._in_channels = in_channels
        self._out_channels = out_channels
        self._device = device
        self.logs_callbacks: LogsCallbacks = BaseCallbacks()
        self.should_stop_callbacks: Callbacks = ShouldStopCallbacks()
@@ -170,16 +175,60 @@ def forward(self, input_tensors: List[torch.Tensor]):
        if self.is_2d_model() and z != 1:
            raise ValueError(f"2d model detected but z != 1 for tensor {input_tensor.shape}")

        # todo: normalization needs to be consistent with the training one (it should be retrieved from the config)
        preprocessor = Compose([Standardize(), ToTensor(expand_dims=True)])
        input_tensor = self._apply_transformation(compose=preprocessor, tensor=input_tensor)

        def apply_final_activation(input_tensors) -> torch.Tensor:
            if self.model.final_activation is not None:
                return self.model.final_activation(input_tensors)
            return input_tensors

        with torch.no_grad():
            if self.is_2d_model():
                input_tensor = input_tensor.squeeze(dim=-3) # b, c, [z], y, x
                predictions = self.model(input_tensor.to(self._device))
                predictions = apply_final_activation(predictions)
                predictions = predictions.unsqueeze(dim=-3) # for consistency
            else:
                predictions = self.model(input_tensor.to(self._device))
                predictions = apply_final_activation(predictions)

        predictions = predictions.cpu()

        # this needs to be exposed as well
        # currently we scale the features to 0 - 1 (consistent scale for rendering across channels)
        postprocessor = Compose([Normalize(norm01=True), ToTensor(expand_dims=True)])
        predictions = self._apply_transformation(compose=postprocessor, tensor=predictions)
        return predictions

    def _apply_transformation(self, compose: Compose, tensor: torch.Tensor) -> torch.Tensor:
        """
        To apply transformations, pytorch 3d unet requires a shape of DxHxW or CxDxHxW
        """
        b, c, z, y, x = tensor.shape
        non_batch_tensors = []
        for batch_idx in range(b):
            # drop batch
            non_batch_tensor = tensor[batch_idx, :]

            # drop channel dim if single channel
            dropped_channel = non_batch_tensor.squeeze(dim=-4) if self.is_input_single_channel() else non_batch_tensor

            # add the channel back with `expand_dims`
            transformed_tensor = compose(dropped_channel.detach().cpu().numpy())

            non_batch_tensors.append(transformed_tensor)

        # add batch dim again
        return torch.stack(non_batch_tensors, dim=0)

    def is_input_single_channel(self) -> bool:
        return self._in_channels == 1

    def is_output_single_channel(self) -> bool:
        return self._out_channels == 1

    @staticmethod
    def get_axes_from_tensor(tensor: torch.Tensor) -> Tuple[str, ...]:
        if tensor.ndim != 5:
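Editor's note (sketch, not part of this commit): the per-item shape handling in `_apply_transformation` exists because the pytorch3dunet transforms used above operate on 3D (DxHxW) or 4D (CxDxHxW) numpy arrays rather than batched 5D tensors. A minimal sketch of the single-channel case, assuming the transform behaviour relied on in this diff (Standardize z-scores the volume, ToTensor(expand_dims=True) restores the channel axis):

```python
# Editor's sketch (not part of this commit): illustrates the shape bookkeeping that
# _apply_transformation performs for a single-channel batch, assuming the
# pytorch3dunet transform behaviour used above.
import torch
from pytorch3dunet.augment.transforms import Compose, Standardize, ToTensor

batch = torch.rand(2, 1, 8, 64, 64)  # b, c, z, y, x

preprocessor = Compose([Standardize(), ToTensor(expand_dims=True)])

items = []
for sample in batch:                      # drop the batch dim -> (c, z, y, x)
    volume = sample.squeeze(dim=-4)       # single channel -> (z, y, x)
    # Standardize z-scores the numpy volume; ToTensor(expand_dims=True) should
    # re-add the channel axis, yielding a (1, z, y, x) torch tensor.
    items.append(preprocessor(volume.numpy()))

preprocessed = torch.stack(items, dim=0)  # back to (b, c, z, y, x)
print(preprocessed.shape)                 # expected: torch.Size([2, 1, 8, 64, 64])
```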
@@ -238,6 +287,9 @@ def parse(self) -> Trainer:

        model = get_model(config["model"])

        in_channels = config["model"]["in_channels"]
        out_channels = config["model"]["out_channels"]

        if torch.cuda.device_count() > 1 and not config["device"] == "cpu":
            model = nn.DataParallel(model)
        if torch.cuda.is_available() and not config["device"] == "cpu":
@@ -266,6 +318,8 @@ def parse(self) -> Trainer:

        return Trainer(
            device=config["device"],
            in_channels=in_channels,
            out_channels=out_channels,
            model=model,
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
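Editor's note (sketch, not part of this commit): the new `in_channels`/`out_channels` arguments are read from the `model` section of the pytorch-3dunet-style YAML config that `TrainerYamlParser` parses. A minimal illustration of the keys involved follows; the concrete values and the model name are assumptions for illustration, and real configs carry many more fields:

```python
# Editor's sketch (not part of this commit): the config keys that
# TrainerYamlParser.parse() reads for the new Trainer arguments.
# Model name and values are assumed for illustration only.
import yaml

config = yaml.safe_load(
    """
device: cpu
model:
  name: UNet2D        # assumed: any name resolvable by pytorch3dunet's get_model
  in_channels: 1
  out_channels: 2
"""
)

in_channels = config["model"]["in_channels"]    # -> 1, forwarded as Trainer(in_channels=...)
out_channels = config["model"]["out_channels"]  # -> 2, forwarded as Trainer(out_channels=...)
```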
