Problem with executing TensorProductConv #61

Open · sylee0124 opened this issue Feb 14, 2025 · 9 comments · Fixed by #64
Labels: bug (Something isn't working)

@sylee0124

Hi, I am trying to run tp_conv, but I have not been able to get it working.
NVRTC runs without any problems, but the program gets stuck afterward: memory is allocated on the GPU, yet nothing happens.
I had no trouble running oeq.TensorProduct.
Here is the test code I was using. I installed from source and am running on an A100 GPU with PyTorch 2.5.1 (CUDA 12.4).

Thank you.

import torch
import e3nn.o3 as o3
import openequivariance as oeq

gen = torch.Generator(device='cuda')

batch_size = 1000
X_ir, Y_ir, Z_ir = o3.Irreps("1x2e"), o3.Irreps("1x3e"), o3.Irreps("1x2e")

instructions = [(0, 0, 0, "uvu", True)]

problem = oeq.TPProblem(X_ir, Y_ir, Z_ir, instructions, shared_weights=False, internal_weights=False)

node_ct, nonzero_ct = 3, 4

# Receiver, sender indices for message passing GNN

X = torch.rand(node_ct, X_ir.dim, device='cuda', generator=gen)
Y = torch.rand(nonzero_ct, Y_ir.dim, device='cuda', generator=gen)
W = torch.rand(nonzero_ct, problem.weight_numel, device='cuda', generator=gen)

tp_conv = oeq.TensorProductConv(problem, torch_op=True, deterministic=False) # Reuse problem from earlier
Z = tp_conv.forward(X, Y, W, torch.tensor([0, 1, 1, 2]), torch.tensor([1, 0, 2, 1])) # Z has shape [node_ct, z_ir.dim]
print(torch.norm(Z))
@vbharadwaj-bk (Member)

Thanks for raising this! Note that your source/destination index tensors need to be int64 and on the GPU for this to work. Can you try the line below?

Z = tp_conv.forward(X, Y, W, torch.tensor([0, 1, 1, 2], dtype=torch.long, device='cuda'), torch.tensor([1, 0, 2, 1], dtype=torch.long, device='cuda')) # Z has shape [node_ct, z_ir.dim]
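
The same fix, with the index tensors built up front (a minimal sketch using the toy graph from the snippet above; the names rows/cols simply mirror the argument names in the forward signature quoted later in this thread):

rows = torch.tensor([0, 1, 1, 2], dtype=torch.long, device='cuda')  # receiver indices, int64 on the GPU
cols = torch.tensor([1, 0, 2, 1], dtype=torch.long, device='cuda')  # sender indices, int64 on the GPU
Z = tp_conv.forward(X, Y, W, rows, cols)  # Z has shape [node_ct, Z_ir.dim]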

@sylee0124 (Author)

I tried the fix, but I still have the same problem.
I think it has something to do with a deadlock in multiprocessing.
This is what I get when I kill the process:

[Image: traceback captured after killing the process]

@vbharadwaj-bk (Member) commented Feb 14, 2025

Strange... the code is working on our A100 system, and we also don't use multiprocessing anywhere in our package (beyond where PyTorch uses it internally). What invocation are you using to run this script?

If you'd like to help debug, you can try commenting out the lines indicated below in the file "ConvolutionBase.py". It will break the functionality, but if the script runs to completion, then we know the problem lies in the kernel call. Otherwise, it lies somewhere at the Python layer.

def setup_torch_module(self):
        '''
        Need two different functions depending on whether the
        convolution is deterministic.
        '''
        if not self.deterministic:
            @torch.library.custom_op(f"openequivariance::conv_forward{self.conv_id}", mutates_args=(), device_types="cuda")
            def forward(L1_in : torch.Tensor, L2_in : torch.Tensor, 
                    weights : torch.Tensor, rows: torch.Tensor, cols: torch.Tensor) -> torch.Tensor:
                L1_in_c, L2_in_c, weights_c = L1_in.contiguous(), L2_in.contiguous(), weights.contiguous()
                L3_out = torch.zeros((L1_in_c.shape[0], self.L3.dim ), dtype=L1_in.dtype, device='cuda')

                # Comment the lines below

                #self.internal.exec_conv_rawptrs(L1_in_c.data_ptr(), L2_in_c.data_ptr(),
                #    weights_c.data_ptr(), L3_out.data_ptr(),
                #    rows.data_ptr(), cols.data_ptr(),
                #    cols.shape[0], L1_in.shape[0], self.workspace_ptr)
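
                # The quoted snippet is truncated here; per the full definition shown later
                # in this thread, forward still ends by returning the zero-initialized output,
                # so the script can run to completion even with the kernel call commented out.
                return L3_out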

@sylee0124
Copy link
Author

I commented out the section in ConvolutionBase.py but still encountered errors.
When I run this test Python code, it spawns multiple subprocesses.
I'm not sure if this is an issue with my setup or something else.

Anyway, thanks for the quick response!

@vbharadwaj-bk (Member)

No problem! OK, that rules out a kernel call issue. How about this: try setting self.forward = None in the setup_torch_module function, and let's see if you get an error message. If not, then the deadlock occurs before the call to forward. See below:

def setup_torch_module(self):
        '''
        Need two different functions depending on whether the
        convolution is deterministic.
        '''
        if not self.deterministic:
            @torch.library.custom_op(f"openequivariance::conv_forward{self.conv_id}", mutates_args=(), device_types="cuda")
            def forward(L1_in : torch.Tensor, L2_in : torch.Tensor, 
                    weights : torch.Tensor, rows: torch.Tensor, cols: torch.Tensor) -> torch.Tensor:
                L1_in_c, L2_in_c, weights_c = L1_in.contiguous(), L2_in.contiguous(), weights.contiguous()
                L3_out = torch.zeros((L1_in_c.shape[0], self.L3.dim ), dtype=L1_in.dtype, device='cuda')

                self.internal.exec_conv_rawptrs(L1_in_c.data_ptr(), L2_in_c.data_ptr(),
                    weights_c.data_ptr(), L3_out.data_ptr(),
                    rows.data_ptr(), cols.data_ptr(),
                    cols.shape[0], L1_in.shape[0], self.workspace_ptr)

                return L3_out
            
            @forward.register_fake
            def _(L1_in, L2_in, weights, rows, cols):
                return L1_in.new_empty(L1_in.shape[0], self.L3.dim)
            
            self.forward = None #forward

vbharadwaj-bk self-assigned this Feb 14, 2025
@vbharadwaj-bk (Member) commented Feb 14, 2025

I have a hunch that the problem could actually lie with the logging library: https://stackoverflow.com/questions/24509650/deadlock-with-logging-multiprocess-multithread-python-script.

One way to check this: can you head to logging_utils.py and set logger.setLevel(logging.CRITICAL)?

If that works, we can open a PR to fix this. This might also be an option, depending on whether or not changing the logging level solves the issue: https://github.com/google/python-atfork/blob/main/atfork/stdlib_fixer.py.
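
For concreteness, a minimal sketch of that logging change (assuming logging_utils.py obtains its logger via logging.getLogger; the logger name below is illustrative, not necessarily the package's actual configuration):

import logging

logger = logging.getLogger("openequivariance")  # assumed existing logger setup in logging_utils.py
logger.setLevel(logging.CRITICAL)  # drop everything below CRITICAL to test the logging-deadlock hypothesis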

vbharadwaj-bk added the bug label Feb 14, 2025
vbharadwaj-bk linked a pull request Feb 15, 2025 that will close this issue
@vbharadwaj-bk (Member)

@sylee0124 Any follow-up on whether the latest commits have helped? We suspect that the multiple processes are spawned by ninja and that disabling the logger might help.

@sylee0124 (Author)

Apologies for the delayed response.
I tested setting self.forward = None, but it resulted in the error:
'NoneType' object has no attribute 'register_autograd'
Also, when I comment out the tp_conv.forward() call, the deadlock does not occur.

Additionally, I tested the latest commit, but the issue persists.
Please let me know if you have any further suggestions or if I can provide more details.

@vbharadwaj-bk (Member) commented Feb 24, 2025

We can try to debug this a little further (but it will be difficult; I tried this on three different machines without encountering the issue, so at some point I will need to ask whether you can replicate it on at least one other system).

For starters: do you get the same error if you (a) install PyTorch Geometric and (b) run examples/readme_tutorial.py? I don't expect any difference, but it's probably good to start there so we are looking at exactly the same piece of code.
