From 4e7238c83e7e821d009b2da7ff1c3bacd6536a1c Mon Sep 17 00:00:00 2001 From: Bo Li Date: Wed, 19 Feb 2025 22:49:56 +0000 Subject: [PATCH 1/3] Update API usage and add environment configuration - Adjust API client initialization for OpenAI and Azure. - Modify image encoding to handle size limits and maintain aspect ratio. - Add new `OpenAICompatible` model class. - Introduce retry mechanism for API requests with configurable retries. - Update `.gitignore` to exclude `.env` and `scripts/`. .gitignore: - Exclude `.env` file for security. - Ensure no scripts directory is tracked. lmms_eval/models/gpt4v.py: - Refactor API client initialization to use new OpenAI and Azure clients. - Update image encoding to handle size limits with resizing logic. - Adjust retry logic for API calls, reducing sleep time. lmms_eval/models/openai_compatible.py: - Create new `OpenAICompatible` model class with similar structure. - Implement encoding functions for images and videos. - Integrate environment variable loading and persistent response caching. miscs/model_dryruns/openai_compatible.sh: - Add sample script for running the new model. --- .gitignore | 3 +- lmms_eval/models/gpt4v.py | 124 +++++++------ lmms_eval/models/openai_compatible.py | 225 +++++++++++++++++++++++ miscs/model_dryruns/openai_compatible.sh | 13 ++ 4 files changed, 305 insertions(+), 60 deletions(-) create mode 100644 lmms_eval/models/openai_compatible.py create mode 100644 miscs/model_dryruns/openai_compatible.sh diff --git a/.gitignore b/.gitignore index ea20fe526..a254579f2 100755 --- a/.gitignore +++ b/.gitignore @@ -42,4 +42,5 @@ VATEX/ lmms_eval/tasks/vatex/__pycache__/utils.cpython-310.pyc lmms_eval/tasks/mlvu/__pycache__/utils.cpython-310.pyc -scripts/ \ No newline at end of file +scripts/ +.env \ No newline at end of file diff --git a/lmms_eval/models/gpt4v.py b/lmms_eval/models/gpt4v.py index af313a573..662b73882 100755 --- a/lmms_eval/models/gpt4v.py +++ b/lmms_eval/models/gpt4v.py @@ -4,11 +4,12 @@ import time from copy import deepcopy from io import BytesIO -from typing import List, Tuple +from typing import List, Tuple, Union import numpy as np import requests as url_requests from accelerate import Accelerator, DistributedType +from openai import AzureOpenAI, OpenAI from tqdm import tqdm from lmms_eval.api.instance import Instance @@ -20,26 +21,19 @@ except ImportError: pass +from loguru import logger as eval_logger from PIL import Image API_TYPE = os.getenv("API_TYPE", "openai") -NUM_SECONDS_TO_SLEEP = 30 -from loguru import logger as eval_logger - +NUM_SECONDS_TO_SLEEP = 10 if API_TYPE == "openai": API_URL = os.getenv("OPENAI_API_URL", "https://api.openai.com/v1/chat/completions") API_KEY = os.getenv("OPENAI_API_KEY", "YOUR_API_KEY") - headers = { - "Authorization": f"Bearer {API_KEY}", - "Content-Type": "application/json", - } + elif API_TYPE == "azure": API_URL = os.getenv("AZURE_ENDPOINT", "https://api.cognitive.microsoft.com/sts/v1.0/issueToken") API_KEY = os.getenv("AZURE_API_KEY", "YOUR_API_KEY") - headers = { - "api-key": API_KEY, - "Content-Type": "application/json", - } + API_VERSION = os.getenv("AZURE_API_VERSION", "2023-07-01-preview") @register_model("gpt4v") @@ -52,6 +46,7 @@ def __init__( timeout: int = 120, continual_mode: bool = False, response_persistent_folder: str = None, + max_size_in_mb: int = 20, **kwargs, ) -> None: super().__init__() @@ -80,6 +75,11 @@ def __init__( self.response_cache = {} self.cache_mode = "start" + if API_TYPE == "openai": + self.client = OpenAI(api_key=API_KEY) + elif API_TYPE == "azure": + self.client = AzureOpenAI(api_key=API_KEY, azure_endpoint=API_URL, api_version=API_VERSION) + accelerator = Accelerator() # assert self.batch_size_per_gpu == 1, "Llava currently does not support batched generation. See https://github.com/haotian-liu/LLaVA/issues/754. HF Llava also has this issue." if accelerator.num_processes > 1: @@ -94,13 +94,30 @@ def __init__( self._rank = self.accelerator.local_process_index self._world_size = self.accelerator.num_processes + self.max_size_in_mb = max_size_in_mb self.device = self.accelerator.device # Function to encode the image - def encode_image(self, image: Image): + def encode_image(self, image: Union[Image.Image, str]): + max_size = self.max_size_in_mb * 1024 * 1024 # 20MB in bytes + if isinstance(image, str): + img = Image.open(image).convert("RGB") + else: + img = image.copy() + output_buffer = BytesIO() - image.save(output_buffer, format="PNG") + img.save(output_buffer, format="PNG") byte_data = output_buffer.getvalue() + + # If image is too large, resize it while maintaining aspect ratio + while len(byte_data) > max_size and img.size[0] > 100 and img.size[1] > 100: + new_size = (int(img.size[0] * 0.75), int(img.size[1] * 0.75)) + img = img.resize(new_size, Image.Resampling.LANCZOS) + + output_buffer = BytesIO() + img.save(output_buffer, format="PNG") + byte_data = output_buffer.getvalue() + base64_str = base64.b64encode(byte_data).decode("utf-8") return base64_str @@ -150,39 +167,30 @@ def generate_until(self, requests) -> List[str]: continue visuals = [doc_to_visual(self.task_dict[task][split][doc_id])] - visuals = self.flatten(visuals) - imgs = [] # multiple images or frames for video - for visual in visuals: - if self.modality == "image": - img = self.encode_image(visual) - imgs.append(img) - elif self.modality == "video": - frames = self.encode_video(visual, self.max_frames_num) - imgs.extend(frames) + if None in visuals: + visuals = [] + imgs = [] + else: + visuals = self.flatten(visuals) + imgs = [] # multiple images or frames for video + for visual in visuals: + if isinstance(visual, str) and (".mp4" in visual or ".avi" in visual or ".mov" in visual or ".flv" in visual or ".wmv" in visual): + frames = self.encode_video(visual, self.max_frames_num) + imgs.extend(frames) + elif isinstance(visual, str) and (".jpg" in visual or ".jpeg" in visual or ".png" in visual or ".gif" in visual or ".bmp" in visual or ".tiff" in visual or ".webp" in visual): + img = self.encode_image(visual) + imgs.append(img) + elif isinstance(visual, Image.Image): + img = self.encode_image(visual) + imgs.append(img) payload = {"messages": []} - if API_TYPE == "openai": - payload["model"] = self.model_version - - response_json = {"role": "user", "content": []} - # When there is no image token in the context, append the image to the text - if self.image_token not in contexts: - payload["messages"].append(deepcopy(response_json)) - payload["messages"][0]["content"].append({"type": "text", "text": contexts}) - for img in imgs: - payload["messages"][0]["content"].append({"type": "image_url", "image_url": {"url": f"data:image/png;base64,{img}"}}) - else: - contexts = contexts.split(self.image_token) - for idx, img in enumerate(imgs): - payload["messages"].append(deepcopy(response_json)) - payload["messages"][idx]["content"].append({"type": "text", "text": contexts[idx]}) - payload["messages"][idx]["content"].append({"type": "image_url", "image_url": {"url": f"data:image/png;base64,{img}"}}) - - # If n image tokens are in the contexts - # contexts will be splitted into n+1 chunks - # Manually add it into the payload - payload["messages"].append(deepcopy(response_json)) - payload["messages"][-1]["content"].append({"type": "text", "text": contexts[-1]}) + payload["model"] = self.model_version + + payload["messages"].append({"role": "user", "content": []}) + payload["messages"][0]["content"].append({"type": "text", "text": contexts}) + for img in imgs: + payload["messages"][0]["content"].append({"type": "image_url", "image_url": {"url": f"data:image/png;base64,{img}"}}) if "max_new_tokens" not in gen_kwargs: gen_kwargs["max_new_tokens"] = 1024 @@ -198,26 +206,24 @@ def generate_until(self, requests) -> List[str]: payload["max_tokens"] = gen_kwargs["max_new_tokens"] payload["temperature"] = gen_kwargs["temperature"] - for attempt in range(5): + MAX_RETRIES = 5 + for attempt in range(MAX_RETRIES): try: - response = url_requests.post(API_URL, headers=headers, json=payload, timeout=self.timeout) - response_data = response.json() - - response_text = response_data["choices"][0]["message"]["content"].strip() + response = self.client.chat.completions.create(**payload) + response_text = response.choices[0].message.content break # If successful, break out of the loop except Exception as e: - try: - error_msg = response.json() - except: - error_msg = "" + error_msg = str(e) + eval_logger.info(f"Attempt {attempt + 1}/{MAX_RETRIES} failed with error: {error_msg}") - eval_logger.info(f"Attempt {attempt + 1} failed with error: {str(e)}.\nReponse: {error_msg}") - if attempt <= 5: - time.sleep(NUM_SECONDS_TO_SLEEP) - else: # If this was the last attempt, log and return empty string - eval_logger.error(f"All 5 attempts failed. Last error message: {str(e)}.\nResponse: {response.json()}") + # On last attempt, log error and set empty response + if attempt == MAX_RETRIES - 1: + eval_logger.error(f"All {MAX_RETRIES} attempts failed. Last error: {error_msg}") response_text = "" + else: + time.sleep(NUM_SECONDS_TO_SLEEP) + res.append(response_text) pbar.update(1) diff --git a/lmms_eval/models/openai_compatible.py b/lmms_eval/models/openai_compatible.py new file mode 100644 index 000000000..50dfb4414 --- /dev/null +++ b/lmms_eval/models/openai_compatible.py @@ -0,0 +1,225 @@ +import base64 +import json +import os +import time +from copy import deepcopy +from io import BytesIO +from typing import List, Tuple, Union + +import numpy as np +import requests as url_requests +from accelerate import Accelerator, DistributedType +from tqdm import tqdm + +from lmms_eval.api.instance import Instance +from lmms_eval.api.model import lmms +from lmms_eval.api.registry import register_model + +try: + from decord import VideoReader, cpu +except ImportError: + pass + +from dotenv import find_dotenv, load_dotenv +from loguru import logger as eval_logger +from openai import OpenAI +from PIL import Image + +load_dotenv(verbose=True) + +@register_model("openai_compatible") +class OpenAICompatible(lmms): + def __init__( + self, + model_version: str = "grok-2-latest", + timeout: int = 120, + max_retries: int = 5, + max_size_in_mb: int = 20, + continual_mode: bool = False, + response_persistent_folder: str = None, + **kwargs, + ) -> None: + super().__init__() + self.model_version = model_version + self.timeout = timeout + self.max_retries = max_retries + self.max_size_in_mb = max_size_in_mb # some models have a limit on the size of the image + self.continual_mode = continual_mode + if self.continual_mode: + if response_persistent_folder is None: + raise ValueError("Continual mode requires a persistent path for the response. Please provide a valid path.") + + os.makedirs(response_persistent_folder, exist_ok=True) + self.response_persistent_folder = response_persistent_folder + self.response_persistent_file = os.path.join(self.response_persistent_folder, f"{self.model_version}_response.json") + + if os.path.exists(self.response_persistent_file): + with open(self.response_persistent_file, "r") as f: + self.response_cache = json.load(f) + self.cache_mode = "resume" + else: + self.response_cache = {} + self.cache_mode = "start" + + self.client = OpenAI(api_key=os.getenv("OPENAI_COMPATIBLE_API_KEY"), base_url=os.getenv("OPENAI_COMPATIBLE_API_URL")) + + accelerator = Accelerator() + # assert self.batch_size_per_gpu == 1, "Llava currently does not support batched generation. See https://github.com/haotian-liu/LLaVA/issues/754. HF Llava also has this issue." + if accelerator.num_processes > 1: + assert accelerator.distributed_type in [DistributedType.FSDP, DistributedType.MULTI_GPU, DistributedType.DEEPSPEED], "Unsupported distributed type provided. Only DDP and FSDP are supported." + self.accelerator = accelerator + if self.accelerator.is_local_main_process: + eval_logger.info(f"Using {accelerator.num_processes} devices with data parallelism") + self._rank = self.accelerator.local_process_index + self._world_size = self.accelerator.num_processes + else: + self.accelerator = accelerator + self._rank = self.accelerator.local_process_index + self._world_size = self.accelerator.num_processes + + self.device = self.accelerator.device + + # Function to encode the image + def encode_image(self, image: Union[Image.Image, str]): + max_size = self.max_size_in_mb * 1024 * 1024 # 20MB in bytes + if isinstance(image, str): + img = Image.open(image).convert("RGB") + else: + img = image.copy() + + output_buffer = BytesIO() + img.save(output_buffer, format="PNG") + byte_data = output_buffer.getvalue() + + # If image is too large, resize it while maintaining aspect ratio + while len(byte_data) > max_size and img.size[0] > 100 and img.size[1] > 100: + new_size = (int(img.size[0] * 0.75), int(img.size[1] * 0.75)) + img = img.resize(new_size, Image.Resampling.LANCZOS) + + output_buffer = BytesIO() + img.save(output_buffer, format="PNG") + byte_data = output_buffer.getvalue() + + base64_str = base64.b64encode(byte_data).decode("utf-8") + return base64_str + + # Function to encode the video + def encode_video(self, video_path, for_get_frames_num): + vr = VideoReader(video_path, ctx=cpu(0)) + total_frame_num = len(vr) + uniform_sampled_frames = np.linspace(0, total_frame_num - 1, for_get_frames_num, dtype=int) + + # Ensure the last frame is included + if total_frame_num - 1 not in uniform_sampled_frames: + uniform_sampled_frames = np.append(uniform_sampled_frames, total_frame_num - 1) + + frame_idx = uniform_sampled_frames.tolist() + frames = vr.get_batch(frame_idx).asnumpy() + + base64_frames = [] + for frame in frames: + img = Image.fromarray(frame) + output_buffer = BytesIO() + img.save(output_buffer, format="PNG") + byte_data = output_buffer.getvalue() + base64_str = base64.b64encode(byte_data).decode("utf-8") + base64_frames.append(base64_str) + + return base64_frames + + def flatten(self, input): + new_list = [] + for i in input: + for j in i: + new_list.append(j) + return new_list + + def generate_until(self, requests) -> List[str]: + res = [] + pbar = tqdm(total=len(requests), disable=(self.rank != 0), desc="Model Responding") + + for contexts, gen_kwargs, doc_to_visual, doc_id, task, split in [reg.args for reg in requests]: + if self.continual_mode is True and self.cache_mode == "resume": + doc_uuid = f"{task}___{split}___{doc_id}" + if doc_uuid in self.response_cache: + response_text = self.response_cache[doc_uuid] + if response_text: + res.append(response_text) + pbar.update(1) + continue + + visuals = [doc_to_visual(self.task_dict[task][split][doc_id])] + if None in visuals: + visuals = [] + imgs = [] + else: + visuals = self.flatten(visuals) + imgs = [] # multiple images or frames for video + for visual in visuals: + if isinstance(visual, str) and (".mp4" in visual or ".avi" in visual or ".mov" in visual or ".flv" in visual or ".wmv" in visual): + frames = self.encode_video(visual, self.max_frames_num) + imgs.extend(frames) + elif isinstance(visual, str) and (".jpg" in visual or ".jpeg" in visual or ".png" in visual or ".gif" in visual or ".bmp" in visual or ".tiff" in visual or ".webp" in visual): + img = self.encode_image(visual) + imgs.append(img) + elif isinstance(visual, Image.Image): + img = self.encode_image(visual) + imgs.append(img) + + payload = {"messages": []} + payload["model"] = self.model_version + + # When there is no image token in the context, append the image to the text + payload["messages"].append({"role": "user", "content": []}) + payload["messages"][0]["content"].append({"type": "text", "text": contexts}) + for img in imgs: + payload["messages"][0]["content"].append({"type": "image_url", "image_url": {"url": f"data:image/png;base64,{img}"}}) + + if "max_new_tokens" not in gen_kwargs: + gen_kwargs["max_new_tokens"] = 1024 + if gen_kwargs["max_new_tokens"] > 4096: + gen_kwargs["max_new_tokens"] = 4096 + if "temperature" not in gen_kwargs: + gen_kwargs["temperature"] = 0 + if "top_p" not in gen_kwargs: + gen_kwargs["top_p"] = None + if "num_beams" not in gen_kwargs: + gen_kwargs["num_beams"] = 1 + + payload["max_tokens"] = gen_kwargs["max_new_tokens"] + payload["temperature"] = gen_kwargs["temperature"] + + for attempt in range(self.max_retries): + try: + response = self.client.chat.completions.create(**payload) + response_text = response.choices[0].message.content + break # If successful, break out of the loop + + except Exception as e: + error_msg = str(e) + eval_logger.info(f"Attempt {attempt + 1}/{self.max_retries} failed with error: {error_msg}") + + # On last attempt, log error and set empty response + if attempt == self.max_retries - 1: + eval_logger.error(f"All {self.max_retries} attempts failed. Last error: {error_msg}") + response_text = "" + else: + time.sleep(self.timeout) + + res.append(response_text) + pbar.update(1) + + if self.continual_mode is True: # Cache the response + doc_uuid = f"{task}___{split}___{doc_id}" + self.response_cache[doc_uuid] = response_text + with open(self.response_persistent_file, "w") as f: + json.dump(self.response_cache, f) + + pbar.close() + return res + + def generate_until_multi_round(self, requests) -> List[str]: + raise NotImplementedError("TODO: Implement multi-round generation for OpenAI compatible models") + + def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]: + raise NotImplementedError("TODO: Implement loglikelihood for OpenAI compatible models") diff --git a/miscs/model_dryruns/openai_compatible.sh b/miscs/model_dryruns/openai_compatible.sh new file mode 100644 index 000000000..c93ec29f2 --- /dev/null +++ b/miscs/model_dryruns/openai_compatible.sh @@ -0,0 +1,13 @@ +# cd ~/prod/lmms-eval-public +# pip3 install -e . +# pip3 install openai + +python3 -m lmms_eval \ + --model openai_compatible \ + --model_args model_version=grok-2-vision-1212 \ + --tasks mme,mmmu_val \ + --batch_size 1 \ + --log_samples \ + --log_samples_suffix openai_compatible \ + --output_path ./logs \ + --limit=8 \ No newline at end of file From e0aa1c727f4fb7a8c31c71ec457e63644a7fa72b Mon Sep 17 00:00:00 2001 From: Bo Li Date: Wed, 19 Feb 2025 23:13:41 +0000 Subject: [PATCH 2/3] Improve code readability and organization - Remove unused import for deepcopy in `openai_compatible.py`. - Add a blank line for better separation of code sections. - Adjust comment formatting for `max_size_in_mb` for consistency. - Ensure consistent spacing around comments. File: `lmms_eval/models/openai_compatible.py` - Removed `deepcopy` import: cleaned up unnecessary code. - Added blank line after `load_dotenv`: improved readability. - Reformatted comment on `max_size_in_mb`: enhanced clarity. - Removed extra blank line before `Accelerator`: tightened spacing. --- lmms_eval/models/openai_compatible.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/lmms_eval/models/openai_compatible.py b/lmms_eval/models/openai_compatible.py index 50dfb4414..7be1cd91e 100644 --- a/lmms_eval/models/openai_compatible.py +++ b/lmms_eval/models/openai_compatible.py @@ -2,7 +2,6 @@ import json import os import time -from copy import deepcopy from io import BytesIO from typing import List, Tuple, Union @@ -27,6 +26,7 @@ load_dotenv(verbose=True) + @register_model("openai_compatible") class OpenAICompatible(lmms): def __init__( @@ -43,7 +43,7 @@ def __init__( self.model_version = model_version self.timeout = timeout self.max_retries = max_retries - self.max_size_in_mb = max_size_in_mb # some models have a limit on the size of the image + self.max_size_in_mb = max_size_in_mb # some models have a limit on the size of the image self.continual_mode = continual_mode if self.continual_mode: if response_persistent_folder is None: @@ -62,7 +62,7 @@ def __init__( self.cache_mode = "start" self.client = OpenAI(api_key=os.getenv("OPENAI_COMPATIBLE_API_KEY"), base_url=os.getenv("OPENAI_COMPATIBLE_API_URL")) - + accelerator = Accelerator() # assert self.batch_size_per_gpu == 1, "Llava currently does not support batched generation. See https://github.com/haotian-liu/LLaVA/issues/754. HF Llava also has this issue." if accelerator.num_processes > 1: From 00942f2d6807c8a153c16356026a71e0cdf4ed43 Mon Sep 17 00:00:00 2001 From: kcz358 Date: Thu, 20 Feb 2025 05:47:33 +0000 Subject: [PATCH 3/3] Fix init --- lmms_eval/models/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/lmms_eval/models/__init__.py b/lmms_eval/models/__init__.py index 1a3dbd617..3927b2944 100644 --- a/lmms_eval/models/__init__.py +++ b/lmms_eval/models/__init__.py @@ -39,6 +39,7 @@ "minimonkey": "MiniMonkey", "moviechat": "MovieChat", "mplug_owl_video": "mplug_Owl", + "openai_compatible": "OpenAICompatible", "oryx": "Oryx", "phi3v": "Phi3v", "qwen2_5_vl": "Qwen2_5_VL",