Changes to support torch._export.aot_compile (#2832)
* Changes to support torch._export.aot_compile

* Changes to support torch._export.aot_compile

* corrected copy paste error in readme

* Update examples/pt2/README.md

Co-authored-by: Mark Saroufim <[email protected]>

* Update examples/pt2/README.md

Co-authored-by: Mark Saroufim <[email protected]>

* Update examples/pt2/torch_export_aot_compile/README.md

Co-authored-by: Mark Saroufim <[email protected]>

* review comments

* addressed review comments

* updates based on review comments

* changed the config for pt2 export

* refactored the code

* moved ninja to common as cpu needs this too

* lint fixes

* lint fixes

* lint fixes

---------

Co-authored-by: Mark Saroufim <[email protected]>
agunapal and msaroufim authored Dec 21, 2023
1 parent ed47cc6 commit 426b4f7
Showing 11 changed files with 350 additions and 12 deletions.
21 changes: 21 additions & 0 deletions examples/pt2/README.md
@@ -107,3 +107,24 @@ print(extra_files['foo.txt'])
# from inference()
print(ep(torch.randn(5)))
```

## torch._export.aot_compile

Using `torch.compile` to wrap your existing eager PyTorch model can result in out-of-the-box speedups. However, `torch.compile` is a JIT compiler. TorchServe has supported `torch.compile` since the PyTorch 2.0 release, but in a production setting with multiple TorchServe instances, each of your instances would `torch.compile` the model on its first inference. Because compilation happens just in time, the first inference request is slow while the model compiles, and TorchServe's model archiver cannot truly guarantee reproducibility.

To solve this problem, `torch.export` provides an experimental API, `torch._export.aot_compile`, which can `torch.export` a torch-compilable model (provided it has no graph breaks) and then compile it ahead of time with AOTInductor.

You can find more details [here](https://pytorch.org/docs/main/torch.compiler_aot_inductor.html).
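As a quick illustration, here is a minimal sketch (with a hypothetical toy module, assuming a PyTorch 2.2 nightly) showing that `torch._export.aot_compile` produces a shared library on disk rather than an in-memory compiled module:

```python
import os

import torch


class TinyModel(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x) + 1


model = TinyModel().eval()
example_inputs = (torch.randn(4, 8),)

with torch.no_grad():
    # Exports the model and compiles it ahead of time with AOTInductor;
    # the shared library is written to the path given in options.
    so_path = torch._export.aot_compile(
        model,
        example_inputs,
        options={"aot_inductor.output_path": os.path.join(os.getcwd(), "tiny.so")},
    )

print(so_path)  # path to the compiled shared library
```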


This is an experimental API and needs the PyTorch 2.2 nightlies.
To enable it, add the following config in your `model-config.yaml`:

```yaml
pt2:
  export:
    aot_compile: true
```
You can find an example [here](./torch_export_aot_compile/README.md).
68 changes: 68 additions & 0 deletions examples/pt2/torch_export_aot_compile/README.md
@@ -0,0 +1,68 @@
# TorchServe inference with torch._export.aot_compile

This example shows how to run TorchServe inference with a Torch exported model using AOTInductor.

To understand when to use `torch._export.aot_compile`, please refer to this [section](../README.md/#torchexportaotcompile)


### Prerequisites

- `PyTorch >= 2.2.0`
- `CUDA 12.1`

Change directory to the example directory, e.g. `cd examples/pt2/torch_export_aot_compile`.

Install the PyTorch 2.2 nightlies by running:
```bash
chmod +x install_pytorch_nightlies.sh
source install_pytorch_nightlies.sh
```

You can also achieve this by installing the TorchServe dependencies with the `nightly_torch` flag:
```bash
python ts_scripts/install_dependencies.py --cuda=cu121 --nightly_torch
```


### Create a Torch exported model with AOTInductor

The model is saved with a `.so` extension.
Here we torch export the model and compile it with AOT Inductor in `max_autotune` mode.
This also makes use of `dynamic_shapes` to support batch sizes from 1 to 32.
In the code, the minimum batch size is set to 2 instead of 1. This is by design; the code still works for batch size 1. You can find an explanation for this [here](https://pytorch.org/docs/main/export.html#expressing-dynamism). The dynamic-shape specification is sketched below.
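For reference, here is the dynamic-shape specification used by the export script (a minimal sketch; the `max=15` cap mirrors the in-code comment referencing pytorch/pytorch#116152):

```python
import torch

# Mark the first (batch) dimension of input "x" as dynamic.
# min is 2 rather than 1 because torch.export specializes dimensions
# that could be 0 or 1; the compiled model still serves batch size 1.
batch_dim = torch.export.Dim("batch", min=2, max=15)
dynamic_shapes = {"x": {0: batch_dim}}
```

To generate the compiled model, run the export script: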

```bash
python resnet18_torch_export.py
```

### Create model archive

```bash
torch-model-archiver --model-name res18-pt2 --handler image_classifier --version 1.0 --serialized-file resnet18_pt2.so --config-file model-config.yaml --extra-files ../../image_classifier/index_to_name.json
mkdir model_store
mv res18-pt2.mar model_store/.
```

#### Start TorchServe
```bash
torchserve --start --model-store model_store --models res18-pt2=res18-pt2.mar --ncs
```

#### Run Inference

```bash
curl http://127.0.0.1:8080/predictions/res18-pt2 -T ../../image_classifier/kitten.jpg
```

This produces the following output:

```
{
  "tabby": 0.4087875485420227,
  "tiger_cat": 0.34661102294921875,
  "Egyptian_cat": 0.13007202744483948,
  "lynx": 0.024034621194005013,
  "bucket": 0.011633828282356262
}
```
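For reference, a minimal sketch of the same request from Python (assuming the `requests` package is installed and TorchServe is running locally):

```python
import requests

# Hypothetical client-side equivalent of the curl command above
with open("../../image_classifier/kitten.jpg", "rb") as f:
    response = requests.post("http://127.0.0.1:8080/predictions/res18-pt2", data=f)
print(response.json())
```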
13 changes: 13 additions & 0 deletions examples/pt2/torch_export_aot_compile/install_pytorch_nightlies.sh
@@ -0,0 +1,13 @@
#!/bin/bash

# Uninstall torchtext, torchdata, torch, torchvision, and torchaudio
pip uninstall torchtext torchdata torch torchvision torchaudio -y

# Install nightly PyTorch and torchvision from the specified index URL
pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu121 --ignore-installed

# Optional: Display the installed PyTorch and torchvision versions
python -c "import torch; print('PyTorch version:', torch.__version__)"
python -c "import torchvision; print('torchvision version:', torchvision.__version__)"

echo "PyTorch and torchvision updated successfully!"
6 changes: 6 additions & 0 deletions examples/pt2/torch_export_aot_compile/model-config.yaml
@@ -0,0 +1,6 @@
batchSize: 15
maxBatchDelay: 5000
responseTimeout: 300
pt2:
  export:
    aot_compile: true
29 changes: 29 additions & 0 deletions examples/pt2/torch_export_aot_compile/resnet18_torch_export.py
@@ -0,0 +1,29 @@
import os

import torch
from torchvision.models import ResNet18_Weights, resnet18

torch.set_float32_matmul_precision("high")

model = resnet18(weights=ResNet18_Weights.DEFAULT)
model.eval()

with torch.no_grad():
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(device=device)
    example_inputs = (torch.randn(2, 3, 224, 224, device=device),)

    # Max value is 15 because of https://github.com/pytorch/pytorch/pull/116152
    # On a CUDA enabled device, we tested a batch_size of 32.
    batch_dim = torch.export.Dim("batch", min=2, max=15)
    so_path = torch._export.aot_compile(
        model,
        example_inputs,
        # Specify the first dimension of the input x as dynamic
        dynamic_shapes={"x": {0: batch_dim}},
        # Specify the generated shared library path
        options={
            "aot_inductor.output_path": os.path.join(os.getcwd(), "resnet18_pt2.so"),
            "max_autotune": True,
        },
    )
1 change: 1 addition & 0 deletions requirements/common.txt
@@ -4,3 +4,4 @@ captum==0.6.0
packaging==23.2
pynvml==11.5.0
pyyaml==6.0
ninja==1.11.1.1
128 changes: 128 additions & 0 deletions test/pytest/test_torch_export.py
@@ -0,0 +1,128 @@
import os
from pathlib import Path

import pytest
import torch
from pkg_resources import packaging

from ts.torch_handler.image_classifier import ImageClassifier
from ts.torch_handler.unit_tests.test_utils.mock_context import MockContext
from ts.utils.util import load_label_mapping
from ts_scripts.utils import try_and_handle

CURR_FILE_PATH = Path(__file__).parent.absolute()
REPO_ROOT_DIR = CURR_FILE_PATH.parents[1]
EXAMPLE_ROOT_DIR = REPO_ROOT_DIR.joinpath("examples", "pt2", "torch_export_aot_compile")
TEST_DATA = REPO_ROOT_DIR.joinpath("examples", "image_classifier", "kitten.jpg")
MAPPING_DATA = REPO_ROOT_DIR.joinpath(
    "examples", "image_classifier", "index_to_name.json"
)
MODEL_SO_FILE = "resnet18_pt2.so"
MODEL_YAML_CFG_FILE = EXAMPLE_ROOT_DIR.joinpath("model-config.yaml")


PT_220_AVAILABLE = packaging.version.parse(torch.__version__) > packaging.version.parse(
    "2.1.1"
)

EXPECTED_RESULTS = ["tabby", "tiger_cat", "Egyptian_cat", "lynx", "bucket"]
TEST_CASES = [
    ("kitten.jpg", EXPECTED_RESULTS[0]),
]


@pytest.fixture
def custom_working_directory(tmp_path):
    # Set the custom working directory
    custom_dir = tmp_path / "model_dir"
    custom_dir.mkdir()
    os.chdir(custom_dir)
    yield custom_dir
    # Clean up and return to the original working directory
    os.chdir(tmp_path)


@pytest.mark.skipif(not PT_220_AVAILABLE, reason="torch version is < 2.2.0")
def test_torch_export_aot_compile(custom_working_directory):
    # Get the path to the custom working directory
    model_dir = custom_working_directory

    # Construct the path to the Python script to execute
    script_path = os.path.join(EXAMPLE_ROOT_DIR, "resnet18_torch_export.py")

    # Generate the compiled .so file with torch._export.aot_compile
    cmd = "python " + script_path
    try_and_handle(cmd)

    # Handler for image classification
    handler = ImageClassifier()

    # Context definition
    ctx = MockContext(
        model_pt_file=MODEL_SO_FILE,
        model_dir=model_dir.as_posix(),
        model_file=None,
        model_yaml_config_file=MODEL_YAML_CFG_FILE,
    )

    torch.manual_seed(42 * 42)
    handler.initialize(ctx)
    handler.context = ctx
    handler.mapping = load_label_mapping(MAPPING_DATA)

    data = {}
    with open(TEST_DATA, "rb") as image:
        image_file = image.read()
        byte_array_type = bytearray(image_file)
        data["body"] = byte_array_type

    result = handler.handle([data], ctx)

    labels = list(result[0].keys())

    assert labels == EXPECTED_RESULTS


@pytest.mark.skipif(not PT_220_AVAILABLE, reason="torch version is < 2.2.0")
def test_torch_export_aot_compile_dynamic_batching(custom_working_directory):
    # Get the path to the custom working directory
    model_dir = custom_working_directory

    # Construct the path to the Python script to execute
    script_path = os.path.join(EXAMPLE_ROOT_DIR, "resnet18_torch_export.py")

    # Generate the compiled .so file with torch._export.aot_compile
    cmd = "python " + script_path
    try_and_handle(cmd)

    # Handler for image classification
    handler = ImageClassifier()

    # Context definition
    ctx = MockContext(
        model_pt_file=MODEL_SO_FILE,
        model_dir=model_dir.as_posix(),
        model_file=None,
        model_yaml_config_file=MODEL_YAML_CFG_FILE,
    )

    torch.manual_seed(42 * 42)
    handler.initialize(ctx)
    handler.context = ctx
    handler.mapping = load_label_mapping(MAPPING_DATA)

    data = {}
    with open(TEST_DATA, "rb") as image:
        image_file = image.read()
        byte_array_type = bytearray(image_file)
        data["body"] = byte_array_type

    # Send a batch of 15 elements
    result = handler.handle([data for i in range(15)], ctx)

    assert len(result) == 15
Empty file.
27 changes: 27 additions & 0 deletions ts/handler_utils/torch_export/load_model.py
@@ -0,0 +1,27 @@
import tempfile

import torch.fx._pytree as fx_pytree
from torch._inductor.utils import aot_inductor_launcher, cache_dir
from torch.utils import _pytree as pytree
from torch.utils.cpp_extension import load_inline


def load_exported_model(model_so_path, device):
    module = load_inline(
        name="aot_inductor",
        cpp_sources=[aot_inductor_launcher(model_so_path, device)],
        # Use a unique build directory to avoid test interference
        build_directory=tempfile.mkdtemp(dir=cache_dir()),
        functions=["run", "get_call_spec"],
        with_cuda=("cuda" == device),
    )
    call_spec = module.get_call_spec()
    in_spec = pytree.treespec_loads(call_spec[0])
    out_spec = pytree.treespec_loads(call_spec[1])

    def optimized(*args):
        flat_inputs = fx_pytree.tree_flatten_spec((args, {}), in_spec)
        flat_outputs = module.run(flat_inputs)
        return pytree.tree_unflatten(flat_outputs, out_spec)

    return optimized
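For illustration, a hypothetical usage sketch of `load_exported_model`, assuming the compiled `resnet18_pt2.so` from the example above is in the current directory:

```python
import torch

from ts.handler_utils.torch_export.load_model import load_exported_model

device = "cuda" if torch.cuda.is_available() else "cpu"
optimized = load_exported_model("resnet18_pt2.so", device)

# Run a batch of 4 images through the compiled model
with torch.no_grad():
    out = optimized(torch.randn(4, 3, 224, 224, device=device))
print(out.shape)  # expected: torch.Size([4, 1000]) for ResNet-18
```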
