example for using TorchInductor caching with torch.compile (#2925)
* example for using torchinductor caching

* example for using torchinductor caching

* example for using torchinductor caching

* update README

* update README

* review comments

* updated readme

* Verified with 4 workers

* verified with 4 workers

* added additional links for debugging
agunapal authored Feb 13, 2024
1 parent 33d87e3 commit bef3b63
Showing 6 changed files with 271 additions and 3 deletions.
9 changes: 6 additions & 3 deletions examples/pt2/README.md
@@ -51,14 +51,17 @@ torchserve takes care of 4 and 5 for you while the remaining steps are your resp

### Note

-`torch.compile()` is a JIT compiler and JIT compilers generally have a startup cost. If that's an issue for you make sure to populate these two environment variables to improve your warm starts.
+`torch.compile()` is a JIT compiler, and JIT compilers generally have a startup cost. To reduce the warm-up time, `TorchInductor` already makes use of caching in `/tmp/torchinductor_USERID` on your machine.
+
+To persist this cache and/or to make use of the additional experimental caching feature, set the following:

```
import os

os.environ["TORCHINDUCTOR_CACHE_DIR"] = "1"
os.environ["TORCHINDUCTOR_FX_GRAPH_CACHE"] = "/path/to/directory" # replace with your desired path
os.environ["TORCHINDUCTOR_CACHE_DIR"] = "/path/to/directory" # replace with your desired path
os.environ["TORCHINDUCTOR_FX_GRAPH_CACHE"] = "1"
```
An example of how to use these with TorchServe is shown [here](./torch_inductor_caching/)
## torch.export.export
184 changes: 184 additions & 0 deletions examples/pt2/torch_inductor_caching/README.md
@@ -0,0 +1,184 @@

# TorchInductor Caching with TorchServe inference of densenet161 model

`torch.compile()` is a JIT compiler, and JIT compilers generally have a startup cost. To handle this, `TorchInductor` already makes use of caching in `/tmp/torchinductor_USERID` on your machine.

## TorchInductor FX Graph Cache
There is an experimental feature to cache the FX graph as well. This is not enabled by default and can be turned on with the following config:

```
import os
os.environ["TORCHINDUCTOR_FX_GRAPH_CACHE"] = "1"
```

The environment variable needs to be set before you `import torch`.

or, alternatively, set the config directly after importing torch:

```
import torch
torch._inductor.config.fx_graph_cache = True
```
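
As a quick sanity check outside of TorchServe, a minimal standalone sketch looks like the following (the toy function, tensor shapes, and variable names are placeholders, not part of this example):

```
import os

# The env-var route must be set before `import torch`
os.environ["TORCHINDUCTOR_FX_GRAPH_CACHE"] = "1"

import torch


def toy_fn(x):
    return torch.nn.functional.relu(x) * 2


compiled_fn = torch.compile(toy_fn, backend="inductor")
# The first call triggers compilation; a later fresh process can reuse the cached FX graph
print(compiled_fn(torch.randn(8, 8)))
```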

To see the effect of caching on `torch.compile` execution times, we need a multi-worker setup. In this example, we use 4 workers. Workers 2, 3, and 4 will see the benefit of caching when they execute `torch.compile`.

We show below how this can be used with TorchServe.


### Pre-requisites

- `PyTorch >= 2.2`

Change directory to the examples directory
Ex: `cd examples/pt2/torch_inductor_caching`


### torch.compile config

`torch.compile` supports a variety of configs, and the performance you get can vary based on the config. You can find the various options [here](https://pytorch.org/docs/stable/generated/torch.compile.html).

In this example, we use the following config:

```yaml
pt2 : {backend: inductor, mode: max-autotune}
```
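
This `pt2` section is picked up by TorchServe when it loads the model and roughly maps onto a `torch.compile` call. A simplified sketch of the equivalent call is shown below (the `Linear` module is only a stand-in for the real densenet161 model):

```
import torch

# Mirrors `pt2 : {backend: inductor, mode: max-autotune}` from the yaml above
pt2_config = {"backend": "inductor", "mode": "max-autotune"}

model = torch.nn.Linear(8, 2)  # stand-in for the loaded densenet161 model
model = torch.compile(model, **pt2_config)
```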
### Create model archive
```
wget https://download.pytorch.org/models/densenet161-8d451a50.pth
mkdir model_store
torch-model-archiver --model-name densenet161 --version 1.0 --model-file ../../image_classifier/densenet_161/model.py --serialized-file densenet161-8d451a50.pth --export-path model_store --extra-files ../../image_classifier/index_to_name.json --handler ./caching_handler.py --config-file model-config-fx-cache.yaml -f
```

#### Start TorchServe
```
torchserve --start --ncs --model-store model_store --models densenet161.mar
```

#### Run Inference

```
curl http://127.0.0.1:8080/predictions/densenet161 -T ../../image_classifier/kitten.jpg && curl http://127.0.0.1:8080/predictions/densenet161 -T ../../image_classifier/kitten.jpg && curl http://127.0.0.1:8080/predictions/densenet161 -T ../../image_classifier/kitten.jpg && curl http://127.0.0.1:8080/predictions/densenet161 -T ../../image_classifier/kitten.jpg
```

produces the output

```
{
"tabby": 0.4664836823940277,
"tiger_cat": 0.4645617604255676,
"Egyptian_cat": 0.06619937717914581,
"lynx": 0.0012969186063855886,
"plastic_bag": 0.00022856894065625966
}{
"tabby": 0.4664836823940277,
"tiger_cat": 0.4645617604255676,
"Egyptian_cat": 0.06619937717914581,
"lynx": 0.0012969186063855886,
"plastic_bag": 0.00022856894065625966
}{
"tabby": 0.4664836823940277,
"tiger_cat": 0.4645617604255676,
"Egyptian_cat": 0.06619937717914581,
"lynx": 0.0012969186063855886,
"plastic_bag": 0.00022856894065625966
}{
"tabby": 0.4664836823940277,
"tiger_cat": 0.4645617604255676,
"Egyptian_cat": 0.06619937717914581,
"lynx": 0.0012969186063855886,
"plastic_bag": 0.00022856894065625966
}
```

## TorchInductor Cache Directory
`TorchInductor` already makes use of caching in `/tmp/torchinductor_USERID` on your machine.

Since the default directory is in `/tmp`, the cache is deleted when the machine restarts.

`torch.compile` provides a config to change the cache directory for `TorchInductor`:

```
import os
os.environ["TORCHINDUCTOR_CACHE_DIR"] = "/path/to/directory" # replace with your desired path
```
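
As a standalone illustration, you can point the cache at a persistent location and check that compiled artifacts appear there (the path below is hypothetical; replace it with a directory that survives restarts):

```
import os

# Hypothetical persistent path; set it before `import torch`
os.environ["TORCHINDUCTOR_CACHE_DIR"] = "/opt/torchinductor_cache"

import torch

model = torch.compile(torch.nn.Linear(16, 4))
model(torch.randn(2, 16))  # the first call compiles and populates the cache directory
print(os.listdir(os.environ["TORCHINDUCTOR_CACHE_DIR"]))  # artifacts persist across restarts
```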


We show below how this can be used with TorchServe.


### Pre-requisites

- `PyTorch >= 2.2`

Change directory to the examples directory
Ex: `cd examples/pt2/torch_inductor_caching`


### torch.compile config

`torch.compile` supports a variety of configs, and the performance you get can vary based on the config. You can find the various options [here](https://pytorch.org/docs/stable/generated/torch.compile.html).

In this example, we use the following config:

```yaml
pt2 : {backend: inductor, mode: max-autotune}
```
### Create model archive
```
wget https://download.pytorch.org/models/densenet161-8d451a50.pth
mkdir model_store
torch-model-archiver --model-name densenet161 --version 1.0 --model-file ../../image_classifier/densenet_161/model.py --serialized-file densenet161-8d451a50.pth --export-path model_store --extra-files ../../image_classifier/index_to_name.json --handler ./caching_handler.py --config-file model-config-cache-dir.yaml -f
```

#### Start TorchServe
```
torchserve --start --ncs --model-store model_store --models densenet161.mar
```

#### Run Inference

```
curl http://127.0.0.1:8080/predictions/densenet161 -T ../../image_classifier/kitten.jpg && curl http://127.0.0.1:8080/predictions/densenet161 -T ../../image_classifier/kitten.jpg && curl http://127.0.0.1:8080/predictions/densenet161 -T ../../image_classifier/kitten.jpg && curl http://127.0.0.1:8080/predictions/densenet161 -T ../../image_classifier/kitten.jpg
```

produces the output

```
{
"tabby": 0.4664836823940277,
"tiger_cat": 0.4645617604255676,
"Egyptian_cat": 0.06619937717914581,
"lynx": 0.0012969186063855886,
"plastic_bag": 0.00022856894065625966
}{
"tabby": 0.4664836823940277,
"tiger_cat": 0.4645617604255676,
"Egyptian_cat": 0.06619937717914581,
"lynx": 0.0012969186063855886,
"plastic_bag": 0.00022856894065625966
}{
"tabby": 0.4664836823940277,
"tiger_cat": 0.4645617604255676,
"Egyptian_cat": 0.06619937717914581,
"lynx": 0.0012969186063855886,
"plastic_bag": 0.00022856894065625966
}{
"tabby": 0.4664836823940277,
"tiger_cat": 0.4645617604255676,
"Egyptian_cat": 0.06619937717914581,
"lynx": 0.0012969186063855886,
"plastic_bag": 0.00022856894065625966
}
```

## Additional links for improving `torch.compile` performance and debugging

- [Compile Threads](https://pytorch.org/blog/training-production-ai-models/#34-controlling-just-in-time-compilation-time) (see the short sketch below)
- [Profiling torch.compile](https://pytorch.org/docs/stable/torch.compiler_profiling_torch_compile.html)
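
For example, when several workers compile at the same time, limiting the number of parallel Inductor compile threads can make compile-time CPU usage more predictable. A hedged sketch is shown below (assuming the `compile_threads` knob discussed in the blog post above; verify the exact name for your PyTorch version):

```
import torch

# Assumption: `torch._inductor.config.compile_threads` controls the number of parallel compile workers
torch._inductor.config.compile_threads = 1  # set this before the first `torch.compile` call
```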
65 changes: 65 additions & 0 deletions examples/pt2/torch_inductor_caching/caching_handler.py
@@ -0,0 +1,65 @@
import logging
import os

import torch
from torch._dynamo.utils import counters

from ts.torch_handler.image_classifier import ImageClassifier

logger = logging.getLogger(__name__)


class TorchInductorCacheHandler(ImageClassifier):
    """
    Image classifier handler that enables TorchInductor caching
    (FX graph cache and/or a custom cache directory) based on the model config.
    """

    def __init__(self):
        super().__init__()
        self.initialized = False

    def initialize(self, ctx):
        """Load the model and apply the TorchInductor caching settings
        from the model yaml config.
        Args:
            ctx (context): It is a JSON Object containing information
            pertaining to the model artifacts parameters.
        """
        self.context = ctx
        self.manifest = ctx.manifest
        properties = ctx.system_properties

        if (
            "handler" in ctx.model_yaml_config
            and "torch_inductor_caching" in ctx.model_yaml_config["handler"]
        ):
            if ctx.model_yaml_config["handler"]["torch_inductor_caching"].get(
                "torch_inductor_fx_graph_cache", False
            ):
                torch._inductor.config.fx_graph_cache = True
            if (
                "torch_inductor_cache_dir"
                in ctx.model_yaml_config["handler"]["torch_inductor_caching"]
            ):
                os.environ["TORCHINDUCTOR_CACHE_DIR"] = ctx.model_yaml_config[
                    "handler"
                ]["torch_inductor_caching"]["torch_inductor_cache_dir"]

        super().initialize(ctx)
        self.initialized = True

    def inference(self, data, *args, **kwargs):
        with torch.inference_mode():
            marshalled_data = data.to(self.device)
            results = self.model(marshalled_data, *args, **kwargs)

        # Debug logging for FX graph cache hits/misses
        if torch._inductor.config.fx_graph_cache:
            fx_graph_cache_hit, fx_graph_cache_miss = (
                counters["inductor"]["fxgraph_cache_hit"],
                counters["inductor"]["fxgraph_cache_miss"],
            )
            logger.info(
                f"TorchInductor FX Graph cache hit {fx_graph_cache_hit}, FX Graph cache miss {fx_graph_cache_miss}"
            )
        return results
7 changes: 7 additions & 0 deletions examples/pt2/torch_inductor_caching/model-config-cache-dir.yaml
@@ -0,0 +1,7 @@
minWorkers: 4
maxWorkers: 4
responseTimeout: 600
pt2 : {backend: inductor, mode: max-autotune}
handler:
  torch_inductor_caching:
    torch_inductor_cache_dir: "/home/ubuntu/serve/examples/pt2/torch_inductor_caching/cache"
7 changes: 7 additions & 0 deletions examples/pt2/torch_inductor_caching/model-config-fx-cache.yaml
@@ -0,0 +1,7 @@
minWorkers: 4
maxWorkers: 4
responseTimeout: 600
pt2 : {backend: inductor, mode: max-autotune}
handler:
  torch_inductor_caching:
    torch_inductor_fx_graph_cache: true
2 changes: 2 additions & 0 deletions ts_scripts/spellcheck_conf/wordlist.txt
@@ -1175,6 +1175,8 @@ BabyLlamaHandler
CMakeLists
TorchScriptHandler
libllamacpp
USERID
torchinductor
libtorch
Andrej
Karpathy's