Commit 3e569e5
use threads to encode visuals
kcz358 committed Feb 20, 2025
1 parent f86961b commit 3e569e5
Showing 1 changed file with 32 additions and 17 deletions.
49 changes: 32 additions & 17 deletions lmms_eval/models/vllm.py
@@ -3,16 +3,16 @@
 import json
 import os
 import time
+from concurrent.futures import ThreadPoolExecutor
 from copy import deepcopy
 from io import BytesIO
 from multiprocessing import cpu_count
-from typing import List, Tuple, Union
+from typing import List, Optional, Tuple, Union
 
 import numpy as np
 from accelerate import Accelerator, DistributedType
 from decord import VideoReader, cpu
 from loguru import logger as eval_logger
-from openai import AsyncOpenAI, OpenAI
 from PIL import Image
 from tqdm import tqdm

@@ -23,7 +23,7 @@
 NUM_SECONDS_TO_SLEEP = 5
 
 try:
-    import vllm
+    from vllm import LLM, SamplingParams
 except ImportError:
     vllm = None

@@ -38,6 +38,11 @@ def __init__(
         batch_size: int = 1,
         timeout: int = 60,
         max_images: int = 32,
+        max_videos: int = 8,
+        max_audios: int = 8,
+        max_frame_num: int = 32,
+        threads: int = 16,  # Threads to use for decoding visuals
+        trust_remote_code: Optional[bool] = True,
         **kwargs,
     ) -> None:
         super().__init__()
@@ -46,9 +51,17 @@ def __init__(
         # Here we just use the same token as llava for convenient
         self.model_version = model_version
         self.max_images = max_images
+        self.max_frame_num = max_frame_num
+        self.threads = threads
 
         accelerator = Accelerator()
-        self.client = vllm.LLM(model=self.model_version, tensor_parallel_size=tensor_parallel_size, gpu_memory_utilization=gpu_memory_utilization, limit_mm_per_prompt={"image": max_images})
+        self.client = LLM(
+            model=self.model_version,
+            tensor_parallel_size=tensor_parallel_size,
+            gpu_memory_utilization=gpu_memory_utilization,
+            limit_mm_per_prompt={"image": max_images, "video": max_videos, "audio": max_audios},
+            trust_remote_code=trust_remote_code,
+        )
         if accelerator.num_processes > 1:
             assert accelerator.distributed_type in [DistributedType.FSDP, DistributedType.MULTI_GPU, DistributedType.DEEPSPEED], "Unsupported distributed type provided. Only DDP and FSDP are supported."
         self.accelerator = accelerator
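
For context, a minimal sketch of how the updated constructor might be called directly. The argument names come from the diff above; the wrapper class name (`VLLM`) and the checkpoint string are illustrative assumptions, since the class definition is not shown in this diff.

# Hypothetical direct construction of the updated wrapper; the class name
# "VLLM" and the checkpoint below are assumptions for illustration.
model = VLLM(
    model_version="Qwen/Qwen2-VL-7B-Instruct",  # assumed checkpoint
    tensor_parallel_size=1,
    gpu_memory_utilization=0.8,
    max_images=32,
    max_videos=8,
    max_audios=8,
    max_frame_num=32,  # frames uniformly sampled per video
    threads=16,        # workers for the visual-encoding thread pool
    trust_remote_code=True,
)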
@@ -79,10 +92,10 @@ def encode_image(self, image: Union[Image.Image, str]):
         return base64_str
 
     # Function to encode the video
-    def encode_video(self, video_path, max_frames_num=8):
+    def encode_video(self, video_path):
         vr = VideoReader(video_path, ctx=cpu(0))
         total_frame_num = len(vr)
-        uniform_sampled_frames = np.linspace(0, total_frame_num - 1, max_frames_num, dtype=int)
+        uniform_sampled_frames = np.linspace(0, total_frame_num - 1, self.max_frame_num, dtype=int)
 
         # Ensure the last frame is included
         if total_frame_num - 1 not in uniform_sampled_frames:
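
The diff truncates `encode_video` after the last-frame check. A sketch of a plausible remainder, assuming decord's batched frame fetch and the same base64-JPEG convention as `encode_image`; everything after the check is an assumption about the surrounding code, not part of this commit.

import base64  # assumed to be imported at the top of the module

def encode_video(self, video_path):
    vr = VideoReader(video_path, ctx=cpu(0))
    total_frame_num = len(vr)
    uniform_sampled_frames = np.linspace(0, total_frame_num - 1, self.max_frame_num, dtype=int)

    # Ensure the last frame is included
    if total_frame_num - 1 not in uniform_sampled_frames:
        uniform_sampled_frames = np.append(uniform_sampled_frames, total_frame_num - 1)

    # --- assumed continuation below this line ---
    frame_idx = uniform_sampled_frames.tolist()
    frames = vr.get_batch(frame_idx).asnumpy()  # (n, H, W, 3) uint8 array

    base64_frames = []
    for frame in frames:
        img = Image.fromarray(frame)
        output_buffer = BytesIO()
        img.save(output_buffer, format="JPEG")
        base64_frames.append(base64.b64encode(output_buffer.getvalue()).decode("utf-8"))
    return base64_frames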
@@ -133,7 +146,7 @@ def generate_until(self, requests) -> List[str]:
                 "max_tokens": gen_kwargs["max_new_tokens"],
                 "top_p": gen_kwargs["top_p"],
             }
-            sampling_params = vllm.SamplingParams(**params)
+            sampling_params = SamplingParams(**params)
 
             visuals = [doc_to_visual(self.task_dict[task][split][doc_id])]
             if None in visuals:
@@ -142,16 +155,18 @@
             else:
                 visuals = self.flatten(visuals)
                 imgs = []  # multiple images or frames for video
-                for visual in visuals:
-                    if isinstance(visual, str) and (".mp4" in visual or ".avi" in visual or ".mov" in visual or ".flv" in visual or ".wmv" in visual):
-                        frames = self.encode_video(visual)
-                        imgs.extend(frames)
-                    elif isinstance(visual, str) and (".jpg" in visual or ".jpeg" in visual or ".png" in visual or ".gif" in visual or ".bmp" in visual or ".tiff" in visual or ".webp" in visual):
-                        img = self.encode_image(visual)
-                        imgs.append(img)
-                    elif isinstance(visual, Image.Image):
-                        img = self.encode_image(visual)
-                        imgs.append(img)
+                all_tasks = []
+                with ThreadPoolExecutor(max_workers=self.threads) as executor:
+                    for visual in visuals:
+                        if isinstance(visual, str) and (".mp4" in visual or ".avi" in visual or ".mov" in visual or ".flv" in visual or ".wmv" in visual):
+                            all_tasks.append(executor.submit(self.encode_video, visual))
+                        elif isinstance(visual, str) and (".jpg" in visual or ".jpeg" in visual or ".png" in visual or ".gif" in visual or ".bmp" in visual or ".tiff" in visual or ".webp" in visual):
+                            all_tasks.append(executor.submit(self.encode_image, visual))
+                        elif isinstance(visual, Image.Image):
+                            all_tasks.append(executor.submit(self.encode_image, visual))
+
+                for task in all_tasks:
+                    imgs.append(task.result())
 
             messages = [{"role": "user", "content": []}]
             # When there is no image token in the context, append the image to the text
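
A note on the design: because futures are appended to `all_tasks` in submission order and `task.result()` is called in that same order, the encoded outputs stay aligned with the input visuals even when threads finish out of order. A toy illustration of this ordering property (not from the commit):

import time
from concurrent.futures import ThreadPoolExecutor

def slow_encode(x):
    time.sleep(0.01 * (5 - x))  # later submissions finish sooner
    return f"encoded-{x}"

with ThreadPoolExecutor(max_workers=4) as executor:
    futures = [executor.submit(slow_encode, x) for x in range(5)]

# Results come back in submission order, not completion order.
print([f.result() for f in futures])  # ['encoded-0', ..., 'encoded-4']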
