[VLM] Qwen2.5-VL
ywang96 authored Feb 5, 2025
1 parent 9a5b155 commit bf3b79e
Showing 14 changed files with 1,315 additions and 52 deletions.
11 changes: 11 additions & 0 deletions docs/source/models/supported_models.md
@@ -846,6 +846,13 @@ See [this page](#generative-models) for more information on how to use generativ
* ✅︎
* ✅︎
* ✅︎
- * `Qwen2_5_VLForConditionalGeneration`
* Qwen2.5-VL
* T + I<sup>E+</sup> + V<sup>E+</sup>
* `Qwen/Qwen2.5-VL-3B-Instruct`, `Qwen/Qwen2.5-VL-72B-Instruct`, etc.
*
* ✅︎
* ✅︎
- * `UltravoxModel`
* Ultravox
* T + A<sup>E+</sup>
@@ -880,6 +887,10 @@ The chat template for Pixtral-HF is incorrect (see [discussion](https://huggingf
A corrected version is available at <gh-file:examples/template_pixtral_hf.jinja>.
:::

:::{note}
To use Qwen2.5-VL series models, you have to install the Hugging Face `transformers` library from source via `pip install git+https://github.com/huggingface/transformers`.
:::
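Because this requirement is easy to miss, a quick preflight check of the installed `transformers` version can save a confusing model-load failure. A minimal sketch, assuming the `packaging` package is available (it ships with any environment that already has `transformers`):

```python
from importlib.metadata import version

from packaging.version import Version

# Qwen2.5-VL support lives in transformers >= 4.49.0; at the time of this
# commit that version was unreleased, hence the install-from-source note above.
installed = Version(version("transformers"))
if installed < Version("4.49.0.dev0"):
    raise RuntimeError(
        f"transformers {installed} is too old for Qwen2.5-VL; install it from "
        "source: pip install git+https://github.com/huggingface/transformers")
```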

### Pooling Models

See [this page](pooling-models) for more information on how to use pooling models.
31 changes: 31 additions & 0 deletions examples/offline_inference/vision_language.py
@@ -531,6 +531,36 @@ def run_qwen2_vl(question: str, modality: str):
return llm, prompt, stop_token_ids


# Qwen2.5-VL
def run_qwen2_5_vl(question: str, modality: str):

model_name = "Qwen/Qwen2.5-VL-3B-Instruct"

llm = LLM(
model=model_name,
max_model_len=4096,
max_num_seqs=5,
mm_processor_kwargs={
"min_pixels": 28 * 28,
"max_pixels": 1280 * 28 * 28,
"fps": 1,
},
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
)

if modality == "image":
placeholder = "<|image_pad|>"
elif modality == "video":
placeholder = "<|video_pad|>"

prompt = ("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
f"<|im_start|>user\n<|vision_start|>{placeholder}<|vision_end|>"
f"{question}<|im_end|>\n"
"<|im_start|>assistant\n")
stop_token_ids = None
return llm, prompt, stop_token_ids


model_example_map = {
"aria": run_aria,
"blip-2": run_blip2,
@@ -557,6 +587,7 @@ def run_qwen2_vl(question: str, modality: str):
"pixtral_hf": run_pixtral_hf,
"qwen_vl": run_qwen_vl,
"qwen2_vl": run_qwen2_vl,
"qwen2_5_vl": run_qwen2_5_vl,
}


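For orientation, `run_qwen2_5_vl` above only constructs the engine and the chat-formatted prompt; the surrounding example script is what decodes the media and calls `LLM.generate`. A rough sketch of that final step for the image case, where the image path and sampling settings are illustrative assumptions rather than part of this commit:

```python
from PIL import Image
from vllm import SamplingParams

# `llm`, `prompt` and `stop_token_ids` are the values returned by
# run_qwen2_5_vl(question, "image") above.
image = Image.open("example.jpg")  # illustrative path

sampling_params = SamplingParams(temperature=0.2,
                                 max_tokens=64,
                                 stop_token_ids=stop_token_ids)
outputs = llm.generate(
    {
        "prompt": prompt,
        # For the video modality, the key becomes "video" with a list of frames.
        "multi_modal_data": {"image": image},
    },
    sampling_params=sampling_params,
)
print(outputs[0].outputs[0].text)
```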
58 changes: 58 additions & 0 deletions examples/offline_inference/vision_language_multi_image.py
@@ -392,6 +392,63 @@ def load_qwen2_vl(question, image_urls: List[str]) -> ModelRequestData:
)


def load_qwen2_5_vl(question, image_urls: List[str]) -> ModelRequestData:
try:
from qwen_vl_utils import process_vision_info
except ModuleNotFoundError:
print('WARNING: `qwen-vl-utils` not installed, input images will not '
'be automatically resized. You can enable this functionality by '
'`pip install qwen-vl-utils`.')
process_vision_info = None

model_name = "Qwen/Qwen2.5-VL-3B-Instruct"

llm = LLM(
model=model_name,
max_model_len=32768 if process_vision_info is None else 4096,
max_num_seqs=5,
limit_mm_per_prompt={"image": len(image_urls)},
)

placeholders = [{"type": "image", "image": url} for url in image_urls]
messages = [{
"role": "system",
"content": "You are a helpful assistant."
}, {
"role":
"user",
"content": [
*placeholders,
{
"type": "text",
"text": question
},
],
}]

processor = AutoProcessor.from_pretrained(model_name)

prompt = processor.apply_chat_template(messages,
tokenize=False,
add_generation_prompt=True)

stop_token_ids = None

if process_vision_info is None:
image_data = [fetch_image(url) for url in image_urls]
else:
image_data, _ = process_vision_info(messages,
return_video_sample_fps=False)

return ModelRequestData(
llm=llm,
prompt=prompt,
stop_token_ids=stop_token_ids,
image_data=image_data,
chat_template=None,
)


model_example_map = {
"aria": load_aria,
"deepseek_vl_v2": load_deepseek_vl2,
@@ -404,6 +461,7 @@ def load_qwen2_vl(question, image_urls: List[str]) -> ModelRequestData:
"pixtral_hf": load_pixtral_hf,
"qwen_vl_chat": load_qwen_vl_chat,
"qwen2_vl": load_qwen2_vl,
"qwen2_5_vl": load_qwen2_5_vl,
}


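Likewise, the multi-image loader only assembles a `ModelRequestData`; all images for the prompt are then passed in a single `multi_modal_data` entry when generating. A hedged sketch of that call, where `image_urls` is assumed to be a list of URL strings provided by the caller:

```python
from vllm import SamplingParams

req = load_qwen2_5_vl("What do these images have in common?", image_urls)

outputs = req.llm.generate(
    {
        "prompt": req.prompt,
        # The whole list goes in one entry; limit_mm_per_prompt in the loader
        # above must allow at least len(image_urls) images.
        "multi_modal_data": {"image": req.image_data},
    },
    sampling_params=SamplingParams(max_tokens=128,
                                   stop_token_ids=req.stop_token_ids),
)
print(outputs[0].outputs[0].text)
```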
22 changes: 22 additions & 0 deletions tests/models/decoder_only/vision_language/test_models.py
@@ -121,6 +121,8 @@
else ("half", "float")),
marks=[pytest.mark.core_model],
),
# TODO(ywang96): Move Qwen2-VL out of core models in favor of Qwen2.5-VL
# once we upgrade to transformers>=4.49.0.
"qwen2_vl": VLMTestInfo(
models=["Qwen/Qwen2-VL-2B-Instruct"],
test_type=(
@@ -138,6 +140,26 @@
image_size_factors=[(), (0.25,), (0.25, 0.25, 0.25), (0.25, 0.2, 0.15)],
marks=[pytest.mark.core_model, pytest.mark.cpu_model],
),
"qwen2_5_vl": VLMTestInfo(
models=["Qwen/Qwen2.5-VL-3B-Instruct"],
test_type=(
VLMTestType.IMAGE,
VLMTestType.MULTI_IMAGE,
VLMTestType.VIDEO
),
prompt_formatter=lambda img_prompt: f"<|im_start|>User\n{img_prompt}<|im_end|>\n<|im_start|>assistant\n", # noqa: E501
img_idx_to_prompt=lambda idx: "<|vision_start|><|image_pad|><|vision_end|>", # noqa: E501
video_idx_to_prompt=lambda idx: "<|vision_start|><|video_pad|><|vision_end|>", # noqa: E501
max_model_len=4096,
max_num_seqs=2,
auto_cls=AutoModelForVision2Seq,
vllm_output_post_proc=model_utils.qwen2_vllm_to_hf_output,
image_size_factors=[(), (0.25,), (0.25, 0.25, 0.25), (0.25, 0.2, 0.15)],
marks=[pytest.mark.skipif(
TRANSFORMERS_VERSION < "4.49.0",
reason="HF model requires transformers>=4.49.0",
), pytest.mark.core_model, pytest.mark.cpu_model],
),
#### Extended model tests
"aria": VLMTestInfo(
models=["rhymes-ai/Aria"],
1 change: 1 addition & 0 deletions tests/models/multimodal/processing/test_common.py
@@ -161,6 +161,7 @@ def _test_processing_correctness(
"nvidia/NVLM-D-72B",
"Qwen/Qwen-VL-Chat",
"Qwen/Qwen2-VL-2B-Instruct",
"Qwen/Qwen2.5-VL-3B-Instruct",
"Qwen/Qwen2-Audio-7B-Instruct",
"fixie-ai/ultravox-v0_3",
])
2 changes: 2 additions & 0 deletions tests/models/registry.py
@@ -264,6 +264,8 @@ def check_available_online(
trust_remote_code=True),
"Qwen2AudioForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2-Audio-7B-Instruct"), # noqa: E501
"Qwen2VLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2-VL-2B-Instruct"), # noqa: E501
"Qwen2_5_VLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2.5-VL-3B-Instruct", # noqa: E501
min_transformers_version="4.49"), # noqa: E501
"UltravoxModel": _HfExamplesInfo("fixie-ai/ultravox-v0_3",
trust_remote_code=True),
# [Encoder-decoder]
4 changes: 2 additions & 2 deletions vllm/entrypoints/chat_utils.py
@@ -410,7 +410,7 @@ def _placeholder_str(self, modality: ModalityStr,
return "<image>"
if model_type == "mllama":
return "<|image|>"
if model_type == "qwen2_vl":
if model_type in ("qwen2_vl", "qwen2_5_vl"):
return "<|vision_start|><|image_pad|><|vision_end|>"
if model_type == "molmo":
return ""
@@ -430,7 +430,7 @@ def _placeholder_str(self, modality: ModalityStr,
return "(<audio>./</audio>)"
raise TypeError(f"Unknown model type: {model_type}")
elif modality == "video":
if model_type == "qwen2_vl":
if model_type in ("qwen2_vl", "qwen2_5_vl"):
return "<|vision_start|><|video_pad|><|vision_end|>"
if model_type in ("minicpmo", "minicpmv"):
return "(<video>./</video>)"
58 changes: 34 additions & 24 deletions vllm/model_executor/layers/rotary_embedding.py
@@ -27,6 +27,7 @@

import torch
import torch.nn as nn
from transformers import PretrainedConfig

from vllm.model_executor.custom_op import CustomOp

@@ -772,8 +773,12 @@ def __init__(
dtype: torch.dtype,
mrope_section: Optional[List[int]] = None,
) -> None:
super().__init__(head_size, rotary_dim, max_position_embeddings, base,
is_neox_style, dtype)
# In Qwen2.5-VL, the maximum index value is related to the duration of
# the input video. We enlarge max_position_embeddings by 4x to get
# a larger cos and sin cache.
self.cache_max_position_num = max_position_embeddings * 4
super().__init__(head_size, rotary_dim, self.cache_max_position_num,
base, is_neox_style, dtype)

self.mrope_section = mrope_section
if self.mrope_section:
@@ -831,49 +836,47 @@ def forward(
@staticmethod
def get_input_positions(
input_tokens: List[int],
hf_config: PretrainedConfig,
image_grid_thw: Union[List[List[int]], torch.Tensor],
video_grid_thw: Union[List[List[int]], torch.Tensor],
image_token_id: int,
video_token_id: int,
vision_start_token_id: int,
vision_end_token_id: int,
spatial_merge_size: int,
second_per_grid_ts: Optional[List[float]] = None,
context_len: int = 0,
seq_len: Optional[int] = None,
) -> Tuple[List[List[int]], int]:
"""Get mrope input positions and delta value."""

llm_positions, mrope_position_delta = \
MRotaryEmbedding.get_input_positions_tensor(
input_tokens,
image_grid_thw,
video_grid_thw,
image_token_id,
video_token_id,
vision_start_token_id,
vision_end_token_id,
spatial_merge_size,
context_len,
seq_len,
input_tokens=input_tokens,
hf_config=hf_config,
image_grid_thw=image_grid_thw,
video_grid_thw=video_grid_thw,
second_per_grid_ts=second_per_grid_ts,
context_len=context_len,
seq_len=seq_len,
)

return llm_positions.tolist(), mrope_position_delta

@staticmethod
def get_input_positions_tensor(
input_tokens: List[int],
hf_config: PretrainedConfig,
image_grid_thw: Union[List[List[int]], torch.Tensor],
video_grid_thw: Union[List[List[int]], torch.Tensor],
image_token_id: int,
video_token_id: int,
vision_start_token_id: int,
vision_end_token_id: int,
spatial_merge_size: int,
second_per_grid_ts: Optional[List[float]] = None,
context_len: int = 0,
seq_len: Optional[int] = None,
) -> Tuple[torch.Tensor, int]:
"""Get mrope input positions and delta value."""

image_token_id = hf_config.image_token_id
video_token_id = hf_config.video_token_id
vision_start_token_id = hf_config.vision_start_token_id
spatial_merge_size = hf_config.vision_config.spatial_merge_size
tokens_per_second = getattr(hf_config.vision_config,
"tokens_per_second", 1.0)

if isinstance(image_grid_thw, torch.Tensor):
image_grid_thw = image_grid_thw.tolist()
if isinstance(video_grid_thw, torch.Tensor):
@@ -892,6 +895,7 @@ def get_input_positions_tensor(

image_index, video_index = 0, 0
for _ in range(image_nums + video_nums):
video_second_per_grid_t = 0.0
if image_token_id in input_tokens and remain_images > 0:
ed_image = input_tokens.index(image_token_id, st)
else:
@@ -915,9 +919,13 @@ def get_input_positions_tensor(
video_grid_thw[video_index][1],
video_grid_thw[video_index][2],
)
video_second_per_grid_t = 1.0
if second_per_grid_ts is not None:
video_second_per_grid_t = second_per_grid_ts[video_index]
video_index += 1
remain_videos -= 1
ed = ed_video

llm_grid_t, llm_grid_h, llm_grid_w = \
t, h // spatial_merge_size, w // spatial_merge_size
text_len = ed - st
@@ -927,8 +935,10 @@ def get_input_positions_tensor(
llm_pos_ids_list.append(
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)

t_index = torch.arange(llm_grid_t).view(-1, 1).expand(
-1, llm_grid_h * llm_grid_w).flatten()
t_index = (torch.arange(llm_grid_t).view(-1, 1).expand(
-1, llm_grid_h * llm_grid_w) * video_second_per_grid_t *
tokens_per_second).long().flatten()

h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(
llm_grid_t, -1, llm_grid_w).flatten()
w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(
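The behavioral core of this file's change is the temporal term of the M-RoPE positions: Qwen2-VL used the raw temporal grid index, while Qwen2.5-VL scales it by the seconds each temporal grid step covers (`second_per_grid_ts`) and by the vision config's `tokens_per_second`, so video positions track wall-clock time. A standalone illustration of just that computation, with made-up grid sizes and rates:

```python
import torch

llm_grid_t, llm_grid_h, llm_grid_w = 4, 2, 2
tokens_per_second = 2.0        # would come from hf_config.vision_config
video_second_per_grid_t = 1.0  # would come from the processor's second_per_grid_ts

# Qwen2-VL: temporal index is just the grid step.
old_t_index = torch.arange(llm_grid_t).view(-1, 1).expand(
    -1, llm_grid_h * llm_grid_w).flatten()

# Qwen2.5-VL: temporal index advances with video time.
new_t_index = (torch.arange(llm_grid_t).view(-1, 1).expand(
    -1, llm_grid_h * llm_grid_w) * video_second_per_grid_t *
               tokens_per_second).long().flatten()

print(old_t_index.tolist())  # [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3]
print(new_t_index.tolist())  # [0, 0, 0, 0, 2, 2, 2, 2, 4, 4, 4, 4, 6, 6, 6, 6]
```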