Skip to content

Commit

Permalink
torch.compile ImageClassifier example (#2915)
Browse files Browse the repository at this point in the history
* updated torch.compile example

* added pytest for example

* lint failure

* show 3x speedup with torch.compile

* show 3x speedup with torch.compile

* show 3x speedup with torch.compile

* added missing file

* added missing file

* added sbin to path

* skipping test

* Skipping pytest for now as its causing other tests to fail
  • Loading branch information
agunapal authored Feb 6, 2024
1 parent fa0f1e3 commit 88eca54
Show file tree
Hide file tree
Showing 6 changed files with 210 additions and 10 deletions.
15 changes: 5 additions & 10 deletions examples/pt2/README.md
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
## PyTorch 2.x integration

PyTorch 2.0 brings more compiler options to PyTorch, for you that should mean better perf either in the form of lower latency or lower memory consumption. Integrating PyTorch 2.0 is fairly trivial but for now the support will be experimental given that most public benchmarks have focused on training instead of inference.
PyTorch 2.0 brings more compiler options to PyTorch, for you that should mean better perf either in the form of lower latency or lower memory consumption.

We strongly recommend you leverage newer hardware so for GPUs that would be an Ampere architecture. You'll get even more benefits from using server GPU deployments like A10G and A100 vs consumer cards. But you should expect to see some speedups for any Volta or Ampere architecture.

## Get started

Install torchserve and ensure that you're using at least `torch>=2.0.0`

To use the latest nightlies, you can run the following commands
```sh
python ts_scripts/install_dependencies.py --cuda=cu118
pip install torchserve torch-model-archiver
python ts_scripts/install_dependencies.py --cuda=cu121 --nightly_torch
pip install torchserve-nightly torch-model-archiver-nightly
```

## torch.compile
Expand All @@ -27,13 +28,7 @@ You can also pass a dictionary with compile options if you need more control ove
pt2 : {backend: inductor, mode: reduce-overhead}
```
As an example let's expand our getting started guide with the only difference being passing in the extra `model_config.yaml` file

```
mkdir model_store
torch-model-archiver --model-name densenet161 --version 1.0 --model-file ./serve/examples/image_classifier/densenet_161/model.py --export-path model_store --extra-files ./serve/examples/image_classifier/index_to_name.json --handler image_classifier --config-file model_config.yaml
torchserve --start --ncs --model-store model_store --models densenet161.mar
```
An example of using `torch.compile` can be found [here](./torch_compile/README.md)

The exact same approach works with any other model, what's going on is the below

Expand Down
94 changes: 94 additions & 0 deletions examples/pt2/torch_compile/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@

# TorchServe inference with torch.compile of densenet161 model

This example shows how to take eager model of `densenet161`, configure TorchServe to use `torch.compile` and run inference using `torch.compile`


### Pre-requisites

- `PyTorch >= 2.0`

Change directory to the examples directory
Ex: `cd examples/pt2/torch_compile`


### torch.compile config

`torch.compile` supports a variety of config and the performance you get can vary based on the config. You can find the various options [here](https://pytorch.org/docs/stable/generated/torch.compile.html)

In this example , we use the following config

```
echo "pt2 : {backend: inductor, mode: reduce-overhead}" > model-config.yaml
```

### Create model archive

```
wget https://download.pytorch.org/models/densenet161-8d451a50.pth
mkdir model_store
torch-model-archiver --model-name densenet161 --version 1.0 --model-file model.py --serialized-file densenet161-8d451a50.pth --export-path model_store --extra-files ../../image_classifier/index_to_name.json --handler image_classifier --config-file model-config.yaml -f
```

#### Start TorchServe
```
torchserve --start --ncs --model-store model_store --models densenet161.mar
```

#### Run Inference

```
curl http://127.0.0.1:8080/predictions/densenet161 -T ../../image_classifier/kitten.jpg
```

produces the output

```
{
"tabby": 0.4664836823940277,
"tiger_cat": 0.4645617604255676,
"Egyptian_cat": 0.06619937717914581,
"lynx": 0.0012969186063855886,
"plastic_bag": 0.00022856894065625966
}
```

### Performance improvement from using `torch.compile`

To measure the handler `preprocess`, `inference`, `postprocess` times, run the following

#### Measure inference time with PyTorch eager

```
echo "handler:" > model-config.yaml && \
echo " profile: true" >> model-config.yaml
```

Once the `yaml` file is updated, create the model-archive, start TorchServe and run inference using the steps shown above.
After a few iterations of warmup, we see the following

```
2024-02-03T00:54:31,136 [INFO ] W-9000-densenet161_1.0-stdout org.pytorch.serve.wlm.WorkerLifeCycle - result=[METRICS]ts_handler_preprocess.Milliseconds:6.118656158447266|#ModelName:densenet161,Level:Model|#type:GAUGE|#hostname:ip-172-31-11-40,1706921671,c02b3170-c8fc-4396-857d-6c6266bf94a9, pattern=[METRICS]
2024-02-03T00:54:31,155 [INFO ] W-9000-densenet161_1.0-stdout org.pytorch.serve.wlm.WorkerLifeCycle - result=[METRICS]ts_handler_inference.Milliseconds:18.77564811706543|#ModelName:densenet161,Level:Model|#type:GAUGE|#hostname:ip-172-31-11-40,1706921671,c02b3170-c8fc-4396-857d-6c6266bf94a9, pattern=[METRICS]
2024-02-03T00:54:31,155 [INFO ] W-9000-densenet161_1.0-stdout org.pytorch.serve.wlm.WorkerLifeCycle - result=[METRICS]ts_handler_postprocess.Milliseconds:0.16630400717258453|#ModelName:densenet161,Level:Model|#type:GAUGE|#hostname:ip-172-31-11-40,1706921671,c02b3170-c8fc-4396-857d-6c6266bf94a9, pattern=[METRICS]
```

#### Measure inference time with `torch.compile`

```
echo "pt2: {backend: inductor, mode: reduce-overhead}" > model-config.yaml && \
echo "handler:" >> model-config.yaml && \
echo " profile: true" >> model-config.yaml
```

Once the `yaml` file is updated, create the model-archive, start TorchServe and run inference using the steps shown above.
`torch.compile` needs a few inferences to warmup. Once warmed up, we see the following
```
2024-02-03T00:56:14,808 [INFO ] W-9000-densenet161_1.0-stdout org.pytorch.serve.wlm.WorkerLifeCycle - result=[METRICS]ts_handler_preprocess.Milliseconds:5.9771199226379395|#ModelName:densenet161,Level:Model|#type:GAUGE|#hostname:ip-172-31-11-40,1706921774,d38601be-6312-46b4-b455-0322150509e5, pattern=[METRICS]
2024-02-03T00:56:14,814 [INFO ] W-9000-densenet161_1.0-stdout org.pytorch.serve.wlm.WorkerLifeCycle - result=[METRICS]ts_handler_inference.Milliseconds:5.8818559646606445|#ModelName:densenet161,Level:Model|#type:GAUGE|#hostname:ip-172-31-11-40,1706921774,d38601be-6312-46b4-b455-0322150509e5, pattern=[METRICS]
2024-02-03T00:56:14,814 [INFO ] W-9000-densenet161_1.0-stdout org.pytorch.serve.wlm.WorkerLifeCycle - result=[METRICS]ts_handler_postprocess.Milliseconds:0.19392000138759613|#ModelName:densenet161,Level:Model|#type:GAUGE|#hostname:ip-172-31-11-40,1706921774,d38601be-6312-46b4-b455-0322150509e5, pattern=[METRICS]
```

### Conclusion

`torch.compile` reduces the inference time from 18ms to 5ms
1 change: 1 addition & 0 deletions examples/pt2/torch_compile/model-config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
pt2 : {backend: inductor, mode: reduce-overhead}
27 changes: 27 additions & 0 deletions examples/pt2/torch_compile/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from torchvision.models.densenet import DenseNet


class ImageClassifier(DenseNet):
def __init__(self):
super(ImageClassifier, self).__init__(48, (6, 12, 36, 24), 96)

def load_state_dict(self, state_dict, strict=True):
# '.'s are no longer allowed in module names, but previous _DenseLayer
# has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
# They are also in the checkpoints in model_urls. This pattern is used
# to find such keys.
# Credit - https://github.com/pytorch/vision/blob/master/torchvision/models/densenet.py#def _load_state_dict()
import re

pattern = re.compile(
r"^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$"
)

for key in list(state_dict.keys()):
res = pattern.match(key)
if res:
new_key = res.group(1) + res.group(2)
state_dict[new_key] = state_dict[key]
del state_dict[key]

return super(ImageClassifier, self).load_state_dict(state_dict, strict)
82 changes: 82 additions & 0 deletions test/pytest/test_example_torch_compile.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import os
from pathlib import Path

import pytest
import torch
from pkg_resources import packaging

from ts.torch_handler.image_classifier import ImageClassifier
from ts.torch_handler.unit_tests.test_utils.mock_context import MockContext
from ts.utils.util import load_label_mapping
from ts_scripts.utils import try_and_handle

CURR_FILE_PATH = Path(__file__).parent.absolute()
REPO_ROOT_DIR = CURR_FILE_PATH.parents[1]
EXAMPLE_ROOT_DIR = REPO_ROOT_DIR.joinpath("examples", "pt2", "torch_compile")
TEST_DATA = REPO_ROOT_DIR.joinpath("examples", "image_classifier", "kitten.jpg")
MAPPING_DATA = REPO_ROOT_DIR.joinpath(
"examples", "image_classifier", "index_to_name.json"
)
MODEL_PTH_FILE = "densenet161-8d451a50.pth"
MODEL_FILE = "model.py"
MODEL_YAML_CFG_FILE = EXAMPLE_ROOT_DIR.joinpath("model-config.yaml")


PT2_AVAILABLE = (
True
if packaging.version.parse(torch.__version__) > packaging.version.parse("2.0")
else False
)

EXPECTED_RESULTS = ["tabby", "tiger_cat", "Egyptian_cat", "lynx", "plastic_bag"]


@pytest.fixture
def custom_working_directory(tmp_path):
# Set the custom working directory
custom_dir = tmp_path / "model_dir"
custom_dir.mkdir()
os.chdir(custom_dir)
yield custom_dir
# Clean up and return to the original working directory
os.chdir(tmp_path)


@pytest.mark.skipif(PT2_AVAILABLE == False, reason="torch version is < 2.0")
@pytest.mark.skip(reason="Skipping as its causing other testcases to fail")
def test_torch_compile_inference(monkeypatch, custom_working_directory):
monkeypatch.syspath_prepend(EXAMPLE_ROOT_DIR)
# Get the path to the custom working directory
model_dir = custom_working_directory

try_and_handle(
f"wget https://download.pytorch.org/models/{MODEL_PTH_FILE} -P {model_dir}"
)

# Handler for Image classification
handler = ImageClassifier()

# Context definition
ctx = MockContext(
model_pt_file=model_dir.joinpath(MODEL_PTH_FILE),
model_dir=EXAMPLE_ROOT_DIR.as_posix(),
model_file=MODEL_FILE,
model_yaml_config_file=MODEL_YAML_CFG_FILE,
)

torch.manual_seed(42 * 42)
handler.initialize(ctx)
handler.context = ctx
handler.mapping = load_label_mapping(MAPPING_DATA)

data = {}
with open(TEST_DATA, "rb") as image:
image_file = image.read()
byte_array_type = bytearray(image_file)
data["body"] = byte_array_type

result = handler.handle([data], ctx)

labels = list(result[0].keys())

assert labels == EXPECTED_RESULTS
1 change: 1 addition & 0 deletions ts_scripts/spellcheck_conf/wordlist.txt
Original file line number Diff line number Diff line change
Expand Up @@ -1175,3 +1175,4 @@ BabyLlamaHandler
CMakeLists
TorchScriptHandler
libllamacpp
warmup

0 comments on commit 88eca54

Please sign in to comment.