Skip to content

Commit

Permalink
Squashed commits from internal development repo
Browse files Browse the repository at this point in the history
  • Loading branch information
Bo Li committed Apr 16, 2024
1 parent 49a23db commit 2c03baf
Show file tree
Hide file tree
Showing 265 changed files with 1,730 additions and 159 deletions.
Empty file modified .github/issue_template.md
100644 → 100755
Empty file.
Empty file modified .github/pull_request_template.md
100644 → 100755
Empty file.
Empty file modified .github/workflows/black.yml
100644 → 100755
Empty file.
Empty file modified .gitignore
100644 → 100755
Empty file.
Empty file modified .pre-commit-config.yaml
100644 → 100755
Empty file.
Empty file modified README.md
100644 → 100755
Empty file.
Empty file modified docs/README.md
100644 → 100755
Empty file.
Empty file modified docs/commands.md
100644 → 100755
Empty file.
Empty file modified docs/model_guide.md
100644 → 100755
Empty file.
Empty file modified docs/task_guide.md
100644 → 100755
Empty file.
11 changes: 2 additions & 9 deletions example_eval.yaml
100644 → 100755
Original file line number Diff line number Diff line change
@@ -1,15 +1,8 @@
- model: llava
model_args: pretrained=liuhaotian/llava-v1.5-7b
tasks: ai2d
tasks: mmmu_val
batch_size: 1
log_samples: true
log_samples_suffix: eval_vizwiz_vqa
log_samples_suffix: eval_mmmu
output_path: "./logs/"

- model: llava
model_args: pretrained=liuhaotian/llava-v1.5-13b
tasks: mme
batch_size: 1
log_samples: true
log_samples_suffix: mme
output_path: "./logs/"
17 changes: 0 additions & 17 deletions llava_repr_requirements.txt
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -27,23 +27,6 @@ shortuuid==1.0.12
sqlitedict==2.1.0
tenacity==8.2.3
torch==2.0.1
openai>=1.0.0
pycocoevalcap
tokenizers==0.15.2
tqdm==4.66.2
tqdm-multiprocess
transformers==4.37.2
zstandard
pillow
pyyaml
sympy
mpmath
Jinja2
openpyxl
Levenshtein
hf_transfer
tenacity
wandb>=0.16.0
transformers-stream-generator
tiktoken
pre-commit
Empty file modified lmms_eval/__init__.py
100644 → 100755
Empty file.
Empty file modified lmms_eval/__main__.py
100644 → 100755
Empty file.
Empty file modified lmms_eval/api/__init__.py
100644 → 100755
Empty file.
Empty file modified lmms_eval/api/filter.py
100644 → 100755
Empty file.
Empty file modified lmms_eval/api/instance.py
100644 → 100755
Empty file.
Empty file modified lmms_eval/api/metrics.py
100644 → 100755
Empty file.
Empty file modified lmms_eval/api/model.py
100644 → 100755
Empty file.
Empty file modified lmms_eval/api/registry.py
100644 → 100755
Empty file.
4 changes: 3 additions & 1 deletion lmms_eval/api/samplers.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,9 @@ def get_context(self, doc, num_fewshot):
+ (
str(self.doc_to_target(doc)[0])
if type(self.doc_to_target(doc)) is list
else self.doc_to_target(doc) if (self.config.doc_to_choice is None or type(self.doc_to_target(doc)) is str) else str(self.doc_to_choice(doc)[self.doc_to_target(doc)])
else self.doc_to_target(doc)
if (self.config.doc_to_choice is None or type(self.doc_to_target(doc)) is str)
else str(self.doc_to_choice(doc)[self.doc_to_target(doc)])
)
for doc in selected_docs
]
Expand Down
25 changes: 25 additions & 0 deletions lmms_eval/api/task.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
import ast
import logging
import random
from glob import glob
import shutil
from tqdm import tqdm

import datasets
Expand All @@ -18,6 +20,7 @@
from typing import Union, List, Any
from collections.abc import Callable
from tenacity import retry, stop_after_attempt, wait_fixed
from huggingface_hub import snapshot_download

from lmms_eval import utils
from lmms_eval.api import samplers
Expand Down Expand Up @@ -678,6 +681,26 @@ def _prepare_metric_and_aggregation(self):

@retry(stop=stop_after_attempt(5), wait=wait_fixed(2))
def download(self, dataset_kwargs=None) -> None:
# If the dataset is a video dataset,
# Recursively search whether there is a zip and unzip it to the huggingface home
if dataset_kwargs is not None and "video" in dataset_kwargs and dataset_kwargs["video"]:
hf_home = os.environ["HF_HOME"]
cache_dir = dataset_kwargs["cache_dir"]

