
2 questions for the composite op feature #8486

Open
Zantares opened this issue Dec 12, 2024 · 8 comments

❓ Questions and Help

Glad to see that the composite op feature has been added to Torch-XLA. I have tried it and have a couple of questions; I hope to get answers/suggestions here:

  1. Some redundant IR (starting from a custom_call) is not erased after the composite op is created, e.g. for GELU:
import torch
import torch_xla
import torch_xla.core.xla_model as xm

from torch_xla import stablehlo
from torch_xla.experimental.mark_pattern_utils import StableHLOCompositeBuilder

class Example(torch.nn.Module):
    def __init__(self):
        super(Example, self).__init__()
        self.gelu = torch.nn.GELU(approximate="none")
        self.composite_op = StableHLOCompositeBuilder("composite.gelu", {"approximate": "none"})

    def forward(self, x):
        x = self.composite_op.mark_inputs(x)
        y = self.gelu(x)
        y = self.composite_op.mark_outputs(y)
        return y

x = torch.randn(10, device=xm.xla_device())
model = Example().to(xm.xla_device())
print(model(x))

input_args = (x, )
exported = torch.export.export(model, input_args)
# print(exported.graph)
stablehlo_gm = stablehlo.exported_program_to_stablehlo(exported)
stablehlo_text = stablehlo_gm.get_stablehlo_text()  # renamed to avoid shadowing the imported stablehlo module
print(stablehlo_text)

The generated StableHLO is:

module @IrToHlo.16 attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} {
  func.func @main(%arg0: tensor<10xf32>) -> tensor<10xf32> {
    %cst = stablehlo.constant dense<0.707106769> : tensor<10xf32>
    %0 = stablehlo.multiply %arg0, %cst : tensor<10xf32>
    %1 = stablehlo.custom_call @mhlo.erf(%0) {mhlo.attributes = {}, mhlo.version = 1 : i64} : (tensor<10xf32>) -> tensor<10xf32>
    %2 = stablehlo.composite "composite.gelu" %arg0 {composite_attributes = {approximate = "none"}, decomposition = @composite.gelu.impl} : (tensor<10xf32>) -> tensor<10xf32>
    return %2 : tensor<10xf32>
  }
  func.func private @composite.gelu.impl(%arg0: tensor<10xf32>) -> tensor<10xf32> {
    %cst = stablehlo.constant dense<1.000000e+00> : tensor<10xf32>
    %cst_0 = stablehlo.constant dense<0.707106769> : tensor<10xf32>
    %cst_1 = stablehlo.constant dense<5.000000e-01> : tensor<10xf32>
    %0 = stablehlo.multiply %arg0, %cst_1 : tensor<10xf32>
    %1 = stablehlo.multiply %arg0, %cst_0 : tensor<10xf32>
    %2 = stablehlo.custom_call @mhlo.erf(%1) {mhlo.attributes = {}, mhlo.version = 1 : i64} : (tensor<10xf32>) -> tensor<10xf32>
    %3 = stablehlo.add %2, %cst : tensor<10xf32>
    %4 = stablehlo.multiply %0, %3 : tensor<10xf32>
    return %4 : tensor<10xf32>
  }
}

The erf custom_call in main (along with the multiply feeding it) is dead code, but it is not erased. I have checked the composite op pass: it leaves these unused ops for a later canonicalizer instead of erasing them directly, but the canonicalizer doesn't remove them. I guess this is because the custom_call is conservatively treated as having side effects.

The question: can the composite op pass erase these ops directly? Is there any special reason to avoid erasing them here?

  2. The composite op feature doesn't work in training. Even though this feature is currently proposed for inference (it works through the export API), I tried to enable it in training locally and found that it reports a warning:

UserWarning: xla::mark_tensor: an autograd kernel was not registered to the Autograd key(s) but we are trying to backprop through it. This may lead to silently incorrect behavior. This behavior is deprecated and will be removed in a future version of PyTorch. If your operator is differentiable, please ensure you have registered an autograd kernel to the correct Autograd key (e.g. DispatchKey::Autograd, DispatchKey::CompositeImplicitAutograd). If your operator is not differentiable, or to squash this warning and use the previous behavior, please register torch::CppFunction::makeFallthrough() to DispatchKey::Autograd. (Triggered internally at /data4/home/luteng/code/pytorch/torch/csrc/autograd/autograd_not_implemented_fallback.cpp:62.)
return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass

As a result, the backward graph is not generated.

The question: is there any plan to support the composite op feature in training? It seems the only missing part is adding an Autograd registration for mark_tensor, but I'm an XLA developer who is not familiar with PyTorch internals, so I don't know how to add it (a rough sketch of what I mean is below).
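For illustration, here is a minimal sketch of the kind of pass-through gradient I have in mind. This is purely illustrative, not the torch_xla implementation; _MarkIdentity and mark_fn are made-up names. Since the marker op is semantically an identity, its backward only needs to forward the incoming gradient:

import torch

class _MarkIdentity(torch.autograd.Function):
    """Pass-through autograd wrapper for an identity-like marker call."""

    @staticmethod
    def forward(ctx, x, mark_fn):
        # mark_fn is any identity-like marking callable, e.g. a closure over
        # StableHLOCompositeBuilder.mark_inputs / mark_outputs for one tensor.
        return mark_fn(x)

    @staticmethod
    def backward(ctx, grad_output):
        # The marker does no math, so the gradient flows straight through;
        # the non-tensor mark_fn argument gets no gradient.
        return grad_output, None

# Hypothetical usage inside forward(): y = _MarkIdentity.apply(x, self.composite_op.mark_inputs)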

lsy323 self-assigned this Dec 13, 2024
lsy323 added the stablehlo (StableHLO related work) label Dec 13, 2024
@lsy323 (Collaborator) commented Dec 16, 2024

Hi @Zantares, thank you for reporting the issue!

Can the composite op pass erase these ops directly? Is there any special reason to avoid erasing them here?

I agree that the erf in main is expected to be removed by DCE; let me try to repro it on my end and investigate.

Regarding the composite op in training, the missing piece may not be adding autograd for mark_tensor. The current StableHLO export flow is integrated with torch.export, and it seems that torch.export doesn't support training yet.

@Zantares (Author)

Thanks for the answer, @lsy323! Based on your reply, let's focus on the first question (the redundant ops) in this issue.

For the second question, I have found the ATen op lowering process in Torch-XLA; I'd like to submit a draft PR later to see if it's acceptable.

@Zantares (Author)

Hi @lsy323, I added a draft PR #8502 to demonstrate a solution for composite ops in training. I hope to get feedback/suggestions, thanks!

@qihqi (Collaborator) commented Jan 8, 2025

Hi @Zantares,

We are moving the export-to-StableHLO feature to torch_xla2 instead. Please take a look at: https://github.com/pytorch/xla/blob/master/docs/source/features/stablehlo.md

My understanding so far of what you are trying to do:

  1. Export the train_step (which is one forward, one backward, one optimizer update) to StableHLO, with some composites for key ops.
  2. Evaluate this StableHLO (in some runtime) repeatedly to perform training; the composites exported above let you call fused kernels.

I am assuming that the runtime used in step 2 is not XLA (otherwise you would train directly in torch_xla and avoid the hassle of exporting). Feel free to share more about the runtime/device you are using, just for my own curiosity :)

Is the above understanding correct?

@Zantares (Author) commented Jan 8, 2025

Hi @qihqi

That's exactly correct. We are trying to train PyTorch models on our custom SW (which is based on MLIR) and HW, and we have chosen StableHLO as the bridge. The initial step is to export StableHLO from Torch-XLA and then evaluate it on our stack. As the next step, we plan to integrate our SW with Torch-XLA via PJRT, though we're aware that this will be a time-consuming process.

@Zantares (Author) commented Jan 8, 2025

Quoting @qihqi: "We are moving the export-to-StableHLO feature to torch_xla2 instead. Please take a look at: https://github.com/pytorch/xla/blob/master/docs/source/features/stablehlo.md"

BTW, I'm also aware of torch_xla2 in the repo, but there is very little introductory material for it. Can I get more detailed information about it?

@qihqi (Collaborator) commented Jan 8, 2025

There are a few resources on torch_xla2 in the repo, starting with the README.md at https://github.com/pytorch/xla/tree/master/experimental/torch_xla2

Then the docs/ subfolder has more details, especially this: https://github.com/pytorch/xla/blob/master/experimental/torch_xla2/docs/how_it_works.md

The examples/ folder has some samples, with comments, on how to use it.

I agree the documentation is lacking, though; we are working on it.

Integrating the SW with PJRT sounds good; it will work with torch_xla2 and jax as well.

Back to your use case: we are also making it easier to create a train_step callable from a model in https://github.com/pytorch/xla/pull/8495/files#diff-9d7a1652aaaef89d6d010ec05e61c3688a24e8e3ef16a38c91723a5103284871

While that is still work in progress, for now you can use torch_xla2.extract_jax to create a JAX callable, write a train step function by hand (one that calls jax.grad followed by an optax update), and then export that.
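Roughly, that interim flow could look like the sketch below (untested; the exact return signature of extract_jax should be double-checked against the torch_xla2 README, and model, batch, and labels are placeholders):

import jax
import optax
import torch_xla2

# Turn the PyTorch model into a pure function over its weights.
weights, jax_func = torch_xla2.extract_jax(model)  # assumed: jax_func(weights, args_tuple)

optimizer = optax.sgd(1e-3)
opt_state = optimizer.init(weights)

def loss_fn(weights, batch, labels):
    logits = jax_func(weights, (batch,))
    return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()

@jax.jit
def train_step(weights, opt_state, batch, labels):
    # One forward, one backward, one optimizer update.
    loss, grads = jax.value_and_grad(loss_fn)(weights, batch, labels)
    updates, opt_state = optimizer.update(grads, opt_state)
    weights = optax.apply_updates(weights, updates)
    return weights, opt_state, loss

# train_step can then be lowered to StableHLO, e.g. via
# jax.jit(train_step).lower(weights, opt_state, batch, labels).compiler_ir('stablehlo').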

@Zantares (Author) commented Jan 9, 2025

(quoting @qihqi's reply above about the torch_xla2 resources and the extract_jax flow)

Appreciate the answer! It also solved the question I mentioned in another issue, #8366.
