2 questions for the composite op feature #8486
Comments
Hi @Zantares, thank you for reporting the issue!
I agree with the first point. Regarding the composite op in training, the missing piece may not be adding autograd for `mark_tensor`.
Thanks for the answer @lsy323! Based on the reply, we can focus on the 1st question about redundant ops in this issue. For the 2nd question, I have found the ATen op lowering process in Torch-XLA, and I'd like to submit a draft PR later to see if it's acceptable.
Hi @Zantares, my understanding so far of what you are trying to do:

1. Export the model from PyTorch to StableHLO through Torch-XLA.
2. Train/run the exported StableHLO program on your own runtime and device.

I am assuming that the runtime used in step 2 is not XLA (otherwise you would train directly in torch_xla and avoid the hassle of export). Feel free to share more about the runtime/device you are using, just for my own curiosity :) Is the above understanding correct?
Hi @qihqi, that's exactly correct. We are trying to train PT models on our custom SW (which is based on MLIR) and HW, and we have chosen StableHLO as the bridge. The initial step is exporting StableHLO from Torch-XLA and then conducting a re-evaluation. As a next step, we plan to integrate our SW with Torch-XLA via PJRT, though we're aware that this will be a time-consuming process.
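For reference, a minimal sketch of that first export step, assuming the `exported_program_to_stablehlo` / `save_as_stablehlo` helpers in `torch_xla.stablehlo`; the model and shapes below are made up for illustration, not the actual workload:

```python
import torch
import torch_xla.stablehlo as stablehlo

# Hypothetical model and sample input, just for illustration.
model = torch.nn.Sequential(torch.nn.Linear(128, 64), torch.nn.GELU()).eval()
sample_args = (torch.randn(8, 128),)

# Capture the program with torch.export, then lower it to StableHLO via Torch-XLA.
exported = torch.export.export(model, sample_args)
shlo_module = stablehlo.exported_program_to_stablehlo(exported)

# Human-readable MLIR text for inspection ...
print(shlo_module.get_stablehlo_text())

# ... and an on-disk artifact (bytecode + weights) that a custom runtime can consume.
stablehlo.save_as_stablehlo(exported, "/tmp/gelu_model_stablehlo")
```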
BTW, I'm also aware of torch_xla2. Is there any documentation or guide on how it works?
There are few resources on torch_xla2 right now. The docs/ subfolder has more details, especially this: https://github.com/pytorch/xla/blob/master/experimental/torch_xla2/docs/how_it_works.md. The examples/ folder also has some samples, with comments, on how to use it. I agree the documentation is lacking, and we are working on it. Integrating the SW with PJRT sounds good; it will work with torch_xla2 and jax as well. Back to your use case: we are also working on making this easier to do, but that part is still work in progress.
Appreciate the answer! It solved the question mentioned in another issue #8366. |
❓ Questions and Help
Glad to see that the composite op feature has been added to Torch-XLA. I have tried this feature and have some questions; I hope to get answers/suggestions here:
1. Some redundant ops (outside of the composite `custom_call`) can't be erased after the composite op is created, e.g. for `Gelu`: in the generated StableHLO, the `erf` op left in `main` is useless and not erased. I have checked the composite op pass; it leaves these useless ops to the later `canonicalizer` instead of erasing them directly, but the `canonicalizer` didn't handle them... I guess this is caused by the side effect of the custom call.

   The question: can the composite op pass erase these ops directly? Is there any special reason to avoid erasing them here? (A marking sketch follows below for reference.)
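A minimal sketch of one way to mark a `Gelu` composite, assuming the `StableHLOCompositeBuilder` helper from `torch_xla.experimental.mark_pattern_utils`; the composite name, attributes, and shapes are illustrative, not the exact reproducer from the report:

```python
import torch
import torch_xla.stablehlo as stablehlo
from torch_xla.experimental.mark_pattern_utils import StableHLOCompositeBuilder


class MarkedGelu(torch.nn.Module):
    """Wraps F.gelu between composite boundary markers (illustrative only)."""

    def __init__(self):
        super().__init__()
        self.builder = StableHLOCompositeBuilder(
            name="composite.gelu", attr={"approximate": "none"})

    def forward(self, x):
        x = self.builder.mark_inputs(x)
        x = torch.nn.functional.gelu(x)
        x = self.builder.mark_outputs(x)
        return x


# Export and lower; the resulting StableHLO should contain a composite for the
# marked region, and the question above is about the leftover ops in `main`.
exported = torch.export.export(MarkedGelu().eval(), (torch.randn(4, 16),))
print(stablehlo.exported_program_to_stablehlo(exported).get_stablehlo_text())
```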
2. When the marked module is used in training, the backward graph is not generated.

   The question: is there any plan to support the composite op feature in training? It seems the missing part is only adding autograd support for `mark_tensor`, but I'm just an XLA developer who is not familiar with PyTorch, so I don't know how to add it...
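For the record, a minimal sketch of the kind of autograd hook-up being asked about, assuming the marking op is value-preserving (an identity on the tensor); `mark_fn` is a hypothetical stand-in for whatever marking call Torch-XLA exposes, not its real name or signature:

```python
import torch


class MarkTensorWithGrad(torch.autograd.Function):
    """Hypothetical sketch: let gradients flow through a boundary-marking call.

    `mark_fn` stands in for the actual marking op (e.g. what mark_tensor does);
    its real name and signature in Torch-XLA are not assumed here.
    """

    @staticmethod
    def forward(ctx, x, mark_fn):
        # Marking is semantically an identity on the tensor value, so the
        # forward pass simply applies it.
        return mark_fn(x)

    @staticmethod
    def backward(ctx, grad_output):
        # The value is unchanged by marking, so the gradient passes through
        # untouched; the non-tensor `mark_fn` argument gets no gradient.
        return grad_output, None


# Usage sketch: wrap the marking call so the marked region stays differentiable.
def mark_with_grad(x, mark_fn):
    return MarkTensorWithGrad.apply(x, mark_fn)
```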