From ed348b122ee4964036197ba81a74c1c349155d49 Mon Sep 17 00:00:00 2001 From: Salar Hosseini Date: Wed, 29 Jan 2025 18:05:28 +0000 Subject: [PATCH 1/2] Add Qwen2.5 vLLM generator (based on LlamaGenerator), update ccl topology in process_output_decode Signed-off-by: Salar Hosseini (cherry picked from commit 4fbdcc30acedf0a7cc1d96b0ef06898d12a29919) --- models/demos/llama3/tt/generator_vllm.py | 95 +++++++++++++++++------- models/demos/llama3/tt/llama_model.py | 4 +- 2 files changed, 71 insertions(+), 28 deletions(-) diff --git a/models/demos/llama3/tt/generator_vllm.py b/models/demos/llama3/tt/generator_vllm.py index bcde0bb45d2..cff4b51b440 100644 --- a/models/demos/llama3/tt/generator_vllm.py +++ b/models/demos/llama3/tt/generator_vllm.py @@ -7,7 +7,6 @@ import torch import PIL from llama_models.llama3.api.chat_format import create_vision_mask -from llama_models.llama3.api.tokenizer import Tokenizer import ttnn from models.demos.llama3.tt.generator import LlamaGenerator @@ -21,6 +20,41 @@ from vllm.model_executor.models.mllama import MLLAMA_IMAGE_TOKEN_ID, MLLAMA_IMAGE_TOKEN +def initialize_vllm_text_transformer( + hf_config, + mesh_device, + max_batch_size, + max_seq_len, + n_layers=None, + dtype=ttnn.bfloat8_b, + optimizations=LlamaOptimizations.performance, +): + # Load model args, weights + model_args = TtModelArgs( + mesh_device, + instruct=("Instruct" in hf_config._name_or_path), + max_batch_size=max_batch_size, + optimizations=optimizations, + max_seq_len=max_seq_len, + ) + assert model_args.model_name.replace("-", "") in hf_config._name_or_path.replace( + "-", "" + ), f"The model specified in vLLM ({hf_config._name_or_path}) does not match the model name ({model_args.model_name}) with model weights ({model_args.DEFAULT_CKPT_DIR})." + if n_layers is not None: + model_args.n_layers = n_layers + state_dict = model_args.load_state_dict() + + tt_model = TtTransformer( + args=model_args, + mesh_device=mesh_device, + dtype=dtype, + state_dict=state_dict, + weight_cache_path=model_args.weight_cache_path(dtype), + use_paged_kv_cache=True, + ) + return tt_model, model_args + + def input_processor_for_mllama(ctx: InputContext, inputs: Union[DecoderOnlyInputs, EncoderDecoderInputs]): """ Based on vllm.model_executor.models.mllama.py::input_processor_for_mllama(). @@ -140,33 +174,42 @@ def __init__(self, *args, **kwargs): @classmethod def initialize_vllm_model(cls, hf_config, mesh_device, max_batch_size, n_layers=None): - instruct_mode = "Instruct" in hf_config._name_or_path - max_seq_len = 131072 # TODO: modify this for different models/devices - optimizations = LlamaOptimizations.performance # TODO: maybe change to accuracy - dtype = ttnn.bfloat8_b - - # Load model args, weights - model_args = TtModelArgs( + tt_model, model_args = initialize_vllm_text_transformer( + hf_config, mesh_device, - instruct=instruct_mode, - max_batch_size=max_batch_size, - optimizations=optimizations, - max_seq_len=max_seq_len, + max_batch_size, + max_seq_len=131072, + n_layers=n_layers, + dtype=ttnn.bfloat8_b, + optimizations=LlamaOptimizations.performance, ) - assert ( - model_args.model_name in hf_config._name_or_path - ), f"The model specified in vLLM ({hf_config._name_or_path}) does not match the model weights ({model_args.DEFAULT_CKPT_DIR})." 
- if n_layers is not None: - model_args.n_layers = n_layers - state_dict = model_args.load_state_dict() - - tt_model = TtTransformer( - args=model_args, - mesh_device=mesh_device, - dtype=dtype, - state_dict=state_dict, - weight_cache_path=model_args.weight_cache_path(dtype), - use_paged_kv_cache=True, + return cls(tt_model, model_args, mesh_device) + + @property + def cache_path(self): + return self.model_args.model_cache_path + + def prefill_forward(self, *args, **kwargs): + return super().prefill_forward_text(*args, **kwargs) + + def decode_forward(self, *args, **kwargs): + return super().decode_forward_text(*args, **kwargs) + + +class TtQwen2ForCausalLM(LlamaGenerator): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + @classmethod + def initialize_vllm_model(cls, hf_config, mesh_device, max_batch_size, n_layers=None): + tt_model, model_args = initialize_vllm_text_transformer( + hf_config, + mesh_device, + max_batch_size, + max_seq_len=131072, + n_layers=n_layers, + dtype=ttnn.bfloat8_b, + optimizations=LlamaOptimizations.performance, ) return cls(tt_model, model_args, mesh_device) diff --git a/models/demos/llama3/tt/llama_model.py b/models/demos/llama3/tt/llama_model.py index c514bb1a3b7..4a4cab1689f 100644 --- a/models/demos/llama3/tt/llama_model.py +++ b/models/demos/llama3/tt/llama_model.py @@ -254,10 +254,10 @@ def process_output_decode(self, tt_out, B, S=1): num_links=2, cluster_axis=0, mesh_device=self.mesh_device, - topology=ttnn.Topology.Linear, + topology=self.args.ccl_topology(), ) else: - tt_out = ttnn.all_gather(tt_out, dim=3, num_links=1, topology=ttnn.Topology.Linear) + tt_out = ttnn.all_gather(tt_out, dim=3, num_links=1, topology=self.args.ccl_topology()) tt_out = ttnn.untilize(tt_out, use_multicore=True) if self.args.num_devices > 1: tt_out = ttnn.to_torch(ttnn.get_device_tensors(tt_out)[0]).float() From 17b2695d327e3d892e18861aa7de8388c1af1442 Mon Sep 17 00:00:00 2001 From: Salar Hosseini Date: Thu, 30 Jan 2025 18:14:48 +0000 Subject: [PATCH 2/2] [Llama3] Pad decode tokens to tile size to fix batch 1 issue with generator Signed-off-by: Salar Hosseini (cherry picked from commit 4221d8ad53663d547b0e69858b244cc9903846ba) --- models/demos/llama3/tt/llama_model.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/models/demos/llama3/tt/llama_model.py b/models/demos/llama3/tt/llama_model.py index 4a4cab1689f..3b784ad0bbb 100644 --- a/models/demos/llama3/tt/llama_model.py +++ b/models/demos/llama3/tt/llama_model.py @@ -174,8 +174,10 @@ def prepare_decode_inputs_host(self, tokens, current_pos, page_table=None): assert current_pos.shape[0] == B, "Batch size mismatch" assert B == self.args.max_batch_size, "Batch size must be equal to max_batch_size" + # Necessary padding to be full tile sized when on device + tokens = torch.nn.functional.pad(tokens.view(-1), (0, 32 - len(tokens)), "constant", 0) tokens = ttnn.from_torch( - tokens.view(-1), + tokens, device=None, dtype=ttnn.uint32, mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device),
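The padding added in the second patch exists because the decode-path token tensor is sent to device as a single 32-wide tile, so a batch smaller than 32 (notably batch 1) must be zero-padded up to the tile width before ttnn.from_torch is called. Below is a minimal, torch-only sketch of that behaviour for reference; TILE_WIDTH and pad_decode_tokens are illustrative names introduced here, not identifiers from the repository.

import torch

TILE_WIDTH = 32  # tile width assumed by the `32 - len(tokens)` padding in prepare_decode_inputs_host

def pad_decode_tokens(tokens: torch.Tensor) -> torch.Tensor:
    # Flatten the (B, 1) decode tokens and zero-pad them to a full tile,
    # mirroring the torch.nn.functional.pad call added by the patch.
    flat = tokens.view(-1)
    pad_len = TILE_WIDTH - flat.shape[0]
    assert pad_len >= 0, "batch size must not exceed the tile width"
    return torch.nn.functional.pad(flat, (0, pad_len), "constant", 0)

if __name__ == "__main__":
    batch_1_tokens = torch.tensor([[42]], dtype=torch.int64)  # the batch-1 case the fix targets
    padded = pad_decode_tokens(batch_1_tokens)
    print(padded.shape)  # torch.Size([32]); positions 1..31 hold the zero padding

With this, the host-side tensor handed to ttnn.from_torch is never narrower than one tile, which matches the batch-1 generator issue called out in the commit message.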