From 93442279cb39f9b892e6cc22a9dd96b8f6676154 Mon Sep 17 00:00:00 2001 From: Pu Fanyi Date: Sun, 22 Sep 2024 02:05:54 +0800 Subject: [PATCH] Optimize parsing of multiple-choice responses in videosearch --- check_missing.py | 6 +- check_reverse.py | 12 ++-- lmms_eval/models/gpt4v.py | 4 +- lmms_eval/models/model_utils/load_video.py | 67 +++++++++++----------- lmms_eval/tasks/mmmu/utils.py | 13 +++-- lmms_eval/tasks/mmmu_for_testing/utils.py | 13 ++--- lmms_eval/tasks/videosearch/utils.py | 20 +++---- 7 files changed, 67 insertions(+), 68 deletions(-) diff --git a/check_missing.py b/check_missing.py index c4462f8f4..75626faff 100644 --- a/check_missing.py +++ b/check_missing.py @@ -1,13 +1,13 @@ from datasets import load_dataset, Dataset # Load the deduplicated VideoSearch dataset -videosearch_dataset = load_dataset('lmms-lab/VideoSearch', 'deduplicated_combined_milestone', split='test') +videosearch_dataset = load_dataset("lmms-lab/VideoSearch", "deduplicated_combined_milestone", split="test") # ID to be removed -id_to_remove = 'validation_Biology_18' +id_to_remove = "validation_Biology_18" # Filter out the row with the missing ID -filtered_rows = [row for row in videosearch_dataset if row['id'] != id_to_remove] +filtered_rows = [row for row in videosearch_dataset if row["id"] != id_to_remove] # Create a new dataset from the filtered rows filtered_dataset = Dataset.from_list(filtered_rows) diff --git a/check_reverse.py b/check_reverse.py index 32ff6f277..bad0d9e59 100644 --- a/check_reverse.py +++ b/check_reverse.py @@ -2,23 +2,23 @@ from datasets import load_dataset # Load the VideoSearch dataset -videosearch_dataset = load_dataset('lmms-lab/VideoSearch', 'final_combined_milestone', split='test') +videosearch_dataset = load_dataset("lmms-lab/VideoSearch", "final_combined_milestone", split="test") # Path to the videos directory (replace with your actual path) -videos_directory = '/mnt/sfs-common/krhu/.cache/huggingface/Combined_milestone/videos/' +videos_directory = "/mnt/sfs-common/krhu/.cache/huggingface/Combined_milestone/videos/" # Get all IDs from the dataset -videosearch_ids = set(videosearch_dataset['id']) +videosearch_ids = set(videosearch_dataset["id"]) # List to store IDs of files that are not in the dataset extra_files = [] # Loop through all .mp4 files in the videos directory for file in os.listdir(videos_directory): - if file.endswith('.mp4'): + if file.endswith(".mp4"): # Extract the ID from the file name (remove the .mp4 extension) - file_id = file.replace('.mp4', '') - + file_id = file.replace(".mp4", "") + # Check if the file ID exists in the VideoSearch dataset if file_id not in videosearch_ids: extra_files.append(file_id) diff --git a/lmms_eval/models/gpt4v.py b/lmms_eval/models/gpt4v.py index 6491078ae..6d4b464ed 100755 --- a/lmms_eval/models/gpt4v.py +++ b/lmms_eval/models/gpt4v.py @@ -47,7 +47,7 @@ class GPT4V(lmms): def __init__( self, - #model_version: str = "gpt-4-vision-preview", + # model_version: str = "gpt-4-vision-preview", modality: str = "video", max_frames_num: int = 32, timeout: int = 120, @@ -59,7 +59,7 @@ def __init__( # Manually set a image token for GPT4V so that we can search for it # and split the text and image # Here we just use the same token as llava for convenient - #self.model_version = model_version + # self.model_version = model_version self.modality = modality self.max_frames_num = max_frames_num self.image_token = "" diff --git a/lmms_eval/models/model_utils/load_video.py b/lmms_eval/models/model_utils/load_video.py index 747ae016a..e7e23a91b 100644 --- a/lmms_eval/models/model_utils/load_video.py +++ b/lmms_eval/models/model_utils/load_video.py @@ -28,56 +28,59 @@ def record_video_length_packet(container): return frames -def read_video_pyav(video_path: str, *, num_frm=8, fps=None, format = "rgb24") -> np.ndarray: +def load_video_stream(container, num_frm: int = 8, fps: float = None): + # container = av.open(video_path) + total_frames = container.streams.video[0].frames + frame_rate = container.streams.video[0].average_rate + if fps is not None: + video_length = total_frames / frame_rate + num_frm = min(num_frm, int(video_length * fps)) + sampled_frm = min(total_frames, num_frm) + indices = np.linspace(0, total_frames - 1, sampled_frm, dtype=int) + + # Append the last frame index if not already included + if total_frames - 1 not in indices: + indices = np.append(indices, total_frames - 1) + + return record_video_length_stream(container, indices) + + +def load_video_packet(container, num_frm: int = 8, fps: float = None): + frames = record_video_length_packet(container) + total_frames = len(frames) + sampled_frm = min(total_frames, num_frm) + indices = np.linspace(0, total_frames - 1, sampled_frm, dtype=int) + + # Append the last frame index if not already included + if total_frames - 1 not in indices: + indices = np.append(indices, total_frames - 1) + + return [frames[i] for i in indices] + + +def read_video_pyav(video_path: str, *, num_frm: int = 8, fps: float = None, format="rgb24") -> np.ndarray: """ Read video using the PyAV library. Args: video_path (str): The path to the video file. num_frm (int, optional): The maximum number of frames to extract. Defaults to 8. - fps (optional): The frames per second for extraction. If `None`, the maximum number of frames will be extracted. Defaults to None. + fps (float, optional): The frames per second for extraction. If `None`, the maximum number of frames will be extracted. Defaults to None. format (str, optional): The format of the extracted frames. Defaults to "rgb24". Returns: np.ndarray: A numpy array containing the extracted frames in RGB format. """ + container = av.open(video_path) if "webm" not in video_path and "mkv" not in video_path: # For mp4, we try loading with stream first try: - container = av.open(video_path) - total_frames = container.streams.video[0].frames - sampled_frm = min(total_frames, num_frm) - indices = np.linspace(0, total_frames - 1, sampled_frm, dtype=int) - - # Append the last frame index if not already included - if total_frames - 1 not in indices: - indices = np.append(indices, total_frames - 1) - - frames = record_video_length_stream(container, indices) + frames = load_video_stream(container, num_frm, fps) except: - container = av.open(video_path) frames = record_video_length_packet(container) - total_frames = len(frames) - sampled_frm = min(total_frames, num_frm) - indices = np.linspace(0, total_frames - 1, sampled_frm, dtype=int) - - # Append the last frame index if not already included - if total_frames - 1 not in indices: - indices = np.append(indices, total_frames - 1) - - frames = [frames[i] for i in indices] else: - container = av.open(video_path) frames = record_video_length_packet(container) - total_frames = len(frames) - sampled_frm = min(total_frames, num_frm) - indices = np.linspace(0, total_frames - 1, sampled_frm, dtype=int) - - # Append the last frame index if not already included - if total_frames - 1 not in indices: - indices = np.append(indices, total_frames - 1) - frames = [frames[i] for i in indices] - return np.stack([x.to_ndarray(format=format) for x in frames]) \ No newline at end of file + return np.stack([x.to_ndarray(format=format) for x in frames]) diff --git a/lmms_eval/tasks/mmmu/utils.py b/lmms_eval/tasks/mmmu/utils.py index 7670fb9e0..5947cb751 100755 --- a/lmms_eval/tasks/mmmu/utils.py +++ b/lmms_eval/tasks/mmmu/utils.py @@ -332,6 +332,7 @@ def evaluate_mmmu(samples): # return pred_index + def parse_multi_choice_response(response, all_choices, index2ans): """ Parse the prediction from the generated response. @@ -352,19 +353,19 @@ def parse_multi_choice_response(response, all_choices, index2ans): # print(f"Found choice with parentheses: {choice}") # candidates.append(choice) # ans_with_brack = True - + # # Step 4: If no candidates, look for choices with a period after (A. B. C. D.) # if len(candidates) == 0: # for choice in all_choices: # e.g., A. B. C. D. # if f"{choice}." in response: # print(f"Found choice with period after: {choice}") # candidates.append(choice) - # Step 2: Look for choices with parentheses e.g., (A) (B) (C) (D) + # Step 2: Look for choices with parentheses e.g., (A) (B) (C) (D) for choice in all_choices: # e.g., (A) (B) (C) (D) - if f"{choice}." in response: - print(f"Found choice with period after: {choice}") - candidates.append(choice) - + if f"{choice}." in response: + print(f"Found choice with period after: {choice}") + candidates.append(choice) + # Step 4: If no candidates, look for choices with a period after (A. B. C. D.) if len(candidates) == 0: for choice in all_choices: # e.g., A. B. C. D. diff --git a/lmms_eval/tasks/mmmu_for_testing/utils.py b/lmms_eval/tasks/mmmu_for_testing/utils.py index 770760f22..af434f830 100755 --- a/lmms_eval/tasks/mmmu_for_testing/utils.py +++ b/lmms_eval/tasks/mmmu_for_testing/utils.py @@ -290,7 +290,7 @@ def evaluate_mmmu(samples): # if f"{choice}." in response: # print(f"Found choice with period after: {choice}") # candidates.append(choice) - + # # Step 4: If no candidates, look for choices with a period after (A. B. C. D.) # if len(candidates) == 0: # for choice in all_choices: # e.g., A. B. C. D. @@ -307,7 +307,6 @@ def evaluate_mmmu(samples): # candidates.append(choice) - # # Step 5: If no candidates and response has more than 5 tokens, try parsing based on content # if len(candidates) == 0 and len(response.split()) > 5: # for index, ans in index2ans.items(): @@ -369,7 +368,7 @@ def evaluate_mmmu(samples): # print(f"Found choice with parentheses: {choice}") # candidates.append(choice) # ans_with_brack = True - + # # Step 4: If no candidates, look for choices with a period after (A. B. C. D.) # if len(candidates) == 0: # for choice in all_choices: # e.g., A. B. C. D. @@ -385,7 +384,6 @@ def evaluate_mmmu(samples): # candidates.append(choice) - # # Step 5: If no candidates and response has more than 5 tokens, try parsing based on content # if len(candidates) == 0 and len(response.split()) > 5: # for index, ans in index2ans.items(): @@ -486,6 +484,7 @@ def evaluate_mmmu(samples): # return pred_index + def parse_multi_choice_response(response, all_choices, index2ans): """ Parse the prediction from the generated response. @@ -500,7 +499,7 @@ def parse_multi_choice_response(response, all_choices, index2ans): ans_with_brack = False ans_with_period = False candidates = [] - + # Step 4: If no candidates, look for choices with a period after (A. B. C. D.) for choice in all_choices: # e.g., A. B. C. D. if f"{choice}." in response: @@ -523,8 +522,6 @@ def parse_multi_choice_response(response, all_choices, index2ans): print(f"Found choice without parentheses (space after): {choice}") candidates.append(choice) - - # Step 5: If no candidates and response has more than 5 tokens, try parsing based on content if len(candidates) == 0 and len(response.split()) > 5: for index, ans in index2ans.items(): @@ -721,4 +718,4 @@ def get_multi_choice_info(options): index2ans[chr(ord(start_chr) + i)] = option all_choices.append(chr(ord(start_chr) + i)) - return index2ans, all_choices \ No newline at end of file + return index2ans, all_choices diff --git a/lmms_eval/tasks/videosearch/utils.py b/lmms_eval/tasks/videosearch/utils.py index a7e02e67c..62c62d2e1 100755 --- a/lmms_eval/tasks/videosearch/utils.py +++ b/lmms_eval/tasks/videosearch/utils.py @@ -56,21 +56,21 @@ def videosearch_doc_to_text(doc, lmms_eval_specific_kwargs=None): lmms_eval_specific_kwargs = {} pre_prompt = "" post_prompt = "" - + pre_prompt = lmms_eval_specific_kwargs["pre_prompt"] question = doc["question"] - + if doc["question_type"] == "multiple-choice": pre_prompt += lmms_eval_specific_kwargs["mcq_prompt"] post_prompt = lmms_eval_specific_kwargs["post_mcq_prompt"] parsed_options = parse_options(ast.literal_eval(doc["options"])) - question += "\n" + parsed_options + question += "\n" + parsed_options else: pre_prompt += lmms_eval_specific_kwargs["open_ended_prompt"] - post_prompt = lmms_eval_specific_kwargs["post_open_ended_prompt"] - + post_prompt = lmms_eval_specific_kwargs["post_open_ended_prompt"] + # print(f"{pre_prompt}{question}") - return f"{pre_prompt}{question}" + return f"{pre_prompt}{question}" def parse_options(options): @@ -181,10 +181,8 @@ def calculate_ins_level_acc(results): DOMAIN_CAT2SUB_CAT = { - "Art and Design": - ["Art", "Art_Theory", "Design", "Music"], - "Business": - ["Accounting", "Economics", "Finance", "Manage", "Marketing"], + "Art and Design": ["Art", "Art_Theory", "Design", "Music"], + "Business": ["Accounting", "Economics", "Finance", "Manage", "Marketing"], "Science": [ "Biology", "Chemistry", @@ -306,7 +304,7 @@ def parse_multi_choice_response(response, all_choices, index2ans): ans_with_brack = False ans_with_period = False candidates = [] - + # Step 4: If no candidates, look for choices with a period after (A. B. C. D.) for choice in all_choices: # e.g., A. B. C. D. if f"{choice}." in response: