* Initial commit for gpt_fast example (remove files and finish tests, add readme, complete readme)
* Add int8 quantization to example
* Add missing json file
* Enable streaming response
* Remove print
* Adapt unit test to list return value, fix lint error
* Assert if batch_size is not 1
* Addressed review comments
* Added GPU compatibility remark

Co-authored-by: Ankith Gunapal <[email protected]>
Showing 5 changed files with 509 additions and 0 deletions.
`examples/large_models/gpt_fast/README.md` (new file, 89 lines)
## GPT-Fast

[GPT fast](https://github.com/pytorch-labs/gpt-fast) is a simple and efficient PyTorch-native transformer text generation codebase.

It features:
* Very low latency
* <1000 lines of Python
* No dependencies other than PyTorch and sentencepiece
* int8/int4 quantization
* Speculative decoding
* Tensor parallelism
* Supports Nvidia and AMD GPUs

More details about gpt-fast can be found in this [blog](https://pytorch.org/blog/accelerating-generative-ai-2/).
The example has been tested on A10, A100, and H100 GPUs.

#### Pre-requisites

`cd` to the example folder `examples/large_models/gpt_fast`.

Install the dependencies and upgrade torch to a nightly build (currently required):
```
git clone https://github.com/pytorch-labs/gpt-fast/
pip install sentencepiece huggingface_hub
pip uninstall torchtext torchdata torch torchvision torchaudio -y
pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu121 --ignore-installed
```

### Step 1: Download and convert the weights

Currently supported models:
```
openlm-research/open_llama_7b
meta-llama/Llama-2-7b-chat-hf
meta-llama/Llama-2-13b-chat-hf
meta-llama/Llama-2-70b-chat-hf
codellama/CodeLlama-7b-Python-hf
codellama/CodeLlama-34b-Python-hf
```
Prepare the weights:
```
export MODEL_REPO=meta-llama/Llama-2-7b-chat-hf
cd gpt-fast
huggingface-cli login
./scripts/prepare.sh $MODEL_REPO
cd ..
```

### (Optional) Step 1.5: Quantize the model to int4

To speed up model loading and inference even further, we can optionally quantize the model to int4 instead of int8. Please see the [blog post](https://pytorch.org/blog/accelerating-generative-ai-2/) for details on the potential accuracy loss.

```
cd gpt-fast
python quantize.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --mode int4
cd ..
```

The quantized model will show up as `checkpoints/$MODEL_REPO/model_int4.pth`. To enable it in the example, point the checkpoint filename in the [`model_config.yaml`](./model_config.yaml) file at this quantized model, as sketched below.
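
For example, with `MODEL_REPO=meta-llama/Llama-2-7b-chat-hf` from Step 1, the handler section of the config would look roughly like this. This is a minimal sketch, not the shipped config; adjust the path to match the checkpoint you actually generated:

```
handler:
    converted_ckpt_dir: "checkpoints/meta-llama/Llama-2-7b-chat-hf/model_int4.pth"
    max_new_tokens: 50
    compile: true
```
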
### Step 2: Generate model archive

```
torch-model-archiver --model-name gpt_fast --version 1.0 --handler handler.py --config-file model_config.yaml --extra-files "gpt-fast/generate.py,gpt-fast/model.py,gpt-fast/quantize.py,gpt-fast/tp.py" --archive-format no-archive
mv gpt-fast/checkpoints gpt_fast/
```

### Step 3: Add the model archive to model store

```
mkdir model_store
mv gpt_fast model_store
```

### Step 4: Start torchserve

```
torchserve --start --ncs --model-store model_store --models gpt_fast
```

### Step 5: Run inference

```
curl "http://localhost:8080/predictions/gpt_fast" -T request.json
# Returns: The capital of France, Paris, is a city of romance, fashion, and art. The city is home to the Eiffel Tower, the Louvre, and the Arc de Triomphe. Paris is also known for its cafes, restaurants
```
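
Because the handler streams intermediate results back through `send_intermediate_predict_response` (see `handler.py` below), the generated text arrives in chunks rather than as one final blob. The following is a minimal Python client sketch for consuming the stream; it assumes the `requests` package is installed and that TorchServe is running locally on the default inference port from Step 4.

```
# Minimal streaming client sketch (assumes `pip install requests` and a local
# TorchServe instance serving the gpt_fast model on port 8080).
import json

import requests

payload = {"prompt": "The capital of France", "max_new_tokens": 50}

with requests.post(
    "http://localhost:8080/predictions/gpt_fast",
    data=json.dumps(payload),
    stream=True,
) as response:
    response.raise_for_status()
    # Each chunk is one intermediate prediction streamed by the handler.
    for chunk in response.iter_content(chunk_size=None):
        print(chunk.decode("utf-8"), end="", flush=True)
print()
```
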
`examples/large_models/gpt_fast/handler.py` (new file, 195 lines)
```python
import json
import logging
import os
import time
from pathlib import Path

import torch
from generate import _load_model, decode_one_token, encode_tokens, prefill
from sentencepiece import SentencePieceProcessor

from ts.handler_utils.timer import timed
from ts.protocol.otf_message_handler import send_intermediate_predict_response
from ts.torch_handler.base_handler import BaseHandler

logger = logging.getLogger(__name__)


class GptHandler(BaseHandler):
    """Serves gpt-fast models through TorchServe and streams generated text back to the client."""

    def __init__(self):
        super().__init__()

        self.model = None
        self.tokenizer = None
        self.context = None
        self.prefill = prefill
        self.decode_one_token = decode_one_token
        self.initialized = False
        self.device = torch.device("cpu")
        self.prompt_length = 0

    def initialize(self, ctx):
        self.context = ctx
        properties = ctx.system_properties
        if torch.cuda.is_available():
            self.map_location = "cuda"
            self.device = torch.device(
                self.map_location + ":" + str(os.getenv("LOCAL_RANK", 0))
            )

        checkpoint_path = Path(ctx.model_yaml_config["handler"]["converted_ckpt_dir"])
        assert checkpoint_path.is_file(), checkpoint_path

        tokenizer_path = checkpoint_path.parent / "tokenizer.model"
        assert tokenizer_path.is_file(), tokenizer_path

        logger.info("Loading model ...")
        t0 = time.time()
        self.model = _load_model(checkpoint_path, self.device, torch.bfloat16, False)
        torch.cuda.synchronize()
        logger.info(f"Time to load model: {time.time() - t0:.02f} seconds")

        self.tokenizer = SentencePieceProcessor(model_file=str(tokenizer_path))

        if ctx.model_yaml_config["handler"]["compile"]:
            self.decode_one_token = torch.compile(
                self.decode_one_token, mode="reduce-overhead", fullgraph=True
            )
            self.prefill = torch.compile(self.prefill, fullgraph=True, dynamic=True)

        torch.manual_seed(42 * 42)

        self.initialized = True

    @timed
    def preprocess(self, requests):
        assert (
            len(requests) == 1
        ), "GPT fast is currently only supported with batch_size=1"
        req_data = requests[0]

        input_data = req_data.get("data") or req_data.get("body")

        if isinstance(input_data, (bytes, bytearray)):
            input_data = input_data.decode("utf-8")

        input_data = json.loads(input_data)

        prompt = input_data["prompt"]

        encoded = encode_tokens(self.tokenizer, prompt, bos=True, device=self.device)

        self.prompt_length = encoded.size(0)

        return {
            "encoded": encoded,
            "max_new_tokens": input_data.get("max_new_tokens", 50),
        }

    @timed
    def inference(self, input_data):
        tokenizer = self.tokenizer
        period_id = tokenizer.encode(".")[0]

        # Callback invoked for each generated token; decodes it and streams the
        # text back to the client as an intermediate response.
        def call_me(x):
            nonlocal period_id, tokenizer
            text = self.tokenizer.decode([period_id] + x.tolist())[1:]
            send_intermediate_predict_response(
                [text],
                self.context.request_ids,
                "Intermediate Prediction success",
                200,
                self.context,
            )

        y = self.generate(
            input_data["encoded"],
            input_data["max_new_tokens"],
            callback=call_me,
            temperature=0.8,
            top_k=1,
        )
        logger.info(f"Num tokens = {y.size(0) - self.prompt_length}")
        return y

    def postprocess(self, y):
        # All text has already been streamed as intermediate responses,
        # so the final response is empty.
        return [""]

    @torch.no_grad()
    def generate(
        self,
        prompt: torch.Tensor,
        max_new_tokens: int,
        *,
        callback=lambda x: x,
        **sampling_kwargs,
    ) -> torch.Tensor:
        """
        Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
        """
        # create an empty tensor of the expected final shape and fill in the current tokens
        T = prompt.size(0)
        T_new = T + max_new_tokens

        max_seq_length = min(T_new, self.model.config.block_size)

        device, dtype = prompt.device, prompt.dtype
        with torch.device(device):
            self.model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)

        # create an empty tensor of the expected final shape and fill in the current tokens
        empty = torch.empty(T_new, dtype=dtype, device=device)
        empty[:T] = prompt
        seq = empty
        input_pos = torch.arange(0, T, device=device)

        next_token = self.prefill(
            self.model, prompt.view(1, -1), input_pos, **sampling_kwargs
        )
        period_id = self.tokenizer.encode(".")[0]
        text = self.tokenizer.decode([period_id] + next_token.tolist())[1:]
        send_intermediate_predict_response(
            [text],
            self.context.request_ids,
            "Intermediate Prediction success",
            200,
            self.context,
        )

        seq[T] = next_token

        input_pos = torch.tensor([T], device=device, dtype=torch.int)

        generated_tokens, _ = self.decode_n_tokens(
            next_token.view(1, -1),
            input_pos,
            max_new_tokens - 1,
            callback=callback,
            **sampling_kwargs,
        )
        seq[T + 1 :] = torch.cat(generated_tokens)

        return seq

    def decode_n_tokens(
        self,
        cur_token: torch.Tensor,
        input_pos: torch.Tensor,
        num_new_tokens: int,
        callback=lambda _: _,
        **sampling_kwargs,
    ):
        new_tokens, new_probs = [], []
        for i in range(num_new_tokens):
            with torch.backends.cuda.sdp_kernel(
                enable_flash=False, enable_mem_efficient=False, enable_math=True
            ):  # Actually better for Inductor to codegen attention here
                next_token, next_prob = self.decode_one_token(
                    self.model, cur_token, input_pos, **sampling_kwargs
                )
                input_pos += 1
                new_tokens.append(next_token.clone())
                callback(new_tokens[-1])
                new_probs.append(next_prob.clone())
                cur_token = next_token.view(1, -1)
        return new_tokens, new_probs
```
`examples/large_models/gpt_fast/model_config.yaml` (new file, 11 lines)
```yaml
# frontend settings
minWorkers: 1
maxWorkers: 1
maxBatchDelay: 200
responseTimeout: 300
deviceType: "gpu"
continuousBatching: false
handler:
    converted_ckpt_dir: "checkpoints/meta-llama/Llama-2-7b-hf/model.pth"
    max_new_tokens: 50
    compile: true
```
`examples/large_models/gpt_fast/request.json` (new file, 4 lines)
```json
{
  "prompt": "The capital of France",
  "max_new_tokens": 50
}
```