example for using TorchInductor caching with torch.compile (#2925)
* example for using TorchInductor caching
* example for using TorchInductor caching
* example for using TorchInductor caching
* update README
* update README
* review comments
* updated readme
* Verified with 4 workers
* verified with 4 workers
* added additional links for debugging
Showing 6 changed files with 271 additions and 3 deletions.
# TorchInductor Caching with TorchServe inference of densenet161 model

`torch.compile()` is a JIT compiler and JIT compilers generally have a startup cost. To handle this, `TorchInductor` already makes use of caching in `/tmp/torchinductor_USERID` of your machine.

## TorchInductor FX Graph Cache
There is an experimental feature to cache the FX Graph as well. This is not enabled by default and can be set with the following config

```
import os
os.environ["TORCHINDUCTOR_FX_GRAPH_CACHE"] = "1"
```

This needs to be set before you `import torch`

or

```
import torch
torch._inductor.config.fx_graph_cache = True
```
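
If you want to confirm the cache is actually being exercised, TorchInductor exposes counters in `torch._dynamo.utils.counters`, which are the same counters the handler in this example logs. Below is a minimal sketch, with a hypothetical toy model used only for illustration: on a cold cache the miss counter increments, while a rerun of the script (or a second worker process) with the cache populated shows a hit instead.

```
import os
os.environ["TORCHINDUCTOR_FX_GRAPH_CACHE"] = "1"  # must be set before importing torch

import torch
from torch._dynamo.utils import counters

# Hypothetical toy model, illustration only
model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU())
compiled = torch.compile(model, backend="inductor")

compiled(torch.randn(2, 8))  # first call triggers compilation

# Cold cache: miss == 1, hit == 0. Warm cache (e.g. another worker): hit == 1.
print("fxgraph_cache_hit:", counters["inductor"]["fxgraph_cache_hit"])
print("fxgraph_cache_miss:", counters["inductor"]["fxgraph_cache_miss"])
```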

To see the effect of caching on `torch.compile` execution times, we need a multi-worker setup. In this example, we use 4 workers. Workers 2, 3 and 4 see the benefit of caching when they execute `torch.compile`.

We show below how this can be used with TorchServe.


### Pre-requisites

- `PyTorch >= 2.2`

Change directory to the examples directory
Ex: `cd examples/pt2/torch_inductor_caching`


### torch.compile config

`torch.compile` supports a variety of configs and the performance you get can vary based on the config. You can find the various options [here](https://pytorch.org/docs/stable/generated/torch.compile.html)

In this example, we use the following config

```yaml
pt2 : {backend: inductor, mode: max-autotune}
```
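
TorchServe applies these options when it compiles the model inside the worker. As a rough standalone equivalent, and only as a sketch rather than TorchServe's actual code, the config above corresponds to a `torch.compile` call like this:

```
import torch
from torchvision.models import densenet161

model = densenet161(weights=None).eval()  # weights omitted; the example loads densenet161-8d451a50.pth separately

# Roughly what `pt2: {backend: inductor, mode: max-autotune}` asks TorchServe to do
compiled_model = torch.compile(model, backend="inductor", mode="max-autotune")
```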
### Create model archive

```
wget https://download.pytorch.org/models/densenet161-8d451a50.pth
mkdir model_store
torch-model-archiver --model-name densenet161 --version 1.0 --model-file ../../image_classifier/densenet_161/model.py --serialized-file densenet161-8d451a50.pth --export-path model_store --extra-files ../../image_classifier/index_to_name.json --handler ./caching_handler.py --config-file model-config-fx-cache.yaml -f
```

#### Start TorchServe

```
torchserve --start --ncs --model-store model_store --models densenet161.mar
```

#### Run Inference

```
curl http://127.0.0.1:8080/predictions/densenet161 -T ../../image_classifier/kitten.jpg && curl http://127.0.0.1:8080/predictions/densenet161 -T ../../image_classifier/kitten.jpg && curl http://127.0.0.1:8080/predictions/densenet161 -T ../../image_classifier/kitten.jpg && curl http://127.0.0.1:8080/predictions/densenet161 -T ../../image_classifier/kitten.jpg
```

produces the output

```
{
  "tabby": 0.4664836823940277,
  "tiger_cat": 0.4645617604255676,
  "Egyptian_cat": 0.06619937717914581,
  "lynx": 0.0012969186063855886,
  "plastic_bag": 0.00022856894065625966
}{
  "tabby": 0.4664836823940277,
  "tiger_cat": 0.4645617604255676,
  "Egyptian_cat": 0.06619937717914581,
  "lynx": 0.0012969186063855886,
  "plastic_bag": 0.00022856894065625966
}{
  "tabby": 0.4664836823940277,
  "tiger_cat": 0.4645617604255676,
  "Egyptian_cat": 0.06619937717914581,
  "lynx": 0.0012969186063855886,
  "plastic_bag": 0.00022856894065625966
}{
  "tabby": 0.4664836823940277,
  "tiger_cat": 0.4645617604255676,
  "Egyptian_cat": 0.06619937717914581,
  "lynx": 0.0012969186063855886,
  "plastic_bag": 0.00022856894065625966
}
```

## TorchInductor Cache Directory
`TorchInductor` already makes use of caching in `/tmp/torchinductor_USERID` of your machine.

Since the default directory is in `/tmp`, the cache is deleted on restart.

`torch.compile` provides a config to change the cache directory for `TorchInductor`

```
import os
os.environ["TORCHINDUCTOR_CACHE_DIR"] = "/path/to/directory" # replace with your desired path
```
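
A quick way to check that the custom directory is being used is to compile something small and then list the directory; after the first compile it should contain Inductor's cache artifacts. A minimal sketch, where the path and toy model are placeholders:

```
import os
os.environ["TORCHINDUCTOR_CACHE_DIR"] = "/tmp/my_inductor_cache"  # placeholder path

import torch

model = torch.nn.Linear(4, 4)  # hypothetical toy model, illustration only
compiled = torch.compile(model, backend="inductor")
compiled(torch.randn(1, 4))  # first call triggers compilation and populates the cache

# Inductor creates and fills this directory during compilation
print(os.listdir(os.environ["TORCHINDUCTOR_CACHE_DIR"]))
```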

We show below how this can be used with TorchServe.


### Pre-requisites

- `PyTorch >= 2.2`

Change directory to the examples directory
Ex: `cd examples/pt2/torch_inductor_caching`


### torch.compile config

`torch.compile` supports a variety of configs and the performance you get can vary based on the config. You can find the various options [here](https://pytorch.org/docs/stable/generated/torch.compile.html)

In this example, we use the following config

```yaml
pt2 : {backend: inductor, mode: max-autotune}
```
### Create model archive

```
wget https://download.pytorch.org/models/densenet161-8d451a50.pth
mkdir model_store
torch-model-archiver --model-name densenet161 --version 1.0 --model-file ../../image_classifier/densenet_161/model.py --serialized-file densenet161-8d451a50.pth --export-path model_store --extra-files ../../image_classifier/index_to_name.json --handler ./caching_handler.py --config-file model-config-cache-dir.yaml -f
```

#### Start TorchServe

```
torchserve --start --ncs --model-store model_store --models densenet161.mar
```

#### Run Inference

```
curl http://127.0.0.1:8080/predictions/densenet161 -T ../../image_classifier/kitten.jpg && curl http://127.0.0.1:8080/predictions/densenet161 -T ../../image_classifier/kitten.jpg && curl http://127.0.0.1:8080/predictions/densenet161 -T ../../image_classifier/kitten.jpg && curl http://127.0.0.1:8080/predictions/densenet161 -T ../../image_classifier/kitten.jpg
```

produces the output

```
{
  "tabby": 0.4664836823940277,
  "tiger_cat": 0.4645617604255676,
  "Egyptian_cat": 0.06619937717914581,
  "lynx": 0.0012969186063855886,
  "plastic_bag": 0.00022856894065625966
}{
  "tabby": 0.4664836823940277,
  "tiger_cat": 0.4645617604255676,
  "Egyptian_cat": 0.06619937717914581,
  "lynx": 0.0012969186063855886,
  "plastic_bag": 0.00022856894065625966
}{
  "tabby": 0.4664836823940277,
  "tiger_cat": 0.4645617604255676,
  "Egyptian_cat": 0.06619937717914581,
  "lynx": 0.0012969186063855886,
  "plastic_bag": 0.00022856894065625966
}{
  "tabby": 0.4664836823940277,
  "tiger_cat": 0.4645617604255676,
  "Egyptian_cat": 0.06619937717914581,
  "lynx": 0.0012969186063855886,
  "plastic_bag": 0.00022856894065625966
}
```

## Additional links for improving `torch.compile` performance and debugging

- [Compile Threads](https://pytorch.org/blog/training-production-ai-models/#34-controlling-just-in-time-compilation-time)
- [Profiling torch.compile](https://pytorch.org/docs/stable/torch.compiler_profiling_torch_compile.html)
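
As a starting point for the profiling link above, a minimal sketch that wraps a compiled call in the PyTorch profiler might look like the following; the toy model and output path are placeholders, and the first profiled call will also capture compilation overhead.

```
import torch
from torch.profiler import profile, ProfilerActivity

model = torch.compile(torch.nn.Linear(16, 16))  # hypothetical toy model

with profile(activities=[ProfilerActivity.CPU], record_shapes=True) as prof:
    model(torch.randn(8, 16))

# Inspect where time is spent, or export a Chrome trace for chrome://tracing
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
prof.export_chrome_trace("compile_trace.json")  # placeholder output path
```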
65 changes: 65 additions & 0 deletions
examples/pt2/torch_inductor_caching/caching_handler.py
import logging
import os

import torch
from torch._dynamo.utils import counters

from ts.torch_handler.image_classifier import ImageClassifier

logger = logging.getLogger(__name__)


class TorchInductorCacheHandler(ImageClassifier):
    """
    ImageClassifier handler that enables TorchInductor caching
    (FX graph cache and/or a custom cache directory) from the model YAML config.
    """

    def __init__(self):
        super().__init__()
        self.initialized = False

    def initialize(self, ctx):
        """Loads and initializes the model, configuring TorchInductor caching
        from the model YAML config.
        Args:
            ctx (context): It is a JSON Object containing information
            pertaining to the model artifacts parameters.
        """
        self.context = ctx
        self.manifest = ctx.manifest
        properties = ctx.system_properties

        if (
            "handler" in ctx.model_yaml_config
            and "torch_inductor_caching" in ctx.model_yaml_config["handler"]
        ):
            # Enable the experimental FX graph cache
            if ctx.model_yaml_config["handler"]["torch_inductor_caching"].get(
                "torch_inductor_fx_graph_cache", False
            ):
                torch._inductor.config.fx_graph_cache = True
            # Redirect the TorchInductor cache to a persistent directory
            if (
                "torch_inductor_cache_dir"
                in ctx.model_yaml_config["handler"]["torch_inductor_caching"]
            ):
                os.environ["TORCHINDUCTOR_CACHE_DIR"] = ctx.model_yaml_config[
                    "handler"
                ]["torch_inductor_caching"]["torch_inductor_cache_dir"]

        super().initialize(ctx)
        self.initialized = True

    def inference(self, data, *args, **kwargs):
        with torch.inference_mode():
            marshalled_data = data.to(self.device)
            results = self.model(marshalled_data, *args, **kwargs)

        # Debug logging for FX Graph Cache hits
        if torch._inductor.config.fx_graph_cache:
            fx_graph_cache_hit, fx_graph_cache_miss = (
                counters["inductor"]["fxgraph_cache_hit"],
                counters["inductor"]["fxgraph_cache_miss"],
            )
            logger.info(
                f"TorchInductor FX Graph cache hit {fx_graph_cache_hit}, "
                f"FX Graph cache miss {fx_graph_cache_miss}"
            )
        return results
7 changes: 7 additions & 0 deletions
examples/pt2/torch_inductor_caching/model-config-cache-dir.yaml
minWorkers: 4
maxWorkers: 4
responseTimeout: 600
pt2 : {backend: inductor, mode: max-autotune}
handler:
  torch_inductor_caching:
    torch_inductor_cache_dir: "/home/ubuntu/serve/examples/pt2/torch_inductor_caching/cache"
7 changes: 7 additions & 0 deletions
examples/pt2/torch_inductor_caching/model-config-fx-cache.yaml
minWorkers: 4
maxWorkers: 4
responseTimeout: 600
pt2 : {backend: inductor, mode: max-autotune}
handler:
  torch_inductor_caching:
    torch_inductor_fx_graph_cache: true