Skip to content

Commit

Permalink
Merge pull request #31 from khanlab/cpu_inference
Browse files Browse the repository at this point in the history
Cpu-based inference
  • Loading branch information
akhanf authored Feb 1, 2021
2 parents 66ff6f5 + 60cbbbb commit f9298cb
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 50 deletions.
11 changes: 10 additions & 1 deletion hippunfold/config/snakebids.yml
Original file line number Diff line number Diff line change
Expand Up @@ -147,8 +147,17 @@ parse_args:
default: False
action: 'store_true'

--use_gpu:
help: 'Enable gpu for inference by setting resource gpus=1 in run_inference rule (default: %(default)s)'
default: False
action: 'store_true'

--nnunet_disable_tta:
help: 'Disable test-time augmentation for nnU-net inference, speeds up inference by 8x, at expense of accuracy (default: %(default)s)'
default: False
action: 'store_true'


#--- workflow specific configuration --

singularity:
Expand Down
9 changes: 5 additions & 4 deletions hippunfold/workflow/rules/nnunet.smk
Original file line number Diff line number Diff line change
Expand Up @@ -59,21 +59,22 @@ rule run_inference:
out_folder = 'templbl',
task = parse_task_from_tar,
chkpnt = parse_chkpnt_from_tar,
disable_tta = '' if config['nnunet_disable_tta'] else '--disable_tta'
output:
nnunet_seg = bids(root='work',datatype='seg_{modality}',**config['subj_wildcards'],suffix='dseg.nii.gz',desc='nnunet',space='corobl',hemi='{hemi,Lflip|R}')
shadow: 'minimal'
threads: 16
threads: 16
resources:
gpus = 1,
gpus = 1 if config['use_gpu'] else 0,
mem_mb = 32000,
time = 30,
time = 30 if config['use_gpu'] else 60,
group: 'subj'
shell: 'mkdir -p {params.model_dir} {params.in_folder} {params.out_folder} && ' #create temp folders
'cp -v {input.in_img} {params.temp_img} && ' #cp input image to temp folder
'tar -xvf {input.model_tar} -C {params.model_dir} && ' #extract model
'export RESULTS_FOLDER={params.model_dir} && ' #set nnunet env var to point to model
'export nnUNet_n_proc_DA={threads} && ' #set threads
'nnUNet_predict -i {params.in_folder} -o {params.out_folder} -t {params.task} -chk {params.chkpnt} && ' # run inference
'nnUNet_predict -i {params.in_folder} -o {params.out_folder} -t {params.task} -chk {params.chkpnt} {params.disable_tta} && ' # run inference
'cp -v {params.temp_lbl} {output.nnunet_seg}' #copy from temp output folder to final output


Expand Down
44 changes: 0 additions & 44 deletions requirements.txt

This file was deleted.

2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
install_requires=[
"snakebids>=0.2.0",
"snakemake>=5.28.0",
"nnunet>=1.6.6",
"nnunet @ git+https://github.com/ylugithub/nnUNet.git@v1.6.6",
"appdirs",
"pandas",
"nibabel",
Expand Down

0 comments on commit f9298cb

Please sign in to comment.