Skip to content

Commit

Permalink
Optimize parsing of multiple-choice responses in videosearch
Browse files Browse the repository at this point in the history
  • Loading branch information
pufanyi committed Sep 21, 2024
1 parent f708591 commit 9344227
Show file tree
Hide file tree
Showing 7 changed files with 67 additions and 68 deletions.
6 changes: 3 additions & 3 deletions check_missing.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
from datasets import load_dataset, Dataset

# Load the deduplicated VideoSearch dataset
videosearch_dataset = load_dataset('lmms-lab/VideoSearch', 'deduplicated_combined_milestone', split='test')
videosearch_dataset = load_dataset("lmms-lab/VideoSearch", "deduplicated_combined_milestone", split="test")

# ID to be removed
id_to_remove = 'validation_Biology_18'
id_to_remove = "validation_Biology_18"

# Filter out the row with the missing ID
filtered_rows = [row for row in videosearch_dataset if row['id'] != id_to_remove]
filtered_rows = [row for row in videosearch_dataset if row["id"] != id_to_remove]

# Create a new dataset from the filtered rows
filtered_dataset = Dataset.from_list(filtered_rows)
Expand Down
12 changes: 6 additions & 6 deletions check_reverse.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,23 @@
from datasets import load_dataset

# Load the VideoSearch dataset
videosearch_dataset = load_dataset('lmms-lab/VideoSearch', 'final_combined_milestone', split='test')
videosearch_dataset = load_dataset("lmms-lab/VideoSearch", "final_combined_milestone", split="test")

# Path to the videos directory (replace with your actual path)
videos_directory = '/mnt/sfs-common/krhu/.cache/huggingface/Combined_milestone/videos/'
videos_directory = "/mnt/sfs-common/krhu/.cache/huggingface/Combined_milestone/videos/"

# Get all IDs from the dataset
videosearch_ids = set(videosearch_dataset['id'])
videosearch_ids = set(videosearch_dataset["id"])

# List to store IDs of files that are not in the dataset
extra_files = []

# Loop through all .mp4 files in the videos directory
for file in os.listdir(videos_directory):
if file.endswith('.mp4'):
if file.endswith(".mp4"):
# Extract the ID from the file name (remove the .mp4 extension)
file_id = file.replace('.mp4', '')
file_id = file.replace(".mp4", "")