cache_dir = os.path.join(hf_home, cache_dir)
cache_path = snapshot_download(repo_id=self.DATASET_PATH, repo_type="dataset")
zip_files = glob(os.path.join(cache_path, "**/*.zip"), recursive=True)
if not os.path.exists(cache_dir):
for zip_file in zip_files:
shutil.unpack_archive(zip_file, cache_dir)

if "builder_script" in dataset_kwargs:
builder_script = dataset_kwargs["builder_script"]
self.DATASET_PATH = os.path.join(cache_path, builder_script)
dataset_kwargs.pop("builder_script")

dataset_kwargs.pop("cache_dir")
dataset_kwargs.pop("video")
download_config = DownloadConfig()
download_config.max_retries = dataset_kwargs.get("max_retries", 3) if dataset_kwargs is not None else 3
download_config.num_proc = dataset_kwargs.get("num_proc", 8) if dataset_kwargs is not None else 8
Expand Down Expand Up @@ -973,6 +996,8 @@ def construct_requests(self, doc_id: int, ctx: str, **kwargs) -> Union[List[Inst
return Instance(request_type=self.OUTPUT_TYPE, arguments=arguments, idx=0, **kwargs)

def process_results(self, doc, results):
# if self.OUTPUT_TYPE == "generate_until":
# results[0] = results[0].strip()
if callable(self.config.process_results):
return self.config.process_results(doc, results)

Expand Down
Empty file modified lmms_eval/evaluator.py
100644 → 100755
Empty file.
3 changes: 2 additions & 1 deletion lmms_eval/filters/__init__.py
100644 → 100755
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from lmms_eval.api.filter import FilterEnsemble
from lmms_eval.api.filter import FilterEnsemble, Filter
from . import selection
from . import extraction
from . import transformation
Expand All @@ -13,6 +13,7 @@
"lowercase": transformation.LowercaseFilter,
"uppercase": transformation.UppercaseFilter,
"map": transformation.MapFilter,
"multi_choice_regex": extraction.MultiChoiceRegexFilter,
# TODO: implement this filter. either it should take in an arbitrary "scoring"/reward function
# that takes an input and returns a scalar and then should select the max reward,
# or should implement different filters for different ways of handling a reward model's inference.
Expand Down
Empty file modified lmms_eval/filters/decontamination.py
100644 → 100755
Empty file.
186 changes: 170 additions & 16 deletions lmms_eval/filters/extraction.py
100644 → 100755
Original file line number Diff line number Diff line change
@@ -1,18 +1,47 @@
import re

import sys
import unicodedata
from lmms_eval.api.filter import Filter


class WhitespaceFilter(Filter):
    """Filter that drops a single leading space from each model response."""

    def __init__(self) -> None:
        # Stateless filter: there is nothing to configure.
        pass

    def apply(self, resps, docs):
        """Return ``resps`` with one leading space removed from each response.

        ``resps`` is iterated as a collection of response collections; the
        nesting is preserved in the returned list of lists. ``docs`` is
        accepted but not used.
        """

        def strip_leading_space(text):
            # Only one space is removed; any further leading whitespace is kept.
            return text[1:] if text.startswith(" ") else text

        return [[strip_leading_space(text) for text in response_set] for response_set in resps]


class RegexFilter(Filter):
""" """

def __init__(self, regex_pattern: str = r"#### (\-?[0-9\.\,]+)", fallback: str = "[invalid]") -> None:
def __init__(
self,
regex_pattern: str = r"#### (\-?[0-9\.\,]+)",
group_select=0,
fallback: str = "[invalid]",
) -> None:
"""
pass a string `regex` to run `re.compile(r"regex")` on.
`fallback` defines the output returned if no matches for the regex are located.
"""
self.regex_pattern = regex_pattern
self.regex = re.compile(regex_pattern)
self.group_select = group_select
self.fallback = fallback

def apply(self, resps, docs):
Expand All @@ -23,9 +52,12 @@ def apply(self, resps, docs):
def filter_set(inst):
filtered = []
for resp in inst:
match = self.regex.search(resp)
match = self.regex.findall(resp)
if match:
match = match.group(1).strip()
match = match[self.group_select]
if isinstance(match, tuple):
match = [m for m in match if m][0]
match = match.strip()
else:
match = self.fallback
filtered.append(match)
Expand All @@ -38,23 +70,145 @@ def filter_set(inst):
return filtered_resps


class WhitespaceFilter(Filter):
""" """
class MultiChoiceRegexFilter(RegexFilter):
"""
A filter used to extract a model's answer on multiple choice questions with
letter answers. assumes each document has a "choices" field
containing the list of answer choices and that the answer label symbols
are of the form (A), (B), (C), ... or A, B, C.
"""

def __init__(self) -> None:
pass
def __init__(
self,
regex_pattern: str = r"#### (\-?[0-9\.\,]+)",
group_select=0,
fallback: str = "[invalid]",
ignore_case=False,
ignore_punctuation=False,
regexes_to_ignore=None,
) -> None:
"""
regex_pattern: The basic regex pattern to use. If fails to match, we will use the customized match procedure
- step 1 : We parse the choices between ([A-Z])s then try to find these choices in the response.
- step 2 : We parse the choice with regex :[\s]*([A-?]), where ? varies by number of choices.
group_select: Selects the (group_select)th match from the findall result.
ignore_case: Ignores the case during step 1 matching
ignore_punctuation: Remove the punctuation during step 1 matching
regexes_to_ignore: Remove these regexes during step 1 matching
"""
super().__init__(regex_pattern, group_select, fallback)
self.ignore_case = ignore_case
self.ignore_punctuation = ignore_punctuation
self.regexes_to_ignore = regexes_to_ignore

def apply(self, resps, docs):
def filter_set(inst):
filtered_resp = []
for resp in inst:
if resp.startswith(" "):
resp = resp[1:]
# here, we assume we have a list, in which each element is
# a list of model responses for some particular input/target pair.
# so we process each of these (same input/target response sets)
# independently (and keep them a list.)

filtered_resp.append(resp)
def find_match(regex, resp, convert_dict={}):
match = regex.findall(resp)
if match:
match = match[self.group_select]
if isinstance(match, tuple):
match = [m for m in match if m][0]
match = match.strip()
if match and match in convert_dict:
match = convert_dict[match]
return match

return filtered_resp
punct_tbl = dict.fromkeys(i for i in range(sys.maxunicode) if unicodedata.category(chr(i)).startswith("P"))

filtered_resps = [filter_set(resp) for resp in resps]
def filter_ignores(st):
if self.regexes_to_ignore is not None:
for s in self.regexes_to_ignore:
st = re.sub(s, "", st)

if self.ignore_case:
st = st.lower()

if self.ignore_punctuation:
# https://stackoverflow.com/a/266162
st = st.translate(punct_tbl)
return st

filtered_resps = []

for r, doc in zip(resps, docs):
fallback_regexes = []
choice_to_alpha = {}
next_alpha = "A"

without_paren_fallback_regexes = []
without_paren_to_target = {}

choices = doc["choices"]
for c in choices:
m = filter_ignores(c.strip())
fallback_regexes.append(f"{re.escape(m)}")
choice_to_alpha[m] = f"({next_alpha})"

without_paren_fallback_regexes.append(next_alpha)
without_paren_to_target[next_alpha] = f"({next_alpha})"

next_alpha = chr(ord(next_alpha) + 1)
fallback_regex = re.compile("|".join(fallback_regexes))
without_paren_fallback_regex = "|".join(without_paren_fallback_regexes)
without_paren_fallback_regex = re.compile(f":[\s]*({without_paren_fallback_regex})")

filtered = []
for resp in r:
match = find_match(self.regex, resp)
if not match:
match = find_match(fallback_regex, filter_ignores(resp), choice_to_alpha)
if not match:
match = find_match(without_paren_fallback_regex, resp, without_paren_to_target)
if not match:
match = self.fallback
filtered.append(match)
filtered_resps.append(filtered)

return filtered_resps


class ExtendedRegexFilter(RegexFilter):
    """Regex filter with optional text normalization helpers.

    Adds ``filter_ignores`` (strip configured patterns, lowercase, remove
    punctuation) and ``find_match`` (pick one ``findall`` result and optionally
    map it through a lookup table) on top of ``RegexFilter``.
    """

    # Translation table for ``str.translate``: every code point whose Unicode
    # category starts with "P" (punctuation) maps to None, i.e. is deleted.
    # Built once at class-definition time because scanning the code space is slow.
    punct_tbl = dict.fromkeys(i for i in range(sys.maxunicode) if unicodedata.category(chr(i)).startswith("P"))

    def __init__(
        self,
        regex_pattern: str = r"#### (\-?[0-9\.\,]+)",
        group_select=0,
        fallback: str = "[invalid]",
        ignore_case=False,
        ignore_punctuation=False,
        regexes_to_ignore=None,
    ) -> None:
        """
        regex_pattern, group_select, fallback: forwarded unchanged to ``RegexFilter.__init__``.
        ignore_case: if true, ``filter_ignores`` lowercases the text.
        ignore_punctuation: if true, ``filter_ignores`` removes Unicode punctuation.
        regexes_to_ignore: optional iterable of regex patterns that ``filter_ignores`` deletes from the text.
        """
        super().__init__(regex_pattern, group_select, fallback)
        self.ignore_case = ignore_case
        self.ignore_punctuation = ignore_punctuation
        self.regexes_to_ignore = regexes_to_ignore

    def filter_ignores(self, st):
        """Return ``st`` normalized according to the configured ignore options.

        Order of operations: delete ``regexes_to_ignore`` matches, then
        lowercase, then strip punctuation.
        """
        if self.regexes_to_ignore is not None:
            for s in self.regexes_to_ignore:
                st = re.sub(s, "", st)

        if self.ignore_case:
            st = st.lower()

        if self.ignore_punctuation:
            # https://stackoverflow.com/a/266162
            st = st.translate(self.punct_tbl)
        return st

    def find_match(self, regex, resp, convert_dict=None):
        """Return the selected match of ``regex`` in ``resp``.

        regex: compiled pattern; its ``findall`` result is indexed with ``self.group_select``.
        resp: the text to search.
        convert_dict: optional mapping; a non-empty match found as a key is replaced by its value.

        Returns the stripped match string (possibly mapped through
        ``convert_dict``), or the falsy ``findall`` result (an empty list)
        when nothing matched. A match whose capture groups are all empty
        yields ``""`` instead of raising ``IndexError``.
        """
        match = regex.findall(resp)
        if match:
            match = match[self.group_select]
            if isinstance(match, tuple):
                # With several capture groups findall yields tuples; take the
                # first non-empty group, or "" when every group is empty.
                match = next((m for m in match if m), "")
            match = match.strip()
            if match and convert_dict and match in convert_dict:
                match = convert_dict[match]
        return match
Empty file modified lmms_eval/filters/selection.py
100644 → 100755
Empty file.
Empty file modified lmms_eval/filters/transformation.py
100644 → 100755
Empty file.
18 changes: 11 additions & 7 deletions lmms_eval/logging_utils.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -89,10 +89,10 @@ def finish(self):
def init_run(self):
if "name" not in self.wandb_args:
if "config" in self.all_args_dict and self.all_args_dict["config"] != "":
self.wandb_args["name"] = self.all_args_dict["config"].split("/")[-1].replace(".yaml", "") + "_" + self.args.log_samples_suffix
self.wandb_args["name"] = self.all_args_dict["config"].split("/")[-1].replace(".yaml", "") + "/" + self.args.log_samples_suffix
else:
task_names = self.args.tasks.replace(",", "/")
self.wandb_args["name"] = f"{self.args.model}_{task_names}_{self.args.log_samples_suffix}"
self.wandb_args["name"] = f"{self.args.model}/<{task_names}>/{self.args.log_samples_suffix}"
if self.args.num_fewshot:
self.wandb_args["name"] += f"_{self.args.num_fewshot}shot"
if "project" not in self.wandb_args:
Expand All @@ -119,6 +119,7 @@ def _get_config(self) -> Dict[str, Any]:
def _sanitize_results_dict(self) -> Tuple[Dict[str, str], Dict[str, Any]]:
"""Sanitize the results dictionary."""
_results = copy.deepcopy(self.results.get("results", dict()))
_results["model_configs"] = self.results.get("model_configs", dict())

# Remove None from the metric string name
tmp_results = copy.deepcopy(_results)
Expand All @@ -138,15 +139,18 @@ def _sanitize_results_dict(self) -> Tuple[Dict[str, str], Dict[str, Any]]:
if isinstance(metric_value, str):
wandb_summary[f"{task}/{metric_name}"] = metric_value

wandb_summary["model_configs"] = self.results.get("model_configs", dict())
for summary_metric, summary_value in wandb_summary.items():
_task, _summary_metric = summary_metric.split("/")
_results[_task].pop(_summary_metric)
if summary_metric != "model_configs":
_task, _summary_metric = summary_metric.split("/")
_results[_task].pop(_summary_metric)

tmp_results = copy.deepcopy(_results)
for task_name, task_results in tmp_results.items():
for metric_name, metric_value in task_results.items():
_results[f"{task_name}/{metric_name}"] = metric_value
_results[task_name].pop(metric_name)
if task_name != "model_configs":
for metric_name, metric_value in task_results.items():
_results[f"{task_name}/{metric_name}"] = metric_value
_results[task_name].pop(metric_name)
for task in self.task_names:
_results.pop(task)

Expand Down
Loading

0 comments on commit 2c03baf

Please sign in to comment.