
How can I disable legacy processing in llava-next #35457

foreverpiano opened this issue Dec 30, 2024 · 3 comments

foreverpiano commented Dec 30, 2024

System Info

transformers 4.47.1

Who can help?

vision models: @amyeroberts, @qubvel

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

if legacy_processing:
    logger.warning_once(
        "Expanding inputs for image tokens in LLaVa-NeXT should be done in processing. "
        "Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly "
        "with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. "
        "Using processors without these attributes in the config is deprecated and will throw an error in v4.50."
    )
    if input_ids.shape[1] != 1:
        inputs_embeds = inputs_embeds.to(image_features.dtype)
        inputs_embeds, attention_mask, position_ids, labels, _ = self._merge_input_ids_with_image_features(
            image_features,
            feature_lens,
            inputs_embeds,
            input_ids,
            attention_mask,
            position_ids,
            labels=labels,
        )
        cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)
    else:
        # Retrieve the first layer to inspect the logits and mask out the hidden states
        # that are set to 0
        first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]
        # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
        batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)
        # Get the target length
        target_length = input_ids.shape[1]
        past_length = first_layer_past_key_value.shape[-1]
        extended_attention_mask = torch.ones(
            (attention_mask.shape[0], past_length),
            dtype=attention_mask.dtype,
            device=attention_mask.device,
        )
        # Filter out only the tokens that can be un-attended, this can happen
        # if one uses Llava + Fused modules where the cache on the
        # first iteration is already big enough, or if one passes custom cache
        valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
        new_batch_index = batch_index[valid_indices]
        new_non_attended_tokens = non_attended_tokens[valid_indices]
        # Zero-out the places where we don't need to attend
        extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0
        attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1)
        position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
        cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)[-target_length:]
# TODO: @raushan retain only the new behavior after v4.47
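
Judging from the warning text, the legacy branch is taken when the processor was loaded without `patch_size` and `vision_feature_select_strategy`. A minimal way to check which branch a given checkpoint will hit, assuming that is indeed the deciding condition (an inference from the warning, not the model's internal logic), could look like this:

# Assumption based on the warning above (not the library's internal check):
# if these processor attributes are None, the model falls back to the legacy
# in-model expansion path shown in the snippet above.
from transformers import LlavaNextProcessor

processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
print("patch_size:", processor.patch_size)
print("vision_feature_select_strategy:", processor.vision_feature_select_strategy)
print("legacy path expected:", processor.patch_size is None
      or processor.vision_feature_select_strategy is None)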

Sample script

import requests
import torch
from PIL import Image
from transformers import LlavaNextForConditionalGeneration, LlavaNextProcessor


def main():
    args = parse_args()  # user-defined CLI parsing (not shown in the issue)
    
    processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
    model = LlavaNextForConditionalGeneration.from_pretrained(
        "llava-hf/llava-v1.6-mistral-7b-hf",
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
        attn_implementation="flash_attention_2",
    ).to("cuda:0")
    
    setup_model_with_compression(model, args)  # user-defined helper (not shown in the issue)
    
    url = "https://github.com/haotian-liu/LLaVA/blob/1a91fc274d7c35a9b50b3cb29c4247ae5837ce39/images/llava_v1_5_radar.jpg?raw=true"
    image = Image.open(requests.get(url, stream=True).raw)
    
    conversation = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": "What is shown in this image?"},
            ],
        },
    ]
    
    prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
    inputs = processor(image, prompt, return_tensors="pt").to("cuda:0")
    
    output = model.generate(**inputs, max_new_tokens=100)
    print(processor.decode(output[0], skip_special_tokens=True))

Expected behavior

How does the legacy processing work, and can I disable it?

@LysandreJik (Member)

cc @zucchini-nlp as well

@zucchini-nlp (Member)

@foreverpiano Legacy processing is currently disabled for the official checkpoints. There is one spurious warning still being emitted, though, which will be removed in the next release.

If you want to update your own LLaVA checkpoint to use the new processing logic, you can do so with:

processor.vision_feature_select_strategy = model.config.vision_feature_select_strategy
processor.patch_size = model.config.vision_config.patch_size
model.config.image_seq_length = {{MIN_SEQ_LENGTH_ONE_IMAGE_TAKES_FOR_YOUR_ARCHITECTURE}}
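
For example, applied to the sample script above, the update would go right after loading the processor and model; a minimal sketch, with the architecture-specific `image_seq_length` left as a commented placeholder:

processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
model = LlavaNextForConditionalGeneration.from_pretrained(
    "llava-hf/llava-v1.6-mistral-7b-hf",
    torch_dtype=torch.float16,
).to("cuda:0")

# Copy the values the processor needs from the model config so that image-token
# expansion happens in the processor instead of in the model's legacy path.
processor.vision_feature_select_strategy = model.config.vision_feature_select_strategy
processor.patch_size = model.config.vision_config.patch_size
# model.config.image_seq_length = ...  # minimum sequence length one image takes for your architecture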

foreverpiano (Author) commented Jan 6, 2025

@zucchini-nlp Thanks for your reply. But when I run the above script, it still takes the legacy processing path. Is this related to the transformers version?

pip info:

torch==2.4.0
torchvision==0.19.0
numpy==1.26.0
transformers==4.47.1
datasets==2.15.0
accelerate==1.2.1
urllib3==2.1.0
bitsandbytes==0.41.2
chardet==5.2.0
scipy==1.11.4
deepdiff==6.7.1
diffusers==0.8.0
sentencepiece==0.1.99
protobuf==4.25.1

jieba==0.42.1
fuzzywuzzy==0.18.0
rouge==1.0.1
IPython==8.18.1

# tools for installing flash-attn
packaging==24.1
ninja
triton==3.0.0
pytest==8.2.1
flash_attn

matplotlib
