Commit be12fa2

[MONAI] python code formatting

Signed-off-by: monai-bot <[email protected]>
monai-bot committed May 10, 2022
1 parent 54d9081 commit be12fa2
Showing 21 changed files with 309 additions and 1,043 deletions.
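
The reformatting is mechanical: statements that fit within a longer line limit are joined onto one line, and imports are regrouped. A minimal sketch of reproducing one such join with black's Python API, assuming a 120-character line length (the bot's exact configuration is not shown on this page):

    # Sketch only: assumes black with line-length 120; check the repo's config for the real settings.
    import black

    src = (
        'parser.add_argument(\n'
        '    "--root", action="store", default="./data_msd", help="data root"\n'
        ')\n'
    )
    # format_str applies the same rules as the black CLI to an in-memory string
    print(black.format_str(src, mode=black.Mode(line_length=120)))
    # -> parser.add_argument("--root", action="store", default="./data_msd", help="data root")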
12 changes: 3 additions & 9 deletions DiNTS/download_msd_datasets.py
@@ -17,17 +17,11 @@

 def main():
     parser = argparse.ArgumentParser(description="training")
-    parser.add_argument(
-        "--msd_task", action="store", default="Task07_Pancreas", help="msd task"
-    )
-    parser.add_argument(
-        "--root", action="store", default="./data_msd", help="data root"
-    )
+    parser.add_argument("--msd_task", action="store", default="Task07_Pancreas", help="msd task")
+    parser.add_argument("--root", action="store", default="./data_msd", help="data root")
     args = parser.parse_args()
 
-    resource = (
-        "https://msd-for-monai.s3-us-west-2.amazonaws.com/" + args.msd_task + ".tar"
-    )
+    resource = "https://msd-for-monai.s3-us-west-2.amazonaws.com/" + args.msd_task + ".tar"
     compressed_file = os.path.join(args.root, args.msd_task + ".tar")
     if not os.path.exists(args.root):
         download_and_extract(resource, compressed_file, args.root)
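
The hunk above is the whole script in miniature; a standalone sketch of the same logic, assuming `download_and_extract` comes from `monai.apps` as in other MONAI examples:

    import os

    from monai.apps import download_and_extract

    msd_task = "Task07_Pancreas"
    root = "./data_msd"
    resource = "https://msd-for-monai.s3-us-west-2.amazonaws.com/" + msd_task + ".tar"
    compressed_file = os.path.join(root, msd_task + ".tar")
    if not os.path.exists(root):
        # download the MSD tarball once and unpack it under the data root
        download_and_extract(resource, compressed_file, root)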
62 changes: 19 additions & 43 deletions DiNTS/ensemble.py
@@ -25,12 +25,18 @@
 from glob import glob
 from pathlib import Path
 
-import monai
 import nibabel as nib
 import numpy as np
 import pandas as pd
 import torch
 import yaml
+from scipy import ndimage as ndi
+
+# from monai.utils.enums import InverseKeys
+from transforms import creating_transforms_ensemble, str2aug
+from utils import keep_largest_cc, parse_monai_specs  # parse_monai_transform_specs,
+
+import monai
 from monai.data import (
     DataLoader,
     Dataset,
@@ -46,33 +52,18 @@
 # from monai.losses import DiceLoss, FocalLoss, GeneralizedDiceLoss
 # from monai.metrics import compute_meandice
 from monai.utils import set_determinism
-from scipy import ndimage as ndi
-
-# from monai.utils.enums import InverseKeys
-from transforms import creating_transforms_ensemble, str2aug
-from utils import keep_largest_cc, parse_monai_specs  # parse_monai_transform_specs,
 
 
 def main():
     parser = argparse.ArgumentParser(description="inference")
-    parser.add_argument(
-        "--algorithm", type=str, default=None, help="ensemble algorithm"
-    )
+    parser.add_argument("--algorithm", type=str, default=None, help="ensemble algorithm")
     parser.add_argument("--checkpoint", type=str, default=None, help="checkpoint")
     parser.add_argument("--config", action="store", required=True, help="configuration")
     parser.add_argument("--local_rank", required=int, help="local process rank")
-    parser.add_argument(
-        "--input_root", action="store", required=True, help="input root"
-    )
-    parser.add_argument(
-        "--original_root", action="store", required=True, help="orignal dataset root"
-    )
-    parser.add_argument(
-        "--output_root", action="store", required=True, help="output root"
-    )
-    parser.add_argument(
-        "--post", default=False, action="store_true", help="post-processing"
-    )
+    parser.add_argument("--input_root", action="store", required=True, help="input root")
+    parser.add_argument("--original_root", action="store", required=True, help="orignal dataset root")
+    parser.add_argument("--output_root", action="store", required=True, help="output root")
+    parser.add_argument("--post", default=False, action="store_true", help="post-processing")
     parser.add_argument("--dir_list", nargs="*", type=str, default=[])
     args = parser.parse_args()

@@ -135,9 +126,7 @@ def main():

     for _i in range(num_folds):
         list_filenames = []
-        for root, dirs, files in os.walk(
-            os.path.join(args.input_root, args.dir_list[_i])
-        ):
+        for root, dirs, files in os.walk(os.path.join(args.input_root, args.dir_list[_i])):
             for basename in files:
                 if "_prob1.nii" in basename:
                     filename = os.path.join(root, basename)
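
The joined `os.walk` call gathers every `_prob1` probability map under one fold directory; an equivalent sketch with `pathlib` (the directory layout here is hypothetical):

    from pathlib import Path

    fold_dir = Path("./ensemble_input") / "fold0"  # hypothetical fold directory
    list_filenames = sorted(str(p) for p in fold_dir.rglob("*_prob1.nii*"))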
@@ -155,31 +144,23 @@
             volume_list = []
             for _k in range(1, output_classes):
                 volume_list.append(
-                    os.path.join(
-                        args.input_root,
-                        all_filenames[_j][_i].replace("_prob1", "_prob" + str(_k)),
-                    )
+                    os.path.join(args.input_root, all_filenames[_j][_i].replace("_prob1", "_prob" + str(_k)))
                 )
             case_dict["fold" + str(_j)] = volume_list
             # print(case_dict)
         files.append(case_dict)
 
     ensemble_files = files
     ensemble_files = partition_dataset(
-        data=ensemble_files,
-        shuffle=False,
-        num_partitions=dist.get_world_size(),
-        even_divisible=False,
+        data=ensemble_files, shuffle=False, num_partitions=dist.get_world_size(), even_divisible=False
     )[dist.get_rank()]
     print("ensemble_files", len(ensemble_files))
 
     key_list = ["fold" + str(_item) for _item in range(num_folds)]
     ensemble_transforms = creating_transforms_ensemble(keys=key_list)
 
     ensemble_ds = monai.data.Dataset(data=ensemble_files, transform=ensemble_transforms)
-    ensemble_loader = DataLoader(
-        ensemble_ds, batch_size=1, shuffle=False, num_workers=2, pin_memory=False
-    )
+    ensemble_loader = DataLoader(ensemble_ds, batch_size=1, shuffle=False, num_workers=2, pin_memory=False)
 
     start_time = time.time()
     for ensemble_data in ensemble_loader:
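
The reflowed `partition_dataset` call is what shards the ensemble cases across processes, each rank keeping one partition. A toy illustration of its behavior, assuming a world size of 2:

    from monai.data import partition_dataset

    files = [{"case": i} for i in range(5)]
    # no shuffling, uneven shard sizes allowed -- matching the call above
    shards = partition_dataset(data=files, shuffle=False, num_partitions=2, even_divisible=False)
    print([len(s) for s in shards])  # [3, 2]; rank r keeps shards[r]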
@@ -213,15 +194,11 @@ def main():
             print(np.amax(nda_out), np.amin(nda_out), np.mean(nda_out))
             # resize to orignal data size
             # find orignal data
-            file_basename = ensemble_data["fold0_meta_dict"]["filename_or_obj"][0].split(
-                os.sep
-            )[-1]
+            file_basename = ensemble_data["fold0_meta_dict"]["filename_or_obj"][0].split(os.sep)[-1]
             original_data_path = list(Path(args.original_root).glob(file_basename))[0]
             original_data = nib.load(original_data_path)
             # get affine matrix
-            seg_affine = (
-                ensemble_data["fold0_meta_dict"]["original_affine"].numpy().squeeze()
-            )
+            seg_affine = ensemble_data["fold0_meta_dict"]["original_affine"].numpy().squeeze()
             img_affine = original_data.affine
             img_shape = original_data.shape
             T = np.matmul(np.linalg.inv(seg_affine), img_affine)
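
The matrix `T` composed on the last line maps voxel indices of the original image into the segmentation's voxel grid, since both affines map voxels to the same world space. A toy check with made-up diagonal affines:

    import numpy as np

    seg_affine = np.diag([2.0, 2.0, 2.0, 1.0])  # made-up: 2 mm grid, voxel -> world
    img_affine = np.diag([1.0, 1.0, 1.0, 1.0])  # made-up: 1 mm grid, voxel -> world
    T = np.matmul(np.linalg.inv(seg_affine), img_affine)

    img_voxel = np.array([4.0, 4.0, 4.0, 1.0])  # homogeneous voxel coordinate in the image
    print(T @ img_voxel)  # [2. 2. 2. 1.] -- the same world point on the coarser grid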
@@ -244,8 +221,7 @@

         out_img = nib.Nifti1Image(nda_out, out_affine)
         out_filename = os.path.join(
-            args.output_root,
-            ensemble_data["fold0_meta_dict"]["filename_or_obj"][0].split(os.sep)[-1],
+            args.output_root, ensemble_data["fold0_meta_dict"]["filename_or_obj"][0].split(os.sep)[-1]
         )
         out_filename = out_filename.replace("_prob1", "")
         print("out_filename", out_filename)
86 changes: 25 additions & 61 deletions DiNTS/infer_multi-gpu.py
@@ -22,7 +22,6 @@
 from datetime import datetime
 from glob import glob
 
-import monai
 import nibabel as nib
 import numpy as np
 import pandas as pd
@@ -32,6 +31,16 @@
 import torch.nn.functional as F
 import yaml
 from auto_unet import AutoUnet
+from torch import nn
+from torch.nn.parallel import DistributedDataParallel
+
+# from torch.utils.data import DataLoader
+# from torch.utils.data.distributed import DistributedSampler
+from torch.utils.tensorboard import SummaryWriter
+from transforms import creating_transforms_testing, str2aug
+from utils import parse_monai_specs  # parse_monai_transform_specs,
+
+import monai
 from monai.data import (
     DataLoader,
     Dataset,
@@ -44,42 +53,21 @@

 # from monai.losses import DiceLoss, FocalLoss, GeneralizedDiceLoss
 from monai.metrics import compute_meandice
-from monai.transforms import (
-    AsDiscrete,
-    BatchInverseTransform,
-    Invertd,
-    KeepLargestConnectedComponent,
-)
+from monai.transforms import AsDiscrete, BatchInverseTransform, Invertd, KeepLargestConnectedComponent
 from monai.utils import set_determinism
 from monai.utils.enums import InverseKeys
-from torch import nn
-from torch.nn.parallel import DistributedDataParallel
-
-# from torch.utils.data import DataLoader
-# from torch.utils.data.distributed import DistributedSampler
-from torch.utils.tensorboard import SummaryWriter
-from transforms import creating_transforms_testing, str2aug
-from utils import parse_monai_specs  # parse_monai_transform_specs,
 
 
 def main():
     parser = argparse.ArgumentParser(description="inference")
     parser.add_argument("--arch_ckpt", action="store", required=True, help="data root")
     parser.add_argument("--checkpoint", type=str, default=None, help="checkpoint")
     parser.add_argument("--config", action="store", required=True, help="configuration")
-    parser.add_argument(
-        "--json", action="store", required=True, help="full path of .json file"
-    )
-    parser.add_argument(
-        "--json_key", action="store", required=True, help=".json data list key"
-    )
+    parser.add_argument("--json", action="store", required=True, help="full path of .json file")
+    parser.add_argument("--json_key", action="store", required=True, help=".json data list key")
     parser.add_argument("--local_rank", required=int, help="local process rank")
-    parser.add_argument(
-        "--output_root", action="store", required=True, help="output root"
-    )
-    parser.add_argument(
-        "--prob", default=False, action="store_true", help="probility map"
-    )
+    parser.add_argument("--output_root", action="store", required=True, help="output root")
+    parser.add_argument("--prob", default=False, action="store_true", help="probility map")
     parser.add_argument("--root", action="store", required=True, help="data root")
     args = parser.parse_args()

@@ -141,9 +129,7 @@ def main():
             transform_string = intensity_norm[_k]
             transform_name, transform_dict = parse_monai_specs(transform_string)
             if dist.get_rank() == 0:
-                print(
-                    "\nintensity normalization {0:d}:\t{1:s}".format(_k + 1, transform_name)
-                )
+                print("\nintensity normalization {0:d}:\t{1:s}".format(_k + 1, transform_name))
                 for _key in transform_dict.keys():
                     print("  {0}:\t{1}".format(_key, transform_dict[_key]))
             transform_class = getattr(monai.transforms, transform_name)
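
The `getattr` on the last line instantiates a MONAI transform from its string name. A minimal sketch of the pattern; the transform name and arguments below are illustrative, not the repo's actual spec:

    import monai

    transform_name = "NormalizeIntensityd"                 # illustrative spec string
    transform_dict = {"keys": ["image"], "nonzero": True}  # illustrative kwargs
    transform_class = getattr(monai.transforms, transform_name)
    transform = transform_class(**transform_dict)  # same as monai.transforms.NormalizeIntensityd(...)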
@@ -170,18 +156,13 @@

     infer_files = files
     infer_files = partition_dataset(
-        data=infer_files,
-        shuffle=False,
-        num_partitions=dist.get_world_size(),
-        even_divisible=False,
+        data=infer_files, shuffle=False, num_partitions=dist.get_world_size(), even_divisible=False
     )[dist.get_rank()]
     print("infer_files", len(infer_files))
 
     # label_interpolation_transform = creating_label_interpolation_transform(label_interpolation, spacing, output_classes)
     # train_transforms = creating_transforms_training(foreground_crop_margin, label_interpolation_transform, num_patches_per_image, patch_size, scale_intensity_range, augmenations)
-    infer_transforms = creating_transforms_testing(
-        foreground_crop_margin, intensity_norm_transforms, spacing
-    )
+    infer_transforms = creating_transforms_testing(foreground_crop_margin, intensity_norm_transforms, spacing)
 
     argmax = AsDiscrete(argmax=True, to_onehot=False, n_classes=output_classes)
     onehot = AsDiscrete(argmax=False, to_onehot=True, n_classes=output_classes)
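
The two `AsDiscrete` instances convert network logits to label indices and label indices to one-hot channels. A small demonstration using the `n_classes` keyword as in this code (newer MONAI releases pass the class count via `to_onehot` instead):

    import torch
    from monai.transforms import AsDiscrete

    logits = torch.tensor([[0.1], [0.7], [0.2]])  # channel-first: 3 classes, one voxel
    argmax = AsDiscrete(argmax=True, to_onehot=False, n_classes=3)
    onehot = AsDiscrete(argmax=False, to_onehot=True, n_classes=3)
    label = argmax(logits)  # [[1.]] -- index of the largest logit
    print(onehot(label))    # [[0.], [1.], [0.]] -- back to 3 channels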
@@ -195,11 +176,7 @@
     # train_loader = DataLoader(train_ds, batch_size=num_images_per_batch, shuffle=True, num_workers=4, pin_memory=torch.cuda.is_available())
     # infer_loader = DataLoader(infer_ds, batch_size=1, shuffle=False, num_workers=4, pin_memory=torch.cuda.is_available())
     infer_loader = DataLoader(
-        infer_ds,
-        batch_size=1,
-        shuffle=False,
-        num_workers=4,
-        pin_memory=torch.cuda.is_available(),
+        infer_ds, batch_size=1, shuffle=False, num_workers=4, pin_memory=torch.cuda.is_available()
     )
 
     # inverter = Invertd(
@@ -261,9 +238,7 @@
     )
 
     code_a = torch.from_numpy(code_a).to(torch.float32).cuda()
-    code_c = (
-        F.one_hot(torch.from_numpy(code_c), model.cell_ops).to(torch.float32).cuda()
-    )
+    code_c = F.one_hot(torch.from_numpy(code_c), model.cell_ops).to(torch.float32).cuda()
     model = model.to(device)
 
     if torch.cuda.device_count() > 1:
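
`F.one_hot` above expands the integer architecture codes into one-hot vectors before they are moved to the GPU; for example:

    import torch
    import torch.nn.functional as F

    code_c = torch.tensor([0, 2, 1])  # toy cell-operation codes
    print(F.one_hot(code_c, 3).to(torch.float32))
    # tensor([[1., 0., 0.],
    #         [0., 0., 1.],
    #         [0., 1., 0.]])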
@@ -280,10 +255,7 @@
         input()
 
     saver = monai.data.NiftiSaver(
-        output_dir=args.output_root,
-        output_postfix="seg",
-        resample=False,
-        output_dtype=np.uint8,
+        output_dir=args.output_root, output_postfix="seg", resample=False, output_dtype=np.uint8
     )
 
     # # amp
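
A usage sketch for the `NiftiSaver` configured above, against the older MONAI API this code targets (the file name is hypothetical; the saver derives output paths from the meta dict):

    import numpy as np

    import monai

    saver = monai.data.NiftiSaver(output_dir="./out", output_postfix="seg", resample=False, output_dtype=np.uint8)
    seg = np.zeros((1, 64, 64, 64), dtype=np.uint8)  # channel-first volume
    meta = {"filename_or_obj": "case_001.nii.gz", "affine": np.eye(4)}
    saver.save(seg, meta)  # writes ./out/case_001/case_001_seg.nii.gz (per-case subfolder by default)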
@@ -467,8 +439,7 @@ def main():
# print("post-processing")

out_filename = os.path.join(
args.output_root,
infer_data["image_meta_dict"]["filename_or_obj"][0].split(os.sep)[-1],
args.output_root, infer_data["image_meta_dict"]["filename_or_obj"][0].split(os.sep)[-1]
)
# out_filename = out_filename.replace("case_", "prediction_") + ".nii.gz"
out_affine = infer_data["image_meta_dict"]["affine"].numpy().squeeze()
@@ -479,18 +450,11 @@
                 if args.prob:
                     for _k in range(1, output_classes):
                         out_filename = os.path.join(
-                            args.output_root,
-                            infer_data["image_meta_dict"]["filename_or_obj"][0].split(
-                                os.sep
-                            )[-1],
+                            args.output_root, infer_data["image_meta_dict"]["filename_or_obj"][0].split(os.sep)[-1]
                         )
                         # out_filename = out_filename.replace("case_", "prediction_") + ".nii.gz"
-                        out_filename = out_filename.replace(
-                            ".nii", "_prob{0:d}.nii".format(_k)
-                        )
-                        out_affine = (
-                            infer_data["image_meta_dict"]["affine"].numpy().squeeze()
-                        )
+                        out_filename = out_filename.replace(".nii", "_prob{0:d}.nii".format(_k))
+                        out_affine = infer_data["image_meta_dict"]["affine"].numpy().squeeze()
 
                         # out_img = nib.Nifti1Image(infer_outputs[_k:_k+1, ...].squeeze().astype(np.float32), out_affine)
                         infer_outputs_indiv = infer_outputs[_k : _k + 1, ...].squeeze()