Export training model to StableHlo #8366
Comments
@qihqi Not sure if exporting for training is something we support today.
Some more background: compared with Torch-XLA, JAX has a convenient API that takes a jitted function as input. Here is an example from the JAX repo:

```python
import jax.numpy as jnp
from jax import grad, jit

...

def loss(params, batch):
    inputs, targets = batch
    preds = predict(params, inputs)
    return -jnp.mean(jnp.sum(preds * targets, axis=1))

...

if __name__ == "__main__":
    @jit
    def update(params, batch):
        # One SGD step: gradients are computed inside the jitted function
        grads = grad(loss)(params, batch)
        return [(w - step_size * dw, b - step_size * db)
                for (w, b), (dw, db) in zip(params, grads)]

    ...
    params = update(params, next(batches))
    ...
```

Then it can be easily exported as below:

```python
# Export the function to StableHLO
from jax import export  # assumes jax >= 0.4.30; older releases had jax.experimental.export

sh_exported = export.export(update)(params, batch)
# get_stablehlo_asm is a helper that is not shown here
sh_text = get_stablehlo_asm(sh_exported.mlir_module())
print(sh_text)
```

I can execute the generated StableHLO and get the expected results. So I'm wondering whether Torch-XLA can export a training model like this.
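For comparison, here is a self-contained sketch of the same JAX flow with the elided pieces filled in. The single-layer `predict`, the `STEP_SIZE` value, and the toy shapes are hypothetical stand-ins, not taken from the JAX example. It assumes jax >= 0.4.30, where `jax.export` is public, and uses `Exported.call` to run the exported program and compare it with the jitted function:

```python
import jax
import jax.numpy as jnp
from jax import export  # assumes jax >= 0.4.30

STEP_SIZE = 0.1  # hypothetical learning rate


def predict(params, inputs):
    # Hypothetical stand-in model: one linear layer followed by log-softmax.
    w, b = params
    logits = inputs @ w + b
    return jax.nn.log_softmax(logits, axis=1)


def loss(params, batch):
    inputs, targets = batch
    preds = predict(params, inputs)
    return -jnp.mean(jnp.sum(preds * targets, axis=1))


@jax.jit
def update(params, batch):
    # Forward pass, backward pass and SGD update are all traced into one program.
    grads = jax.grad(loss)(params, batch)
    return jax.tree_util.tree_map(lambda p, g: p - STEP_SIZE * g, params, grads)


# Toy inputs: 2 examples with 4 features, 3 classes (hypothetical shapes).
params = (jnp.zeros((4, 3)), jnp.zeros((3,)))
batch = (jnp.ones((2, 4)), jax.nn.one_hot(jnp.array([0, 2]), 3))

# Export the jitted training step; mlir_module() returns the StableHLO as MLIR text.
exported = export.export(update)(params, batch)
stablehlo_text = exported.mlir_module()
print(stablehlo_text[:200])

# Run the exported program directly; it should match the jitted function.
new_params = exported.call(params, batch)
```

Because `jax.grad` is traced inside the jitted `update`, the backward pass and the parameter update end up in the same StableHLO module. That is the property the question asks Torch-XLA about.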
Answered in #8486 (comment); closing this issue.
❓ Questions and Help
The export API only supports `torch.nn.Module` as input. Is there any way to export a training model with a `step_fn` to StableHLO? Here is a simple training case from the examples:
The guide at https://pytorch.org/xla/master/features/stablehlo.html#torch-export-to-stablehlo only explains how to export the original `self.model`; it doesn't describe how to export the model together with the optimizer and loss function.