Export training model to StableHlo #8366

Closed
Zantares opened this issue Nov 8, 2024 · 3 comments

Comments

@Zantares

Zantares commented Nov 8, 2024

❓ Questions and Help

The export API only accepts a torch.nn.Module as input. Is there any way to export a training model with a step_fn to StableHLO?

Here is a simple training case from the examples:

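import torch.nn as nn
import torch.optim as optim
import torch_xla
import torchvision

class TrainResNetBase:  # enclosing class from the example (name assumed)

  def __init__(self):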
    ...
    self.device = torch_xla.device()
    self.model = torchvision.models.resnet50().to(self.device)
    self.optimizer = optim.SGD(self.model.parameters(), weight_decay=1e-4)
    self.loss_fn = nn.CrossEntropyLoss()
    ...

  def run_optimizer(self):
    self.optimizer.step()

  def step_fn(self, data, target):
    self.optimizer.zero_grad()
    output = self.model(data)
    loss = self.loss_fn(output, target)
    loss.backward()
    self.run_optimizer()
    return loss

The guide at https://pytorch.org/xla/master/features/stablehlo.html#torch-export-to-stablehlo only explains how to export the bare self.model; it doesn't explain how to export the model together with its optimizer and loss function.
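One way to frame the question: JAX's `update` in the next comment is a pure function from `(params, batch)` to new `params`, whereas `step_fn` mutates the model and the optimizer in place. Below is a minimal sketch of the same step rewritten in that pure style with `torch.func`, assuming plain SGD (no weight decay or momentum), an assumed `step_size`, and a tiny `nn.Linear` as a hypothetical stand-in for ResNet-50. It only illustrates the shape such a step would take; it does not show or claim that `torch.export` or Torch-XLA's StableHLO exporter accepts it.

```python
import torch
import torch.nn as nn
from torch.func import functional_call, grad_and_value

torch.manual_seed(0)

# Hypothetical stand-in for resnet50 so the sketch runs quickly on CPU.
model = nn.Linear(8, 3)
loss_fn = nn.CrossEntropyLoss()
step_size = 0.1  # assumed learning rate; plain SGD, no weight decay or momentum


def compute_loss(params, data, target):
    # Run the module with an explicit parameter dict instead of its own state.
    output = functional_call(model, params, (data,))
    return loss_fn(output, target)


def train_step(params, data, target):
    # Pure step: parameters in, new parameters and loss out; no optimizer
    # object and no in-place updates, mirroring JAX's `update(params, batch)`.
    grads, loss = grad_and_value(compute_loss)(params, data, target)
    new_params = {name: p - step_size * grads[name] for name, p in params.items()}
    return new_params, loss


params = {name: p.detach() for name, p in model.named_parameters()}
data = torch.randn(16, 8)
target = torch.randint(0, 3, (16,))

losses = []
for _ in range(20):
    params, loss = train_step(params, data, target)
    losses.append(loss.item())  # loss measured before each update

print(f"first loss {losses[0]:.4f}, last loss {losses[-1]:.4f}")
```

Because all state is explicit here, a stateful optimizer (for example SGD with momentum) would have to carry its buffers as extra inputs and outputs of `train_step`.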

@JackCaoG
Collaborator

JackCaoG commented Nov 8, 2024

@qihqi, I'm not sure whether exporting for training is something we support today.

@Zantares
Author

Adding more background:

Compared with Torch-XLA, JAX has a convenient API that takes a jitted function as input. Here is an example from the JAX repo:

...
def loss(params, batch):
  inputs, targets = batch
  preds = predict(params, inputs)
  return -jnp.mean(jnp.sum(preds * targets, axis=1))
...


if __name__ == "__main__":
  @jit
  def update(params, batch):
    grads = grad(loss)(params, batch)
    return [(w - step_size * dw, b - step_size * db)
            for (w, b), (dw, db) in zip(params, grads)]

  ...
  params = update(params, next(batches))
  ...

It can then be exported easily:

  # Export the function to StableHLO
  sh_exported = export.export(update)(params, batch)
  sh_text = get_stablehlo_asm(sh_exported.mlir_module())
  print(sh_text)

I can execute the generated StableHLO and get the expected results, so I'm wondering whether Torch-XLA can export a training model the same way.
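For reference, a self-contained sketch in the same style as the JAX example above, with a hypothetical one-layer `predict`, random data, and an assumed `step_size`. It assumes JAX 0.4.30 or later (for `jax.export`) and reads the StableHLO text with `Exported.mlir_module()` rather than the `get_stablehlo_asm` helper, which the snippet above does not define.

```python
import jax
import jax.numpy as jnp
from jax import export

step_size = 0.1  # assumed learning rate


def predict(params, inputs):
    # Hypothetical one-layer classifier; params is a list of (w, b) pairs.
    w, b = params[0]
    return jax.nn.log_softmax(inputs @ w + b)


def loss(params, batch):
    inputs, targets = batch
    preds = predict(params, inputs)
    return -jnp.mean(jnp.sum(preds * targets, axis=1))


@jax.jit
def update(params, batch):
    # The backward pass and the SGD update are traced into one jitted function.
    grads = jax.grad(loss)(params, batch)
    return [(w - step_size * dw, b - step_size * db)
            for (w, b), (dw, db) in zip(params, grads)]


key_w, key_x = jax.random.split(jax.random.PRNGKey(0))
params = [(0.01 * jax.random.normal(key_w, (8, 3)), jnp.zeros(3))]
inputs = jax.random.normal(key_x, (16, 8))
targets = jax.nn.one_hot(jnp.arange(16) % 3, 3)
batch = (inputs, targets)

# Export the whole training step (forward, backward, update) to StableHLO.
exported = export.export(update)(params, batch)
sh_text = exported.mlir_module()  # MLIR text in the StableHLO dialect
print(sh_text[:300])

# Run the serialized step; it should match calling `update` directly.
new_params = exported.call(params, batch)
```

The exported module contains the gradient computation and the parameter update, not just the forward pass, which is the property the question asks about for Torch-XLA.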

@Zantares
Author

Zantares commented Jan 9, 2025

Answered in #8486 (comment); closing this issue.
