DynUNet crashes with DataParallel and DeepSupervision #6442

Open
razorx89 opened this issue Apr 27, 2023 · 15 comments · Fixed by #6444 · May be fixed by #6484
Assignees
Labels
bug Something isn't working

Comments

@razorx89
Contributor

Describe the bug
DynUNet crashes in a torch.nn.DataParallel scenario because a mutable list is used to collect the deep-supervision heads:

return DynUNetSkipLayer(
    index,
    downsample=downsamples[0],
    upsample=upsamples[0],
    next_layer=next_layer,
    heads=self.heads,
    super_head=superheads[0],
)

self.heads[self.index - 1] = self.super_head(upout)

This does not work with multiple GPUs in this scenario, because we end up with tensors in the list living on different CUDA devices. The code crashes when stacking the tensors in the list at:
if self.training and self.deep_supervision:
    out_all = [out]
    for feature_map in self.heads:
        out_all.append(interpolate(feature_map, out.shape[2:]))
    return torch.stack(out_all, dim=1)

To Reproduce
Run torch.nn.DataParallel(DynUNet(..., deep_supervision=True), device_ids=[0, 1])

Expected behavior
The DynUNet forward pass should be thread-safe. I know that DistributedDataParallel is superior and would avoid the problem; however, DataParallel should still work if the results were collected from the blocks' return values instead of through a "global" mutable list.
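
For illustration only, here is a minimal, self-contained sketch (not MONAI's actual code; the toy module names are made up) of the pattern described above, where each block returns its deep-supervision output through its return value instead of writing into a list shared across DataParallel replicas:

import torch
import torch.nn as nn
from torch.nn.functional import interpolate

class BlockWithHead(nn.Module):
    # Toy block: the supervision head output travels with the return value,
    # so no shared mutable state is needed.
    def __init__(self, channels: int, out_channels: int):
        super().__init__()
        self.body = nn.Conv3d(channels, channels, 3, padding=1)
        self.super_head = nn.Conv3d(channels, out_channels, 1)

    def forward(self, x):
        feat = self.body(x)
        return feat, self.super_head(feat)

class TinyDeepSupervisionNet(nn.Module):
    def __init__(self, in_channels: int = 1, out_channels: int = 1):
        super().__init__()
        self.stem = nn.Conv3d(in_channels, 8, 3, padding=1)
        self.block = BlockWithHead(8, out_channels)
        self.out = nn.Conv3d(8, out_channels, 1)

    def forward(self, x):
        feat, head = self.block(self.stem(x))
        out = self.out(feat)
        if self.training:
            # Stack the main output and the supervision head as DynUNet does,
            # but each replica only ever touches its own tensors.
            return torch.stack([out, interpolate(head, out.shape[2:])], dim=1)
        return out

net = TinyDeepSupervisionNet()
y = net(torch.randn(2, 1, 16, 16, 16))  # training mode: shape (2, 2, 1, 16, 16, 16)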

Environment

================================
Printing MONAI config...
================================
MONAI version: 1.1.0
Numpy version: 1.24.2
Pytorch version: 1.14.0a0+410ce96
MONAI flags: HAS_EXT = False, USE_COMPILED = False, USE_META_DICT = False
MONAI rev id: a2ec3752f54bfc3b40e7952234fbeb5452ed63e3
MONAI __file__: /usr/local/lib/python3.8/dist-packages/monai/__init__.py

Optional dependencies:
Pytorch Ignite version: NOT INSTALLED or UNKNOWN VERSION.
Nibabel version: 3.2.2
scikit-image version: 0.19.3
Pillow version: 9.4.0
Tensorboard version: 2.12.0
gdown version: NOT INSTALLED or UNKNOWN VERSION.
TorchVision version: 0.15.0a0
tqdm version: 4.64.1
lmdb version: NOT INSTALLED or UNKNOWN VERSION.
psutil version: 5.9.4
pandas version: 1.5.3
einops version: NOT INSTALLED or UNKNOWN VERSION.
transformers version: NOT INSTALLED or UNKNOWN VERSION.
mlflow version: NOT INSTALLED or UNKNOWN VERSION.
pynrrd version: NOT INSTALLED or UNKNOWN VERSION.

For details about installing the optional dependencies, please visit:
    https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies


================================
Printing system config...
================================
System: Linux
Linux version: Ubuntu 20.04.5 LTS
Platform: Linux-5.15.0-67-generic-x86_64-with-glibc2.29
Processor: x86_64
Machine: x86_64
Python version: 3.8.10
Process name: python
Command: ['python', '-c', 'import monai; monai.config.print_debug_info()']
Open files: []
Num physical CPUs: 48
Num logical CPUs: 48
Num usable CPUs: 48
CPU usage (%): [4.9, 5.4, 4.4, 4.4, 4.9, 4.4, 4.4, 4.4, 5.3, 4.9, 4.9, 4.9, 4.9, 5.3, 4.9, 5.3, 4.9, 4.4, 6.3, 4.9, 4.9, 4.4, 4.4, 5.3, 4.9, 4.9, 4.4, 4.9, 4.4, 4.9, 4.4, 4.9, 5.3, 4.9, 4.4, 4.9, 4.4, 4.9, 4.4, 4.9, 4.9, 4.4, 4.9, 4.4, 4.9, 4.4, 4.9, 99.5]
CPU freq. (MHz): 1646
Load avg. in last 1, 5, 15 mins (%): [0.2, 1.2, 6.7]
Disk usage (%): 60.9
Avg. sensor temp. (Celsius): UNKNOWN for given OS
Total physical memory (GB): 1007.8
Available memory (GB): 991.8
Used memory (GB): 9.7

================================
Printing GPU config...
================================
Num GPUs: 2
Has CUDA: True
CUDA version: 11.8
cuDNN enabled: True
cuDNN version: 8700
Current device: 0
Library compiled for CUDA architectures: ['sm_52', 'sm_60', 'sm_61', 'sm_70', 'sm_75', 'sm_80', 'sm_86', 'sm_90', 'compute_90']
GPU 0 Name: NVIDIA RTX A6000
GPU 0 Is integrated: False
GPU 0 Is multi GPU board: False
GPU 0 Multi processor count: 84
GPU 0 Total memory (GB): 47.5
GPU 0 CUDA capability (maj.min): 8.6
GPU 1 Name: NVIDIA RTX A6000
GPU 1 Is integrated: False
GPU 1 Is multi GPU board: False
GPU 1 Multi processor count: 84
GPU 1 Total memory (GB): 47.5
GPU 1 CUDA capability (maj.min): 8.6
@Nic-Ma
Contributor

Nic-Ma commented Apr 27, 2023

Hi @yiheng-wang-nv ,

Is DynUNet thread-safe?

Thanks.

@yiheng-wang-nv
Contributor

Hi @razorx89 , could you provide detailed code that reproduces the crash?

I ran a simple test with the following code and did not encounter an error:

import torch
from monai.networks.nets import DynUNet

kernels = [[3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3]]
strides = [[1, 1, 1], [2, 2, 2], [2, 2, 2], [2, 2, 2]]

net = DynUNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=1,
    kernel_size=kernels,
    strides=strides,
    upsample_kernel_size=strides[1:],
    deep_supervision=True,
    deep_supr_num=1,
)

net = torch.nn.DataParallel(net, device_ids=[0, 1])

@razorx89
Contributor Author

razorx89 commented Apr 27, 2023

Here you go, it crashes on the second batch:

import torch
from monai.networks.nets import DynUNet

kernels = [[3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3]]
strides = [[1, 1, 1], [2, 2, 2], [2, 2, 2], [2, 2, 2]]

net = DynUNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=1,
    kernel_size=kernels,
    strides=strides,
    upsample_kernel_size=strides[1:],
    deep_supervision=True,
    deep_supr_num=2,
)
net = net.cuda()
net = torch.nn.DataParallel(net, device_ids=[0, 1])

x = torch.randn(16, 1, 64, 64, 64)
x = x.cuda()

for i in range(10):
    print("Batch", i)
    net(x)
Batch 0
Batch 1
Traceback (most recent call last):
  File "dynunet_bug.py", line 27, in <module>
    net(x)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1423, in _call_impl
    return forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/parallel/data_parallel.py", line 171, in forward
    outputs = self.parallel_apply(replicas, inputs, kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/parallel/data_parallel.py", line 181, in parallel_apply
    return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/parallel/parallel_apply.py", line 89, in parallel_apply
    output.reraise()
  File "/usr/local/lib/python3.8/dist-packages/torch/_utils.py", line 601, in reraise
    raise exception
RuntimeError: Caught RuntimeError in replica 1 on device 1.
Original Traceback (most recent call last):
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/parallel/parallel_apply.py", line 64, in _worker
    output = module(*input, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1423, in _call_impl
    return forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/container.py", line 204, in forward
    input = module(input)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1423, in _call_impl
    return forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/monai/networks/nets/dynunet.py", line 273, in forward
    return torch.stack(out_all, dim=1)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:1 and cuda:0! (when checking argument for argument tensors in method wrapper_cat)

@yiheng-wang-nv
Contributor

Hi @razorx89 , thanks, I have reproduced the issue. I submitted a PR in #6444 and tested it with the code you posted. Could you please help review that PR?

wyli pushed a commit that referenced this issue Apr 27, 2023
Fixes #6442 .

### Description

This PR fixes the thread-safety issue of DynUNet. The list created in the
forward function is replaced by a tensor.

### Types of changes
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: Yiheng Wang <[email protected]>
@razorx89
Contributor Author

Thanks, it does not crash anymore. However, I don't think the returned supervision heads are correct; in my experiments the model does not learn at all (I cannot provide an example of this). Revisiting the code line posted above:

self.heads[self.index - 1] = self.super_head(upout)

Add a print statement at this location:

print(x.device, id(self.heads), self.index)
cuda:0 140709361877056 1
cuda:1 140709361877056 1

You will see that both replicas write to the same list instance. Thus, both replicas return a tensor with the same content (plus or minus race conditions), regardless of each replica's input.
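
To see why, here is a minimal standalone sketch (requires two visible GPUs; the module is made up for illustration). DataParallel replicates the module per device, but a plain Python list attribute ends up shared by reference between the replicas, which is what the identical id() above shows:

import torch
import torch.nn as nn

class SharedListDemo(nn.Module):
    def __init__(self):
        super().__init__()
        self.heads = [None]  # plain Python attribute, not a registered buffer

    def forward(self, x):
        # Both replicas print the same id(), i.e. they mutate the very same list.
        print(x.device, id(self.heads))
        self.heads[0] = x.mean()
        return x

net = nn.DataParallel(SharedListDemo().cuda(), device_ids=[0, 1])
net(torch.randn(4, 1).cuda())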

@razorx89
Contributor Author

razorx89 commented May 2, 2023

Any updates on this, @yiheng-wang-nv? I would love to see this issue reopened, since it is still not working correctly.

@yiheng-wang-nv
Contributor

Hi @ericspod , could you please give some suggestions here? In the multi-thread (DataParallel) case, it seems the interpolation approach used to concatenate the outputs at different scales (

out_all = torch.zeros(out.shape[0], len(self.heads) + 1, *out.shape[1:], device=out.device, dtype=out.dtype)

) cannot work. I remember we made this kind of change in order to support TorchScript.

@yiheng-wang-nv
Contributor

To add an error example: when the batch size is set to 9 with 2 GPUs, torch.nn.DataParallel scatters the input into one tensor with batch size 4 and one with batch size 5, and an error then occurs:

import torch
from monai.networks.nets import DynUNet

kernels = [[3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3]]
strides = [[1, 1, 1], [2, 2, 2], [2, 2, 2], [2, 2, 2]]

net = DynUNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=1,
    kernel_size=kernels,
    strides=strides,
    upsample_kernel_size=strides[1:],
    deep_supervision=True,
    deep_supr_num=2,
)
net = net.cuda()
net = torch.nn.DataParallel(net, device_ids=[0, 1])

x = torch.randn(9, 1, 64, 64, 64)
x = x.cuda()

for i in range(10):
    print("Batch", i)
    net(x)

@yiheng-wang-nv
Contributor

Hi @wyli @Nic-Ma @ericspod , I'm not sure if this issue can be fixed soon. What do you think of mentioning the limitation (of DataParallel support) in the docstrings first?

@wyli wyli linked a pull request May 6, 2023 that will close this issue
@wyli
Contributor

wyli commented May 6, 2023

It seems that making device-specific heads works fine (#6484), but it's a bit hacky.

@ericspod
Member

ericspod commented May 6, 2023

I'm not sure device-specific heads are going to be enough if DataParallel is doing things with multiple threads. It's a race condition when two or more threads are accessing and modifying the self.heads list. In normal Python we'd use something like threading.local to have thread-specific data, or use locks to synchronize access to a shared object, but I don't know about TorchScript compatibility for that. Perhaps it's best to just say that DataParallel isn't compatible, since DistributedDataParallel can be used in its place.
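
As a rough, hypothetical sketch of that plain-Python approach (thread-local storage; the helper below is made up for illustration and, as noted, would not be TorchScript-scriptable):

import threading
import torch

# Hypothetical helper: each DataParallel worker thread gets its own list of
# head tensors, so replicas no longer race on one shared object.
_tls = threading.local()

def get_heads(num_heads: int) -> list:
    if getattr(_tls, "heads", None) is None or len(_tls.heads) != num_heads:
        _tls.heads = [torch.empty(0) for _ in range(num_heads)]
    return _tls.heads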

@wyli
Contributor

wyli commented May 6, 2023

Yes, at each forward pass new threads are created (https://github.com/pytorch/pytorch/blob/0bf9722a3af5c00125ccb557a3618f11e5413236/torch/nn/parallel/parallel_apply.py#L73), so it's not easy to find a generic solution that is also compatible with TorchScript.

I guess the main use case of DataParallel is to quickly try large batch sizes when multiple GPUs are available; to properly leverage the GPUs and accelerate training, DistributedDataParallel is necessary.
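
For completeness, a rough sketch of the DistributedDataParallel alternative mentioned above (assuming a single node with two GPUs, a launch via torchrun --nproc_per_node=2, and net being the bare DynUNet constructed in the earlier snippets, before any DataParallel wrapping):

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

dist.init_process_group(backend="nccl")  # torchrun provides rank/world-size env vars
local_rank = dist.get_rank() % torch.cuda.device_count()
torch.cuda.set_device(local_rank)

ddp_net = DDP(net.cuda(local_rank), device_ids=[local_rank])
out = ddp_net(torch.randn(8, 1, 64, 64, 64, device=f"cuda:{local_rank}"))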

@razorx89
Contributor Author

razorx89 commented May 8, 2023

Maybe my two cents on why I am using DataParallel: I am working on a single-node, multi-GPU system. Every n-th epoch during training I run an evaluation on the validation set, using a sliding window inferer on full-size CT images. During training I distribute cropped images from multiple CTs across multiple GPUs and aggregate the losses before the optimization step. During validation I compute metrics on a single CT at a time using a sliding window inferer with a DataParallel model, so the sw_batch_size crops produced by the sliding window algorithm get distributed across the GPUs. I cannot batch the CT images because they all have a different number of slices (depth), or I would have to use padding, which may increase memory usage and/or runtime (e.g. when mixing whole-body and abdomen CTs). As far as I know, this pattern is not easily implementable with DistributedDataParallel.
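
For context, a rough sketch of that validation pattern (model and ct_volume are placeholders, and the roi_size / sw_batch_size values are made up):

import torch
from monai.inferers import sliding_window_inference

# Spread the sw_batch_size crops of a single full-size CT across both GPUs.
dp_model = torch.nn.DataParallel(model.cuda(), device_ids=[0, 1])

with torch.no_grad():
    pred = sliding_window_inference(
        inputs=ct_volume,      # (1, 1, H, W, D); D differs per CT, so no batching
        roi_size=(96, 96, 96),
        sw_batch_size=8,       # crops per forward pass, scattered by DataParallel
        predictor=dp_model,
    )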

@wyli
Contributor

wyli commented May 8, 2023

Thanks, perhaps these changes work for your use case: #6484

@KumoLiu KumoLiu added the bug Something isn't working label Dec 7, 2023
@chezhia

chezhia commented Jun 18, 2024

Hi @razorx89, I ran into the same issue of deep supervision not working in DataParallel mode. The solution I came up with is to change these lines in the DynUNet definition (monai/networks/nets/dynunet.py), and it seems to work for my test case. Can you try this:

def forward(self, x):
    out = self.skip_layers(x)
    out = self.output_block(out)
    if self.training and self.deep_supervision:
        out_all = [out]  # 'out' should be on 'cuda:0' by default
        for feature_map in self.heads:
            # Interpolate feature map to the size of 'out' and ensure device consistency if necessary
            interpolated_map = interpolate(feature_map, out.shape[2:]).to(out.device)
            out_all.append(interpolated_map)
        return torch.stack(out_all, dim=1)  # This should not cause device mismatch errors
    return out
