Problem with executing TensorProductConv #61
Thanks for raising this! Note that your source / destination index tensors need to be in int64 format and on the device for this to work. Can you try the line below?

    Z = tp_conv.forward(X, Y, W,
                        torch.tensor([0, 1, 1, 2], dtype=torch.long, device='cuda'),
                        torch.tensor([1, 0, 2, 1], dtype=torch.long, device='cuda'))  # Z has shape [node_ct, z_ir.dim]
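To make that concrete, here is a minimal sketch of preparing the index tensors before the call. The names `rows` and `cols` and the int32/CPU starting point are illustrative assumptions, not something taken from the reporter's script:

```python
import torch

# Hypothetical starting point: index tensors created on the CPU in int32.
rows = torch.tensor([0, 1, 1, 2], dtype=torch.int32)
cols = torch.tensor([1, 0, 2, 1], dtype=torch.int32)

# Cast to int64 and move them onto the GPU, as required above.
rows = rows.to(dtype=torch.long, device='cuda')
cols = cols.to(dtype=torch.long, device='cuda')

Z = tp_conv.forward(X, Y, W, rows, cols)  # Z has shape [node_ct, z_ir.dim]
```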
Strange... the code works on our A100 system, and we don't use multiprocessing anywhere in our package (beyond where PyTorch uses it internally). What invocation are you using to run this script? If you'd like to help debug, you can try commenting out the lines indicated below in ConvolutionBase.py. It will break the functionality, but if it runs to completion, we know the problem lies in the kernel call; otherwise, it lies somewhere at the Python layer.
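The specific lines referred to above were not captured here, but judging from the forward implementation quoted later in the thread, the bisection step would look roughly like the sketch below (an illustration, not the maintainer's exact diff): comment out the raw kernel launch so the op just returns its zero-initialized output.

```python
# Sketch of the stubbed-out forward body in ConvolutionBase.py (names taken from the
# snippet quoted later in this thread; the exact lines may differ):
def forward(L1_in, L2_in, weights, rows, cols):
    L1_in_c, L2_in_c, weights_c = L1_in.contiguous(), L2_in.contiguous(), weights.contiguous()
    L3_out = torch.zeros((L1_in_c.shape[0], self.L3.dim), dtype=L1_in.dtype, device='cuda')
    # Kernel launch commented out: the result stays zero, but if the script now runs
    # to completion the hang is inside the kernel call; if it still hangs, the problem
    # is at the Python layer.
    # self.internal.exec_conv_rawptrs(L1_in_c.data_ptr(), L2_in_c.data_ptr(),
    #                                 weights_c.data_ptr(), L3_out.data_ptr(),
    #                                 rows.data_ptr(), cols.data_ptr(),
    #                                 cols.shape[0], L1_in.shape[0], self.workspace_ptr)
    return L3_out
```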
I commented out the section in ConvolutionBase.py but still encountered errors. Anyway, thanks for the quick response!
No problem! OK, we've ruled out a kernel call issue. How about this: try setting self.forward = None in the setup_torch_module function, and let's see if you get an error message. If not, then the deadlock occurs before the call to forward. See below:

    def setup_torch_module(self):
        '''
        Need two different functions depending on whether the
        convolution is deterministic.
        '''
        if not self.deterministic:
            @torch.library.custom_op(f"openequivariance::conv_forward{self.conv_id}", mutates_args=(), device_types="cuda")
            def forward(L1_in: torch.Tensor, L2_in: torch.Tensor,
                        weights: torch.Tensor, rows: torch.Tensor, cols: torch.Tensor) -> torch.Tensor:
                L1_in_c, L2_in_c, weights_c = L1_in.contiguous(), L2_in.contiguous(), weights.contiguous()
                L3_out = torch.zeros((L1_in_c.shape[0], self.L3.dim), dtype=L1_in.dtype, device='cuda')
                self.internal.exec_conv_rawptrs(L1_in_c.data_ptr(), L2_in_c.data_ptr(),
                                                weights_c.data_ptr(), L3_out.data_ptr(),
                                                rows.data_ptr(), cols.data_ptr(),
                                                cols.shape[0], L1_in.shape[0], self.workspace_ptr)
                return L3_out

            @forward.register_fake
            def _(L1_in, L2_in, weights, rows, cols):
                return L1_in.new_empty(L1_in.shape[0], self.L3.dim)

            self.forward = None  # forward
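For what it's worth, here is a sketch of what this check looks like from the caller's side (assuming `tp_conv`, `X`, `Y`, `W`, `rows`, and `cols` are set up as earlier in the thread): with `self.forward = None`, a call that actually reaches `forward` should fail fast instead of hanging.

```python
try:
    Z = tp_conv.forward(X, Y, W, rows, cols)
except TypeError as e:
    # Expected once self.forward is None: "'NoneType' object is not callable".
    # Seeing this error means execution reached the forward call; a silent hang
    # instead means the deadlock happens before forward is ever invoked.
    print("Reached the forward call:", e)
```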
I have a hunch that the problem could actually lie with the logging library: https://stackoverflow.com/questions/24509650/deadlock-with-logging-multiprocess-multithread-python-script. One way to check this: can you head to logging_utils.py and change the logging level so the logger is effectively disabled? If that works, we can open a PR to fix this. This might also be an option, depending on whether or not changing the logging level solves the issue: https://github.com/google/python-atfork/blob/main/atfork/stdlib_fixer.py.
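To make "changing the logging level" concrete, a sketch of the kind of tweak meant here is below. The function and logger names are assumptions about logging_utils.py, not its actual contents; the point is just to silence the package logger and drop handlers whose internal locks could be inherited across a fork.

```python
import logging

def get_silenced_logger(name="openequivariance"):
    # Hypothetical helper: raise the level so nothing is emitted, and replace any
    # existing handlers (whose locks can deadlock after fork) with a no-op handler.
    logger = logging.getLogger(name)
    logger.setLevel(logging.CRITICAL)
    logger.handlers.clear()
    logger.addHandler(logging.NullHandler())
    logger.propagate = False
    return logger
```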
@sylee0124 Any follow-up on whether the latest commits have helped? We suspect that ninja spawns multiple processes and that disabling the logger might help.
Apologies for the delayed response. I tested the latest commit, but the issue persists.
We can try to debug this a little further (but it'll be difficult: I tried this on three different machines without encountering the issue, so at some point I'll need to ask whether you can replicate it on at least one other system). For starters: do you get the same error if you a) install PyTorch Geometric and b) run
Hi, I am trying to run tp_conv, but I failed to do so. NVRTC runs without any problems, but it gets stuck afterward: memory is allocated on the GPU, but nothing happens. I had no trouble running oeq.TensorProduct.
Here is the test code I was using. I installed from source and am using an A100 GPU with PyTorch 2.5.1 (CUDA 12.4).
Thank you.