diff --git a/lmms_eval/models/__init__.py b/lmms_eval/models/__init__.py index f07d9af1..1a3dbd61 100644 --- a/lmms_eval/models/__init__.py +++ b/lmms_eval/models/__init__.py @@ -11,6 +11,7 @@ logger.add(sys.stdout, level="WARNING") AVAILABLE_MODELS = { + "aria": "Aria", "auroracap": "AuroraCap", "batch_gpt4": "BatchGPT4", "claude": "Claude", @@ -21,44 +22,44 @@ "gpt4v": "GPT4V", "idefics2": "Idefics2", "instructblip": "InstructBLIP", + "internvideo2": "InternVideo2", "internvl": "InternVLChat", "internvl2": "InternVL2", "llama_vid": "LLaMAVid", + "llama_vision": "LlamaVision", "llava": "Llava", "llava_hf": "LlavaHf", "llava_onevision": "Llava_OneVision", "llava_onevision_moviechat": "Llava_OneVision_MovieChat", "llava_sglang": "LlavaSglang", "llava_vid": "LlavaVid", - "slime": "Slime", "longva": "LongVA", "mantis": "Mantis", "minicpm_v": "MiniCPM_V", "minimonkey": "MiniMonkey", "moviechat": "MovieChat", "mplug_owl_video": "mplug_Owl", + "oryx": "Oryx", "phi3v": "Phi3v", - "qwen_vl": "Qwen_VL", - "qwen2_vl": "Qwen2_VL", "qwen2_5_vl": "Qwen2_5_VL", "qwen2_5_vl_interleave": "Qwen2_5_VL_Interleave", "qwen2_audio": "Qwen2_Audio", + "qwen2_vl": "Qwen2_VL", + "qwen_vl": "Qwen_VL", "qwen_vl_api": "Qwen_VL_API", "reka": "Reka", + "ross": "Ross", + "slime": "Slime", "srt_api": "SRT_API", "tinyllava": "TinyLlava", "videoChatGPT": "VideoChatGPT", + "videochat2": "VideoChat2", "video_llava": "VideoLLaVA", "vila": "VILA", + "vita": "VITA", + "vllm": "VLLM", "xcomposer2_4KHD": "XComposer2_4KHD", - "internvideo2": "InternVideo2", "xcomposer2d5": "XComposer2D5", - "oryx": "Oryx", - "videochat2": "VideoChat2", - "llama_vision": "LlamaVision", - "aria": "Aria", - "ross": "Ross", - "vita": "VITA", } diff --git a/lmms_eval/models/vllm.py b/lmms_eval/models/vllm.py new file mode 100644 index 00000000..3b125718 --- /dev/null +++ b/lmms_eval/models/vllm.py @@ -0,0 +1,194 @@ +import asyncio +import base64 +import json +import os +import time +from concurrent.futures import ThreadPoolExecutor +from copy import deepcopy +from io import BytesIO +from multiprocessing import cpu_count +from typing import List, Optional, Tuple, Union + +import numpy as np +from accelerate import Accelerator, DistributedType +from decord import VideoReader, cpu +from loguru import logger as eval_logger +from PIL import Image +from tqdm import tqdm + +from lmms_eval.api.instance import Instance +from lmms_eval.api.model import lmms +from lmms_eval.api.registry import register_model + +NUM_SECONDS_TO_SLEEP = 5 + +try: + from vllm import LLM, SamplingParams +except ImportError: + vllm = None + + +@register_model("vllm") +class VLLM(lmms): + def __init__( + self, + model_version: str = "Qwen/Qwen2.5-VL-3B-Instruct", + tensor_parallel_size: int = 1, + gpu_memory_utilization: float = 0.8, + batch_size: int = 1, + timeout: int = 60, + max_images: int = 32, + max_videos: int = 8, + max_audios: int = 8, + max_frame_num: int = 32, + threads: int = 16, # Threads to use for decoding visuals + trust_remote_code: Optional[bool] = True, + **kwargs, + ) -> None: + super().__init__() + # Manually set a image token for GPT4V so that we can search for it + # and split the text and image + # Here we just use the same token as llava for convenient + self.model_version = model_version + self.max_images = max_images + self.max_frame_num = max_frame_num + self.threads = threads + + accelerator = Accelerator() + self.client = LLM( + model=self.model_version, + tensor_parallel_size=tensor_parallel_size, + gpu_memory_utilization=gpu_memory_utilization, + limit_mm_per_prompt={"image": max_images, "video": max_videos, "audio": max_audios}, + trust_remote_code=trust_remote_code, + ) + if accelerator.num_processes > 1: + assert accelerator.distributed_type in [DistributedType.FSDP, DistributedType.MULTI_GPU, DistributedType.DEEPSPEED], "Unsupported distributed type provided. Only DDP and FSDP are supported." + self.accelerator = accelerator + if self.accelerator.is_local_main_process: + eval_logger.info(f"Using {accelerator.num_processes} devices with data parallelism") + self._rank = self.accelerator.local_process_index + self._world_size = self.accelerator.num_processes + else: + self.accelerator = accelerator + self._rank = self.accelerator.local_process_index + self._world_size = self.accelerator.num_processes + + self.device = self.accelerator.device + self.batch_size_per_gpu = int(batch_size) + + # Function to encode the image + def encode_image(self, image: Union[Image.Image, str]): + if isinstance(image, str): + img = Image.open(image).convert("RGB") + else: + img = image.copy() + + output_buffer = BytesIO() + img.save(output_buffer, format="PNG") + byte_data = output_buffer.getvalue() + + base64_str = base64.b64encode(byte_data).decode("utf-8") + return base64_str + + # Function to encode the video + def encode_video(self, video_path): + vr = VideoReader(video_path, ctx=cpu(0)) + total_frame_num = len(vr) + uniform_sampled_frames = np.linspace(0, total_frame_num - 1, self.max_frame_num, dtype=int) + + # Ensure the last frame is included + if total_frame_num - 1 not in uniform_sampled_frames: + uniform_sampled_frames = np.append(uniform_sampled_frames, total_frame_num - 1) + + frame_idx = uniform_sampled_frames.tolist() + frames = vr.get_batch(frame_idx).asnumpy() + + base64_frames = [] + for frame in frames: + img = Image.fromarray(frame) + output_buffer = BytesIO() + img.save(output_buffer, format="PNG") + byte_data = output_buffer.getvalue() + base64_str = base64.b64encode(byte_data).decode("utf-8") + base64_frames.append(base64_str) + + return base64_frames + + def flatten(self, input): + new_list = [] + for i in input: + for j in i: + new_list.append(j) + return new_list + + def generate_until(self, requests) -> List[str]: + res = [] + pbar = tqdm(total=len(requests), disable=(self.rank != 0), desc="Model Responding") + + batch_size = self.batch_size_per_gpu + batched_requests = [requests[i : i + batch_size] for i in range(0, len(requests), batch_size)] + for batch_requests in batched_requests: + batched_messages = [] + for idx in range(len(batch_requests)): + contexts, gen_kwargs, doc_to_visual, doc_id, task, split = batch_requests[idx].arguments + if "max_new_tokens" not in gen_kwargs: + gen_kwargs["max_new_tokens"] = 1024 + if gen_kwargs["max_new_tokens"] > 4096: + gen_kwargs["max_new_tokens"] = 4096 + if "temperature" not in gen_kwargs: + gen_kwargs["temperature"] = 0 + if "top_p" not in gen_kwargs: + gen_kwargs["top_p"] = 0.95 + + params = { + "temperature": gen_kwargs["temperature"], + "max_tokens": gen_kwargs["max_new_tokens"], + "top_p": gen_kwargs["top_p"], + } + sampling_params = SamplingParams(**params) + + visuals = [doc_to_visual(self.task_dict[task][split][doc_id])] + if None in visuals: + visuals = [] + imgs = [] + else: + visuals = self.flatten(visuals) + imgs = [] # multiple images or frames for video + all_tasks = [] + with ThreadPoolExecutor(max_workers=self.threads) as executor: + for visual in visuals: + if isinstance(visual, str) and (".mp4" in visual or ".avi" in visual or ".mov" in visual or ".flv" in visual or ".wmv" in visual): + all_tasks.append(executor.submit(self.encode_video, visual)) + elif isinstance(visual, str) and (".jpg" in visual or ".jpeg" in visual or ".png" in visual or ".gif" in visual or ".bmp" in visual or ".tiff" in visual or ".webp" in visual): + all_tasks.append(executor.submit(self.encode_image, visual)) + elif isinstance(visual, Image.Image): + all_tasks.append(executor.submit(self.encode_image, visual)) + + for task in all_tasks: + imgs.append(task.result()) + + messages = [{"role": "user", "content": []}] + # When there is no image token in the context, append the image to the text + messages[0]["content"].append({"type": "text", "text": contexts}) + for img in imgs: + messages[0]["content"].append({"type": "image_url", "image_url": {"url": f"data:image/png;base64,{img}"}}) + + batched_messages.append(messages) + + response = self.client.chat(sampling_params=sampling_params, messages=batched_messages) + response_text = [o.outputs[0].text for o in response] + + assert len(response_text) == len(batch_requests) + res.extend(response_text) + pbar.update(len(batch_requests)) + + pbar.close() + return res + + def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]: + # TODO + assert False, "GPT4V not support" + + def generate_until_multi_round(self, requests) -> List[str]: + raise NotImplementedError("TODO: Implement multi-round generation") diff --git a/lmms_eval/tasks/mmmu/mmmu_val.yaml b/lmms_eval/tasks/mmmu/mmmu_val.yaml index a301f7cb..5d1394d2 100755 --- a/lmms_eval/tasks/mmmu/mmmu_val.yaml +++ b/lmms_eval/tasks/mmmu/mmmu_val.yaml @@ -13,4 +13,10 @@ metric_list: aggregation: !function utils.mmmu_aggregate_results higher_is_better: true +lmms_eval_specific_kwargs: + default: + prompt_type: "format" + multiple_choice_prompt: "Answer with the option's letter from the given choices directly." + open_ended_prompt: "Answer the question using a single word or phrase." + include: _default_template_yaml \ No newline at end of file diff --git a/lmms_eval/tasks/mmmu/mmmu_val_reasoning.yaml b/lmms_eval/tasks/mmmu/mmmu_val_reasoning.yaml index d9d7a21b..1b3c467a 100755 --- a/lmms_eval/tasks/mmmu/mmmu_val_reasoning.yaml +++ b/lmms_eval/tasks/mmmu/mmmu_val_reasoning.yaml @@ -7,6 +7,16 @@ doc_to_text: !function utils.mmmu_doc_to_text doc_to_target: "answer" # The return value of process_results will be used by metrics process_results: !function utils.mmmu_reasoning_process_results +# repeats: 8 +# filter_list: +# # - name: "pass@64" +# # filter: +# # - function: "take_first_k" +# # k: 64 +# - name: "pass@8" +# filter: +# - function: "take_first_k" +# k: 8 metric_list: - metric: mmmu_judge_acc @@ -16,11 +26,15 @@ metric_list: lmms_eval_specific_kwargs: default: prompt_type: "reasoning" - multiple_choice_prompt: "Please show step-by-step reasoning, and answer the question with option letter from given choices." - open_ended_prompt: "Please show step-by-step reasoning, and answer the question using a single word or phrase." + multiple_choice_prompt: "Please reason step by step, and answer the question with option letter from given choices in the format of Answer: