Contribute EgoLife model and evaluation pipeline for EgoPlan & Egothink #559

Closed
wants to merge 24 commits into from

Commits (24)
e298464
[Feat] Add qwen2_audio model support and Automatic speech recognition…
Prophet-C Oct 1, 2024
b43ad10
add clotho_aqa task
pbcong Oct 4, 2024
67bceed
Apply black formatting
pbcong Oct 4, 2024
ddee5fd
formatting
pbcong Oct 4, 2024
1b7536d
excluding xl due to downloading issue.
Luodian Oct 16, 2024
93a051f
[Feat] add audiobench version of clothoaqa (#302)
pbcong Oct 7, 2024
fafee41
Add AIR_bench task (#315)
pbcong Oct 12, 2024
2a3a172
add common_voice_15 and people_speech tasks (#316)
Prophet-C Oct 12, 2024
f96f59d
add indent to yaml
Yingluo-momo Oct 16, 2024
8b44868
Add openhermes task (#323)
pbcong Oct 16, 2024
ab81297
[Refactor] Fixing doc to audio return type, qwen_audio revise (#329)
kcz358 Oct 18, 2024
4bd5b1c
add muchomusic and vocalsound task (#331)
pbcong Oct 18, 2024
6e826a0
add alpaca audio task (#333)
pbcong Oct 19, 2024
aebba33
[feat] added gigaspeech config (#334)
Yingluo-momo Oct 20, 2024
2e60ad1
add tedlium_long_form and tedlium_dev_test tasks (#345)
Prophet-C Oct 24, 2024
49651d5
[Feat] add-wavcaps (#349)
Yingluo-momo Oct 25, 2024
60f699d
Update dep and fix log samples for audio (#355)
kcz358 Oct 27, 2024
87f28e3
fix vocalsound (#362)
pbcong Oct 28, 2024
58c89e1
Add using simple prompt for Qwen2 Audio to align (#360)
kcz358 Oct 28, 2024
fdcd2a8
Add retry for gpt api call and improve air_bench aggregation function…
pbcong Oct 30, 2024
9359059
update egogpt model for lmms-eval
choiszt Feb 27, 2025
89b8876
support evaluation pipeline for EgoPlan
choiszt Feb 27, 2025
884f31a
support evaluation for EgoThink
choiszt Feb 27, 2025
08b049b
lint check
choiszt Feb 27, 2025
10 changes: 9 additions & 1 deletion lmms_eval/evaluator.py
@@ -492,7 +492,15 @@ def evaluate(
 metrics = task.process_results(doc, [req.filtered_resps[filter_key] for req in requests])
 if log_samples:
     target = task.doc_to_target(doc)
-    saved_doc = {key: value for key, value in doc.items() if "image" not in key}
+    saved_doc = {}
+    for key, value in doc.items():
+        # Skip image fields when logging the sample
+        if "image" not in key:
+            # Also skip raw audio payloads (dicts carrying the sampled "array")
+            if isinstance(value, dict) and "array" in value:
+                continue
+            else:
+                saved_doc[key] = value
     filtered_arguments = []
     for req in requests:
         # check if req.args is a list of tuples, and each item in the list is a serializable object
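For illustration, the same filtering logic as a standalone sketch applied to a hypothetical doc; the field names below are made up for the example and are not taken from a specific task:

def filter_doc_for_logging(doc: dict) -> dict:
    # Mirror of the saved_doc filter above: drop image fields and raw audio payloads
    saved_doc = {}
    for key, value in doc.items():
        if "image" in key:
            continue  # image fields are never logged
        if isinstance(value, dict) and "array" in value:
            continue  # audio payloads (sampled array + metadata) are skipped too
        saved_doc[key] = value
    return saved_doc

doc = {
    "question": "What sound is heard in the clip?",
    "image": "<PIL.Image object>",
    "audio": {"array": [0.0, 0.1, 0.0], "sampling_rate": 16000},
    "answer_gt": "a dog barking",
}
print(filter_doc_for_logging(doc))  # keeps only "question" and "answer_gt"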
2 changes: 1 addition & 1 deletion lmms_eval/evaluator_utils.py
@@ -346,7 +346,7 @@ def consolidate_group_results(
     task_root=None,
     show_group_table=False,
     task_aggregation_list=None,
-) -> Tuple[dict, dict, bool, Union[None,]]:
+) -> Tuple[dict, dict, bool, Union[None, dict]]:
     """
     (Recursively) calculates groups' aggregated metrics and updates the results and versions dictionaries with this info.

2 changes: 2 additions & 0 deletions lmms_eval/models/__init__.py
@@ -36,6 +36,7 @@
"phi3v": "Phi3v",
"qwen_vl": "Qwen_VL",
"qwen2_vl": "Qwen2_VL",
"qwen2_audio": "Qwen2_Audio",
"qwen_vl_api": "Qwen_VL_API",
"reka": "Reka",
"srt_api": "SRT_API",
@@ -49,6 +50,7 @@
"oryx": "Oryx",
"videochat2": "VideoChat2",
"llama_vision": "LlamaVision",
"egogpt": "EgoGPT",
}


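As a rough sketch of how these registry entries are consumed (the loader below is illustrative, not the repo's exact code), each key names a module under lmms_eval/models and each value names the class registered in it:

import importlib

AVAILABLE_MODELS = {
    "qwen2_audio": "Qwen2_Audio",
    "egogpt": "EgoGPT",
}

def load_model_class(name: str):
    # Hypothetical loader: resolve lmms_eval.models.<name>.<ClassName>
    class_name = AVAILABLE_MODELS[name]
    module = importlib.import_module(f"lmms_eval.models.{name}")
    return getattr(module, class_name)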
472 changes: 472 additions & 0 deletions lmms_eval/models/egogpt.py

Large diffs are not rendered by default.

7 changes: 7 additions & 0 deletions lmms_eval/models/model_utils/audio_processing.py
@@ -0,0 +1,7 @@
import numpy as np
from librosa import resample


def downsample_audio(audio_array: np.ndarray, original_sr: int, target_sr: int) -> np.ndarray:
    """Resample an audio array from its original sampling rate to the target rate."""
    audio_resample_array = resample(audio_array, orig_sr=original_sr, target_sr=target_sr)
    return audio_resample_array
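A quick usage sketch for the helper, using a synthetic tone and illustrative sampling rates (44.1 kHz down to the 16 kHz commonly expected by audio feature extractors):

import numpy as np

from lmms_eval.models.model_utils.audio_processing import downsample_audio

original_sr, target_sr = 44_100, 16_000
t = np.linspace(0, 1, original_sr, endpoint=False)
tone = np.sin(2 * np.pi * 440 * t).astype(np.float32)  # one second of a 440 Hz sine

resampled = downsample_audio(tone, original_sr, target_sr)
print(tone.shape, resampled.shape)  # (44100,) -> (16000,)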
284 changes: 284 additions & 0 deletions lmms_eval/models/qwen2_audio.py
@@ -0,0 +1,284 @@
import base64
from io import BytesIO
from typing import List, Optional, Tuple, Union

import decord
import torch
from accelerate import Accelerator, DistributedType
from loguru import logger as eval_logger
from tqdm import tqdm
from transformers import AutoProcessor, Qwen2AudioForConditionalGeneration

from lmms_eval import utils
from lmms_eval.api.instance import Instance
from lmms_eval.api.model import lmms
from lmms_eval.api.registry import register_model
from lmms_eval.models.model_utils.audio_processing import downsample_audio


@register_model("qwen2_audio")
class Qwen2_Audio(lmms):
"""
Qwen2_Audio Model
"https://github.com/QwenLM/Qwen2-Audio"
"""

def __init__(
self,
pretrained: str = "Qwen/Qwen2-Audio-7B", # Qwen/Qwen2-Audio-7B-Instruct
device: Optional[str] = "cuda",
device_map: Optional[str] = "cuda",
batch_size: Optional[Union[int, str]] = 1,
use_cache=True,
add_generation_prompt: bool = True,
add_system_prompt: bool = True,
simple_prompt: bool = False,
**kwargs,
) -> None:
super().__init__()
# Do not use kwargs for now
assert kwargs == {}, f"Unexpected kwargs: {kwargs}"

accelerator = Accelerator()
self.add_generation_prompt = add_generation_prompt
self.add_system_prompt = add_system_prompt
# If using the simple prompt, only prepend "<|audio_bos|><|AUDIO|><|audio_eos|>"
# to the prompt text, to align with the original Qwen2-Audio formatting
self.simple_prompt = simple_prompt
if accelerator.num_processes > 1:
self._device = torch.device(f"cuda:{accelerator.local_process_index}")
self.device_map = f"cuda:{accelerator.local_process_index}"
elif accelerator.num_processes == 1 and device_map == "auto":
self._device = torch.device(device)
self.device_map = device_map
else:
self._device = torch.device(f"cuda:{accelerator.local_process_index}")
self.device_map = f"cuda:{accelerator.local_process_index}"

self._model = Qwen2AudioForConditionalGeneration.from_pretrained(
pretrained,
torch_dtype="auto",
device_map=device_map,
).eval()

self.processor = AutoProcessor.from_pretrained(pretrained)
self.processor.tokenizer.padding_side = "left"
self._tokenizer = self.processor.tokenizer

if not self.add_system_prompt:
# Overwrite chat template to exclude system prompt
self.processor.chat_template = (
"{% set audio_count = namespace(value=0) %}"
"{% for message in messages %}"
"<|im_start|>{{ message['role'] }}\n"
"{% if message['content'] is string %}"
"{{ message['content'] }}<|im_end|>\n"
"{% else %}"
"{% for content in message['content'] %}"
"{% if 'audio' in content or 'audio_url' in content %}"
"{% set audio_count.value = audio_count.value + 1 %}"
"Audio {{ audio_count.value }}: <|audio_bos|><|AUDIO|><|audio_eos|>\n"
"{% elif 'text' in content %}"
"{{ content['text'] }}"
"{% endif %}"
"{% endfor %}"
"<|im_end|>\n"
"{% endif %}"
"{% endfor %}"
"{% if add_generation_prompt %}"
"<|im_start|>assistant\n"
"{% endif %}"
)

self._config = self.model.config
self.batch_size_per_gpu = int(batch_size)
self.use_cache = use_cache

if accelerator.num_processes > 1:
assert accelerator.distributed_type in [
DistributedType.FSDP,
DistributedType.MULTI_GPU,
], "Unsupported distributed type provided. Only DDP and FSDP are supported."
if accelerator.distributed_type == DistributedType.FSDP:
self._model = accelerator.prepare(self.model)
else:
self._model = accelerator.prepare_model(self.model, evaluation_mode=True)
self.accelerator = accelerator
if self.accelerator.is_local_main_process:
eval_logger.info(f"Using {accelerator.num_processes} devices with data parallelism")
self._rank = self.accelerator.local_process_index
self._world_size = self.accelerator.num_processes
else:
self.model.to(self._device)
self._rank = 0
self._world_size = 1

@property
def config(self):
# return the associated transformers.AutoConfig for the given pretrained model.
return self._config

@property
def tokenizer(self):
return self._tokenizer

@property
def model(self):
# returns the model, unwrapping it if using Accelerate
if hasattr(self, "accelerator"):
return self.accelerator.unwrap_model(self._model)
else:
return self._model

@property
def eot_token_id(self):
return self.tokenizer.eos_token_id

@property
def max_length(self):
return self._max_length

@property
def batch_size(self):
return self.batch_size_per_gpu

@property
def device(self):
return self._device

@property
def rank(self):
return self._rank

@property
def world_size(self):
return self._world_size

def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
raise NotImplementedError("Loglikelihood is not implemented for Qwen2_Audio")

def flatten(self, input):
new_list = []
for i in input:
for j in i:
new_list.append(j)
return new_list

def generate_until(self, requests: List[Instance]) -> List[str]:
res = []

def _collate(x):
# the negative sign on len(toks) sorts descending - this has a few advantages:
# - time estimates will always be over not underestimates, which is more useful for planning
# - to know the size of a batch when going through the list, you know the first one is always the batch
# padded context length. this is useful to simplify the batching logic and more importantly to make
# automatic adaptive batches much much easier to implement
# - any OOMs will happen right away rather than near the end
toks = self.tokenizer.encode(x[0])
return -len(toks), x[0]

pbar = tqdm(total=len(requests), disable=(self.rank != 0), desc="Model Responding")
# we group requests by their generation_kwargs,
# so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling
# in the same batch.
re_ords = utils.Collator([reg.args for reg in requests], _collate, grouping=True)
chunks = re_ords.get_batched(n=self.batch_size, batch_fn=None)
for chunk in chunks:
contexts, all_gen_kwargs, doc_to_visual, doc_id, task, split = zip(*chunk)
task = task[0]
split = split[0]
batched_audios = [doc_to_visual[0](self.task_dict[task][split][ids]) for ids in doc_id]
flattened_audios = self.flatten(batched_audios)

# we assume all gen kwargs in the batch are the same
# this is safe to assume because the `grouper` object ensures it.
gen_kwargs = all_gen_kwargs[0]

# Set default values for until and max_new_tokens
until = [self.tokenizer.decode(self.eot_token_id)]

# Update values from gen_kwargs if present
if "until" in gen_kwargs:
until = gen_kwargs.pop("until")
if isinstance(until, str):
until = [until]
elif not isinstance(until, list):
raise ValueError(f"Expected `gen_kwargs['until']` to be of type Union[str,list] but got {type(until)}")

# contexts = "<|audio_bos|><|AUDIO|><|audio_eos|>" + contexts

if isinstance(contexts, tuple):
contexts = list(contexts)

if not self.simple_prompt:
conversations = []
for idx, context in enumerate(contexts):
conv = [{"role": "user", "content": []}]
for _ in batched_audios[idx]:
# This placeholder is just used to make the chat template work;
# we already have the sampled audio array
conv[0]["content"].append({"type": "audio", "audio_url": "placeholder.wav"})
conv[0]["content"].append({"type": "text", "text": context})
conversations.append(conv)

text = [self.processor.apply_chat_template(conversation, add_generation_prompt=self.add_generation_prompt, tokenize=False) for conversation in conversations]
else:
text = ["<|audio_bos|><|AUDIO|><|audio_eos|>" + context for context in contexts]
audios = [downsample_audio(audio["array"], audio["sampling_rate"], self.processor.feature_extractor.sampling_rate) for audio in flattened_audios]

inputs = self.processor(text=text, audios=audios, return_tensors="pt", padding=True, sampling_rate=self.processor.feature_extractor.sampling_rate)

if self.device_map == "auto":
inputs = inputs.to("cuda")
else:
inputs = inputs.to(self.device)

if "max_new_tokens" not in gen_kwargs:
gen_kwargs["max_new_tokens"] = 256
if "temperature" not in gen_kwargs:
gen_kwargs["temperature"] = 0
if "top_p" not in gen_kwargs:
gen_kwargs["top_p"] = None
if "num_beams" not in gen_kwargs:
gen_kwargs["num_beams"] = 1

try:
cont = self.model.generate(
**inputs,
do_sample=True if gen_kwargs["temperature"] > 0 else False,
temperature=gen_kwargs["temperature"],
top_p=gen_kwargs["top_p"],
num_beams=gen_kwargs["num_beams"],
max_new_tokens=gen_kwargs["max_new_tokens"],
min_new_tokens=1,
use_cache=self.use_cache,
)

# cont = self.model.generate(**inputs, max_new_tokens=256, min_new_tokens=1, do_sample=False)

generated_ids_trimmed = [out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, cont)]
# generated_ids_trimmed = cont[:, inputs.input_ids.size(1):]
answers = self.processor.batch_decode(generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False)
for i, ans in enumerate(answers):
for term in until:
if len(term) > 0:
ans = ans.split(term)[0]
answers[i] = ans

except Exception as e:
eval_logger.debug(f"Error while generating: {e}. It is possibly due to blank audio in {contexts}")
answers = [""] * len(contexts)

for ans, context in zip(answers, contexts):
res.append(ans)
self.cache_hook.add_partial("generate_until", (context, gen_kwargs), ans)
pbar.update(1)

# reorder this group of results back to original unsorted form
res = re_ords.get_original(res)

pbar.close()
return res

def generate_until_multi_round(self, requests) -> List[str]:
raise NotImplementedError("TODO: Implement multi-round generation")
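Outside of the harness, the same generation path can be exercised directly with transformers; the following is a minimal sketch mirroring the wrapper's chat-template branch (checkpoint name, silent test clip, and prompt are placeholders):

import numpy as np
from transformers import AutoProcessor, Qwen2AudioForConditionalGeneration

processor = AutoProcessor.from_pretrained("Qwen/Qwen2-Audio-7B-Instruct")
model = Qwen2AudioForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2-Audio-7B-Instruct", torch_dtype="auto", device_map="cuda"
).eval()

# One second of silence at the feature extractor's sampling rate, standing in for a real clip
sr = processor.feature_extractor.sampling_rate
audio = np.zeros(sr, dtype=np.float32)

conversation = [{"role": "user", "content": [
    {"type": "audio", "audio_url": "placeholder.wav"},
    {"type": "text", "text": "What do you hear in this clip?"},
]}]
text = processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)

inputs = processor(text=[text], audios=[audio], sampling_rate=sr, return_tensors="pt", padding=True).to(model.device)
output_ids = model.generate(**inputs, max_new_tokens=256, do_sample=False)
trimmed = output_ids[:, inputs.input_ids.shape[1]:]  # drop the prompt tokens
print(processor.batch_decode(trimmed, skip_special_tokens=True)[0])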
7 changes: 7 additions & 0 deletions lmms_eval/tasks/air_bench/_default_template_yaml
@@ -0,0 +1,7 @@
dataset_path: lmms-lab/AIR_Bench
dataset_kwargs:
token: True

metadata:
gpt_eval_model_name: gpt-4o
version: 0.0
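The gpt_eval metric relies on a judge call to the model configured above; a hedged sketch of the kind of retry-wrapped request the task utilities make (function name, prompt wording, and backoff policy are illustrative, not the repo's exact implementation):

import time

from openai import OpenAI

client = OpenAI()  # assumes OPENAI_API_KEY is set in the environment

def judge_answer(question: str, reference: str, prediction: str, retries: int = 5) -> str:
    # Ask the judge model to grade a prediction, retrying on transient API failures
    prompt = (
        f"Question: {question}\nReference answer: {reference}\nModel answer: {prediction}\n"
        "Rate the model answer from 1 to 10 and briefly justify the score."
    )
    for attempt in range(retries):
        try:
            response = client.chat.completions.create(
                model="gpt-4o",
                messages=[{"role": "user", "content": prompt}],
                temperature=0,
            )
            return response.choices[0].message.content
        except Exception:
            time.sleep(2 ** attempt)  # exponential backoff before retrying
    return ""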
6 changes: 6 additions & 0 deletions lmms_eval/tasks/air_bench/air_bench_chat.yaml
@@ -0,0 +1,6 @@
group: air_bench_chat
tasks:
- air_bench_chat_sound
- air_bench_chat_music
- air_bench_chat_speech
- air_bench_chat_mixed
25 changes: 25 additions & 0 deletions lmms_eval/tasks/air_bench/air_bench_chat_mixed.yaml
@@ -0,0 +1,25 @@
task: "air_bench_chat_mixed"
dataset_name: "Chat"
test_split: mixed
doc_to_target: "answer_gt"
doc_to_visual: !function utils.air_bench_doc_to_audio
doc_to_text: !function utils.air_bench_doc_to_text_chat

generation_kwargs:
max_new_tokens: 1024
temperature: 0.2
top_p: 1.0
num_beams: 1

lmms_eval_specific_kwargs:
default:
pre_prompt: ""
post_prompt: "Give a detail answer to the question in English."
metric_list:
- metric: gpt_eval
aggregation: !function utils.air_bench_aggregate_results_chat
higher_is_better: true

process_results: !function utils.air_bench_process_results_chat

include: _default_template_yaml
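The !function hooks above live in the task's utils.py; as a rough sketch of what such hooks conventionally look like in lmms-eval tasks (dataset field names and bodies are assumptions, not the actual AIR-Bench implementation):

def air_bench_doc_to_audio(doc):
    # The harness expects a list of audio payloads (dicts with "array" and "sampling_rate")
    return [doc["audio"]]

def air_bench_doc_to_text_chat(doc, lmms_eval_specific_kwargs=None):
    kwargs = lmms_eval_specific_kwargs or {}
    pre_prompt = kwargs.get("pre_prompt", "")
    post_prompt = kwargs.get("post_prompt", "")
    return f"{pre_prompt}{doc['question']} {post_prompt}".strip()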