diff --git a/lmms_eval/models/__init__.py b/lmms_eval/models/__init__.py index df481097..c53856bf 100644 --- a/lmms_eval/models/__init__.py +++ b/lmms_eval/models/__init__.py @@ -62,6 +62,7 @@ "vllm": "VLLM", "xcomposer2_4KHD": "XComposer2_4KHD", "xcomposer2d5": "XComposer2D5", + "egogpt": "EgoGPT", } diff --git a/lmms_eval/models/egogpt.py b/lmms_eval/models/egogpt.py new file mode 100644 index 00000000..f75bd5df --- /dev/null +++ b/lmms_eval/models/egogpt.py @@ -0,0 +1,472 @@ +import copy +import json +import logging +import math +import re +import warnings +from datetime import timedelta +from typing import List, Optional, Tuple, Union + +import numpy as np +import PIL +import torch +import transformers +from accelerate import Accelerator, DistributedType, InitProcessGroupKwargs +from accelerate.state import AcceleratorState +from decord import VideoReader, cpu +from packaging import version +from tqdm import tqdm +from transformers import AutoConfig + +from lmms_eval import utils +from lmms_eval.api.instance import Instance +from lmms_eval.api.model import lmms +from lmms_eval.api.registry import register_model + +# Suppress warnings +warnings.filterwarnings("ignore") + +# Configure logging +eval_logger = logging.getLogger("lmms-eval") + +# Enable TF32 for CUDA +torch.backends.cuda.matmul.allow_tf32 = True + +# Import LLaVA modules +try: + import copy + import os + import re + import sys + import warnings + + import numpy as np + import requests + import soundfile as sf + import torch + import whisper + from decord import VideoReader, cpu + from egogpt.constants import ( + DEFAULT_IMAGE_TOKEN, + DEFAULT_SPEECH_TOKEN, + IGNORE_INDEX, + IMAGE_TOKEN_INDEX, + SPEECH_TOKEN_INDEX, + ) + from egogpt.conversation import SeparatorStyle, conv_templates + from egogpt.mm_utils import get_model_name_from_path, process_images + from egogpt.model.builder import load_pretrained_model + from PIL import Image + from scipy.signal import resample +except ImportError as e: + eval_logger.debug(f"egogpt is not installed. Please install egogpt to use this model.\nError: {e}") + + +# Determine best attention implementation +if version.parse(torch.__version__) >= version.parse("2.1.2"): + best_fit_attn_implementation = "sdpa" +else: + best_fit_attn_implementation = "eager" + + +@register_model("egogpt") +class EgoGPT(lmms): + """ + EgoGPT Model + """ + + def __init__( + self, + pretrained: str = "checkpoints/egogpt_IT_12k_1126_zero3", + truncation: Optional[bool] = True, + device: Optional[str] = "cuda:0", + batch_size: Optional[Union[int, str]] = 1, + model_name: Optional[str] = None, + attn_implementation: Optional[str] = best_fit_attn_implementation, + device_map: Optional[str] = "cuda:0", + conv_template: Optional[str] = "qwen_1_5", + use_cache: Optional[bool] = True, + truncate_context: Optional[bool] = False, # whether to truncate the context in generation, set it False for LLaVA-1.6 + customized_config: Optional[str] = None, # ends in json + max_frames_num: Optional[int] = 32, + mm_spatial_pool_stride: Optional[int] = 2, + mm_spatial_pool_mode: Optional[str] = "bilinear", + token_strategy: Optional[str] = "single", # could be "single" or "multiple", "multiple" denotes adding multiple tokens for each frame + video_decode_backend: str = "decord", + **kwargs, + ) -> None: + super().__init__() + # Do not use kwargs for now + assert kwargs == {}, f"Unexpected kwargs: {kwargs}" + + accelerator_kwargs = InitProcessGroupKwargs(timeout=timedelta(weeks=52)) + accelerator = Accelerator(kwargs_handlers=[accelerator_kwargs]) + if accelerator.num_processes > 1: + self._device = torch.device(f"cuda:{accelerator.local_process_index}") + self.device_map = f"cuda:{accelerator.local_process_index}" + elif accelerator.num_processes == 1 and device_map == "auto": + self._device = torch.device(device) + self.device_map = device_map + else: + self._device = torch.device(f"cuda:{accelerator.local_process_index}") + self.device_map = f"cuda:{accelerator.local_process_index}" + + egogpt_model_args = {} + if attn_implementation is not None: + egogpt_model_args["attn_implementation"] = attn_implementation + + self.pretrained = pretrained + self.token_strategy = token_strategy + self.max_frames_num = max_frames_num + self.mm_spatial_pool_stride = mm_spatial_pool_stride + self.mm_spatial_pool_mode = mm_spatial_pool_mode + self.video_decode_backend = video_decode_backend + # Try to load the model with the multimodal argument + self._tokenizer, self._model, self._max_length = load_pretrained_model(pretrained, device_map=self.device_map, **egogpt_model_args) + self._image_processor = self._model.get_vision_tower().image_processor + self._config = self._model.config + self.model.eval() + self.truncation = truncation + self.batch_size_per_gpu = int(batch_size) + self.conv_template = conv_template + self.use_cache = use_cache + self.truncate_context = truncate_context + assert self.batch_size_per_gpu == 1 + + if accelerator.num_processes > 1: + assert accelerator.distributed_type in [DistributedType.FSDP, DistributedType.MULTI_GPU, DistributedType.DEEPSPEED], "Unsupported distributed type provided. Only DDP and FSDP are supported." + # If you want to use DistributedType.DEEPSPEED, you have to run accelerate config before using the model + # Also, you have to select zero stage 0 (equivalent to DDP) in order to make the prepare model works + # I tried to set different parameters in the kwargs to let default zero 2 stage works, but it didn't work. + if accelerator.distributed_type == DistributedType.DEEPSPEED: + kwargs = { + "train_micro_batch_size_per_gpu": self.batch_size_per_gpu, + "train_batch_size": self.batch_size_per_gpu * accelerator.num_processes, + } + AcceleratorState().deepspeed_plugin.deepspeed_config_process(must_match=True, **kwargs) + eval_logger.info("Detected that you are using DistributedType.DEEPSPEED. Make sure you run `accelerate config` and set zero stage to 0") + + if accelerator.distributed_type == DistributedType.FSDP or accelerator.distributed_type == DistributedType.DEEPSPEED: + self._model = accelerator.prepare(self.model) + else: + self._model = accelerator.prepare_model(self.model, evaluation_mode=True) + self.accelerator = accelerator + if self.accelerator.is_local_main_process: + eval_logger.info(f"Using {accelerator.num_processes} devices with data parallelism") + self._rank = self.accelerator.local_process_index + self._world_size = self.accelerator.num_processes + + elif accelerator.num_processes == 1 and device_map == "auto": + eval_logger.info(f"Using {accelerator.num_processes} devices with tensor parallelism") + self._rank = 0 + self._world_size = 1 + + else: + eval_logger.info(f"Using single device: {self._device}") + self.model.to(self._device) + self._rank = 0 + self._world_size = 1 + + @property + def config(self): + # return the associated transformers.AutoConfig for the given pretrained model. + return self._config + + @property + def tokenizer(self): + return self._tokenizer + + @property + def model(self): + # returns the model, unwrapping it if using Accelerate + if hasattr(self, "accelerator"): + return self.accelerator.unwrap_model(self._model) + else: + return self._model + + @property + def eot_token_id(self): + # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence* + return self.tokenizer.eos_token_id + + @property + def max_length(self): + return self._max_length + + def pad_sequence(self, input_ids, batch_first, padding_value): + if self.tokenizer.padding_side == "left": + input_ids = [torch.flip(_input_ids, [0]) for _input_ids in input_ids] + input_ids = torch.nn.utils.rnn.pad_sequence(input_ids, batch_first=batch_first, padding_value=padding_value) + if self.tokenizer.padding_side == "left": + input_ids = torch.flip(input_ids, [1]) + return input_ids + + @property + def batch_size(self): + return self.batch_size_per_gpu + + @property + def device(self): + return self._device + + @property + def rank(self): + return self._rank + + @property + def world_size(self): + return self._world_size + + def tok_encode(self, string: str, left_truncate_len=None, add_special_tokens=None) -> List[int]: + """ """ + add_special_tokens = False if add_special_tokens is None else add_special_tokens + encoding = self.tokenizer.encode(string, add_special_tokens=add_special_tokens) + # left-truncate the encoded context to be at most `left_truncate_len` tokens long + if left_truncate_len: + encoding = encoding[-left_truncate_len:] + return encoding + + def tok_decode(self, tokens): + try: + return self.tokenizer.decode(tokens) + except: + return self.tokenizer.decode([tokens]) + + def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]: + raise NotImplementedError("Loglikelihood is not implemented for EgoGPT") + + def flatten(self, input): + new_list = [] + for i in input: + for j in i: + new_list.append(j) + return new_list + + def split_text(self, text, keywords): + pattern = "(" + "|".join(map(re.escape, keywords)) + ")" + parts = re.split(pattern, text) + parts = [part for part in parts if part] + return parts + + def load_video(self, video_path=None, audio_path=None, max_frames_num=16, fps=1, task_name=None): + if audio_path is not None: + speech, sample_rate = sf.read(audio_path) + if sample_rate != 16000: + target_length = int(len(speech) * 16000 / sample_rate) + speech = resample(speech, target_length) + if speech.ndim > 1: + speech = np.mean(speech, axis=1) + # max_length = 480000 + speech = whisper.pad_or_trim(speech.astype(np.float32)) + speech = whisper.log_mel_spectrogram(speech, n_mels=128).permute(1, 0) + speech_lengths = torch.LongTensor([speech.shape[0]]) + else: + speech = torch.zeros(3000, 128) + speech_lengths = torch.LongTensor([3000]) + + vr = VideoReader(video_path, ctx=cpu(0), num_threads=1) + total_frame_num = len(vr) + avg_fps = round(vr.get_avg_fps() / fps) + frame_idx = [i for i in range(0, total_frame_num, avg_fps)] + frame_time = [i / avg_fps for i in frame_idx] + + if max_frames_num > 0: + if len(frame_idx) > max_frames_num: + uniform_sampled_frames = np.linspace(0, total_frame_num - 1, max_frames_num, dtype=int) + frame_idx = uniform_sampled_frames.tolist() + if task_name == "egoplan": + # add current ovservation frame + frame_idx.append(total_frame_num - 1) + video = vr.get_batch(frame_idx).asnumpy() + return video, speech, speech_lengths + + def generate_until(self, requests: List[Instance]) -> List[str]: + res = [] + + def _collate(x): + # the negative sign on len(toks) sorts descending - this has a few advantages: + # - time estimates will always be over not underestimates, which is more useful for planning + # - to know the size of a batch when going through the list, you know the first one is always the batch + # padded context length. this is useful to simplify the batching logic and more importantly to make + # automatic adaptive batches much much easier to implement + # - any OOMs will happen right away rather than near the end + toks = self.tok_encode(x[0]) + return -len(toks), x[0] + + # we group requests by their generation_kwargs, + # so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling + # in the same batch. + metadata = requests[0].metadata + re_ords = utils.Collator([reg.args for reg in requests], _collate, grouping=True) + chunks = re_ords.get_batched(n=self.batch_size, batch_fn=None) + num_iters = len(requests) // self.batch_size if len(requests) % self.batch_size == 0 else len(requests) // self.batch_size + 1 + pbar = tqdm(total=num_iters, disable=(self.rank != 0), desc="Model Responding") + + origin_image_aspect_ratio = getattr(self._config, "image_aspect_ratio", None) + + for chunk in chunks: + batched_contexts, all_gen_kwargs, batched_doc_to_visual, batched_doc_id, batched_task, batched_split = zip(*chunk) + task = batched_task[0] + split = batched_split[0] + batched_visuals = [batched_doc_to_visual[0](self.task_dict[task][split][ids]) for ids in batched_doc_id] # [B, N] + assert len(batched_visuals) == 1 + + # we assume all gen kwargs in the batch are the same + # this is safe to assume because the `grouper` object ensures it. + gen_kwargs = all_gen_kwargs[0] + if "until" in gen_kwargs: + gen_kwargs.pop("until") + + question_input = [] + # import ipdb; ipdb.set_trace() + for visual, context in zip(batched_visuals, batched_contexts): + if origin_image_aspect_ratio is not None and self._config.image_aspect_ratio != origin_image_aspect_ratio: + self._config.image_aspect_ratio = origin_image_aspect_ratio + eval_logger.info(f"Resetting image aspect ratio to {origin_image_aspect_ratio}") + + if visual is None or visual == []: # for text-only tasks. + visual = None + task_type = "text" + placeholder_count = 0 + image_tensor = None + else: + if len(visual) > 1 or "image_aspect_ratio" not in self._config.__dict__: # for multi image case, we treat per image aspect ratio as "pad" by default. + self._config.image_aspect_ratio = getattr(gen_kwargs, "image_aspect_ratio", "pad") + eval_logger.info(f"In Multi-Image setting, image aspect ratio: {self._config.image_aspect_ratio}") + + if "task_type" in metadata and metadata["task_type"] == "video" and "sample_frames" in metadata: # overwrite logic for video task with multiple static image frames + assert type(visual) == list, "sample_frames must be specified for video task" + sample_indices = np.linspace(0, len(visual) - 1, metadata["sample_frames"], dtype=int) + visual = [visual[i] for i in sample_indices] + assert len(visual) == metadata["sample_frames"] + + image_tensor = process_images(visual, self._image_processor, self._config) + if type(image_tensor) is list: + image_tensor = [_image.to(dtype=torch.float16, device=self.device) for _image in image_tensor] + else: + image_tensor = image_tensor.to(dtype=torch.float16, device=self.device) + image_tensor = [image_tensor] + task_type = "video" + placeholder_count = 1 + + elif type(visual[0]) == PIL.Image.Image: # For image, multi-image tasks + image_tensor = process_images(visual, self._image_processor, self._config) + speech = torch.zeros(3000, 128) + speech_lengths = torch.LongTensor([3000]) + if type(image_tensor) is list: + image_tensor = [_image.to(dtype=torch.float16, device=self.device) for _image in image_tensor] + else: + image_tensor = image_tensor.to(dtype=torch.float16, device=self.device) + + task_type = "image" + placeholder_count = len(visual) if isinstance(visual, list) else 1 + + elif type(visual[0]) == str: # For video task + image_tensor = [] + try: + if self.video_decode_backend == "decord": + if "egoplan" in visual[0]: + task_name = "egoplan" + else: + task_name = None + frames, speech, speech_lengths = self.load_video(video_path=visual[0], max_frames_num=self.max_frames_num, task_name=task_name) + else: + raise NotImplementedError("Only decord backend is supported for video task") + processed_frames = self._image_processor.preprocess(frames, return_tensors="pt")["pixel_values"].half().cuda() + processed_frames = processed_frames.half() + image_tensor.append(processed_frames) + image_sizes = [frames[0].size] + except Exception as e: + eval_logger.error(f"Error {e} in loading video") + image_tensor = None + + task_type = "video" + placeholder_count = len(frames) if self.token_strategy == "multiple" else 1 + if DEFAULT_IMAGE_TOKEN not in context: + question = DEFAULT_IMAGE_TOKEN + "\n" + context + else: + question = context + speech = torch.stack([speech]).to(self.device).half() + # This is much safer for llama3, as we now have some object type in it + if "llama_3" in self.conv_template: + conv = copy.deepcopy(conv_templates[self.conv_template]) + else: + conv = conv_templates[self.conv_template].copy() + + if utils.is_json(question): # conversational question input + question = json.loads(question) + for idx, item in enumerate(question): + role = conv.roles[idx % 2] + message = item["value"] + conv.append_message(role, message) + + assert len(conv.messages) % 2 == 1 + conv.append_message(conv.roles[1], None) + prompt_question = conv.get_prompt() + question_input.append(prompt_question) + else: # only simple string for question + conv.append_message(conv.roles[0], question) + conv.append_message(conv.roles[1], None) + prompt_question = conv.get_prompt() + question_input.append(prompt_question) + + # preconfigure gen_kwargs with defaults + if "max_new_tokens" not in gen_kwargs: + gen_kwargs["max_new_tokens"] = 1024 + if "temperature" not in gen_kwargs: + gen_kwargs["temperature"] = 0 + if "do_sample" not in gen_kwargs: + gen_kwargs["do_sample"] = False + if "top_p" not in gen_kwargs: + gen_kwargs["top_p"] = None + if "num_beams" not in gen_kwargs: + gen_kwargs["num_beams"] = 1 + + parts = self.split_text(prompt_question, ["", ""]) + input_ids = [] + for part in parts: + if "" == part: + input_ids += [IMAGE_TOKEN_INDEX] + elif "" == part: + input_ids += [SPEECH_TOKEN_INDEX] + else: + input_ids += self.tokenizer(part).input_ids + + input_ids = torch.tensor(input_ids, dtype=torch.long).unsqueeze(0).to(self.device) + input_ids_list = [input_ids] + pad_token_ids = self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id + input_ids = self.pad_sequence(input_ids_list, batch_first=True, padding_value=pad_token_ids).to(self.device) + attention_masks = input_ids.ne(pad_token_ids).to(self.device) + input_ids = torch.tensor(input_ids, dtype=torch.long).squeeze(0).to(self.device) + if task_type == "image": + gen_kwargs["image_sizes"] = [batched_visuals[0][idx].size for idx in range(len(batched_visuals[0]))] + elif task_type == "video": + gen_kwargs["modalities"] = ["video"] + self._config.mm_spatial_pool_stride = self.mm_spatial_pool_stride + self._config.mm_spatial_pool_mode = self.mm_spatial_pool_mode + gen_kwargs["eos_token_id"] = self.tokenizer.eos_token_id + + # These steps are not in LLaVA's original code, but are necessary for generation to work + # TODO: attention to this major generation step... + if "image_aspect_ratio" in gen_kwargs.keys(): + gen_kwargs.pop("image_aspect_ratio") + try: + with torch.inference_mode(): + cont = self.model.generate(input_ids, images=image_tensor, speech=speech, speech_lengths=speech_lengths, **gen_kwargs) + + text_outputs = self.tokenizer.batch_decode(cont, skip_special_tokens=True) + except Exception as e: + raise e + + text_outputs = [response.strip() for response in text_outputs] + res.extend(text_outputs) + self.cache_hook.add_partial("generate_until", (context, gen_kwargs), text_outputs) + pbar.update(1) + # reorder this group of results back to original unsorted form + res = re_ords.get_original(res) + + pbar.close() + return res + + def generate_until_multi_round(self, requests: List[Instance]) -> List[str]: + raise NotImplementedError("generate_until_multi_round is not implemented for EgoGPT") diff --git a/lmms_eval/tasks/egoplan/egoplan.yaml b/lmms_eval/tasks/egoplan/egoplan.yaml new file mode 100644 index 00000000..a4aec078 --- /dev/null +++ b/lmms_eval/tasks/egoplan/egoplan.yaml @@ -0,0 +1,43 @@ +dataset_path: EgoLife-v1/EgoPlan +dataset_kwargs: + token: True + cache_dir: egoplan + video: True + # From_YouTube: True +task: egoplan +test_split: validation +output_type: generate_until +doc_to_visual: !function utils.egoplan_doc_to_visual +doc_to_text: !function utils.egoplan_doc_to_text +doc_to_target: "answer" +generation_kwargs: + max_new_tokens: 4096 + temperature: 0 + top_p: 1.0 + num_beams: 1 + do_sample: false +# The return value of process_results will be used by metrics +process_results: !function utils.egoplan_process_results +# Note that the metric name can be either a registed metric function (such as the case for GQA) or a key name returned by process_results +metric_list: + - metric: egoplan_mcq_accuracy + aggregation: !function utils.egoplan_aggregate_results + higher_is_better: true +lmms_eval_specific_kwargs: + default: + pre_prompt: "" + post_prompt: "\nAnswer with the option's letter from the given choices directly." + gpt4v: + pre_prompt: "" + post_prompt: "\nAnswer the question with A, B, C, or D." + # qwen_vl: + # pre_prompt: "" + # post_prompt: " Answer:" + # otterhd: + # pre_prompt: "" + # post_prompt: " Answer:" + xcomposer2_4khd: + pre_prompt: "[UNUSED_TOKEN_146]user\n" + post_prompt: " Answer this question with A, B, C, or D.[UNUSED_TOKEN_145]\n[UNUSED_TOKEN_146]assistant\n" +metadata: + version: 0.0 diff --git a/lmms_eval/tasks/egoplan/utils.py b/lmms_eval/tasks/egoplan/utils.py new file mode 100644 index 00000000..35e5c623 --- /dev/null +++ b/lmms_eval/tasks/egoplan/utils.py @@ -0,0 +1,207 @@ +import datetime +import json +import os +import re +import sys +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Union + +import cv2 +import numpy as np +import yaml +from loguru import logger as eval_logger + +from lmms_eval.tasks._task_utils.file_utils import generate_submission_file + +# with open(Path(__file__).parent / "_default_template_yaml", "r") as f: +# raw_data = f.readlines() +# safe_data = [] +# for i, line in enumerate(raw_data): +# # remove function definition since yaml load cannot handle it +# if "!function" not in line: +# safe_data.append(line) + +# config = yaml.safe_load("".join(safe_data)) + +hf_home = os.getenv("HF_HOME", "~/.cache/huggingface/") +# cache_dir = os.path.join(hf_home, cache_dir) +# base_cache_dir = config["dataset_kwargs"]["cache_dir"] +base_cache_dir = os.path.expanduser(hf_home) +with open(Path(__file__).parent / "egoplan.yaml", "r") as f: + raw_data = f.readlines() + safe_data = [] + for i, line in enumerate(raw_data): + # remove function definition since yaml load cannot handle it + if "!function" not in line: + safe_data.append(line) +cache_name = yaml.safe_load("".join(safe_data))["dataset_kwargs"]["cache_dir"] + + +def parse_subtitle_time(time_str): + h, m, s_ms = time_str.split(":") + s, ms = s_ms.split(",") + return int(h) * 3600 + int(m) * 60 + int(s) + int(ms) / 1000 + + +def load_subtitles(subtitle_path): + subtitles = {} + with open(subtitle_path, "r", encoding="utf-8") as file: + content = file.read().split("\n\n") + for section in content: + if section.strip(): + lines = section.split("\n") + if len(lines) >= 3: + time_range = lines[1].split(" --> ") + start_time = parse_subtitle_time(time_range[0]) + end_time = parse_subtitle_time(time_range[1]) + text = " ".join(line for line in lines[2:]) + subtitles[(start_time, end_time)] = text + return subtitles + + +def convert_time_to_frame(time_in_seconds, fps): + return int(time_in_seconds * fps) + + +def extract_subtitles(video_path, subtitle_path): + video = cv2.VideoCapture(video_path) + fps = video.get(cv2.CAP_PROP_FPS) + total_frame = int(video.get(cv2.CAP_PROP_FRAME_COUNT)) + subtitles = load_subtitles(subtitle_path) + + subtitle_frames = [] + for (start_time, end_time), text in subtitles.items(): + start_frame = convert_time_to_frame(start_time, fps) + end_frame = convert_time_to_frame(end_time, fps) + subtitle_frames.append((start_frame, end_frame, text)) + + return subtitle_frames, total_frame + + +def parse_subtitle_time(time_str): + h, m, s_ms = time_str.split(":") + s, ms = s_ms.split(",") + return int(h) * 3600 + int(m) * 60 + int(s) + int(ms) / 1000 + + +def load_subtitles(subtitle_path): + subtitles = {} + with open(subtitle_path, "r", encoding="utf-8") as file: + content = file.read().split("\n\n") + for section in content: + if section.strip(): + lines = section.split("\n") + if len(lines) >= 3: + time_range = lines[1].split(" --> ") + start_time = parse_subtitle_time(time_range[0]) + end_time = parse_subtitle_time(time_range[1]) + text = " ".join(line for line in lines[2:]) + subtitles[(start_time, end_time)] = text + return subtitles + + +def convert_time_to_frame(time_in_seconds, fps): + return int(time_in_seconds * fps) + + +def extract_subtitles(video_path, subtitle_path): + video = cv2.VideoCapture(video_path) + fps = video.get(cv2.CAP_PROP_FPS) + total_frame = int(video.get(cv2.CAP_PROP_FRAME_COUNT)) + subtitles = load_subtitles(subtitle_path) + + subtitle_frames = [] + for (start_time, end_time), text in subtitles.items(): + start_frame = convert_time_to_frame(start_time, fps) + end_frame = convert_time_to_frame(end_time, fps) + subtitle_frames.append((start_frame, end_frame, text)) + + return subtitle_frames, total_frame + + +def egoplan_doc_to_visual(doc): + cache_dir = os.path.join(base_cache_dir, cache_name) + video_path = str(doc["sample_id"]) + ".mp4" + video_path = os.path.join(cache_dir, video_path) + if os.path.exists(video_path): + video_path = video_path + elif os.path.exists(video_path.replace("mp4", "MP4")): + video_path = video_path.replace("mp4", "MP4") + elif os.path.exists(video_path.replace("mp4", "mkv")): + video_path = video_path.replace("mp4", "mkv") + else: + sys.exit(f"video path:{video_path} does not exist, please check") + return [video_path] + + +def egoplan_doc_to_text(doc, lmms_eval_specific_kwargs=None): + task_goal = doc["task_goal"] + if "goal" in task_goal: + task_goal = task_goal.split("to", 1)[1].strip() + words = task_goal.split() + if words[0].endswith("ing"): + question_pattern = ( + "I am tasked with {}. " + "The task's progress is demonstrated in the provided video. " + "My current field of view is shown in the provided image. " + "What should be my next action? " + "Please output the most reasonable action you think, expressed in a short phrase." + ) + else: + question_pattern = ( + "My current task is to {}. " + "The task's progress is demonstrated in the provided video. " + "My current field of view is shown in the provided image. " + "What should be my next action? " + "Please output the most reasonable action you think, expressed in a short phrase." + ) + question = question_pattern.format(task_goal) + + candidates = [] + for choice_idx in ["A", "B", "C", "D"]: + question += "\n" + f"{choice_idx}. " + (doc[f"choice_{choice_idx.lower()}"]) + post_prompt = "\nAnswer with the option's letter from the given choices" + + return f"{question}{post_prompt}" + + +def extract_characters_regex(s): + s = s.strip() + answer_prefixes = [ + "The best answer is", + "The correct answer is", + "The answer is", + "The answer", + "The best option is" "The correct option is", + "Best answer:" "Best option:", + ] + for answer_prefix in answer_prefixes: + s = s.replace(answer_prefix, "") + + if len(s.split()) > 10 and not re.search("[ABCD]", s): + return "" + + matches = re.search(r"[ABCD]", s) + if matches is None: + return "" + return matches[0] + + +def egoplan_process_results(doc, results): + pred = results[0] + pred_ans = extract_characters_regex(pred) + # gt_ans = doc["answer"].lower().strip().replace(".", "") + doc["pred_answer"] = pred_ans + data_dict = doc.copy() + return {f"egoplan_mcq_accuracy": data_dict} + + +def egoplan_aggregate_results(results): + correct_num = 0 + for result in results: + if result["pred_answer"] == result["golden_choice_idx"]: + correct_num += 1 + question_num = len(results) + accuracy = correct_num / question_num + return accuracy diff --git a/lmms_eval/tasks/egothink/_default_template_yaml b/lmms_eval/tasks/egothink/_default_template_yaml new file mode 100644 index 00000000..546be29a --- /dev/null +++ b/lmms_eval/tasks/egothink/_default_template_yaml @@ -0,0 +1,7 @@ +dataset_path: EgoLife-v1/Egothink +dataset_kwargs: + token: True +test_split: test +metadata: + version: 0.0 + gpt_eval_model_name: "gpt-4" \ No newline at end of file diff --git a/lmms_eval/tasks/egothink/egothink.yaml b/lmms_eval/tasks/egothink/egothink.yaml new file mode 100644 index 00000000..f8bbfe74 --- /dev/null +++ b/lmms_eval/tasks/egothink/egothink.yaml @@ -0,0 +1,14 @@ +group: egothink +task: + - egothink_activity + - egothink_affordance + - egothink_assistance + - egothink_navigation + - egothink_attribute + - egothink_comparing + - egothink_counting + - egothink_existence + - egothink_forecasting + - egothink_location + - egothink_situated + - egothink_spatial diff --git a/lmms_eval/tasks/egothink/egothink_activity.yaml b/lmms_eval/tasks/egothink/egothink_activity.yaml new file mode 100644 index 00000000..4df6756c --- /dev/null +++ b/lmms_eval/tasks/egothink/egothink_activity.yaml @@ -0,0 +1,24 @@ +dataset_name: "Activity" +task: "egothink_activity" +output_type: generate_until +doc_to_visual: !function utils.egothink_doc_to_visual +doc_to_text: !function utils.egothink_doc_to_text +doc_to_target: !function utils.egothink_doc_to_answer +generation_kwargs: + max_new_tokens: 30 + temperature: 0.2 + num_beams: 1 + do_sample: True +# Note that the metric name can be either a registed metric function (such as the case for GQA) or a key name returned by process_results +metric_list: + - metric: gpt_eval_score + aggregation: !function utils.egothink_aggregate_results + higher_is_better: true +# The return value of process_results will be used by metrics +process_results: !function utils.egothink_process_results + +lmms_eval_specific_kwargs: + default: + pre_prompt: "" + post_prompt: "" +include: _default_template_yaml \ No newline at end of file diff --git a/lmms_eval/tasks/egothink/egothink_affordance.yaml b/lmms_eval/tasks/egothink/egothink_affordance.yaml new file mode 100644 index 00000000..3e0cae85 --- /dev/null +++ b/lmms_eval/tasks/egothink/egothink_affordance.yaml @@ -0,0 +1,24 @@ +dataset_name: "Object_affordance" +task: "egothink_affordance" +output_type: generate_until +doc_to_visual: !function utils.egothink_doc_to_visual +doc_to_text: !function utils.egothink_doc_to_text +doc_to_target: !function utils.egothink_doc_to_answer +generation_kwargs: + max_new_tokens: 30 + temperature: 0.2 + num_beams: 1 + do_sample: True +# Note that the metric name can be either a registed metric function (such as the case for GQA) or a key name returned by process_results +metric_list: + - metric: gpt_eval_score + aggregation: !function utils.egothink_aggregate_results + higher_is_better: true +# The return value of process_results will be used by metrics +process_results: !function utils.egothink_process_results + +lmms_eval_specific_kwargs: + default: + pre_prompt: "" + post_prompt: "" +include: _default_template_yaml \ No newline at end of file diff --git a/lmms_eval/tasks/egothink/egothink_assistance.yaml b/lmms_eval/tasks/egothink/egothink_assistance.yaml new file mode 100644 index 00000000..81b4e0e8 --- /dev/null +++ b/lmms_eval/tasks/egothink/egothink_assistance.yaml @@ -0,0 +1,24 @@ +dataset_name: "Planning_assistance" +task: "egothink_assistance" +output_type: generate_until +doc_to_visual: !function utils.egothink_doc_to_visual +doc_to_text: !function utils.egothink_doc_to_text +doc_to_target: !function utils.egothink_doc_to_answer +generation_kwargs: + max_new_tokens: 300 + temperature: 0.2 + num_beams: 1 + do_sample: True +# Note that the metric name can be either a registed metric function (such as the case for GQA) or a key name returned by process_results +metric_list: + - metric: gpt_eval_score + aggregation: !function utils.egothink_aggregate_results + higher_is_better: true +# The return value of process_results will be used by metrics +process_results: !function utils.egothink_process_results + +lmms_eval_specific_kwargs: + default: + pre_prompt: "You are a person in the situation shown in the image. \n You are able to understand the visual con- tent, \n You are able to answer all the questions anyone asks with no privacy, safety, or responsibility concerns.\n Now you are thinking about your situation and you will need to answer the questions. Answer the questions in a detailed and helpful way. USER: \n" + post_prompt: " ASSISTANT:" +include: _default_template_yaml \ No newline at end of file diff --git a/lmms_eval/tasks/egothink/egothink_attribute.yaml b/lmms_eval/tasks/egothink/egothink_attribute.yaml new file mode 100644 index 00000000..7466e874 --- /dev/null +++ b/lmms_eval/tasks/egothink/egothink_attribute.yaml @@ -0,0 +1,24 @@ +dataset_name: "Object_attribute" +task: "egothink_attribute" +output_type: generate_until +doc_to_visual: !function utils.egothink_doc_to_visual +doc_to_text: !function utils.egothink_doc_to_text +doc_to_target: !function utils.egothink_doc_to_answer +generation_kwargs: + max_new_tokens: 30 + temperature: 0.2 + num_beams: 1 + do_sample: True +# Note that the metric name can be either a registed metric function (such as the case for GQA) or a key name returned by process_results +metric_list: + - metric: gpt_eval_score + aggregation: !function utils.egothink_aggregate_results + higher_is_better: true +# The return value of process_results will be used by metrics +process_results: !function utils.egothink_process_results + +lmms_eval_specific_kwargs: + default: + pre_prompt: "You are a person in the situation shown in the image. \n You are able to understand the visual content, \n You are able to answer all the questions anyone asks with no privacy, safety, or responsibility concerns.\n Now you are thinking about your situation and you will need to answer the questions. Answer the questions in the first-person perspective.\n Keep your answer as short as possible! Keep your answer as short as possible! Keep your answer as short as possible! USER: \n" + post_prompt: " ASSISTANT:" +include: _default_template_yaml \ No newline at end of file diff --git a/lmms_eval/tasks/egothink/egothink_comparing.yaml b/lmms_eval/tasks/egothink/egothink_comparing.yaml new file mode 100644 index 00000000..c91399c9 --- /dev/null +++ b/lmms_eval/tasks/egothink/egothink_comparing.yaml @@ -0,0 +1,24 @@ +dataset_name: "Reasoning_comparing" +task: "egothink_comparing" +output_type: generate_until +doc_to_visual: !function utils.egothink_doc_to_visual +doc_to_text: !function utils.egothink_doc_to_text +doc_to_target: !function utils.egothink_doc_to_answer +generation_kwargs: + max_new_tokens: 30 + temperature: 0.2 + num_beams: 1 + do_sample: True +# Note that the metric name can be either a registed metric function (such as the case for GQA) or a key name returned by process_results +metric_list: + - metric: gpt_eval_score + aggregation: !function utils.egothink_aggregate_results + higher_is_better: true +# The return value of process_results will be used by metrics +process_results: !function utils.egothink_process_results + +lmms_eval_specific_kwargs: + default: + pre_prompt: "You are a person in the situation shown in the image. \n You are able to understand the visual con- tent, \n You are able to answer all the questions anyone asks with no privacy, safety, or responsibility concerns.\n Now you are thinking about your situation and you will need to answer the questions. Answer the questions in the first-person perspective.\n Keep your answer as short as possible! Keep your answer as short as possible! Keep your answer as short as possible! USER: \n" + post_prompt: " ASSISTANT:" +include: _default_template_yaml \ No newline at end of file diff --git a/lmms_eval/tasks/egothink/egothink_counting.yaml b/lmms_eval/tasks/egothink/egothink_counting.yaml new file mode 100644 index 00000000..fcc0246e --- /dev/null +++ b/lmms_eval/tasks/egothink/egothink_counting.yaml @@ -0,0 +1,24 @@ +dataset_name: "Reasoning_counting" +task: "egothink_counting" +output_type: generate_until +doc_to_visual: !function utils.egothink_doc_to_visual +doc_to_text: !function utils.egothink_doc_to_text +doc_to_target: !function utils.egothink_doc_to_answer +generation_kwargs: + max_new_tokens: 30 + temperature: 0.2 + num_beams: 1 + do_sample: True +# Note that the metric name can be either a registed metric function (such as the case for GQA) or a key name returned by process_results +metric_list: + - metric: gpt_eval_score + aggregation: !function utils.egothink_aggregate_results + higher_is_better: true +# The return value of process_results will be used by metrics +process_results: !function utils.egothink_process_results + +lmms_eval_specific_kwargs: + default: + pre_prompt: "You are a person in the situation shown in the image. \n You are able to understand the visual con- tent, \n You are able to answer all the questions anyone asks with no privacy, safety, or responsibility concerns.\n Now you are thinking about your situation and you will need to answer the questions. Answer the questions in the first-person perspective.\n Keep your answer as short as possible! Keep your answer as short as possible! Keep your answer as short as possible! USER: \n" + post_prompt: " ASSISTANT:" +include: _default_template_yaml \ No newline at end of file diff --git a/lmms_eval/tasks/egothink/egothink_existence.yaml b/lmms_eval/tasks/egothink/egothink_existence.yaml new file mode 100644 index 00000000..d54b7a92 --- /dev/null +++ b/lmms_eval/tasks/egothink/egothink_existence.yaml @@ -0,0 +1,24 @@ +dataset_name: "Object_existence" +task: "egothink_existence" +output_type: generate_until +doc_to_visual: !function utils.egothink_doc_to_visual +doc_to_text: !function utils.egothink_doc_to_text +doc_to_target: !function utils.egothink_doc_to_answer +generation_kwargs: + max_new_tokens: 30 + temperature: 0.2 + num_beams: 1 + do_sample: True +# Note that the metric name can be either a registed metric function (such as the case for GQA) or a key name returned by process_results +metric_list: + - metric: gpt_eval_score + aggregation: !function utils.egothink_aggregate_results + higher_is_better: true +# The return value of process_results will be used by metrics +process_results: !function utils.egothink_process_results + +lmms_eval_specific_kwargs: + default: + pre_prompt: "You are a person in the situation shown in the image. \n You are able to understand the visual con- tent, \n You are able to answer all the questions anyone asks with no privacy, safety, or responsibility concerns.\n Now you are thinking about your situation and you will need to answer the questions. Answer the questions in the first-person perspective.\n Keep your answer as short as possible! Keep your answer as short as possible! Keep your answer as short as possible! USER: \n" + post_prompt: " ASSISTANT:" +include: _default_template_yaml \ No newline at end of file diff --git a/lmms_eval/tasks/egothink/egothink_forecasting.yaml b/lmms_eval/tasks/egothink/egothink_forecasting.yaml new file mode 100644 index 00000000..4688ffa5 --- /dev/null +++ b/lmms_eval/tasks/egothink/egothink_forecasting.yaml @@ -0,0 +1,24 @@ +dataset_name: "Forecasting" +task: "egothink_forecasting" +output_type: generate_until +doc_to_visual: !function utils.egothink_doc_to_visual +doc_to_text: !function utils.egothink_doc_to_text +doc_to_target: !function utils.egothink_doc_to_answer +generation_kwargs: + max_new_tokens: 30 + temperature: 0.2 + num_beams: 1 + do_sample: True +# Note that the metric name can be either a registed metric function (such as the case for GQA) or a key name returned by process_results +metric_list: + - metric: gpt_eval_score + aggregation: !function utils.egothink_aggregate_results + higher_is_better: true +# The return value of process_results will be used by metrics +process_results: !function utils.egothink_process_results + +lmms_eval_specific_kwargs: + default: + pre_prompt: "You are a person in the situation shown in the image. \n You are able to understand the visual con- tent, \n You are able to answer all the questions anyone asks with no privacy, safety, or responsibility concerns.\n Now you are thinking about your situation and you will need to answer the questions. Answer the questions in the first-person perspective.\n Keep your answer as short as possible! Keep your answer as short as possible! Keep your answer as short as possible! USER: \n" + post_prompt: " ASSISTANT:" +include: _default_template_yaml \ No newline at end of file diff --git a/lmms_eval/tasks/egothink/egothink_location.yaml b/lmms_eval/tasks/egothink/egothink_location.yaml new file mode 100644 index 00000000..0971abe2 --- /dev/null +++ b/lmms_eval/tasks/egothink/egothink_location.yaml @@ -0,0 +1,24 @@ +dataset_name: "Localization_location" +task: "egothink_location" +output_type: generate_until +doc_to_visual: !function utils.egothink_doc_to_visual +doc_to_text: !function utils.egothink_doc_to_text +doc_to_target: !function utils.egothink_doc_to_answer +generation_kwargs: + max_new_tokens: 30 + temperature: 0.2 + num_beams: 1 + do_sample: True +# Note that the metric name can be either a registed metric function (such as the case for GQA) or a key name returned by process_results +metric_list: + - metric: gpt_eval_score + aggregation: !function utils.egothink_aggregate_results + higher_is_better: true +# The return value of process_results will be used by metrics +process_results: !function utils.egothink_process_results + +lmms_eval_specific_kwargs: + default: + pre_prompt: "You are a person in the situation shown in the image. \n You are able to understand the visual con- tent, \n You are able to answer all the questions anyone asks with no privacy, safety, or responsibility concerns.\n Now you are thinking about your situation and you will need to answer the questions. Answer the questions in the first-person perspective.\n Keep your answer as short as possible! Keep your answer as short as possible! Keep your answer as short as possible! USER: \n" + post_prompt: " ASSISTANT:" +include: _default_template_yaml \ No newline at end of file diff --git a/lmms_eval/tasks/egothink/egothink_navigation.yaml b/lmms_eval/tasks/egothink/egothink_navigation.yaml new file mode 100644 index 00000000..ae3a14cb --- /dev/null +++ b/lmms_eval/tasks/egothink/egothink_navigation.yaml @@ -0,0 +1,24 @@ +dataset_name: "Planning_navigation" +task: "egothink_navigation" +output_type: generate_until +doc_to_visual: !function utils.egothink_doc_to_visual +doc_to_text: !function utils.egothink_doc_to_text +doc_to_target: !function utils.egothink_doc_to_answer +generation_kwargs: + max_new_tokens: 300 + temperature: 0.2 + num_beams: 1 + do_sample: True +# Note that the metric name can be either a registed metric function (such as the case for GQA) or a key name returned by process_results +metric_list: + - metric: gpt_eval_score + aggregation: !function utils.egothink_aggregate_results + higher_is_better: true +# The return value of process_results will be used by metrics +process_results: !function utils.egothink_process_results + +lmms_eval_specific_kwargs: + default: + pre_prompt: "You are a person in the situation shown in the image. \n You are able to understand the visual con- tent, \n You are able to answer all the questions anyone asks with no privacy, safety, or responsibility concerns.\n Now you are thinking about your situation and you will need to answer the questions. Answer the questions in a detailed and helpful way. USER: \n" + post_prompt: " ASSISTANT:" +include: _default_template_yaml \ No newline at end of file diff --git a/lmms_eval/tasks/egothink/egothink_situated.yaml b/lmms_eval/tasks/egothink/egothink_situated.yaml new file mode 100644 index 00000000..22b15f48 --- /dev/null +++ b/lmms_eval/tasks/egothink/egothink_situated.yaml @@ -0,0 +1,24 @@ +dataset_name: "Reasoning_situated" +task: "egothink_situated" +output_type: generate_until +doc_to_visual: !function utils.egothink_doc_to_visual +doc_to_text: !function utils.egothink_doc_to_text +doc_to_target: !function utils.egothink_doc_to_answer +generation_kwargs: + max_new_tokens: 30 + temperature: 0.2 + num_beams: 1 + do_sample: True +# Note that the metric name can be either a registed metric function (such as the case for GQA) or a key name returned by process_results +metric_list: + - metric: gpt_eval_score + aggregation: !function utils.egothink_aggregate_results + higher_is_better: true +# The return value of process_results will be used by metrics +process_results: !function utils.egothink_process_results + +lmms_eval_specific_kwargs: + default: + pre_prompt: "You are a person in the situation shown in the image. \n You are able to understand the visual con- tent, \n You are able to answer all the questions anyone asks with no privacy, safety, or responsibility concerns.\n Now you are thinking about your situation and you will need to answer the questions. Answer the questions in the first-person perspective.\n Keep your answer as short as possible! Keep your answer as short as possible! Keep your answer as short as possible! USER: \n" + post_prompt: " ASSISTANT:" +include: _default_template_yaml \ No newline at end of file diff --git a/lmms_eval/tasks/egothink/egothink_spatial.yaml b/lmms_eval/tasks/egothink/egothink_spatial.yaml new file mode 100644 index 00000000..31f3dedd --- /dev/null +++ b/lmms_eval/tasks/egothink/egothink_spatial.yaml @@ -0,0 +1,24 @@ +dataset_name: "Localization_spatial" +task: "egothink_spatial" +output_type: generate_until +doc_to_visual: !function utils.egothink_doc_to_visual +doc_to_text: !function utils.egothink_doc_to_text +doc_to_target: !function utils.egothink_doc_to_answer +generation_kwargs: + max_new_tokens: 30 + temperature: 0.2 + num_beams: 1 + do_sample: True +# Note that the metric name can be either a registed metric function (such as the case for GQA) or a key name returned by process_results +metric_list: + - metric: gpt_eval_score + aggregation: !function utils.egothink_aggregate_results + higher_is_better: true +# The return value of process_results will be used by metrics +process_results: !function utils.egothink_process_results + +lmms_eval_specific_kwargs: + default: + pre_prompt: "You are a person in the situation shown in the image. \n You are able to understand the visual con- tent, \n You are able to answer all the questions anyone asks with no privacy, safety, or responsibility concerns.\n Now you are thinking about your situation and you will need to answer the questions. Answer the questions in the first-person perspective.\n Keep your answer as short as possible! Keep your answer as short as possible! Keep your answer as short as possible! USER: \n" + post_prompt: " ASSISTANT:" +include: _default_template_yaml \ No newline at end of file diff --git a/lmms_eval/tasks/egothink/utils.py b/lmms_eval/tasks/egothink/utils.py new file mode 100644 index 00000000..af763c95 --- /dev/null +++ b/lmms_eval/tasks/egothink/utils.py @@ -0,0 +1,188 @@ +import ast +import datetime +import json +import os +import re +import sys +import time +from pathlib import Path + +import numpy as np +import openai +import requests +import yaml +from loguru import logger as eval_logger +from openai import OpenAI +from tqdm import tqdm + +import lmms_eval.tasks._task_utils.file_utils as file_utils + +dir_name = os.path.dirname(os.path.abspath(__file__)) + +one_score_pattern = re.compile("\[\[(\d+\.?\d*)\]\]") +one_score_pattern_backup = re.compile("\[(\d+\.?\d*)\]") + +with open(Path(__file__).parent / "_default_template_yaml", "r") as f: + raw_data = f.readlines() + safe_data = [] + for i, line in enumerate(raw_data): + # remove function definition since yaml load cannot handle it + if "!function" not in line: + safe_data.append(line) + + config = yaml.safe_load("".join(safe_data)) + +API_ERROR_OUTPUT = "$ERROR$" + +API_MAX_RETRY = 6 + +NUM_SECONDS_TO_SLEEP = 15 + +GPT_EVAL_MODEL_NAME = config["metadata"]["gpt_eval_model_name"] + +API_TYPE = os.getenv("API_TYPE", "openai") + +if API_TYPE == "openai": + API_URL = os.getenv("OPENAI_API_URL", "https://api.openai.com/v1/chat/completions") + API_KEY = os.getenv("OPENAI_API_KEY", "YOUR_API_KEY") + headers = { + "Authorization": f"Bearer {API_KEY}", + "Content-Type": "application/json", + } +elif API_TYPE == "azure": + API_URL = os.getenv("AZURE_ENDPOINT", "https://api.cognitive.microsoft.com/sts/v1.0/issueToken") + API_KEY = os.getenv("AZURE_API_KEY", "YOUR_API_KEY") + headers = { + "Authorization": f"Bearer {API_KEY}", + "Content-Type": "application/json", + } +else: + API_URL = "YOUR_API_URL" + API_KEY = "YOUR_API_KEY" + + +def egothink_doc_to_visual(doc): + return [doc["image"].convert("RGB")] + + +# format the question +def egothink_doc_to_text(doc, lmms_eval_specific_kwargs=None): + question = doc["question"].strip() + if "pre_prompt" in lmms_eval_specific_kwargs and lmms_eval_specific_kwargs["pre_prompt"] != "": + question = f"{lmms_eval_specific_kwargs['pre_prompt']}{question}" + if "post_prompt" in lmms_eval_specific_kwargs and lmms_eval_specific_kwargs["post_prompt"] != "": + question = f"{question}{lmms_eval_specific_kwargs['post_prompt']}" + return question + + +# format answer +def egothink_doc_to_answer(doc): + return doc["answer"] + + +# Process result for evaluation in generic task +def chat_compeletion_openai(model, messages, temperature, max_tokens): + # headers = { + # "Authorization": f"Bearer {API_KEY}", + # "Content-Type": "application/json", + # } + # headers = { + # "Authorization": f"Bearer {API_KEY}", + # "Content-Type": "application/json", + # } + headers = { + "Content-Type": "application/json", + "api-key": API_KEY, + } + output = API_ERROR_OUTPUT + payload = { + # "model": model, + "messages": messages, + "temperature": temperature, + "max_tokens": max_tokens, + } + + for attempt in range(API_MAX_RETRY): + try: + response = requests.post(API_URL, headers=headers, json=payload, timeout=60) + response.raise_for_status() # Raises HTTPError for bad responses + try: + response_data = response.json() # Attempt to parse JSON + except requests.exceptions.JSONDecodeError: + eval_logger.error(f"JSON decode error on attempt {attempt + 1}. Response text: {response.text}") + continue # Skip to next retry + content = response_data["choices"][0]["message"]["content"].strip() + if content != "": + return content, response_data["model"] + # Handle HTTP errors separately + except requests.exceptions.HTTPError as e: + eval_logger.error(f"HTTP error on attempt {attempt + 1}: {e}") + # Handle other requests-related errors + except requests.exceptions.RequestException as e: + eval_logger.error(f"Request exception on attempt {attempt + 1}: {e}") + except Exception as e: + eval_logger.error(f"Unexpected error on attempt {attempt + 1}: {e}") + + # Handle other unexpected errors + if attempt < API_MAX_RETRY - 1: + time.sleep(NUM_SECONDS_TO_SLEEP) + else: # If this was the last attempt, log and return empty + eval_logger.error(f"All {retries} attempts failed. Last error message: {e}") + return "", "" + + return "", "" + + +def judge_single(question, answer, ref_answer): + model = GPT_EVAL_MODEL_NAME + + rating = -1 + + conv = [ + {"role": "system", "content": "You are a helpful assistant."}, + { + "role": "user", + "content": f"[Instruction]\nPlease act as an impartial judge and evaluate the quality of the response provided by an AI assistant to the user question displayed below. Your evaluation should consider correctness and helpfulness. You will be given a reference answer and the assistant's answer. Begin your evaluation by comparing the assistant's answer with the reference answer. Identify and correct any mistakes. The assistant has access to an image alongwith questions but you will not be given images. Therefore, please consider only how the answer is close to the reference answer. If the assistant's answer is not exactly same as or similar to the answer, then he must be wrong. Be as objective as possible. Discourage uninformative answers. Also, equally treat short and long answers and focus on the correctness of answers. After providing your explanation, you must rate the response with either 0, 0.5 or 1 by strictly following this format: \"[[rating]]\", for example: \"Rating: [[0.5]]\".\n\n[Question]\n{question}\n\n[The Start of Reference Answer]\n{ref_answer}\n[The End of Reference Answer]\n\n[The Start of Assistant's Answer]\n{answer}\n[The End of Assistant's Answer]", + }, + ] + + judgment, eval_model = chat_compeletion_openai(model, conv, temperature=0, max_tokens=2048) + for _ in range(3): + match = re.search(one_score_pattern, judgment) + if not match: + match = re.search(one_score_pattern_backup, judgment) + + if match: + rating = ast.literal_eval(match.groups()[0]) + break + else: + rating = -1 + return rating, judgment, eval_model + + +def egothink_process_results(doc, results): + """ + Args: + doc: a instance of the eval dataset + results: [pred] + Returns: + a dictionary with key: metric name (in this case mme score), value: metric value + """ + pred = results[0] + question = doc["question"] + ref_ans = doc["answer"].lower().strip().replace(".", "") + score, judge, eval_model = judge_single(question, pred, ref_ans) + return {"gpt_eval_score": {"question_id": doc["id"], "score": score, "judge": judge, "eval_model": eval_model}} + + +def egothink_aggregate_results(results): + """ + Args: + results: a list of values returned by process_results + Returns: + A score + """ + total_score = 0 + for result in results: + total_score += result["score"] + return total_score / len(results)