Qwen2-VL used to work with inputs_embeds instead of input_ids, but no more #35463

Open · 2 of 4 tasks
minostauros opened this issue Dec 31, 2024 · 0 comments · May be fixed by #35466

System Info

  • transformers version: 4.47.1
  • Platform: Linux-4.18.0-513.18.1.el8_9.x86_64-x86_64-with-glibc2.35
  • Python version: 3.10.12
  • Huggingface_hub version: 0.27.0
  • Safetensors version: 0.4.5
  • Accelerate version: 1.2.1
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.5.1+cu121 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using distributed or parallel set-up in script?: No
  • Using GPU in script?: Yes
  • GPU type: NVIDIA H100 80GB HBM3

Who can help?

@zucchini-nlp

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Preparation

import torch
from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor
from qwen_vl_utils import process_vision_info

model = Qwen2VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2-VL-7B-Instruct",
    torch_dtype=torch.bfloat16,
    attn_implementation="eager",  # flash_attention_2 also produces the same error
    device_map="auto",
)
processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")

messages = [
    {
        "role": "user",
        "content": [
            {
                "type": "image",
                "image": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg",
            },
            {"type": "text", "text": "Describe this image."},
        ],
    }
]

text = processor.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
image_inputs, video_inputs = process_vision_info(messages)
inputs = processor(
    text=[text],
    images=image_inputs,
    videos=video_inputs,
    padding=True,
    return_tensors="pt",
)
inputs = inputs.to("cuda")

Working example

generated_ids = model.generate(**inputs, max_new_tokens=128)
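
For completeness, the generated ids can be decoded as in the Qwen2-VL model card (a minimal sketch; generating from input_ids returns the prompt plus the new tokens, so the prompt is trimmed off first):

# Trim the prompt ids from each sequence, then decode only the new tokens.
generated_ids_trimmed = [
    out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
]
output_text = processor.batch_decode(
    generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
)
print(output_text)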

Used to work

The following worked at commit 9470d65 but fails in v4.47.1 (see the comparison between the two revisions):

input_ids = inputs["input_ids"]
attention_mask = inputs["attention_mask"]
pixel_values = inputs["pixel_values"]
image_grid_thw = inputs["image_grid_thw"]

# Embed the text tokens, then splice the vision features into the
# positions of the image placeholder tokens.
inputs_embeds = model.model.embed_tokens(input_ids)
if pixel_values is not None:
    pixel_values = pixel_values.type(model.visual.get_dtype())
    image_embeds = model.visual(pixel_values, grid_thw=image_grid_thw)
    n_image_tokens = (input_ids == model.config.image_token_id).sum().item()
    n_image_features = image_embeds.shape[0]
    if n_image_tokens != n_image_features:
        raise ValueError(
            f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
        )
    image_mask = (
        (input_ids == model.config.image_token_id)
        .unsqueeze(-1)
        .expand_as(inputs_embeds)
        .to(inputs_embeds.device)
    )
    image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
    inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)

if attention_mask is not None:
    attention_mask = attention_mask.to(inputs_embeds.device)

# Generate from the merged embeddings instead of input_ids.
generated_ids = model.generate(inputs_embeds=inputs_embeds, attention_mask=attention_mask, max_new_tokens=128)
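
If this path worked, decoding would be simpler (a sketch, assuming the usual generate behavior: when only inputs_embeds is passed and input_ids is not, the returned ids contain just the newly generated tokens, so no prompt trimming is needed):

# No prompt ids are present in the output when generating from inputs_embeds.
output_text = processor.batch_decode(
    generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
)
print(output_text)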

Expected behavior

The latter should work the same as the former.

Error message produced by the latter:

  File "/usr/local/lib/python3.10/dist-packages/transformers/models/qwen2_vl/modeling_qwen2_vl.py", line 578, in forward
    attn_weights = attn_weights + causal_mask
RuntimeError: The size of tensor a (2362) must match the size of tensor b (1182) at non-singleton dimension 3