Add Qwen2.5 vLLM generator (based on LlamaGenerator), fix batch 1 issue with generator's decode forward #17422

Merged 2 commits on Jan 31, 2025
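The change adds a text-only Qwen2.5 generator class that reuses LlamaGenerator, and fixes decode with batch size 1 by padding the host-side token tensor to a full tile. Purely as a hedged sketch of how such a class is typically wired up as an out-of-tree model (the exact integration in the TT vLLM fork may differ):

    from vllm import ModelRegistry

    from models.demos.llama3.tt.generator_vllm import TtQwen2ForCausalLM

    # Map the HF architecture name to the TT implementation. In the TT fork, the worker is
    # then expected to build the model via TtQwen2ForCausalLM.initialize_vllm_model(...)
    # (an assumption based on the classmethod added in this PR, not confirmed here).
    ModelRegistry.register_model("Qwen2ForCausalLM", TtQwen2ForCausalLM)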
models/demos/llama3/tt/generator_vllm.py (69 additions, 26 deletions)

@@ -7,7 +7,6 @@
 import torch
 import PIL
 from llama_models.llama3.api.chat_format import create_vision_mask
-from llama_models.llama3.api.tokenizer import Tokenizer
 import ttnn

 from models.demos.llama3.tt.generator import LlamaGenerator
@@ -21,6 +20,41 @@
 from vllm.model_executor.models.mllama import MLLAMA_IMAGE_TOKEN_ID, MLLAMA_IMAGE_TOKEN


+def initialize_vllm_text_transformer(
+    hf_config,
+    mesh_device,
+    max_batch_size,
+    max_seq_len,
+    n_layers=None,
+    dtype=ttnn.bfloat8_b,
+    optimizations=LlamaOptimizations.performance,
+):
+    # Load model args, weights
+    model_args = TtModelArgs(
+        mesh_device,
+        instruct=("Instruct" in hf_config._name_or_path),
+        max_batch_size=max_batch_size,
+        optimizations=optimizations,
+        max_seq_len=max_seq_len,
+    )
+    assert model_args.model_name.replace("-", "") in hf_config._name_or_path.replace(
+        "-", ""
+    ), f"The model specified in vLLM ({hf_config._name_or_path}) does not match the model name ({model_args.model_name}) with model weights ({model_args.DEFAULT_CKPT_DIR})."
+    if n_layers is not None:
+        model_args.n_layers = n_layers
+    state_dict = model_args.load_state_dict()
+
+    tt_model = TtTransformer(
+        args=model_args,
+        mesh_device=mesh_device,
+        dtype=dtype,
+        state_dict=state_dict,
+        weight_cache_path=model_args.weight_cache_path(dtype),
+        use_paged_kv_cache=True,
+    )
+    return tt_model, model_args
+
+
 def input_processor_for_mllama(ctx: InputContext, inputs: Union[DecoderOnlyInputs, EncoderDecoderInputs]):
     """
     Based on vllm.model_executor.models.mllama.py::input_processor_for_mllama().
@@ -140,33 +174,42 @@ def __init__(self, *args, **kwargs):

     @classmethod
     def initialize_vllm_model(cls, hf_config, mesh_device, max_batch_size, n_layers=None):
-        instruct_mode = "Instruct" in hf_config._name_or_path
-        max_seq_len = 131072  # TODO: modify this for different models/devices
-        optimizations = LlamaOptimizations.performance  # TODO: maybe change to accuracy
-        dtype = ttnn.bfloat8_b
-
-        # Load model args, weights
-        model_args = TtModelArgs(
+        tt_model, model_args = initialize_vllm_text_transformer(
+            hf_config,
             mesh_device,
-            instruct=instruct_mode,
-            max_batch_size=max_batch_size,
-            optimizations=optimizations,
-            max_seq_len=max_seq_len,
+            max_batch_size,
+            max_seq_len=131072,
+            n_layers=n_layers,
+            dtype=ttnn.bfloat8_b,
+            optimizations=LlamaOptimizations.performance,
         )
-        assert (
-            model_args.model_name in hf_config._name_or_path
-        ), f"The model specified in vLLM ({hf_config._name_or_path}) does not match the model weights ({model_args.DEFAULT_CKPT_DIR})."
-        if n_layers is not None:
-            model_args.n_layers = n_layers
-        state_dict = model_args.load_state_dict()
-
-        tt_model = TtTransformer(
-            args=model_args,
-            mesh_device=mesh_device,
-            dtype=dtype,
-            state_dict=state_dict,
-            weight_cache_path=model_args.weight_cache_path(dtype),
-            use_paged_kv_cache=True,
-        )
         return cls(tt_model, model_args, mesh_device)
+
+    @property
+    def cache_path(self):
+        return self.model_args.model_cache_path
+
+    def prefill_forward(self, *args, **kwargs):
+        return super().prefill_forward_text(*args, **kwargs)
+
+    def decode_forward(self, *args, **kwargs):
+        return super().decode_forward_text(*args, **kwargs)
+
+
+class TtQwen2ForCausalLM(LlamaGenerator):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    @classmethod
+    def initialize_vllm_model(cls, hf_config, mesh_device, max_batch_size, n_layers=None):
+        tt_model, model_args = initialize_vllm_text_transformer(
+            hf_config,
+            mesh_device,
+            max_batch_size,
+            max_seq_len=131072,
+            n_layers=n_layers,
+            dtype=ttnn.bfloat8_b,
+            optimizations=LlamaOptimizations.performance,
+        )
+        return cls(tt_model, model_args, mesh_device)

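Both generator classes now build their TT transformer through the shared initialize_vllm_text_transformer helper above. A hedged sketch of constructing the new Qwen2.5 generator directly (the checkpoint name is illustrative, and the ttnn mesh-device call is an assumption about the usual device-opening API):

    import ttnn
    from transformers import AutoConfig

    from models.demos.llama3.tt.generator_vllm import TtQwen2ForCausalLM

    # Illustrative Qwen2.5 instruct checkpoint; the helper asserts that this name matches
    # the TT model weights pointed to by the model args.
    hf_config = AutoConfig.from_pretrained("Qwen/Qwen2.5-7B-Instruct")

    # Assumed ttnn call for opening a 1x1 mesh; use a larger MeshShape on multi-chip systems.
    mesh_device = ttnn.open_mesh_device(ttnn.MeshShape(1, 1))

    model = TtQwen2ForCausalLM.initialize_vllm_model(hf_config, mesh_device, max_batch_size=32)
    # prefill_forward / decode_forward dispatch to the *_text methods on LlamaGenerator,
    # so this model runs the same decode path patched in llama_model.py below.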
models/demos/llama3/tt/llama_model.py (5 additions, 3 deletions)

@@ -174,8 +174,10 @@ def prepare_decode_inputs_host(self, tokens, current_pos, page_table=None):
         assert current_pos.shape[0] == B, "Batch size mismatch"
         assert B == self.args.max_batch_size, "Batch size must be equal to max_batch_size"

+        # Necessary padding to be full tile sized when on device
+        tokens = torch.nn.functional.pad(tokens.view(-1), (0, 32 - len(tokens)), "constant", 0)
         tokens = ttnn.from_torch(
-            tokens.view(-1),
+            tokens,
             device=None,
             dtype=ttnn.uint32,
             mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device),
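The two added lines above are the batch-1 fix from the PR title: the flattened token tensor is zero-padded to a full tile width (32) on host before conversion to a ttnn tensor, since with a single user the tensor has only one element and would not fill the 32-wide tile expected on device. A minimal torch-only sketch of the same padding (values illustrative):

    import torch

    tokens = torch.tensor([[42]])  # decode step with batch size 1: one token for one user
    flat = tokens.view(-1)         # shape (1,)
    # Right-pad with zeros so the row spans exactly one 32-wide tile.
    padded = torch.nn.functional.pad(flat, (0, 32 - len(flat)), "constant", 0)
    assert padded.shape == (32,)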
@@ -254,10 +256,10 @@ def process_output_decode(self, tt_out, B, S=1):
                 num_links=2,
                 cluster_axis=0,
                 mesh_device=self.mesh_device,
-                topology=ttnn.Topology.Linear,
+                topology=self.args.ccl_topology(),
             )
         else:
-            tt_out = ttnn.all_gather(tt_out, dim=3, num_links=1, topology=ttnn.Topology.Linear)
+            tt_out = ttnn.all_gather(tt_out, dim=3, num_links=1, topology=self.args.ccl_topology())
         tt_out = ttnn.untilize(tt_out, use_multicore=True)
         if self.args.num_devices > 1:
             tt_out = ttnn.to_torch(ttnn.get_device_tensors(tt_out)[0]).float()
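The topology change above replaces the hard-coded ttnn.Topology.Linear with whatever self.args.ccl_topology() reports, so the all-gather matches the device mesh. The real helper lives in the model-args class in this repo; purely as a hypothetical illustration of the idea (attribute names and selection logic are assumptions, not the actual implementation):

    def ccl_topology(self):
        # Assumed logic: use a ring all-gather when the mesh forms a closed ring,
        # otherwise fall back to the linear topology.
        if self.num_devices > 1 and getattr(self, "is_galaxy", False):
            return ttnn.Topology.Ring
        return ttnn.Topology.Linear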