torch.compile ImageClassifier example (#2915)
* updated torch.compile example
* added pytest for example
* lint failure
* show 3x speedup with torch.compile
* show 3x speedup with torch.compile
* show 3x speedup with torch.compile
* added missing file
* added missing file
* added sbin to path
* skipping test
* Skipping pytest for now as its causing other tests to fail
Showing 6 changed files with 210 additions and 10 deletions.
@@ -0,0 +1,94 @@

# TorchServe inference with torch.compile of densenet161 model

This example shows how to take an eager `densenet161` model, configure TorchServe to use `torch.compile`, and run inference with the compiled model.

### Pre-requisites

- `PyTorch >= 2.0`

Change directory to the example's directory
Ex: `cd examples/pt2/torch_compile`

### torch.compile config

`torch.compile` supports a variety of configuration options, and the performance you get can vary with the options chosen. The available options are listed [here](https://pytorch.org/docs/stable/generated/torch.compile.html).

In this example, we use the following config:

```
echo "pt2 : {backend: inductor, mode: reduce-overhead}" > model-config.yaml
```
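For orientation, the keys of the `pt2` entry mirror keyword arguments of `torch.compile` (`backend`, `mode`, and so on). The sketch below assumes the handler forwards that mapping more or less directly as keyword arguments; the helper `compile_kwargs_from_config` and the `ALLOWED_KEYS` set are hypothetical names used for illustration, not TorchServe's actual code.

```python
# Illustrative sketch only: `compile_kwargs_from_config` and `ALLOWED_KEYS`
# are hypothetical and not part of TorchServe.
ALLOWED_KEYS = {"backend", "mode", "fullgraph", "dynamic", "options"}


def compile_kwargs_from_config(model_config: dict) -> dict:
    """Return keyword arguments for torch.compile from a parsed model-config.yaml."""
    pt2 = model_config.get("pt2") or {}
    unknown = set(pt2) - ALLOWED_KEYS
    if unknown:
        raise ValueError(f"unsupported torch.compile options: {sorted(unknown)}")
    return dict(pt2)


# Parsed form of: pt2 : {backend: inductor, mode: reduce-overhead}
config = {"pt2": {"backend": "inductor", "mode": "reduce-overhead"}}
kwargs = compile_kwargs_from_config(config)
print(kwargs)

# With PyTorch >= 2.0 installed, the kwargs would be applied like this:
#   compiled_model = torch.compile(model, **kwargs)
```

Rejecting unknown keys early is a design choice of this sketch: a typo in the YAML then fails at load time rather than surfacing later as a confusing `torch.compile` error.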
### Create model archive

```
wget https://download.pytorch.org/models/densenet161-8d451a50.pth
mkdir model_store
torch-model-archiver --model-name densenet161 --version 1.0 --model-file model.py --serialized-file densenet161-8d451a50.pth --export-path model_store --extra-files ../../image_classifier/index_to_name.json --handler image_classifier --config-file model-config.yaml -f
```

#### Start TorchServe

```
torchserve --start --ncs --model-store model_store --models densenet161.mar
```

#### Run Inference

```
curl http://127.0.0.1:8080/predictions/densenet161 -T ../../image_classifier/kitten.jpg
```

produces the output

```
{
  "tabby": 0.4664836823940277,
  "tiger_cat": 0.4645617604255676,
  "Egyptian_cat": 0.06619937717914581,
  "lynx": 0.0012969186063855886,
  "plastic_bag": 0.00022856894065625966
}
```
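The same request can be issued from Python. The following is a minimal standard-library sketch, assuming TorchServe is listening on the default inference address used in the `curl` command; `build_prediction_request` and `predict` are hypothetical helper names introduced here. `curl -T` uploads the file with an HTTP `PUT`, so the sketch does the same.

```python
import json
import urllib.request


def build_prediction_request(
    image_bytes: bytes,
    url: str = "http://127.0.0.1:8080/predictions/densenet161",
) -> urllib.request.Request:
    """Build the same PUT request that `curl -T <file>` sends."""
    return urllib.request.Request(url, data=image_bytes, method="PUT")


def predict(image_path: str) -> dict:
    """Send the image to a running TorchServe instance and decode the JSON reply."""
    with open(image_path, "rb") as f:
        request = build_prediction_request(f.read())
    with urllib.request.urlopen(request, timeout=30) as response:
        return json.loads(response.read())


# Usage (requires TorchServe to be running with densenet161 registered):
#   print(predict("../../image_classifier/kitten.jpg"))
```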
### Performance improvement from using `torch.compile`

To measure the handler's `preprocess`, `inference`, and `postprocess` times, enable handler profiling in `model-config.yaml` as follows.

#### Measure inference time with PyTorch eager

```
echo "handler:" > model-config.yaml && \
echo "  profile: true" >> model-config.yaml
```

Once the `yaml` file is updated, stop any running TorchServe instance (`torchserve --stop`), re-create the model archive, start TorchServe, and run inference using the steps shown above.
After a few warmup iterations, we see the following:

```
2024-02-03T00:54:31,136 [INFO ] W-9000-densenet161_1.0-stdout org.pytorch.serve.wlm.WorkerLifeCycle - result=[METRICS]ts_handler_preprocess.Milliseconds:6.118656158447266|#ModelName:densenet161,Level:Model|#type:GAUGE|#hostname:ip-172-31-11-40,1706921671,c02b3170-c8fc-4396-857d-6c6266bf94a9, pattern=[METRICS]
2024-02-03T00:54:31,155 [INFO ] W-9000-densenet161_1.0-stdout org.pytorch.serve.wlm.WorkerLifeCycle - result=[METRICS]ts_handler_inference.Milliseconds:18.77564811706543|#ModelName:densenet161,Level:Model|#type:GAUGE|#hostname:ip-172-31-11-40,1706921671,c02b3170-c8fc-4396-857d-6c6266bf94a9, pattern=[METRICS]
2024-02-03T00:54:31,155 [INFO ] W-9000-densenet161_1.0-stdout org.pytorch.serve.wlm.WorkerLifeCycle - result=[METRICS]ts_handler_postprocess.Milliseconds:0.16630400717258453|#ModelName:densenet161,Level:Model|#type:GAUGE|#hostname:ip-172-31-11-40,1706921671,c02b3170-c8fc-4396-857d-6c6266bf94a9, pattern=[METRICS]
```
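Reading these values by eye gets tedious over many requests. The sketch below pulls the metric name, unit, and value out of one such line with a regular expression; the pattern is inferred from the log lines shown here and is an assumption about their format, and `parse_metric` is a hypothetical helper.

```python
import re

# Matches e.g. "[METRICS]ts_handler_inference.Milliseconds:18.77...|#ModelName:..."
# The pattern is inferred from the sample log lines, not from a format spec.
METRIC_RE = re.compile(r"\[METRICS\](?P<name>\w+)\.(?P<unit>\w+):(?P<value>[0-9.]+)\|")


def parse_metric(line: str):
    """Return (metric_name, unit, value) for a handler metric log line, else None."""
    match = METRIC_RE.search(line)
    if match is None:
        return None
    return match["name"], match["unit"], float(match["value"])


sample = (
    "2024-02-03T00:54:31,155 [INFO ] W-9000-densenet161_1.0-stdout "
    "org.pytorch.serve.wlm.WorkerLifeCycle - result=[METRICS]"
    "ts_handler_inference.Milliseconds:18.77564811706543|#ModelName:densenet161,"
    "Level:Model|#type:GAUGE|#hostname:ip-172-31-11-40,1706921671,"
    "c02b3170-c8fc-4396-857d-6c6266bf94a9, pattern=[METRICS]"
)
print(parse_metric(sample))
```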
#### Measure inference time with `torch.compile`

```
echo "pt2: {backend: inductor, mode: reduce-overhead}" > model-config.yaml && \
echo "handler:" >> model-config.yaml && \
echo "  profile: true" >> model-config.yaml
```

Once the `yaml` file is updated, re-create the model archive, restart TorchServe, and run inference using the steps shown above.
`torch.compile` needs a few inferences to warm up. Once warmed up, we see the following:

```
2024-02-03T00:56:14,808 [INFO ] W-9000-densenet161_1.0-stdout org.pytorch.serve.wlm.WorkerLifeCycle - result=[METRICS]ts_handler_preprocess.Milliseconds:5.9771199226379395|#ModelName:densenet161,Level:Model|#type:GAUGE|#hostname:ip-172-31-11-40,1706921774,d38601be-6312-46b4-b455-0322150509e5, pattern=[METRICS]
2024-02-03T00:56:14,814 [INFO ] W-9000-densenet161_1.0-stdout org.pytorch.serve.wlm.WorkerLifeCycle - result=[METRICS]ts_handler_inference.Milliseconds:5.8818559646606445|#ModelName:densenet161,Level:Model|#type:GAUGE|#hostname:ip-172-31-11-40,1706921774,d38601be-6312-46b4-b455-0322150509e5, pattern=[METRICS]
2024-02-03T00:56:14,814 [INFO ] W-9000-densenet161_1.0-stdout org.pytorch.serve.wlm.WorkerLifeCycle - result=[METRICS]ts_handler_postprocess.Milliseconds:0.19392000138759613|#ModelName:densenet161,Level:Model|#type:GAUGE|#hostname:ip-172-31-11-40,1706921774,d38601be-6312-46b4-b455-0322150509e5, pattern=[METRICS]
```
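Because the first requests after startup can include compilation time, it helps to discard them before timing anything. The sketch below shows one way to do that from the client side; `time_after_warmup` is a hypothetical helper, and the lambda is a stand-in workload so the snippet runs without a server. Note that this measures end-to-end client latency, not the handler metrics logged above.

```python
import statistics
import time
from typing import Callable, List


def time_after_warmup(
    send_request: Callable[[], object], warmup: int = 5, measured: int = 20
) -> List[float]:
    """Call `send_request` `warmup` times untimed, then return per-call latencies in ms."""
    for _ in range(warmup):
        send_request()  # discarded: may include compilation time
    latencies_ms = []
    for _ in range(measured):
        start = time.perf_counter()
        send_request()
        latencies_ms.append((time.perf_counter() - start) * 1000.0)
    return latencies_ms


# Stand-in workload so the sketch runs without a server; replace it with a
# function that sends an image to the predictions endpoint.
latencies = time_after_warmup(lambda: sum(range(1000)), warmup=2, measured=5)
print(f"median latency: {statistics.median(latencies):.3f} ms")
```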
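### Conclusion

`torch.compile` reduces the `inference` time from about 18.8 ms to about 5.9 ms, roughly a 3x speedup, while the `preprocess` and `postprocess` times remain essentially unchanged.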
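As a quick check of the arithmetic, the snippet below recomputes the speedup from the metric values logged above. These are single logged requests, so the ratios are indicative rather than a benchmark: about 3.2x for the `inference` step alone and about 2.1x for the three handler stages combined, since `preprocess` is not affected by compilation.

```python
# Values in milliseconds, copied from the [METRICS] log lines above.
eager = {
    "preprocess": 6.118656158447266,
    "inference": 18.77564811706543,
    "postprocess": 0.16630400717258453,
}
compiled = {
    "preprocess": 5.9771199226379395,
    "inference": 5.8818559646606445,
    "postprocess": 0.19392000138759613,
}

# Speedup of the model forward pass alone.
inference_speedup = eager["inference"] / compiled["inference"]
# Speedup of preprocess + inference + postprocess together.
handler_speedup = sum(eager.values()) / sum(compiled.values())

print(f"inference speedup: {inference_speedup:.2f}x")
print(f"whole-handler speedup: {handler_speedup:.2f}x")
```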
@@ -0,0 +1 @@
pt2 : {backend: inductor, mode: reduce-overhead}
@@ -0,0 +1,27 @@
from torchvision.models.densenet import DenseNet


class ImageClassifier(DenseNet):
    def __init__(self):
        super(ImageClassifier, self).__init__(48, (6, 12, 36, 24), 96)

    def load_state_dict(self, state_dict, strict=True):
        # '.'s are no longer allowed in module names, but previous _DenseLayer
        # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
        # They are also in the checkpoints in model_urls. This pattern is used
        # to find such keys.
        # Credit - https://github.com/pytorch/vision/blob/master/torchvision/models/densenet.py#def _load_state_dict()
        import re

        pattern = re.compile(
            r"^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$"
        )

        for key in list(state_dict.keys()):
            res = pattern.match(key)
            if res:
                new_key = res.group(1) + res.group(2)
                state_dict[new_key] = state_dict[key]
                del state_dict[key]

        return super(ImageClassifier, self).load_state_dict(state_dict, strict)
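To see what the key-renaming pattern in `load_state_dict` does without loading a checkpoint, the sketch below applies the same regular expression to a few sample keys. The helper `rename_legacy_keys`, the sample keys, and the placeholder string values are illustrative assumptions; a real state dict maps keys to tensors.

```python
import re

# Same pattern as in ImageClassifier.load_state_dict above.
pattern = re.compile(
    r"^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$"
)


def rename_legacy_keys(state_dict: dict) -> dict:
    """Return a copy with keys such as '...norm.1.weight' renamed to '...norm1.weight'."""
    renamed = {}
    for key, value in state_dict.items():
        match = pattern.match(key)
        new_key = match.group(1) + match.group(2) if match else key
        renamed[new_key] = value
    return renamed


# Sample keys in the legacy naming style, with placeholder values.
legacy = {
    "features.denseblock1.denselayer1.norm.1.weight": "w",
    "features.denseblock1.denselayer1.conv.2.weight": "c",
    "classifier.bias": "b",
}
print(rename_legacy_keys(legacy))
```

Unlike the method above, this sketch builds a new dictionary instead of mutating the input, which makes it easier to compare the keys before and after.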
@@ -0,0 +1,82 @@
import os
from pathlib import Path

import pytest
import torch
from packaging import version

from ts.torch_handler.image_classifier import ImageClassifier
from ts.torch_handler.unit_tests.test_utils.mock_context import MockContext
from ts.utils.util import load_label_mapping
from ts_scripts.utils import try_and_handle

CURR_FILE_PATH = Path(__file__).parent.absolute()
REPO_ROOT_DIR = CURR_FILE_PATH.parents[1]
EXAMPLE_ROOT_DIR = REPO_ROOT_DIR.joinpath("examples", "pt2", "torch_compile")
TEST_DATA = REPO_ROOT_DIR.joinpath("examples", "image_classifier", "kitten.jpg")
MAPPING_DATA = REPO_ROOT_DIR.joinpath(
    "examples", "image_classifier", "index_to_name.json"
)
MODEL_PTH_FILE = "densenet161-8d451a50.pth"
MODEL_FILE = "model.py"
MODEL_YAML_CFG_FILE = EXAMPLE_ROOT_DIR.joinpath("model-config.yaml")


# torch.compile requires PyTorch >= 2.0 (use >= so that a plain 2.0 release qualifies)
PT2_AVAILABLE = version.parse(torch.__version__) >= version.parse("2.0")

EXPECTED_RESULTS = ["tabby", "tiger_cat", "Egyptian_cat", "lynx", "plastic_bag"]


@pytest.fixture
def custom_working_directory(tmp_path):
    # Remember where we started so the directory can be restored afterwards
    original_dir = os.getcwd()
    # Set the custom working directory
    custom_dir = tmp_path / "model_dir"
    custom_dir.mkdir()
    os.chdir(custom_dir)
    yield custom_dir
    # Clean up and return to the original working directory
    os.chdir(original_dir)


@pytest.mark.skipif(not PT2_AVAILABLE, reason="torch version is < 2.0")
@pytest.mark.skip(reason="Skipping as its causing other testcases to fail")
def test_torch_compile_inference(monkeypatch, custom_working_directory):
    monkeypatch.syspath_prepend(EXAMPLE_ROOT_DIR)
    # Get the path to the custom working directory
    model_dir = custom_working_directory

    try_and_handle(
        f"wget https://download.pytorch.org/models/{MODEL_PTH_FILE} -P {model_dir}"
    )

    # Handler for Image classification
    handler = ImageClassifier()

    # Context definition
    ctx = MockContext(
        model_pt_file=model_dir.joinpath(MODEL_PTH_FILE),
        model_dir=EXAMPLE_ROOT_DIR.as_posix(),
        model_file=MODEL_FILE,
        model_yaml_config_file=MODEL_YAML_CFG_FILE,
    )

    torch.manual_seed(42 * 42)
    handler.initialize(ctx)
    handler.context = ctx
    handler.mapping = load_label_mapping(MAPPING_DATA)

    data = {}
    with open(TEST_DATA, "rb") as image:
        image_file = image.read()
        byte_array_type = bytearray(image_file)
        data["body"] = byte_array_type

    result = handler.handle([data], ctx)

    labels = list(result[0].keys())

    assert labels == EXPECTED_RESULTS
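The `custom_working_directory` fixture changes the process-wide working directory, which is a plausible way for one test to disturb others if the directory is not put back. As a standard-library sketch of the save-and-restore pattern (the context manager name `working_directory` is hypothetical), the following guarantees restoration even when the body raises:

```python
import contextlib
import os
import tempfile


@contextlib.contextmanager
def working_directory(path):
    """Temporarily switch to `path`, restoring the previous directory on exit."""
    previous = os.getcwd()
    os.chdir(path)
    try:
        yield path
    finally:
        os.chdir(previous)  # runs even if the body raises


before = os.getcwd()
with tempfile.TemporaryDirectory() as tmp:
    with working_directory(tmp):
        inside = os.getcwd()
after = os.getcwd()
print(before == after)
```

Inside pytest, the built-in `monkeypatch.chdir` fixture method offers the same guarantee without a custom helper.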
@@ -1175,3 +1175,4 @@ BabyLlamaHandler
CMakeLists
TorchScriptHandler
libllamacpp
warmup