# Check if the file ID exists in the VideoSearch dataset
if file_id not in videosearch_ids:
extra_files.append(file_id)
Expand Down
4 changes: 2 additions & 2 deletions lmms_eval/models/gpt4v.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
class GPT4V(lmms):
def __init__(
self,
#model_version: str = "gpt-4-vision-preview",
# model_version: str = "gpt-4-vision-preview",
modality: str = "video",
max_frames_num: int = 32,
timeout: int = 120,
Expand All @@ -59,7 +59,7 @@ def __init__(
# Manually set a image token for GPT4V so that we can search for it
# and split the text and image
# Here we just use the same token as llava for convenient
#self.model_version = model_version
# self.model_version = model_version
self.modality = modality
self.max_frames_num = max_frames_num
self.image_token = "<image>"
Expand Down
67 changes: 35 additions & 32 deletions lmms_eval/models/model_utils/load_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,56 +28,59 @@ def record_video_length_packet(container):
return frames


def read_video_pyav(video_path: str, *, num_frm=8, fps=None, format = "rgb24") -> np.ndarray:
def load_video_stream(container, num_frm: int = 8, fps: float = None):
# container = av.open(video_path)
total_frames = container.streams.video[0].frames
frame_rate = container.streams.video[0].average_rate
if fps is not None:
video_length = total_frames / frame_rate
num_frm = min(num_frm, int(video_length * fps))
sampled_frm = min(total_frames, num_frm)
indices = np.linspace(0, total_frames - 1, sampled_frm, dtype=int)

# Append the last frame index if not already included
if total_frames - 1 not in indices:
indices = np.append(indices, total_frames - 1)

return record_video_length_stream(container, indices)


def load_video_packet(container, num_frm: int = 8, fps: float = None):
frames = record_video_length_packet(container)
total_frames = len(frames)
sampled_frm = min(total_frames, num_frm)
indices = np.linspace(0, total_frames - 1, sampled_frm, dtype=int)

# Append the last frame index if not already included
if total_frames - 1 not in indices:
indices = np.append(indices, total_frames - 1)

return [frames[i] for i in indices]


def read_video_pyav(video_path: str, *, num_frm: int = 8, fps: float = None, format="rgb24") -> np.ndarray:
"""
Read video using the PyAV library.
Args:
video_path (str): The path to the video file.
num_frm (int, optional): The maximum number of frames to extract. Defaults to 8.
fps (optional): The frames per second for extraction. If `None`, the maximum number of frames will be extracted. Defaults to None.
fps (float, optional): The frames per second for extraction. If `None`, the maximum number of frames will be extracted. Defaults to None.
format (str, optional): The format of the extracted frames. Defaults to "rgb24".
Returns:
np.ndarray: A numpy array containing the extracted frames in RGB format.
"""

container = av.open(video_path)

if "webm" not in video_path and "mkv" not in video_path:
# For mp4, we try loading with stream first
try:
container = av.open(video_path)
total_frames = container.streams.video[0].frames
sampled_frm = min(total_frames, num_frm)
indices = np.linspace(0, total_frames - 1, sampled_frm, dtype=int)

# Append the last frame index if not already included
if total_frames - 1 not in indices:
indices = np.append(indices, total_frames - 1)

frames = record_video_length_stream(container, indices)
frames = load_video_stream(container, num_frm, fps)
except:
container = av.open(video_path)
frames = record_video_length_packet(container)
total_frames = len(frames)
sampled_frm = min(total_frames, num_frm)
indices = np.linspace(0, total_frames - 1, sampled_frm, dtype=int)

# Append the last frame index if not already included
if total_frames - 1 not in indices:
indices = np.append(indices, total_frames - 1)

frames = [frames[i] for i in indices]
else:
container = av.open(video_path)
frames = record_video_length_packet(container)
total_frames = len(frames)
sampled_frm = min(total_frames, num_frm)
indices = np.linspace(0, total_frames - 1, sampled_frm, dtype=int)

# Append the last frame index if not already included
if total_frames - 1 not in indices:
indices = np.append(indices, total_frames - 1)

frames = [frames[i] for i in indices]
return np.stack([x.to_ndarray(format=format) for x in frames])
return np.stack([x.to_ndarray(format=format) for x in frames])
13 changes: 7 additions & 6 deletions lmms_eval/tasks/mmmu/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,7 @@ def evaluate_mmmu(samples):

# return pred_index


def parse_multi_choice_response(response, all_choices, index2ans):
"""
Parse the prediction from the generated response.
Expand All @@ -352,19 +353,19 @@ def parse_multi_choice_response(response, all_choices, index2ans):
# print(f"Found choice with parentheses: {choice}")
# candidates.append(choice)
# ans_with_brack = True

# # Step 4: If no candidates, look for choices with a period after (A. B. C. D.)
# if len(candidates) == 0:
# for choice in all_choices: # e.g., A. B. C. D.
# if f"{choice}." in response:
# print(f"Found choice with period after: {choice}")
# candidates.append(choice)
# Step 2: Look for choices with parentheses e.g., (A) (B) (C) (D)
# Step 2: Look for choices with parentheses e.g., (A) (B) (C) (D)
for choice in all_choices: # e.g., (A) (B) (C) (D)
if f"{choice}." in response:
print(f"Found choice with period after: {choice}")
candidates.append(choice)
if f"{choice}." in response:
print(f"Found choice with period after: {choice}")
candidates.append(choice)

# Step 4: If no candidates, look for choices with a period after (A. B. C. D.)
if len(candidates) == 0:
for choice in all_choices: # e.g., A. B. C. D.
Expand Down
13 changes: 5 additions & 8 deletions lmms_eval/tasks/mmmu_for_testing/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,7 @@ def evaluate_mmmu(samples):
# if f"{choice}." in response:
# print(f"Found choice with period after: {choice}")
# candidates.append(choice)

# # Step 4: If no candidates, look for choices with a period after (A. B. C. D.)
# if len(candidates) == 0:
# for choice in all_choices: # e.g., A. B. C. D.
Expand All @@ -307,7 +307,6 @@ def evaluate_mmmu(samples):
# candidates.append(choice)



# # Step 5: If no candidates and response has more than 5 tokens, try parsing based on content
# if len(candidates) == 0 and len(response.split()) > 5:
# for index, ans in index2ans.items():
Expand Down Expand Up @@ -369,7 +368,7 @@ def evaluate_mmmu(samples):
# print(f"Found choice with parentheses: {choice}")
# candidates.append(choice)
# ans_with_brack = True

# # Step 4: If no candidates, look for choices with a period after (A. B. C. D.)
# if len(candidates) == 0:
# for choice in all_choices: # e.g., A. B. C. D.
Expand All @@ -385,7 +384,6 @@ def evaluate_mmmu(samples):
# candidates.append(choice)



# # Step 5: If no candidates and response has more than 5 tokens, try parsing based on content
# if len(candidates) == 0 and len(response.split()) > 5:
# for index, ans in index2ans.items():
Expand Down Expand Up @@ -486,6 +484,7 @@ def evaluate_mmmu(samples):

# return pred_index


def parse_multi_choice_response(response, all_choices, index2ans):
"""
Parse the prediction from the generated response.
Expand All @@ -500,7 +499,7 @@ def parse_multi_choice_response(response, all_choices, index2ans):
ans_with_brack = False
ans_with_period = False
candidates = []

# Step 4: If no candidates, look for choices with a period after (A. B. C. D.)
for choice in all_choices: # e.g., A. B. C. D.
if f"{choice}." in response:
Expand All @@ -523,8 +522,6 @@ def parse_multi_choice_response(response, all_choices, index2ans):
print(f"Found choice without parentheses (space after): {choice}")
candidates.append(choice)



# Step 5: If no candidates and response has more than 5 tokens, try parsing based on content
if len(candidates) == 0 and len(response.split()) > 5:
for index, ans in index2ans.items():
Expand Down Expand Up @@ -721,4 +718,4 @@ def get_multi_choice_info(options):
index2ans[chr(ord(start_chr) + i)] = option
all_choices.append(chr(ord(start_chr) + i))

return index2ans, all_choices
return index2ans, all_choices
20 changes: 9 additions & 11 deletions lmms_eval/tasks/videosearch/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,21 +56,21 @@ def videosearch_doc_to_text(doc, lmms_eval_specific_kwargs=None):
lmms_eval_specific_kwargs = {}
pre_prompt = ""
post_prompt = ""

pre_prompt = lmms_eval_specific_kwargs["pre_prompt"]
question = doc["question"]

if doc["question_type"] == "multiple-choice":
pre_prompt += lmms_eval_specific_kwargs["mcq_prompt"]
post_prompt = lmms_eval_specific_kwargs["post_mcq_prompt"]
parsed_options = parse_options(ast.literal_eval(doc["options"]))
question += "\n" + parsed_options
question += "\n" + parsed_options
else:
pre_prompt += lmms_eval_specific_kwargs["open_ended_prompt"]
post_prompt = lmms_eval_specific_kwargs["post_open_ended_prompt"]
post_prompt = lmms_eval_specific_kwargs["post_open_ended_prompt"]

# print(f"{pre_prompt}{question}")
return f"{pre_prompt}{question}"
return f"{pre_prompt}{question}"


def parse_options(options):
Expand Down Expand Up @@ -181,10 +181,8 @@ def calculate_ins_level_acc(results):


DOMAIN_CAT2SUB_CAT = {
"Art and Design":
["Art", "Art_Theory", "Design", "Music"],
"Business":
["Accounting", "Economics", "Finance", "Manage", "Marketing"],
"Art and Design": ["Art", "Art_Theory", "Design", "Music"],
"Business": ["Accounting", "Economics", "Finance", "Manage", "Marketing"],
"Science": [
"Biology",
"Chemistry",
Expand Down Expand Up @@ -306,7 +304,7 @@ def parse_multi_choice_response(response, all_choices, index2ans):
ans_with_brack = False
ans_with_period = False
candidates = []

# Step 4: If no candidates, look for choices with a period after (A. B. C. D.)
for choice in all_choices: # e.g., A. B. C. D.
if f"{choice}." in response:
Expand Down

0 comments on commit 9344227

Please sign in to comment.