DynUNet crashes with DataParallel and DeepSupervision #6442
Hi @yiheng-wang-nv, is DynUNet thread-safe? Thanks.
Hi @razorx89, could you provide detailed code that reproduces the crash? I did a simple test with the following code and did not hit the error.
Here you go, it crashes on the second batch:

```python
import torch
from monai.networks.nets import DynUNet

kernels = [[3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3]]
strides = [[1, 1, 1], [2, 2, 2], [2, 2, 2], [2, 2, 2]]

net = DynUNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=1,
    kernel_size=kernels,
    strides=strides,
    upsample_kernel_size=strides[1:],
    deep_supervision=True,
    deep_supr_num=2,
)
net = net.cuda()
net = torch.nn.DataParallel(net, device_ids=[0, 1])

x = torch.randn(16, 1, 64, 64, 64)
x = x.cuda()

for i in range(10):
    print("Batch", i)
    net(x)
```
Fixes #6442.

### Description

This PR fixes the thread-safety issue of DynUNet: the list created in the forward function is replaced by a tensor.

### Types of changes

- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/` folder.

Signed-off-by: Yiheng Wang <[email protected]>
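For context, a minimal sketch of the general thread-safe pattern (not the exact MONAI implementation; `TinyDeepSupervisionNet` and its layers are hypothetical): create the heads container locally inside `forward` and fill it from return values, so each call and each replica builds its own result.

```python
import torch
import torch.nn as nn

class TinyDeepSupervisionNet(nn.Module):
    """Hypothetical toy network illustrating the thread-safe pattern."""

    def __init__(self):
        super().__init__()
        self.body = nn.Conv3d(1, 4, 3, padding=1)
        self.out = nn.Conv3d(4, 1, 1)
        self.aux = nn.Conv3d(4, 1, 1)  # deep-supervision head

    def forward(self, x):
        feat = self.body(x)
        main = self.out(feat)
        if self.training:
            # The list is created inside forward, so every call (and every
            # DataParallel replica) gets its own container; the heads are
            # stacked into a single tensor before returning.
            heads = [main, self.aux(feat)]
            return torch.stack(heads, dim=1)
        return main
```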
Thanks, it does not crash anymore. However, I don't think the returned supervision heads are correct: in my experiments the network does not learn at all (I cannot provide an example of this). Revisiting the line referenced above (MONAI/monai/networks/nets/dynunet.py, Line 51 in 5f344cc), add a print statement at that location:

`print(x.device, id(self.heads), self.index)`

You will see that both replicas write to the same list instance. Thus both replicas return a tensor with the same content (plus/minus race conditions), regardless of the replica's input.
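To make the shared-state problem concrete, here is a standalone sketch (`SharedStateDemo` is a hypothetical toy module, not MONAI code; it needs at least two GPUs). `DataParallel` replication shallow-copies non-tensor attributes, so every replica ends up holding the same Python list object:

```python
import torch
import torch.nn as nn

class SharedStateDemo(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv3d(1, 1, 3, padding=1)
        self.heads = []  # plain Python attribute, shared across replicas

    def forward(self, x):
        # Both replicas print the same id(): they mutate one shared list.
        print(x.device, id(self.heads))
        self.heads.append(x.mean())
        return self.conv(x)

if torch.cuda.device_count() >= 2:
    net = nn.DataParallel(SharedStateDemo().cuda(), device_ids=[0, 1])
    net(torch.randn(4, 1, 8, 8, 8).cuda())
    print(len(net.module.heads))  # 2: appends from both devices landed here
```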
Any updates on this, @yiheng-wang-nv? I would love to see this issue reopened, since it is still not working correctly.
Hi @ericspod, could you please give some suggestions here? In the multi-thread (DataParallel) case, it seems the interpolate approach (MONAI/monai/networks/nets/dynunet.py, Line 272 in c2a9a31) …
To add an error example: when the batch size is set to 9 and 2 GPUs are used, …
It seems that making the heads device-specific works fine (#6484), but it's a bit hacky.
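For reference, a sketch of what such a device-specific workaround might look like (hypothetical names, not the actual #6484 diff): key the stored heads by device, so concurrent replicas write to disjoint slots of the shared dict.

```python
import torch
import torch.nn as nn

class PerDeviceHeads(nn.Module):
    """Hypothetical sketch: supervision heads stored per device.

    The dict itself is still shared by all DataParallel replicas, but each
    replica only reads and writes the entry for its own device, so replicas
    no longer overwrite each other's intermediate heads.
    """

    def __init__(self):
        super().__init__()
        self.heads = {}  # device -> list of head tensors

    def put(self, x: torch.Tensor) -> None:
        self.heads.setdefault(x.device, []).append(x)

    def take(self, device: torch.device) -> list:
        # pop so the slot is clean for the next forward pass
        return self.heads.pop(device, [])
```

This is still shared mutable state, which is part of why it feels hacky and, as noted below, hard to reconcile with TorchScript.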
I'm not sure device-specific heads are going to be enough if new threads are created for every forward pass.
Yes, at each forward pass new threads are created (https://github.com/pytorch/pytorch/blob/0bf9722a3af5c00125ccb557a3618f11e5413236/torch/nn/parallel/parallel_apply.py#L73). It's not easy to get a solution that is both generic and compatible with TorchScript. I guess the main use case of …
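This is easy to observe with a small sketch (`ThreadDemo` is a hypothetical toy module; needs two GPUs): the thread id seen inside `forward` typically changes between consecutive forward passes, so thread-local storage would not survive across calls either.

```python
import threading
import torch
import torch.nn as nn

class ThreadDemo(nn.Module):
    def forward(self, x):
        # Under DataParallel with >1 device, forward runs in worker threads
        # spawned anew by parallel_apply on every call.
        print(x.device, threading.get_ident())
        return x * 2

if torch.cuda.device_count() >= 2:
    net = nn.DataParallel(ThreadDemo().cuda(), device_ids=[0, 1])
    x = torch.randn(4, 1).cuda()
    for _ in range(2):
        net(x)  # thread ids usually differ across iterations
```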
Maybe my two cents on why I am using `DataParallel`: …
Thanks, perhaps these changes work for your use case: #6484.
Hi @razorx89, I ran into the same issue of deep supervision not working in `DataParallel`.
Describe the bug
DynUNet crashes in a `torch.nn.DataParallel` scenario, since a mutable list is used to collect the supervision heads:

MONAI/monai/networks/nets/dynunet.py, Lines 212 to 219 in 5f344cc
MONAI/monai/networks/nets/dynunet.py, Line 51 in 5f344cc
This does not work for multiple GPUs in this scenario, because we end up with tensors in the list that live on different CUDA devices. The code crashes when stacking the tensors in the list at:

MONAI/monai/networks/nets/dynunet.py, Lines 271 to 275 in 5f344cc
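The crash itself is the usual cross-device `torch.stack` error, as in this minimal two-GPU snippet:

```python
import torch

a = torch.zeros(2, device="cuda:0")
b = torch.zeros(2, device="cuda:1")
# RuntimeError: Expected all tensors to be on the same device, ...
torch.stack([a, b])
```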
To Reproduce
Run `torch.nn.DataParallel(DynUNet(..., deep_supervision=True), device_ids=[0, 1])`.
Expected behavior
DynUNet's forward should be thread-safe. I know that `DistributedDataParallel` is superior and would avoid the problem; however, it should still work by correctly collecting results from block return values instead of using a "global" mutable list.

Environment