[Model] add vllm compatible models (#544)
* Add VLLM model integration and update configurations

- Introduce VLLM model in the model registry.
- Update AVAILABLE_MODELS to include new models:
  - models/__init__.py: Added "aria", "internvideo2", "llama_vision", "oryx", "ross", "slime", "videochat2", "vllm", "xcomposer2_4KHD", "xcomposer2d5".
- Create vllm.py for VLLM model implementation:
  - Implemented encoding for images and videos.
  - Added methods for generating responses and handling multi-round generation.
- Update MMMU tasks with new prompt formats and evaluation metrics:
  - mmmu_val.yaml: Added specific kwargs for prompt types.
  - mmmu_val_reasoning.yaml: Enhanced prompts for reasoning tasks.
  - utils.py: Adjusted evaluation rules and scoring for predictions.
- Added script for easy model execution:
  - vllm_qwen2vl.sh: Script to run VLLM with specified parameters.

* Set environment variables for VLLM script

- Configure environment for better performance and debugging.
- Added variables to control multiprocessing and NCCL behavior.

miscs/vllm_qwen2vl.sh:
- Set `VLLM_WORKER_MULTIPROC_METHOD` to `spawn` for compatibility.
- Enabled `NCCL_BLOCKING_WAIT` to avoid hangs.
- Increased `NCCL_TIMEOUT` to 18000000 for long-running processes.
- Set `NCCL_DEBUG` to `DEBUG` for detailed logs.
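
For reference, an illustrative Python equivalent of the exports in this script; the values for VLLM_WORKER_MULTIPROC_METHOD, NCCL_TIMEOUT, and NCCL_DEBUG come from the commit description, while the "1" for NCCL_BLOCKING_WAIT is an assumption (the commit only says it is enabled):

    import os

    # Set before vLLM workers and NCCL communicators are created.
    os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"  # spawn-based worker processes for compatibility
    os.environ["NCCL_BLOCKING_WAIT"] = "1"                # assumed value; "enabled" per the commit message
    os.environ["NCCL_TIMEOUT"] = "18000000"               # generous timeout for long-running evaluations
    os.environ["NCCL_DEBUG"] = "DEBUG"                    # verbose NCCL logging, as set in the script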

* Rename scripts and update paths

- Renamed the model dry-run scripts for clarity and moved them under miscs/model_dryruns/:
  - miscs/repr_scripts.sh -> miscs/model_dryruns/llava_1_5.sh
  - miscs/cicd_qwen2vl.sh -> miscs/model_dryruns/qwen2vl.sh
  - miscs/tinyllava_repr_scripts.sh -> miscs/model_dryruns/tinyllava.sh
  - miscs/vllm_qwen2vl.sh -> miscs/model_dryruns/vllm_qwen2vl.sh
- Updated parameters in the vllm_qwen2vl.sh script.
  - miscs/model_dryruns/vllm_qwen2vl.sh: Added `--limit=64` to output path command.

* Optimize image handling in VLLM model

- Simplify image conversion in the `to_base64` method:
  - vllm.py: Directly convert input image to RGB format instead of copying it.
- Remove unnecessary base64 encoding for images:
  - vllm.py: Return the PIL image directly instead of converting it to base64.
- Update video frame processing to return PIL images:
  - vllm.py: Replace base64 encoding of frames with returning the PIL frames directly.

* Revert "Optimize image handling in VLLM model"

This reverts commit 469e1fc.

* Use threads to encode visuals

---------

Co-authored-by: kcz358 <[email protected]>
Luodian and kcz358 authored Feb 20, 2025
1 parent 2508d42 commit 968d5f1
Showing 9 changed files with 266 additions and 28 deletions.
21 changes: 11 additions & 10 deletions lmms_eval/models/__init__.py
@@ -11,6 +11,7 @@
logger.add(sys.stdout, level="WARNING")

AVAILABLE_MODELS = {
"aria": "Aria",
"auroracap": "AuroraCap",
"batch_gpt4": "BatchGPT4",
"claude": "Claude",
@@ -21,44 +22,44 @@
"gpt4v": "GPT4V",
"idefics2": "Idefics2",
"instructblip": "InstructBLIP",
"internvideo2": "InternVideo2",
"internvl": "InternVLChat",
"internvl2": "InternVL2",
"llama_vid": "LLaMAVid",
"llama_vision": "LlamaVision",
"llava": "Llava",
"llava_hf": "LlavaHf",
"llava_onevision": "Llava_OneVision",
"llava_onevision_moviechat": "Llava_OneVision_MovieChat",
"llava_sglang": "LlavaSglang",
"llava_vid": "LlavaVid",
"slime": "Slime",
"longva": "LongVA",
"mantis": "Mantis",
"minicpm_v": "MiniCPM_V",
"minimonkey": "MiniMonkey",
"moviechat": "MovieChat",
"mplug_owl_video": "mplug_Owl",
"oryx": "Oryx",
"phi3v": "Phi3v",
"qwen_vl": "Qwen_VL",
"qwen2_vl": "Qwen2_VL",
"qwen2_5_vl": "Qwen2_5_VL",
"qwen2_5_vl_interleave": "Qwen2_5_VL_Interleave",
"qwen2_audio": "Qwen2_Audio",
"qwen2_vl": "Qwen2_VL",
"qwen_vl": "Qwen_VL",
"qwen_vl_api": "Qwen_VL_API",
"reka": "Reka",
"ross": "Ross",
"slime": "Slime",
"srt_api": "SRT_API",
"tinyllava": "TinyLlava",
"videoChatGPT": "VideoChatGPT",
"videochat2": "VideoChat2",
"video_llava": "VideoLLaVA",
"vila": "VILA",
"vita": "VITA",
"vllm": "VLLM",
"xcomposer2_4KHD": "XComposer2_4KHD",
"internvideo2": "InternVideo2",
"xcomposer2d5": "XComposer2D5",
"oryx": "Oryx",
"videochat2": "VideoChat2",
"llama_vision": "LlamaVision",
"aria": "Aria",
"ross": "Ross",
"vita": "VITA",
}


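The new "vllm" entry above maps a registry name to the VLLM class defined in lmms_eval/models/vllm.py (shown next). A minimal sketch of how such a name-to-class registry can be resolved, assuming an importlib-based loader; the actual lmms_eval loader may differ in details:

    import importlib

    AVAILABLE_MODELS = {"vllm": "VLLM", "qwen2_vl": "Qwen2_VL"}  # excerpt of the mapping above

    def get_model_class(model_name: str):
        # Hypothetical loader: resolve "vllm" -> lmms_eval.models.vllm.VLLM.
        if model_name not in AVAILABLE_MODELS:
            raise ValueError(f"Unknown model: {model_name}")
        module = importlib.import_module(f"lmms_eval.models.{model_name}")
        return getattr(module, AVAILABLE_MODELS[model_name])

    # Example: cls = get_model_class("vllm"); model = cls(model_version="Qwen/Qwen2.5-VL-3B-Instruct")
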
194 changes: 194 additions & 0 deletions lmms_eval/models/vllm.py
@@ -0,0 +1,194 @@
import asyncio
import base64
import json
import os
import time
from concurrent.futures import ThreadPoolExecutor
from copy import deepcopy
from io import BytesIO
from multiprocessing import cpu_count
from typing import List, Optional, Tuple, Union

import numpy as np
from accelerate import Accelerator, DistributedType
from decord import VideoReader, cpu
from loguru import logger as eval_logger
from PIL import Image
from tqdm import tqdm

from lmms_eval.api.instance import Instance
from lmms_eval.api.model import lmms
from lmms_eval.api.registry import register_model

NUM_SECONDS_TO_SLEEP = 5

try:
from vllm import LLM, SamplingParams
except ImportError:
    LLM = SamplingParams = None  # vllm is an optional dependency; the model cannot be constructed without it


@register_model("vllm")
class VLLM(lmms):
def __init__(
self,
model_version: str = "Qwen/Qwen2.5-VL-3B-Instruct",
tensor_parallel_size: int = 1,
gpu_memory_utilization: float = 0.8,
batch_size: int = 1,
timeout: int = 60,
max_images: int = 32,
max_videos: int = 8,
max_audios: int = 8,
max_frame_num: int = 32,
threads: int = 16, # Threads to use for decoding visuals
trust_remote_code: Optional[bool] = True,
**kwargs,
) -> None:
super().__init__()
        # Store the model configuration; visuals are later encoded to base64 PNGs
        # and passed to vLLM as image_url entries in OpenAI-style chat messages.
self.model_version = model_version
self.max_images = max_images
self.max_frame_num = max_frame_num
self.threads = threads

accelerator = Accelerator()
self.client = LLM(
model=self.model_version,
tensor_parallel_size=tensor_parallel_size,
gpu_memory_utilization=gpu_memory_utilization,
limit_mm_per_prompt={"image": max_images, "video": max_videos, "audio": max_audios},
trust_remote_code=trust_remote_code,
)
if accelerator.num_processes > 1:
assert accelerator.distributed_type in [DistributedType.FSDP, DistributedType.MULTI_GPU, DistributedType.DEEPSPEED], "Unsupported distributed type provided. Only DDP and FSDP are supported."
self.accelerator = accelerator
if self.accelerator.is_local_main_process:
eval_logger.info(f"Using {accelerator.num_processes} devices with data parallelism")
self._rank = self.accelerator.local_process_index
self._world_size = self.accelerator.num_processes
else:
self.accelerator = accelerator
self._rank = self.accelerator.local_process_index
self._world_size = self.accelerator.num_processes

self.device = self.accelerator.device
self.batch_size_per_gpu = int(batch_size)

# Function to encode the image
def encode_image(self, image: Union[Image.Image, str]):
if isinstance(image, str):
img = Image.open(image).convert("RGB")
else:
img = image.copy()

output_buffer = BytesIO()
img.save(output_buffer, format="PNG")
byte_data = output_buffer.getvalue()

base64_str = base64.b64encode(byte_data).decode("utf-8")
return base64_str

# Function to encode the video
def encode_video(self, video_path):
vr = VideoReader(video_path, ctx=cpu(0))
total_frame_num = len(vr)
uniform_sampled_frames = np.linspace(0, total_frame_num - 1, self.max_frame_num, dtype=int)

# Ensure the last frame is included
if total_frame_num - 1 not in uniform_sampled_frames:
uniform_sampled_frames = np.append(uniform_sampled_frames, total_frame_num - 1)

frame_idx = uniform_sampled_frames.tolist()
frames = vr.get_batch(frame_idx).asnumpy()

base64_frames = []
for frame in frames:
img = Image.fromarray(frame)
output_buffer = BytesIO()
img.save(output_buffer, format="PNG")
byte_data = output_buffer.getvalue()
base64_str = base64.b64encode(byte_data).decode("utf-8")
base64_frames.append(base64_str)

return base64_frames

def flatten(self, input):
new_list = []
for i in input:
for j in i:
new_list.append(j)
return new_list

def generate_until(self, requests) -> List[str]:
res = []
pbar = tqdm(total=len(requests), disable=(self.rank != 0), desc="Model Responding")

batch_size = self.batch_size_per_gpu
batched_requests = [requests[i : i + batch_size] for i in range(0, len(requests), batch_size)]
for batch_requests in batched_requests:
batched_messages = []
for idx in range(len(batch_requests)):
contexts, gen_kwargs, doc_to_visual, doc_id, task, split = batch_requests[idx].arguments
if "max_new_tokens" not in gen_kwargs:
gen_kwargs["max_new_tokens"] = 1024
if gen_kwargs["max_new_tokens"] > 4096:
gen_kwargs["max_new_tokens"] = 4096
if "temperature" not in gen_kwargs:
gen_kwargs["temperature"] = 0
if "top_p" not in gen_kwargs:
gen_kwargs["top_p"] = 0.95

params = {
"temperature": gen_kwargs["temperature"],
"max_tokens": gen_kwargs["max_new_tokens"],
"top_p": gen_kwargs["top_p"],
}
sampling_params = SamplingParams(**params)

visuals = [doc_to_visual(self.task_dict[task][split][doc_id])]
if None in visuals:
visuals = []
imgs = []
else:
visuals = self.flatten(visuals)
imgs = [] # multiple images or frames for video
all_tasks = []
with ThreadPoolExecutor(max_workers=self.threads) as executor:
for visual in visuals:
if isinstance(visual, str) and (".mp4" in visual or ".avi" in visual or ".mov" in visual or ".flv" in visual or ".wmv" in visual):
all_tasks.append(executor.submit(self.encode_video, visual))
elif isinstance(visual, str) and (".jpg" in visual or ".jpeg" in visual or ".png" in visual or ".gif" in visual or ".bmp" in visual or ".tiff" in visual or ".webp" in visual):
all_tasks.append(executor.submit(self.encode_image, visual))
elif isinstance(visual, Image.Image):
all_tasks.append(executor.submit(self.encode_image, visual))

for task in all_tasks:
imgs.append(task.result())

messages = [{"role": "user", "content": []}]
# When there is no image token in the context, append the image to the text
messages[0]["content"].append({"type": "text", "text": contexts})
for img in imgs:
messages[0]["content"].append({"type": "image_url", "image_url": {"url": f"data:image/png;base64,{img}"}})

batched_messages.append(messages)

response = self.client.chat(sampling_params=sampling_params, messages=batched_messages)
response_text = [o.outputs[0].text for o in response]

assert len(response_text) == len(batch_requests)
res.extend(response_text)
pbar.update(len(batch_requests))

pbar.close()
return res

def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
# TODO
assert False, "GPT4V not support"

def generate_until_multi_round(self, requests) -> List[str]:
raise NotImplementedError("TODO: Implement multi-round generation")
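
The heart of generate_until above: each visual is encoded to a base64 PNG in a thread pool and attached to the prompt as an OpenAI-style image_url entry before self.client.chat is called. A stripped-down sketch of that pattern outside the harness; the helper names and the example path are placeholders:

    import base64
    from concurrent.futures import ThreadPoolExecutor
    from io import BytesIO

    from PIL import Image

    def encode_image(path: str) -> str:
        # Same idea as VLLM.encode_image above: RGB image -> PNG bytes -> base64 string.
        img = Image.open(path).convert("RGB")
        buf = BytesIO()
        img.save(buf, format="PNG")
        return base64.b64encode(buf.getvalue()).decode("utf-8")

    def build_user_message(prompt: str, image_paths: list, threads: int = 16) -> dict:
        # Encode all visuals concurrently ("use threads to encode visuals" in the commit message).
        with ThreadPoolExecutor(max_workers=threads) as pool:
            encoded = list(pool.map(encode_image, image_paths))
        content = [{"type": "text", "text": prompt}]
        content += [{"type": "image_url", "image_url": {"url": f"data:image/png;base64,{b64}"}} for b64 in encoded]
        return {"role": "user", "content": content}

    # messages = [build_user_message("Describe the image.", ["example.png"])]  # example.png is a placeholder
    # The wrapper batches such message lists and calls self.client.chat(sampling_params=..., messages=batched_messages).
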
6 changes: 6 additions & 0 deletions lmms_eval/tasks/mmmu/mmmu_val.yaml
@@ -13,4 +13,10 @@ metric_list:
aggregation: !function utils.mmmu_aggregate_results
higher_is_better: true

lmms_eval_specific_kwargs:
default:
prompt_type: "format"
multiple_choice_prompt: "Answer with the option's letter from the given choices directly."
open_ended_prompt: "Answer the question using a single word or phrase."

include: _default_template_yaml
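
These lmms_eval_specific_kwargs provide the prompt suffixes that mmmu_doc_to_text (see the utils.py diff below) forwards to construct_prompt. A simplified sketch of how such kwargs select a suffix; append_task_prompt is a hypothetical helper, and the real construct_prompt also interleaves the answer options and image tokens:

    def append_task_prompt(question: str, question_type: str, kwargs: dict) -> str:
        # question_type follows the MMMU convention ("multiple-choice" vs. open-ended questions).
        if question_type == "multiple-choice":
            return f"{question}\n{kwargs['multiple_choice_prompt']}"
        return f"{question}\n{kwargs['open_ended_prompt']}"

    kwargs = {
        "prompt_type": "format",
        "multiple_choice_prompt": "Answer with the option's letter from the given choices directly.",
        "open_ended_prompt": "Answer the question using a single word or phrase.",
    }
    # append_task_prompt("Which structure is highlighted?", "multiple-choice", kwargs)
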
20 changes: 17 additions & 3 deletions lmms_eval/tasks/mmmu/mmmu_val_reasoning.yaml
@@ -7,6 +7,16 @@ doc_to_text: !function utils.mmmu_doc_to_text
doc_to_target: "answer"
# The return value of process_results will be used by metrics
process_results: !function utils.mmmu_reasoning_process_results
# repeats: 8
# filter_list:
# # - name: "pass@64"
# # filter:
# # - function: "take_first_k"
# # k: 64
# - name: "pass@8"
# filter:
# - function: "take_first_k"
# k: 8

metric_list:
- metric: mmmu_judge_acc
@@ -16,11 +26,15 @@ metric_list:
lmms_eval_specific_kwargs:
default:
prompt_type: "reasoning"
multiple_choice_prompt: "Please show step-by-step reasoning, and answer the question with option letter from given choices."
open_ended_prompt: "Please show step-by-step reasoning, and answer the question using a single word or phrase."
multiple_choice_prompt: "Please reason step by step, and answer the question with option letter from given choices in the format of Answer: <option letter>."
open_ended_prompt: "Please reason step by step, and answer the question using a single word or phrase in the format of Answer: <answer>."

generation_kwargs:
max_new_tokens: 256
max_new_tokens: 16384
temperature: 0.7
do_sample: true
top_p: 0.95
top_k: 50
until:
- "</s>"
- "Q:"
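
When the new vllm model runs this task, generate_until (in the vllm.py diff above) maps these generation_kwargs onto vLLM SamplingParams; do_sample, top_k, and until are not forwarded by that wrapper, and max_new_tokens is capped at 4096. A sketch of the mapping, assuming vllm is installed:

    from vllm import SamplingParams

    gen_kwargs = {"max_new_tokens": 16384, "temperature": 0.7, "top_p": 0.95}

    # Mirrors the defaults and the 4096-token cap applied in VLLM.generate_until.
    sampling_params = SamplingParams(
        temperature=gen_kwargs.get("temperature", 0),
        max_tokens=min(gen_kwargs.get("max_new_tokens", 1024), 4096),
        top_p=gen_kwargs.get("top_p", 0.95),
    )
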
33 changes: 18 additions & 15 deletions lmms_eval/tasks/mmmu/utils.py
@@ -59,17 +59,17 @@
```
# Evaluation Rules
- The model prediction contains the reasoning process, you should spot the final answer from the it.
- For multiple-choice questions: Score 1 if the predicted answer matches the correct answer.
- The model prediction may contain the reasoning process, you should spot the final answer from it.
- For multiple-choice questions: Score 1 if the predicted answer matches the ground truth answer, it can be directly in option letters or the content of the options.
- For open-ended questions:
* Score 1 if the prediction matches the answer semantically and contains all key elements
* Score 1 if the prediction matches the answer semantically, it can be in different format.
* Score 0 for partially correct answers or answers with extra incorrect information, even if the reasoning process is correct.
- Ignore minor differences in formatting, capitalization, or spacing since the model may explain in a different way.
- Treat numerical answers as correct if they match within reasonable precision
- For questions requiring units, both value and unit must be correct
# Strict Output format
[0/1]"""
0 or 1"""

if API_TYPE == "openai":
API_URL = os.getenv("OPENAI_API_URL", "https://api.openai.com/v1/chat/completions")
@@ -96,7 +96,7 @@ def get_chat_response(content: str, max_tokens: int, retries: int = 5):
payload = {
"model": MODEL_VERSION,
"messages": messages,
"temperature": 0.2,
"temperature": 0.0,
"max_tokens": max_tokens,
}

@@ -144,10 +144,7 @@ def construct_prompt(doc, mc_prompt="", open_ended_prompt=""):


def mmmu_doc_to_text(doc, lmms_eval_specific_kwargs=None):
if lmms_eval_specific_kwargs is not None and "multiple_choice_prompt" in lmms_eval_specific_kwargs:
question = construct_prompt(doc, lmms_eval_specific_kwargs["multiple_choice_prompt"], lmms_eval_specific_kwargs["open_ended_prompt"])
else:
question = construct_prompt(doc)
question = construct_prompt(doc, lmms_eval_specific_kwargs["multiple_choice_prompt"], lmms_eval_specific_kwargs["open_ended_prompt"])
if config["metadata"]["interleaved_format"]:
question = replace_images_tokens(question)
return question
@@ -178,11 +175,16 @@ def mmmu_process_results(doc, results):


def mmmu_reasoning_process_results(doc, results):
pred = results[0]
formatted_question = construct_prompt(doc, MC_PROMPT, OPEN_ENDED_PROMPT)
llm_judge_prompt = JUDGE_RULES.format(question=formatted_question, answer=doc["answer"], pred=pred)
llm_judge_score = get_chat_response(llm_judge_prompt, max_tokens=20, retries=3)
mmmu_judge_acc = {"id": doc["id"], "subdomain": extract_subset_name(doc["id"]), "question_type": doc["question_type"], "answer": doc["answer"], "pred": pred, "score": llm_judge_score}
parsed_preds = []
scores = []
for pred in results:
formatted_question = construct_prompt(doc, MC_PROMPT, OPEN_ENDED_PROMPT)
llm_judge_prompt = JUDGE_RULES.format(question=formatted_question, answer=doc["answer"], pred=pred)
llm_judge_score = get_chat_response(llm_judge_prompt, max_tokens=20, retries=3)
scores.append(llm_judge_score)
parsed_preds.append(pred)

mmmu_judge_acc = {"id": doc["id"], "subdomain": extract_subset_name(doc["id"]), "question_type": doc["question_type"], "answer": doc["answer"], "pred": parsed_preds, "score": scores}
return {"mmmu_judge_acc": mmmu_judge_acc}


@@ -247,7 +249,8 @@ def mmmu_aggregate_judge_results(results):
total_score = 0
for result in results:
try:
total_score += int(result["score"])
item_score = 1 if "1" in result["score"] or "[1]" in result["score"] else 0
total_score += item_score
except:
eval_logger.warning(f"Failed to convert score to int for {result['id']}: {result['score']}")
total_score += 0
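
With the judge instructed to output a bare "0 or 1", aggregation now looks for the substring "1" (or the older "[1]" form) instead of parsing an integer, and each document carries a list of scores, one per sampled prediction. A simplified re-implementation for illustration; the in-tree mmmu_aggregate_judge_results differs in details:

    def parse_judge_score(raw: str) -> int:
        # Accepts both the new "0 or 1" output format and the older "[0/1]" style.
        return 1 if "1" in raw else 0

    def aggregate_judge_accuracy(results: list) -> float:
        # Each result mirrors the mmmu_judge_acc dict built above; "score" is now a list,
        # one judge response per sampled prediction for the document.
        total, count = 0, 0
        for result in results:
            for raw in result["score"]:
                total += parse_judge_score(raw)
                count += 1
        return total / max(count, 1)

    # aggregate_judge_accuracy([{"id": "validation_1", "score": ["1"]}, {"id": "validation_2", "score": ["0"]}])
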
miscs/repr_scripts.sh -> miscs/model_dryruns/llava_1_5.sh: File renamed without changes.
miscs/cicd_qwen2vl.sh -> miscs/model_dryruns/qwen2vl.sh: File renamed without changes.
miscs/tinyllava_repr_scripts.sh -> miscs/model_dryruns/tinyllava.sh: File renamed without changes.
miscs/model_dryruns/vllm_qwen2vl.sh: diff not loaded in this view.