diff --git a/examples/pt2/README.md b/examples/pt2/README.md
index bc9908d03a..ef4578fbed 100644
--- a/examples/pt2/README.md
+++ b/examples/pt2/README.md
@@ -107,3 +107,24 @@ print(extra_files['foo.txt'])
 # from inference()
 print(ep(torch.randn(5)))
 ```
+
+## torch._export.aot_compile
+
+Using `torch.compile` to wrap your existing eager PyTorch model can result in out-of-the-box speedups. However, `torch.compile` is a JIT compiler. TorchServe has supported `torch.compile` since the PyTorch 2.0 release. In a production setting, when you have multiple instances of TorchServe, each of your instances would `torch.compile` the model on the first inference. Because `torch.compile` is a JIT compiler, TorchServe's model archiver is not able to truly guarantee reproducibility.
+
+In addition, the first inference request with `torch.compile` will be slow as the model needs to compile.
+
+To solve this problem, `torch.export` has an experimental API, `torch._export.aot_compile`, which can `torch.export` a torch-compilable model, provided it has no graph breaks, and then run it with AOTInductor.
+
+You can find more details [here](https://pytorch.org/docs/main/torch.compiler_aot_inductor.html)
+
+
+This is an experimental API and needs PyTorch 2.2 nightlies.
+To achieve this, add the following config to your `model-config.yaml`
+
+```yaml
+pt2 :
+  export:
+    aot_compile: true
+```
+You can find an example [here](./torch_export_aot_compile/README.md)
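+
+A rough sketch of the export step is shown below; the tiny model is only a stand-in, and the complete, runnable version used by this example is [resnet18_torch_export.py](./torch_export_aot_compile/resnet18_torch_export.py):
+
+```python
+import os
+
+import torch
+
+# Stand-in model; in practice this is your own eager model with no graph breaks
+model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU()).eval()
+example_inputs = (torch.randn(1, 3, 224, 224),)
+
+# AOT-compile the exported model into a shared library that TorchServe can load
+torch._export.aot_compile(
+    model,
+    example_inputs,
+    options={"aot_inductor.output_path": os.path.join(os.getcwd(), "model.so")},
+)
+```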
diff --git a/examples/pt2/torch_export_aot_compile/README.md b/examples/pt2/torch_export_aot_compile/README.md
new file mode 100644
index 0000000000..faef26dfe0
--- /dev/null
+++ b/examples/pt2/torch_export_aot_compile/README.md
@@ -0,0 +1,68 @@
+# TorchServe inference with torch._export.aot_compile
+
+This example shows how to run TorchServe with a Torch exported model using AOTInductor
+
+To understand when to use `torch._export.aot_compile`, please refer to this [section](../README.md/#torchexportaotcompile)
+
+
+### Pre-requisites
+
+- `PyTorch >= 2.2.0`
+- `CUDA 12.1`
+
+Change directory to the examples directory
+Ex: `cd examples/pt2/torch_export_aot_compile`
+
+Install PyTorch 2.2 nightlies by running
+```
+chmod +x install_pytorch_nightlies.sh
+source install_pytorch_nightlies.sh
+```
+
+You can also achieve this by installing TorchServe dependencies with the `nightly_torch` flag
+```
+python ts_scripts/install_dependencies.py --cuda=cu121 --nightly_torch
+```
+
+
+### Create a Torch exported model with AOTInductor
+
+The model is saved with a `.so` extension.
+Here we are torch exporting with AOTInductor in `max_autotune` mode.
+This also makes use of `dynamic_shapes` to support a range of batch sizes (the export script currently caps the dynamic batch dimension at 15; see the comment in `resnet18_torch_export.py`).
+In the code, the min batch_size is set to 2 instead of 1. This is by design; the code still works for batch size 1. You can find an explanation for this [here](https://pytorch.org/docs/main/export.html#expressing-dynamism)
+
+```
+python resnet18_torch_export.py
+```
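+
+Optionally, you can sanity-check the generated `resnet18_pt2.so` before archiving it. The sketch below assumes a source install of TorchServe that provides `ts.handler_utils.torch_export.load_model` (added as part of this change) and a CUDA device, since the export script targets the GPU when one is available:
+
+```python
+import os
+
+import torch
+
+from ts.handler_utils.torch_export.load_model import load_exported_model
+
+# Load the AOT-compiled shared library and run one batch through it
+model = load_exported_model(os.path.abspath("resnet18_pt2.so"), "cuda")
+out = model(torch.randn(4, 3, 224, 224, device="cuda"))
+print(out.shape)  # expect torch.Size([4, 1000]) for ResNet-18
+```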
+
+### Create model archive
+
+```
+torch-model-archiver --model-name res18-pt2 --handler image_classifier --version 1.0 --serialized-file resnet18_pt2.so --config-file model-config.yaml --extra-files ../../image_classifier/index_to_name.json
+mkdir model_store
+mv res18-pt2.mar model_store/.
+```
+
+#### Start TorchServe
+```
+torchserve --start --model-store model_store --models res18-pt2=res18-pt2.mar --ncs
+```
+
+#### Run Inference
+
+```
+curl http://127.0.0.1:8080/predictions/res18-pt2 -T ../../image_classifier/kitten.jpg
+```
+
+This produces the following output:
+
+```
+{
+  "tabby": 0.4087875485420227,
+  "tiger_cat": 0.34661102294921875,
+  "Egyptian_cat": 0.13007202744483948,
+  "lynx": 0.024034621194005013,
+  "bucket": 0.011633828282356262
+}
+```
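+
+To exercise the dynamic batch dimension from the client side, one option (a sketch that assumes the default inference port and the model registered as above) is to send several requests concurrently so TorchServe can group them into a single batch:
+
+```
+# With batchSize: 15 and maxBatchDelay: 5000 in model-config.yaml, these requests
+# can be grouped into a single forward pass of the AOT-compiled model
+for i in {1..8}; do
+  curl -s http://127.0.0.1:8080/predictions/res18-pt2 -T ../../image_classifier/kitten.jpg &
+done
+wait
+```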
diff --git a/examples/pt2/torch_export_aot_compile/install_pytorch_nightlies.sh b/examples/pt2/torch_export_aot_compile/install_pytorch_nightlies.sh
new file mode 100755
index 0000000000..f7ee5c39db
--- /dev/null
+++ b/examples/pt2/torch_export_aot_compile/install_pytorch_nightlies.sh
@@ -0,0 +1,13 @@
+#!/bin/bash
+
+# Uninstall torchtext, torchdata, torch, torchvision, and torchaudio
+pip uninstall torchtext torchdata torch torchvision torchaudio -y
+
+# Install nightly PyTorch and torchvision from the specified index URL
+pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu121 --ignore-installed
+
+# Optional: Display the installed PyTorch and torchvision versions
+python -c "import torch; print('PyTorch version:', torch.__version__)"
+python -c "import torchvision; print('torchvision version:', torchvision.__version__)"
+
+echo "PyTorch and torchvision updated successfully!"
diff --git a/examples/pt2/torch_export_aot_compile/model-config.yaml b/examples/pt2/torch_export_aot_compile/model-config.yaml
new file mode 100644
index 0000000000..af05f710c1
--- /dev/null
+++ b/examples/pt2/torch_export_aot_compile/model-config.yaml
@@ -0,0 +1,6 @@
+batchSize: 15
+maxBatchDelay: 5000
+responseTimeout: 300
+pt2 :
+  export:
+    aot_compile: true
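+# Note: batchSize should stay within the dynamic batch range used at export time
+# (resnet18_torch_export.py sets the batch Dim to max=15); a larger batch cannot be
+# run through the AOT-compiled model in a single forward pass.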
diff --git a/examples/pt2/torch_export_aot_compile/resnet18_torch_export.py b/examples/pt2/torch_export_aot_compile/resnet18_torch_export.py
new file mode 100644
index 0000000000..db31228bac
--- /dev/null
+++ b/examples/pt2/torch_export_aot_compile/resnet18_torch_export.py
@@ -0,0 +1,29 @@
+import os
+
+import torch
+from torchvision.models import ResNet18_Weights, resnet18
+
+torch.set_float32_matmul_precision("high")
+
+model = resnet18(weights=ResNet18_Weights.DEFAULT)
+model.eval()
+
+with torch.no_grad():
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    model = model.to(device=device)
+    example_inputs = (torch.randn(2, 3, 224, 224, device=device),)
+
+    # Max value is 15 because of https://github.com/pytorch/pytorch/pull/116152
+    # On a CUDA enabled device, we tested batch_size of 32.
+    batch_dim = torch.export.Dim("batch", min=2, max=15)
+    so_path = torch._export.aot_compile(
+        model,
+        example_inputs,
+        # Specify the first dimension of the input x as dynamic
+        dynamic_shapes={"x": {0: batch_dim}},
+        # Specify the generated shared library path
+        options={
+            "aot_inductor.output_path": os.path.join(os.getcwd(), "resnet18_pt2.so"),
+            "max_autotune": True,
+        },
+    )
diff --git a/requirements/common.txt b/requirements/common.txt
index a5b7501a6a..2023978013 100644
--- a/requirements/common.txt
+++ b/requirements/common.txt
@@ -4,3 +4,4 @@ captum==0.6.0
 packaging==23.2
 pynvml==11.5.0
 pyyaml==6.0
+ninja==1.11.1.1
diff --git a/test/pytest/test_torch_export.py b/test/pytest/test_torch_export.py
new file mode 100644
index 0000000000..7031c983eb
--- /dev/null
+++ b/test/pytest/test_torch_export.py
@@ -0,0 +1,128 @@
+from pathlib import Path
+
+import torch
+from pkg_resources import packaging
+
+from ts.torch_handler.image_classifier import ImageClassifier
+from ts.torch_handler.unit_tests.test_utils.mock_context import MockContext
+from ts.utils.util import load_label_mapping
+from ts_scripts.utils import try_and_handle
+
+CURR_FILE_PATH = Path(__file__).parent.absolute()
+REPO_ROOT_DIR = CURR_FILE_PATH.parents[1]
+EXAMPLE_ROOT_DIR = REPO_ROOT_DIR.joinpath("examples", "pt2", "torch_export_aot_compile")
+TEST_DATA = REPO_ROOT_DIR.joinpath("examples", "image_classifier", "kitten.jpg")
+MAPPING_DATA = REPO_ROOT_DIR.joinpath(
+    "examples", "image_classifier", "index_to_name.json"
+)
+MODEL_SO_FILE = "resnet18_pt2.so"
+MODEL_YAML_CFG_FILE = EXAMPLE_ROOT_DIR.joinpath("model-config.yaml")
+
+
+PT_220_AVAILABLE = (
+    True
+    if packaging.version.parse(torch.__version__) > packaging.version.parse("2.1.1")
+    else False
+)
+
+EXPECTED_RESULTS = ["tabby", "tiger_cat", "Egyptian_cat", "lynx", "bucket"]
+TEST_CASES = [
+    ("kitten.jpg", EXPECTED_RESULTS[0]),
+]
+
+
+import os
+
+import pytest
+
+
+@pytest.fixture
+def custom_working_directory(tmp_path):
+    # Set the custom working directory
+    custom_dir = tmp_path / "model_dir"
+    custom_dir.mkdir()
+    os.chdir(custom_dir)
+    yield custom_dir
+    # Clean up and return to the original working directory
+    os.chdir(tmp_path)
+
+
+@pytest.mark.skipif(PT_220_AVAILABLE == False, reason="torch version is < 2.2.0")
+def test_torch_export_aot_compile(custom_working_directory):
+    # Get the path to the custom working directory
+    model_dir = custom_working_directory
+
+    # Construct the path to the Python script to execute
+    script_path = os.path.join(EXAMPLE_ROOT_DIR, "resnet18_torch_export.py")
+
+    # Generate the .so file with torch._export.aot_compile
+    cmd = "python " + script_path
+    try_and_handle(cmd)
+
+    # Handler for Image classification
+    handler = ImageClassifier()
+
+    # Context definition
+    ctx = MockContext(
+        model_pt_file=MODEL_SO_FILE,
+        model_dir=model_dir.as_posix(),
+        model_file=None,
+        model_yaml_config_file=MODEL_YAML_CFG_FILE,
+    )
+
+    torch.manual_seed(42 * 42)
+    handler.initialize(ctx)
+    handler.context = ctx
+    handler.mapping = load_label_mapping(MAPPING_DATA)
+
+    data = {}
+    with open(TEST_DATA, "rb") as image:
+        image_file = image.read()
+        byte_array_type = bytearray(image_file)
+        data["body"] = byte_array_type
+
+    result = handler.handle([data], ctx)
+
+    labels = list(result[0].keys())
+
+    assert labels == EXPECTED_RESULTS
+
+
+@pytest.mark.skipif(PT_220_AVAILABLE == False, reason="torch version is < 2.2.0")
+def test_torch_export_aot_compile_dynamic_batching(custom_working_directory):
+    # Get the path to the custom working directory
+    model_dir = custom_working_directory
+
+    # Construct the path to the Python script to execute
+    script_path = os.path.join(EXAMPLE_ROOT_DIR, "resnet18_torch_export.py")
+
+    # Generate the .so file with torch._export.aot_compile
+    cmd = "python " + script_path
+    try_and_handle(cmd)
+
+    # Handler for Image classification
+    handler = ImageClassifier()
+
+    # Context definition
+    ctx = MockContext(
+        model_pt_file=MODEL_SO_FILE,
+        model_dir=model_dir.as_posix(),
+        model_file=None,
+        model_yaml_config_file=MODEL_YAML_CFG_FILE,
+    )
+
+    torch.manual_seed(42 * 42)
+    handler.initialize(ctx)
+    handler.context = ctx
+    handler.mapping = load_label_mapping(MAPPING_DATA)
+
+    data = {}
+    with open(TEST_DATA, "rb") as image:
+        image_file = image.read()
+        byte_array_type = bytearray(image_file)
+        data["body"] = byte_array_type
+
+    # Send a batch of 15 elements
+    result = handler.handle([data for i in range(15)], ctx)
+
+    assert len(result) == 15
diff --git a/ts/handler_utils/torch_export/__init__.py b/ts/handler_utils/torch_export/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/ts/handler_utils/torch_export/load_model.py b/ts/handler_utils/torch_export/load_model.py
new file mode 100644
index 0000000000..fe3ca867c5
--- /dev/null
+++ b/ts/handler_utils/torch_export/load_model.py
@@ -0,0 +1,27 @@
+import tempfile
+
+import torch.fx._pytree as fx_pytree
+from torch._inductor.utils import aot_inductor_launcher, cache_dir
+from torch.utils import _pytree as pytree
+from torch.utils.cpp_extension import load_inline
+
+
+def load_exported_model(model_so_path, device):
+    module = load_inline(
+        name="aot_inductor",
+        cpp_sources=[aot_inductor_launcher(model_so_path, device)],
+        # use a unique build directory to avoid test interference
+        build_directory=tempfile.mkdtemp(dir=cache_dir()),
+        functions=["run", "get_call_spec"],
+        with_cuda=("cuda" == device),
+    )
+    call_spec = module.get_call_spec()
+    in_spec = pytree.treespec_loads(call_spec[0])
+    out_spec = pytree.treespec_loads(call_spec[1])
+
+    def optimized(*args):
+        flat_inputs = fx_pytree.tree_flatten_spec((args, {}), in_spec)
+        flat_outputs = module.run(flat_inputs)
+        return pytree.tree_unflatten(flat_outputs, out_spec)
+
+    return optimized
diff --git a/ts/torch_handler/base_handler.py b/ts/torch_handler/base_handler.py
index 711e24956c..02b56ca98c 100644
--- a/ts/torch_handler/base_handler.py
+++ b/ts/torch_handler/base_handler.py
@@ -53,6 +53,10 @@
     )
     PT2_AVAILABLE = False
 
+if packaging.version.parse(torch.__version__) > packaging.version.parse("2.1.1"):
+    PT220_AVAILABLE = True
+else:
+    PT220_AVAILABLE = False
 
 if os.environ.get("TS_IPEX_ENABLE", "false") == "true":
     try:
@@ -180,26 +184,43 @@
             self.model = setup_ort_session(self.model_pt_path, self.map_location)
             logger.info("Succesfully setup ort session")
 
+        elif (
+            self.model_pt_path.endswith(".so")
+            and self._use_torch_export_aot_compile()
+            and PT220_AVAILABLE
+        ):
+            # Set cuda device to the gpu_id of the backend worker
+            # This is needed as the API for loading the exported model doesn't yet have a device id
+            if torch.cuda.is_available() and properties.get("gpu_id") is not None:
+                torch.cuda.set_device(self.device)
+
+            self.model = self._load_torch_export_aot_compile(self.model_pt_path)
+            logger.warning(
+                "torch._export is an experimental feature! Successfully loaded torch exported model."
+            )
         else:
             raise RuntimeError("No model weights could be loaded")
 
         if hasattr(self, "model_yaml_config") and "pt2" in self.model_yaml_config:
             pt2_value = self.model_yaml_config["pt2"]
 
-            # pt2_value can be the backend, passed as a str, or arbitrary kwargs, passed as a dict
-            if isinstance(pt2_value, str):
-                compile_options = dict(backend=pt2_value)
-            elif isinstance(pt2_value, dict):
-                compile_options = pt2_value
+            if "export" in pt2_value:
+                valid_backend = False
             else:
-                raise ValueError("pt2 should be str or dict")
+                # pt2_value can be the backend, passed as a str, or arbitrary kwargs, passed as a dict
+                if isinstance(pt2_value, str):
+                    compile_options = dict(backend=pt2_value)
+                elif isinstance(pt2_value, dict):
+                    compile_options = pt2_value
+                else:
+                    raise ValueError("pt2 should be str or dict")
 
-            # if backend is not provided, compile will use its default, which is valid
-            valid_backend = (
-                check_valid_pt2_backend(compile_options["backend"])
-                if "backend" in compile_options
-                else True
-            )
+                # if backend is not provided, compile will use its default, which is valid
+                valid_backend = (
+                    check_valid_pt2_backend(compile_options["backend"])
+                    if "backend" in compile_options
+                    else True
+                )
         else:
             valid_backend = False
 
@@ -234,6 +255,11 @@ def initialize(self, context):
 
         self.initialized = True
 
+    def _load_torch_export_aot_compile(self, model_so_path):
+        from ts.handler_utils.torch_export.load_model import load_exported_model
+
+        return load_exported_model(model_so_path, self.map_location)
+
     def _load_torchscript_model(self, model_pt_path):
         """Loads the PyTorch model and returns the NN model object.
 
@@ -285,6 +311,18 @@
         model.load_state_dict(state_dict)
         return model
 
+    def _use_torch_export_aot_compile(self):
+        torch_export_aot_compile = False
+        if hasattr(self, "model_yaml_config") and "pt2" in self.model_yaml_config:
+            # Check if torch._export.aot_compile is being used
+            pt2_value = self.model_yaml_config["pt2"]
+            export_value = pt2_value.get("export", None)
+            if isinstance(export_value, dict) and "aot_compile" in export_value:
+                torch_export_aot_compile = (
+                    True if export_value["aot_compile"] == True else False
+                )
+        return torch_export_aot_compile
+
     @timed
     def preprocess(self, data):
         """
diff --git a/ts_scripts/spellcheck_conf/wordlist.txt b/ts_scripts/spellcheck_conf/wordlist.txt
index 44c851b6bb..8724dd70a4 100644
--- a/ts_scripts/spellcheck_conf/wordlist.txt
+++ b/ts_scripts/spellcheck_conf/wordlist.txt
@@ -1147,3 +1147,10 @@ sam
 zlib
 verifier
 fastgen
+reproducibility
+AOTInductor
+aot
+compilable
+nightlies
+torchexportaotcompile
+autotune