From 236667d46b38a7864970b3f1a075d3e9232f21ea Mon Sep 17 00:00:00 2001 From: white <1102737450@qq.com> Date: Tue, 17 Sep 2024 16:53:27 +0800 Subject: [PATCH] [Model] support cogvlm2 model (#261) --- lmms_eval/models/__init__.py | 1 + lmms_eval/models/cogvlm2.py | 226 +++++++++++++++++++++++++++++++++++ 2 files changed, 227 insertions(+) create mode 100644 lmms_eval/models/cogvlm2.py diff --git a/lmms_eval/models/__init__.py b/lmms_eval/models/__init__.py index 33e4480eb..bfbb22608 100755 --- a/lmms_eval/models/__init__.py +++ b/lmms_eval/models/__init__.py @@ -13,6 +13,7 @@ AVAILABLE_MODELS = { "batch_gpt4": "BatchGPT4", "claude": "Claude", + "cogvlm2": "CogVLM2", "from_log": "FromLog", "fuyu": "Fuyu", "gemini_api": "GeminiAPI", diff --git a/lmms_eval/models/cogvlm2.py b/lmms_eval/models/cogvlm2.py new file mode 100644 index 000000000..ee346debf --- /dev/null +++ b/lmms_eval/models/cogvlm2.py @@ -0,0 +1,226 @@ +import warnings +from typing import List, Optional, Tuple, Union + +import torch +from accelerate import Accelerator, DistributedType +from accelerate.state import AcceleratorState +from tqdm import tqdm +from transformers import AutoModelForCausalLM, AutoTokenizer + +from lmms_eval import utils +from lmms_eval.api.instance import Instance +from lmms_eval.api.model import lmms +from lmms_eval.api.registry import register_model + +warnings.filterwarnings("ignore") + +from loguru import logger as eval_logger + + +@register_model("cogvlm2") +class CogVLM2(lmms): + """ + CogVLM2 Model + """ + + def __init__( + self, + pretrained: str = "THUDM/cogvlm2-llama3-chinese-chat-19B", + device: Optional[str] = "cuda", + dtype: Optional[Union[str, torch.dtype]] = torch.bfloat16, + batch_size: Optional[Union[int, str]] = 1, + trust_remote_code: Optional[bool] = True, + **kwargs, + ) -> None: + super().__init__() + # Do not use kwargs for now + assert kwargs == {}, f"Unexpected kwargs: {kwargs}" + + accelerator = Accelerator() + if accelerator.num_processes > 1: + self._device = torch.device(f"cuda:{accelerator.local_process_index}") + else: + self._device = device + self.dtype = dtype + self._model = AutoModelForCausalLM.from_pretrained(pretrained, trust_remote_code=trust_remote_code, torch_dtype=dtype, device_map=self._device) + self._tokenizer = AutoTokenizer.from_pretrained(pretrained, trust_remote_code=trust_remote_code) + self._config = self._model.config + self.model.eval() + self.model.tie_weights() + self.batch_size_per_gpu = int(batch_size) + if accelerator.num_processes > 1: + assert accelerator.distributed_type in [DistributedType.FSDP, DistributedType.MULTI_GPU, DistributedType.DEEPSPEED], "Unsupported distributed type provided. Only DDP and FSDP are supported." + # If you want to use DistributedType.DEEPSPEED, you have to run accelerate config before using the model + # Also, you have to select zero stage 0 (equivalent to DDP) in order to make the prepare model works + # I tried to set different parameters in the kwargs to let default zero 2 stage works, but it didn't work. + if accelerator.distributed_type == DistributedType.DEEPSPEED: + kwargs = { + "train_micro_batch_size_per_gpu": self.batch_size_per_gpu, + "train_batch_size": self.batch_size_per_gpu * accelerator.num_processes, + } + AcceleratorState().deepspeed_plugin.deepspeed_config_process(must_match=True, **kwargs) + eval_logger.info("Detected that you are using DistributedType.DEEPSPEED. Make sure you run `accelerate config` and set zero stage to 0") + if accelerator.distributed_type == DistributedType.FSDP or accelerator.distributed_type == DistributedType.DEEPSPEED: + self._model = accelerator.prepare(self.model) + else: + self._model = accelerator.prepare_model(self.model, evaluation_mode=True) + self.accelerator = accelerator + if self.accelerator.is_local_main_process: + eval_logger.info(f"Using {accelerator.num_processes} devices with data parallelism") + self._rank = self.accelerator.local_process_index + self._world_size = self.accelerator.num_processes + else: + self._rank = 0 + self._word_size = 1 + + @property + def config(self): + # return the associated transformers.AutoConfig for the given pretrained model. + return self._config + + @property + def tokenizer(self): + return self._tokenizer + + @property + def model(self): + # returns the model, unwrapping it if using Accelerate + if hasattr(self, "accelerator"): + return self.accelerator.unwrap_model(self._model) + else: + return self._model + + @property + def eot_token_id(self): + # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence* + return self.tokenizer.eos_token_id + + @property + def max_length(self): + return self._max_length + + @property + def batch_size(self): + return self.batch_size_per_gpu + + @property + def device(self): + return self._device + + @property + def rank(self): + return self._rank + + @property + def world_size(self): + return self._world_size + + def tok_encode(self, string: str, left_truncate_len=None, add_special_tokens=None) -> List[int]: + """ """ + add_special_tokens = False if add_special_tokens is None else add_special_tokens + encoding = self.tokenizer.encode(string, add_special_tokens=add_special_tokens) + # left-truncate the encoded context to be at most `left_truncate_len` tokens long + if left_truncate_len: + encoding = encoding[-left_truncate_len:] + return encoding + + def tok_decode(self, tokens): + return self.tokenizer.decode(tokens) + + def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]: + # TODO + assert False, "We have not implemented this function for CogVLM2 yet" + + def flatten(self, input): + new_list = [] + for i in input: + for j in i: + new_list.append(j) + return new_list + + def generate_until(self, requests: List[Instance]) -> List[str]: + res = [] + + def _collate(x): + # the negative sign on len(toks) sorts descending - this has a few advantages: + # - time estimates will always be over not underestimates, which is more useful for planning + # - to know the size of a batch when going through the list, you know the first one is always the batch + # padded context length. this is useful to simplify the batching logic and more importantly to make + # automatic adaptive batches much much easier to implement + # - any OOMs will happen right away rather than near the end + toks = self.tok_encode(x[0]) + return -len(toks), x[0] + + # we group requests by their generation_kwargs, + # so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling + # in the same batch. + re_ords = utils.Collator([reg.args for reg in requests], _collate, grouping=True) + chunks = re_ords.get_batched(n=self.batch_size, batch_fn=None) + num_iters = len(requests) // self.batch_size if len(requests) % self.batch_size == 0 else len(requests) // self.batch_size + 1 + pbar = tqdm(total=num_iters, disable=(self.rank != 0), desc="Model Responding") + for chunk in chunks: + contexts, all_gen_kwargs, doc_to_visual, doc_id, task, split = zip(*chunk) + task = task[0] + split = split[0] + visuals = [doc_to_visual[0](self.task_dict[task][split][ids]) for ids in doc_id] + visuals = self.flatten(visuals) + # we assume all gen kwargs in the batch are the same + # this is safe to assume because the `grouper` object ensures it. + gen_kwargs = all_gen_kwargs[0] + + # Set default values for until and max_new_tokens + until = [self.tok_decode(self.eot_token_id)] + + # Update values from gen_kwargs if present + if "until" in gen_kwargs: + until = gen_kwargs.pop("until") + if isinstance(until, str): + until = [until] + elif not isinstance(until, list): + raise ValueError(f"Expected `gen_kwargs['until']` to be of type Union[str,list] but got {type(until)}") + assert self.batch_size_per_gpu == 1, "Do not support batch_size_per_gpu > 1 for now" + assert len(visuals) == 1, "CogVLM2 interface does not support bn_image > 1 for now" + context = contexts[0] + if "" in context: + context = context.replace("", "") + + if "max_new_tokens" not in gen_kwargs: + gen_kwargs["max_new_tokens"] = 1024 + if "temperature" not in gen_kwargs: + gen_kwargs["temperature"] = 0 + if "top_p" not in gen_kwargs: + gen_kwargs["top_p"] = None + if "num_beams" not in gen_kwargs: + gen_kwargs["num_beams"] = 1 + + image = visuals[0] + input_by_model = self.model.build_conversation_input_ids(self.tokenizer, query=context, history=[], images=[image]) + + inputs = { + "input_ids": input_by_model["input_ids"].unsqueeze(0).to(self.device), + "token_type_ids": input_by_model["token_type_ids"].unsqueeze(0).to(self.device), + "attention_mask": input_by_model["attention_mask"].unsqueeze(0).to(self.device), + "images": [[input_by_model["images"][0].to(self.device).to(self.dtype)]], + } + if "cross_images" in input_by_model and input_by_model["cross_images"]: + inputs["cross_images"] = [[input_by_model["cross_images"][0].to(self.device).to(self.dtype)]] + + try: + outputs = self.model.generate(**inputs, **gen_kwargs) + outputs = outputs[:, inputs["input_ids"].shape[1] :] + response = self.tokenizer.decode(outputs[0]) + response = response.split("")[0] + response = response.split("<|end_of_text|>")[0] + + context = [{"role": "user", "content": context}, {"role": "assistant", "content": response}] + except Exception as e: + eval_logger.error(f"Error {e} in generating") + cont = "" + res.append(response) + self.cache_hook.add_partial("generate_until", (context, gen_kwargs), response) + pbar.update(1) + # reorder this group of results back to original unsorted form + res = re_ords.get_original(res) + + pbar.close() + return res