diff --git a/.gitignore b/.gitignore index 6bcecb71..419d53da 100755 --- a/.gitignore +++ b/.gitignore @@ -166,3 +166,4 @@ checkpoints/ *.txt pipeline/serve/deploy/otterhd_endpoint.py pipeline/benchmarks/models/llava_model.py +eval_results/
diff --git a/eval_results b/eval_results deleted file mode 100644 index 9d65d82f..00000000 --- a/eval_results +++ /dev/null @@ -1,175 +0,0 @@ -================================================================================ - EVALUATION REPORT -================================================================================ -MODEL INFO: {'name': 'otter_image', 'model_path': '/mnt/petrelfs/zhangyuanhan/Otter/checkpoints/otter_llava_sft_nonconv_nogroup'} -[2023-12-20 17:11:37,449] [INFO] [real_accelerator.py:133:get_accelerator] Setting ds_accelerator to cuda (auto detect) -The current model version is configured for Otter-Image with max_num_frames set to None. -[... per-parameter size listing for gated cross-attention layers 3-31 and perceiver layers 0-5 omitted ...]
-Parameter: perceiver.layers.5.feed_forward.1.weight, Size: 4.194304 M -Parameter: perceiver.layers.5.feed_forward.3.weight, Size: 4.194304 M -Parameter: perceiver.norm.weight, Size: 0.001024 M -Parameter: perceiver.norm.bias, Size: 0.001024 M -Total Trainable param: 1.441004 B -Imported class: - -DATASET: MMEDataset --------------------- -=========== Cognition =========== diff --git a/pipeline/accelerate_configs/accelerate_config_ddp.yaml b/pipeline/accelerate_configs/accelerate_config_ddp.yaml index 9cc01be2..90b96540 100755 --- a/pipeline/accelerate_configs/accelerate_config_ddp.yaml +++ b/pipeline/accelerate_configs/accelerate_config_ddp.yaml @@ -5,7 +5,7 @@ machine_rank: 0 main_training_function: main mixed_precision: bf16 num_machines: 1 -num_processes: 2 +num_processes: 1 rdzv_backend: static same_network: false tpu_use_cluster: false diff --git a/pipeline/accelerate_configs/accelerate_config_zero2_pretrain.yaml b/pipeline/accelerate_configs/accelerate_config_zero2_pretrain.yaml new file mode 100755 index 00000000..164d18db --- /dev/null +++ b/pipeline/accelerate_configs/accelerate_config_zero2_pretrain.yaml @@ -0,0 +1,18 @@ +compute_environment: LOCAL_MACHINE +deepspeed_config: + gradient_accumulation_steps: 4 + gradient_clipping: 1.0 + offload_optimizer_device: none + offload_param_device: none + zero3_init_flag: false + zero_stage: 2 +distributed_type: DEEPSPEED +fsdp_config: {} +machine_rank: 0 +main_process_ip: null +main_process_port: 29501 +main_training_function: main +mixed_precision: bf16 +num_machines: 1 +num_processes: 8 +use_cpu: false \ No newline at end of file diff --git a/pipeline/accelerate_configs/accelerate_config_zero2_slurm.yaml b/pipeline/accelerate_configs/accelerate_config_zero2_slurm.yaml index 97939ea9..1dcf0bf6 100755 --- a/pipeline/accelerate_configs/accelerate_config_zero2_slurm.yaml +++ b/pipeline/accelerate_configs/accelerate_config_zero2_slurm.yaml @@ -1,7 +1,7 @@ compute_environment: LOCAL_MACHINE deepspeed_config: deepspeed_multinode_launcher: standard - gradient_accumulation_steps: 4 + gradient_accumulation_steps: 1 gradient_clipping: 1.0 offload_optimizer_device: none offload_param_device: none diff --git a/pipeline/accelerate_configs/accelerate_config_zero3.yaml b/pipeline/accelerate_configs/accelerate_config_zero3.yaml index cb2f2ef9..a9646504 100755 --- a/pipeline/accelerate_configs/accelerate_config_zero3.yaml +++ b/pipeline/accelerate_configs/accelerate_config_zero3.yaml @@ -1,11 +1,12 @@ compute_environment: LOCAL_MACHINE deepspeed_config: - gradient_accumulation_steps: 4 + gradient_accumulation_steps: 16 gradient_clipping: 1.0 offload_optimizer_device: none offload_param_device: none zero3_init_flag: true zero3_save_16bit_model: true + stage3_gather_16bit_weights_on_model_save: true zero_stage: 3 distributed_type: DEEPSPEED fsdp_config: {} @@ -15,5 +16,5 @@ main_process_port: 20333 main_training_function: main mixed_precision: bf16 num_machines: 1 -num_processes: 4 +num_processes: 8 use_cpu: false \ No newline at end of file diff --git a/pipeline/benchmarks/datasets/base_eval_dataset.py b/pipeline/benchmarks/datasets/base_eval_dataset.py index 597c733a..fa1bce2c 100644 --- a/pipeline/benchmarks/datasets/base_eval_dataset.py +++ b/pipeline/benchmarks/datasets/base_eval_dataset.py @@ -1,8 +1,9 @@ from abc import ABC, abstractmethod from PIL import Image from typing import Dict, List, Any -import base64 import io +import os +import base64 import importlib AVAILABLE_EVAL_DATASETS: Dict[str, str] = { @@ -17,20 +18,6 @@ "mmmu": "MMMUDataset", } 
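For quick context on the accelerate/DeepSpeed config changes above (num_processes and gradient_accumulation_steps), the effective global batch size is the product of per-device batch size, process count, and gradient-accumulation steps. A minimal sketch, assuming a hypothetical per-device batch size of 2 that is not taken from these configs:

```python
# Back-of-the-envelope check for the config hunks above: global batch size =
# per-device batch size * num_processes * gradient_accumulation_steps.
# The per-device batch size (2) is an assumed example value.
def effective_batch_size(per_device: int, num_processes: int, grad_accum: int) -> int:
    return per_device * num_processes * grad_accum

print(effective_batch_size(2, 8, 16))  # accelerate_config_zero3.yaml after this change -> 256
print(effective_batch_size(2, 8, 4))   # new accelerate_config_zero2_pretrain.yaml -> 64
```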
-def get_pil_image(raw_image_data) -> Image.Image: - if isinstance(raw_image_data, Image.Image): - return raw_image_data - - elif isinstance(raw_image_data, dict) and "bytes" in raw_image_data: - return Image.open(io.BytesIO(raw_image_data["bytes"])) - - elif isinstance(raw_image_data, str): # Assuming this is a base64 encoded string - image_bytes = base64.b64decode(raw_image_data) - return Image.open(io.BytesIO(image_bytes)) - - else: - raise ValueError("Unsupported image data format") - class BaseEvalDataset(ABC): def __init__(self, name: str, dataset_path: str, *, max_batch_size: int = 1): @@ -38,6 +25,17 @@ def __init__(self, name: str, dataset_path: str, *, max_batch_size: int = 1): self.dataset_path = dataset_path self.max_batch_size = max_batch_size + def get_pil_image(self, raw_image_data) -> Image.Image: + if isinstance(raw_image_data, Image.Image): + return raw_image_data.convert("RGB") + elif isinstance(raw_image_data, dict) and "bytes" in raw_image_data: + return Image.open(io.BytesIO(raw_image_data["bytes"])).convert("RGB") + elif isinstance(raw_image_data, str): # Assuming this is a base64 encoded string + image_bytes = base64.b64decode(raw_image_data) + return Image.open(io.BytesIO(image_bytes)).convert("RGB") + else: + raise ValueError("Unsupported image data format") + def evaluate(self, model, **kwargs): return self._evaluate(model, **kwargs) # batch = min(model.max_batch_size, self.max_batch_size) @@ -50,10 +48,18 @@ def evaluate(self, model, **kwargs): @abstractmethod def _evaluate(self, model: str): pass + + # @abstractmethod # TODO: add back after every dataset has been updated + def evaluate_multi_gpu(self, model, model_version, rank, world_size): + pass - -def load_dataset(dataset_name: str, dataset_args: Dict[str, str] = {}) -> BaseEvalDataset: - assert dataset_name in AVAILABLE_EVAL_DATASETS, f"{dataset_name} is not an available eval dataset." +def load_dataset( + dataset_name: str, + dataset_args: Dict[str, str] = {}, +) -> BaseEvalDataset: + assert ( + dataset_name in AVAILABLE_EVAL_DATASETS + ), f"{dataset_name} is not an available eval dataset." module_path = "pipeline.benchmarks.datasets." + dataset_name dataset_formal_name = AVAILABLE_EVAL_DATASETS[dataset_name] imported_module = importlib.import_module(module_path) @@ -63,4 +69,4 @@ def load_dataset(dataset_name: str, dataset_args: Dict[str, str] = {}) -> BaseEv # get dataset args without "name" init_args = dataset_args.copy() init_args.pop("name") - return dataset_class(**init_args) + return dataset_class(**init_args) \ No newline at end of file diff --git a/pipeline/benchmarks/datasets/mme.py b/pipeline/benchmarks/datasets/mme.py index 49cfaabd..5a5f19ad 100644 --- a/pipeline/benchmarks/datasets/mme.py +++ b/pipeline/benchmarks/datasets/mme.py @@ -2,17 +2,24 @@ import io from PIL import Image import json -from sklearn.metrics import accuracy_score, precision_score, recall_score, confusion_matrix +from sklearn.metrics import ( + accuracy_score, + precision_score, + recall_score, + confusion_matrix, +) import os import numpy as np from datasets import load_dataset from typing import Union -from .base_eval_dataset import BaseEvalDataset +from .base_eval_dataset import BaseEvalDataset from tqdm import tqdm import datetime import pytz -utc_plus_8 = pytz.timezone("Asia/Singapore") # You can also use 'Asia/Shanghai', 'Asia/Taipei', etc. +utc_plus_8 = pytz.timezone( + "Asia/Singapore" +) # You can also use 'Asia/Shanghai', 'Asia/Taipei', etc. 
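For reference, the BaseEvalDataset refactor above turns get_pil_image into a method and makes every branch return an RGB image. A minimal sketch of the three accepted input formats, using a hypothetical DummyDataset subclass and assuming the repository package is importable:

```python
import base64
import io

from PIL import Image
from pipeline.benchmarks.datasets.base_eval_dataset import BaseEvalDataset


# Hypothetical subclass used only to exercise the new get_pil_image method.
class DummyDataset(BaseEvalDataset):
    def _evaluate(self, model):  # required by the ABC; unused in this sketch
        pass


ds = DummyDataset("dummy", dataset_path="")
img = Image.new("L", (8, 8))  # grayscale input; every branch converts to RGB
buf = io.BytesIO()
img.save(buf, format="PNG")

print(ds.get_pil_image(img).mode)                                        # RGB
print(ds.get_pil_image({"bytes": buf.getvalue()}).mode)                  # RGB
print(ds.get_pil_image(base64.b64encode(buf.getvalue()).decode()).mode)  # RGB
```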
utc_now = pytz.utc.localize(datetime.datetime.utcnow()) utc_plus_8_time = utc_now.astimezone(utc_plus_8) @@ -27,18 +34,18 @@ "scene", "landmark", "artwork", - "ocr", + "OCR", + ], + "Cognition": [ + "commonsense_reasoning", + "numerical_calculation", + "text_translation", + "code_reasoning", ], - "Cognition": ["commonsense", "numerical", "text", "code"], } class MMEDataset(BaseEvalDataset): - def decode_base64_to_image(self, base64_string): - image_data = base64.b64decode(base64_string) - image = Image.open(io.BytesIO(image_data)) - return image - def __init__( self, data_path: str = "Otter-AI/MME", @@ -47,7 +54,8 @@ def __init__( default_output_path: str = "./logs/MME", split: str = "test", debug: bool = False, - prompt: str = None, + replace_prompt: str = "Please answer yes or no.", + prompt: str = "\nAnswer the question using a single word or phrase.", ): super().__init__("MMEDataset", data_path) @@ -56,18 +64,23 @@ def __init__( self.data = load_dataset(data_path, split=split, cache_dir=cache_dir) self.debug = debug self.prompt = prompt + self.replace_prompt = replace_prompt self.category_data = {} # for idx in range(len(self.ids)): for item in tqdm(self.data, desc="Loading data"): - id = item["id"] - category = id.split("_")[0].lower() - question = item["instruction"] + question_id = item["question_id"] + category = item["category"] + question = item["question"] answer = item["answer"] - image_id = item["image_ids"][0] - image = item["images"][0] + image = item["image"] - data = {"question": question, "answer": answer, "image": image} + data = { + "question_id": question_id, + "question": question, + "answer": answer, + "image": image, + } if category in eval_type_dict["Cognition"]: eval_type = "Cognition" @@ -82,10 +95,10 @@ def __init__( if category not in self.category_data[eval_type]: self.category_data[eval_type][category] = {} - if image_id not in self.category_data[eval_type][category]: - self.category_data[eval_type][category][image_id] = [] + if question_id not in self.category_data[eval_type][category]: + self.category_data[eval_type][category][question_id] = [] - self.category_data[eval_type][category][image_id].append(data) + self.category_data[eval_type][category][question_id].append(data) def parse_pred_ans(self, pred_ans): pred_ans = pred_ans.lower().strip().replace(".", "") @@ -153,38 +166,56 @@ def compute_metric(self, gts, preds): def _evaluate(self, model): model_score_dict = {} - self.default_output_path = os.path.join(self.default_output_path, f"{model.name}_{self.cur_datetime}") + self.default_output_path = os.path.join( + self.default_output_path, f"{self.cur_datetime}" + ) if not os.path.exists(self.default_output_path): os.makedirs(self.default_output_path) + debug_jsonl_file = open(os.path.join(self.default_output_path, "output.jsonl"), "w") + for eval_type in self.category_data.keys(): print("===========", eval_type, "===========") scores = 0 task_score_dict = {} - for task_name in tqdm(self.category_data[eval_type].keys(), desc=f"Evaluating {eval_type}"): + for task_name in tqdm( + self.category_data[eval_type].keys(), desc=f"Evaluating {eval_type}" + ): img_num = len(self.category_data[eval_type][task_name]) task_other_ans_num = 0 task_score = 0 acc_plus_correct_num = 0 gts = [] preds = [] - for image_pair in tqdm(self.category_data[eval_type][task_name].values(), desc=f"Evaluating {eval_type} {task_name}"): + for image_pair in tqdm( + self.category_data[eval_type][task_name].values(), + desc=f"Evaluating {eval_type} {task_name}", + ): assert 
len(image_pair) == 2 img_correct_num = 0 for item in image_pair: question = item["question"] - image = item["image"] - gt_ans = item["answer"].lower().strip().replace(".", "") + if self.replace_prompt is not None: + question = question.replace(self.replace_prompt, "").strip() if self.prompt is not None: question = f"{question}{self.prompt}" + + image = self.get_pil_image(item["image"]) + gt_ans = item["answer"].lower().strip().replace(".", "") response = model.generate(question, image) if self.debug: - print(f"\n# Query: {question}") - print(f"\n# Response: {response}") - pred_ans = self.parse_pred_ans(response) + # print(f"\n# Query: {question}") + # print(f"\n# Response: {response}") + jsonl_data = { + "question_id": f"{item['question_id']}", + "prompt": question, + "text": response, + } + debug_jsonl_file.write(json.dumps(jsonl_data) + "\n") + pred_ans = self.parse_pred_ans(response) assert gt_ans in ["yes", "no"] assert pred_ans in ["yes", "no", "other"] @@ -212,10 +243,12 @@ def _evaluate(self, model): task_score_dict[task_name] = task_score scores += task_score - output_path = os.path.join(self.default_output_path, f"{task_name}.json") + output_path = os.path.join( + self.default_output_path, f"{task_name}.json" + ) with open(output_path, "w") as f: json.dump(metric_dict, f) print(f"total score: {scores}") for task_name, score in task_score_dict.items(): - print(f"\t {task_name} score: {score}") + print(f"\t {task_name} score: {score}") \ No newline at end of file diff --git a/pipeline/benchmarks/models/base_model.py b/pipeline/benchmarks/models/base_model.py index ab78af82..f9e53808 100644 --- a/pipeline/benchmarks/models/base_model.py +++ b/pipeline/benchmarks/models/base_model.py @@ -20,6 +20,7 @@ "llava_model": "LLaVA_Model", "instructblip": "InstructBLIP", "gpt4v": "OpenAIGPT4Vision", + "otter_image_llava": "OtterImageLlava" } diff --git a/pipeline/benchmarks/models/llava_model.py b/pipeline/benchmarks/models/llava_model.py index a365e490..b380ec3a 100644 --- a/pipeline/benchmarks/models/llava_model.py +++ b/pipeline/benchmarks/models/llava_model.py @@ -30,6 +30,7 @@ def __init__( ): super().__init__(model_name, model_path) init_model_name = get_model_name_from_path(model_path) + # import pdb;pdb.set_trace() self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(model_path, model_base, init_model_name) self.conv_mode = conv_mode self.temperature = temperature diff --git a/pipeline/benchmarks/models/otter_image.py b/pipeline/benchmarks/models/otter_image.py index 20f9ccda..5ad8f785 100644 --- a/pipeline/benchmarks/models/otter_image.py +++ b/pipeline/benchmarks/models/otter_image.py @@ -11,16 +11,20 @@ from src.otter_ai import OtterForConditionalGeneration from .base_model import BaseModel +from torchvision import transforms # Disable warnings requests.packages.urllib3.disable_warnings() +import os +os.environ["TOKENIZERS_PARALLELISM"] = "false" def get_pil_image(raw_image_data) -> Image.Image: + # import pdb;pdb.set_trace() if isinstance(raw_image_data, Image.Image): - return raw_image_data + return raw_image_data.convert('RGB') else: - return Image.open(BytesIO(raw_image_data["bytes"])) + return Image.open(BytesIO(raw_image_data["bytes"])).convert('RGB') def get_formatted_prompt(prompt: str) -> str: @@ -41,10 +45,40 @@ def __init__(self, model_path="luodian/OTTER-Image-MPT7B", load_bit="bf16"): precision["torch_dtype"] = torch.float16 elif load_bit == "fp32": precision["torch_dtype"] = torch.float32 - self.model = 
OtterForConditionalGeneration.from_pretrained(model_path, device_map="sequential", **precision) + # quantization_config = BitsAndBytesConfig( + # load_in_4bit=True, + # bnb_4bit_compute_dtype=torch.bfloat16 + # ) + self.model = OtterForConditionalGeneration.from_pretrained(model_path, device_map="auto", **precision) + dirname = os.path.dirname(model_path) + state_dict = {} + counter = 0 + # for _ in os.listdir(model_path): + # if "pytorch_model" in _ and "-of-" in _: + # counter += 1 + # for _ in range(counter): + # state_dict.update(torch.load(f"{dirname}/pytorch_model-0000{_+1}-of-0000{counter}.bin", map_location="cpu")) + # load_msg = self.model.load_state_dict( + # state_dict, + # False, + # ) + # print(load_mzsg) self.model.text_tokenizer.padding_side = "left" self.tokenizer = self.model.text_tokenizer - self.image_processor = transformers.CLIPImageProcessor() + self.patch_image_size = 336 + self.mean = [0.481, 0.458, 0.408] + self.std = [0.269, 0.261, 0.276] + # self.image_processor = transformers.CLIPImageProcessor() + self.image_processor = transforms.Compose( + [ + transforms.Resize( + (self.patch_image_size, self.patch_image_size), + interpolation=transforms.InterpolationMode.BICUBIC, + ), + transforms.ToTensor(), + transforms.Normalize(mean=self.mean, std=self.std), + ] + ) self.model.eval() def generate(self, question: str, raw_image_data): @@ -53,7 +87,12 @@ def generate(self, question: str, raw_image_data): if input_data.size == (224, 224) and not any(input_data.getdata()): # Check if image is blank 224x224 image vision_x = torch.zeros(1, 1, 1, 3, 224, 224, dtype=next(self.model.parameters()).dtype) else: - vision_x = self.image_processor.preprocess([input_data], return_tensors="pt")["pixel_values"].unsqueeze(1).unsqueeze(0) + # vision_x = self.image_processor.preprocess([input_data], return_tensors="pt")["pixel_values"].unsqueeze(1).unsqueeze(0) + try: + vision_x = self.image_processor(input_data).unsqueeze(0).unsqueeze(0).unsqueeze(0) + except: + import pdb;pdb.set_trace() + # import pdb;pdb.set_trace() else: raise ValueError("Invalid input data. 
Expected PIL Image.") @@ -68,15 +107,16 @@ def generate(self, question: str, raw_image_data): vision_x = vision_x.to(dtype=model_dtype) lang_x_input_ids = lang_x["input_ids"] lang_x_attention_mask = lang_x["attention_mask"] - + # with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False): generated_text = self.model.generate( vision_x=vision_x.to(self.model.device), lang_x=lang_x_input_ids.to(self.model.device), attention_mask=lang_x_attention_mask.to(self.model.device), - max_new_tokens=512, + max_new_tokens=128, num_beams=3, no_repeat_ngram_size=3, pad_token_id=self.tokenizer.eos_token_id, + early_stopping=True, ) parsed_output = self.model.text_tokenizer.decode(generated_text[0]).split("")[-1].split("<|endofchunk|>")[0].strip() return parsed_output diff --git a/pipeline/benchmarks/models/otter_image_llava.py b/pipeline/benchmarks/models/otter_image_llava.py new file mode 100644 index 00000000..d153fdb3 --- /dev/null +++ b/pipeline/benchmarks/models/otter_image_llava.py @@ -0,0 +1,174 @@ +import mimetypes +import os +from io import BytesIO +from typing import Union +import cv2 +import requests +import torch +import transformers +from PIL import Image + +from src.otter_ai import OtterForConditionalGeneration +from .base_model import BaseModel + +from torchvision import transforms + +# Disable warnings +requests.packages.urllib3.disable_warnings() + +import os +os.environ["TOKENIZERS_PARALLELISM"] = "false" + +def get_pil_image(raw_image_data) -> Image.Image: + # import pdb;pdb.set_trace() + if isinstance(raw_image_data, Image.Image): + return raw_image_data.convert('RGB') + else: + return Image.open(BytesIO(raw_image_data["bytes"])).convert('RGB') + + +def get_formatted_prompt(prompt: str) -> str: + return f"User: {prompt} GPT:" + + +def get_formatted_forward_prompt(question: str, answer: str) -> str: + return f"User: {question} GPT: {answer}" + + + +def expand2square(pil_img, background_color): + width, height = pil_img.size + if width == height: + return pil_img + elif width > height: + result = Image.new(pil_img.mode, (width, width), background_color) + result.paste(pil_img, (0, (width - height) // 2)) + return result + else: + result = Image.new(pil_img.mode, (height, height), background_color) + result.paste(pil_img, ((height - width) // 2, 0)) + return result + +def process_image(image, image_processor): + image = expand2square(image, tuple(int(x*255) for x in image_processor.image_mean)) + image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0] + return image + +class OtterImageLlava(BaseModel): + def __init__(self, model_path="luodian/OTTER-Image-MPT7B", load_bit="bf16"): + super().__init__("otter", model_path) + precision = {} + if load_bit == "bf16": + precision["torch_dtype"] = torch.bfloat16 + elif load_bit == "fp16": + precision["torch_dtype"] = torch.float16 + elif load_bit == "fp32": + precision["torch_dtype"] = torch.float32 + # quantization_config = BitsAndBytesConfig( + # load_in_4bit=True, + # bnb_4bit_compute_dtype=torch.bfloat16 + # ) + self.model = OtterForConditionalGeneration.from_pretrained(model_path, device_map="auto", **precision) + self.image_processor = self.model.image_processor + dirname = os.path.dirname(model_path) + state_dict = {} + counter = 0 + # for _ in os.listdir(model_path): + # if "pytorch_model" in _ and "-of-" in _: + # counter += 1 + # for _ in range(counter): + # state_dict.update(torch.load(f"{dirname}/pytorch_model-0000{_+1}-of-0000{counter}.bin", map_location="cpu")) + # 
load_msg = self.model.load_state_dict( + # state_dict, + # False, + # ) + # print(load_mzsg) + self.model.text_tokenizer.padding_side = "left" + self.tokenizer = self.model.text_tokenizer + self.patch_image_size = 336 + # self.mean = [0.481, 0.458, 0.408] + # self.std = [0.269, 0.261, 0.276] + # self.image_processor = transformers.CLIPImageProcessor() + # self.image_processor = transforms.Compose( + # [ + # transforms.Resize( + # (self.patch_image_size, self.patch_image_size), + # interpolation=transforms.InterpolationMode.BICUBIC, + # ), + # transforms.ToTensor(), + # transforms.Normalize(mean=self.mean, std=self.std), + # ] + # ) + self.patch_resize_transform = lambda x,y: process_image(x,y) + self.model.eval() + + def generate(self, question: str, raw_image_data): + input_data = get_pil_image(raw_image_data) + if isinstance(input_data, Image.Image): + if input_data.size == (224, 224) and not any(input_data.getdata()): # Check if image is blank 224x224 image + vision_x = torch.zeros(1, 1, 1, 3, 224, 224, dtype=next(self.model.parameters()).dtype) + else: + # vision_x = self.image_processor.preprocess([input_data], return_tensors="pt")["pixel_values"].unsqueeze(1).unsqueeze(0) + try: + vision_x = self.patch_resize_transform(input_data,self.image_processor).unsqueeze(0).unsqueeze(0).unsqueeze(0) + except: + import pdb;pdb.set_trace() + # import pdb;pdb.set_trace() + else: + raise ValueError("Invalid input data. Expected PIL Image.") + + lang_x = self.model.text_tokenizer( + [ + get_formatted_prompt(question), + ], + return_tensors="pt", + ) + + model_dtype = next(self.model.parameters()).dtype + vision_x = vision_x.to(dtype=model_dtype) + lang_x_input_ids = lang_x["input_ids"] + lang_x_attention_mask = lang_x["attention_mask"] + # with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False): + generated_text = self.model.generate( + vision_x=vision_x.to(self.model.device), + lang_x=lang_x_input_ids.to(self.model.device), + attention_mask=lang_x_attention_mask.to(self.model.device), + max_new_tokens=128, + num_beams=3, + no_repeat_ngram_size=3, + pad_token_id=self.tokenizer.eos_token_id, + early_stopping=True, + ) + parsed_output = self.model.text_tokenizer.decode(generated_text[0]).split("")[-1].split("<|endofchunk|>")[0].strip() + return parsed_output + + def get_vision_x(self, input_data): + if isinstance(input_data, Image.Image): + if input_data.size == (224, 224) and not any(input_data.getdata()): # Check if image is blank 224x224 image + vision_x = torch.zeros(1, 1, 1, 3, 224, 224, dtype=next(self.model.parameters()).dtype) + else: + vision_x = self.image_processor.preprocess([input_data], return_tensors="pt")["pixel_values"].unsqueeze(1).unsqueeze(0) + else: + raise ValueError("Invalid input data. 
Expected PIL Image.") + model_dtype = next(self.model.parameters()).dtype + vision_x = vision_x.to(dtype=model_dtype) + return vision_x + + def eval_forward(self, question, answer, image): + query = get_formatted_forward_prompt(question, answer) + tokens = self.tokenizer(query, return_tensors="pt") + input_ids = tokens["input_ids"] + attention_mask = tokens["attention_mask"] + with torch.no_grad(): + vision_x = self.get_vision_x(image) + loss = self.model(vision_x=vision_x.to(self.model.device), lang_x=input_ids.to(self.model.device), attention_mask=attention_mask.to(self.model.device))[0] + return loss + + +if __name__ == "__main__": + model = OtterImage("/data/pufanyi/training_data/checkpoints/OTTER-Image-MPT7B") + image = Image.open("/data/pufanyi/project/Otter-2/pipeline/evaluation/test_data/test.jpg") + response = model.generate("What is this?", image) + print(response) + response = model.generate("What is this?", image) + print(response) diff --git a/pipeline/mimicit_utils/data.py b/pipeline/mimicit_utils/data.py index c15b2c07..f90ca70d 100755 --- a/pipeline/mimicit_utils/data.py +++ b/pipeline/mimicit_utils/data.py @@ -680,7 +680,7 @@ def get_mimicit_dataset(args, image_processor, tokenizer, epoch=0, floor=False): # Converting multiple types of mimic-it datasets into a unified format dataset for key, item in dataset_info.items(): if item != {}: # if the category is not empty - unified_dataset = MimicitDataset(args, dataset_info=dataset_info[key], task_group=key) + unified_dataset = MimicitDataset(args, image_processor=image_processor, dataset_info=dataset_info[key], task_group=key) unified_datasets.append(unified_dataset) # round_fn = math.floor if floor else math.ceil diff --git a/pipeline/mimicit_utils/llava_pretrain_dataset.py b/pipeline/mimicit_utils/llava_pretrain_dataset.py index d4d7a574..c908db15 100755 --- a/pipeline/mimicit_utils/llava_pretrain_dataset.py +++ b/pipeline/mimicit_utils/llava_pretrain_dataset.py @@ -61,6 +61,24 @@ def random_seed(seed, *addl_seeds): import numpy as np +def expand2square(pil_img, background_color): + width, height = pil_img.size + if width == height: + return pil_img + elif width > height: + result = Image.new(pil_img.mode, (width, width), background_color) + result.paste(pil_img, (0, (width - height) // 2)) + return result + else: + result = Image.new(pil_img.mode, (height, height), background_color) + result.paste(pil_img, ((height - width) // 2, 0)) + return result + +def process_image(image, image_processor): + image = expand2square(image, tuple(int(x*255) for x in image_processor.image_mean)) + image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0] + return image + def resample_data(data, N): # If N is equal to the length of the list, return the list @@ -97,11 +115,12 @@ def extract_rgb_number(path): class LlavaPretrainDataset(Dataset): - def __init__(self, args, dataset_info, task_group=""): + def __init__(self, args, image_processor, dataset_info, task_group=""): self.args = args self.tokenizer = args.tokenizer self.keep_symbols = args.keep_symbols if hasattr(args, "keep_symbols") else True self.task_group = task_group + self.image_processor = image_processor # remove more symbols in the question and answer, make the question and answer more clean and training loss more stable. 
self.mimicit_paths = [] @@ -132,16 +151,17 @@ def __init__(self, args, dataset_info, task_group=""): (self.mean, self.std) = (IDEFICS_STANDARD_MEAN, IDEFICS_STANDARD_STD) if args.model_name == "idefics" else (FLAMINGO_MEAN, FLAMINGO_STD) if args.model_name == "otter" or args.model_name == "fuyu": - self.patch_resize_transform = transforms.Compose( - [ - transforms.Resize( - (args.patch_image_size, args.patch_image_size), - interpolation=transforms.InterpolationMode.BICUBIC, - ), - transforms.ToTensor(), - transforms.Normalize(mean=self.mean, std=self.std), - ] - ) + self.patch_resize_transform = lambda x,y: process_image(x,y).squeeze(0) + # self.patch_resize_transform = transforms.Compose( + # [ + # transforms.Resize( + # (args.patch_image_size, args.patch_image_size), + # interpolation=transforms.InterpolationMode.BICUBIC, + # ), + # transforms.ToTensor(), + # transforms.Normalize(mean=self.mean, std=self.std), + # ] + # ) elif args.model_name == "idefics": checkpoint_path = os.environ.get("IDEFICS_LOCAL_PATH", "HuggingFaceM4/idefics-9b-instruct") master_print(f"Local Idefics Checkpoints Path: {checkpoint_path}") @@ -315,7 +335,7 @@ def process_images(self, image_ids, is_video=False): # import pdb;pdb.set_trace() cur_image = Image.open(f"{self.images_paths[0]}/{cur_image_id}").convert("RGB") - cur_patch_image = self.patch_resize_transform(cur_image).unsqueeze(0) + cur_patch_image = self.patch_resize_transform(cur_image,self.image_processor).unsqueeze(0) if len(patch_images) == 0: patch_images = cur_patch_image else: @@ -607,7 +627,9 @@ def preload_dataset(path): args.tokenizer = text_tokenizer + from transformers import CLIPImageProcessor + image_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14-336") dataset_info = preload_dataset("/mnt/petrelfs/zhangyuanhan/Otter/shared_scripts/llava_pretrain.yaml") - dataset = LlavaPretrainDataset(args, dataset_info["IMAGE_TEXT"], "IMAGE_TEXT") + dataset = LlavaPretrainDataset(args, image_processor,dataset_info["IMAGE_TEXT"], "IMAGE_TEXT") for _ in dataset: print(_) \ No newline at end of file diff --git a/pipeline/mimicit_utils/mimicit_dataset.py b/pipeline/mimicit_utils/mimicit_dataset.py index 952d1191..e7edf9da 100755 --- a/pipeline/mimicit_utils/mimicit_dataset.py +++ b/pipeline/mimicit_utils/mimicit_dataset.py @@ -63,6 +63,24 @@ def random_seed(seed, *addl_seeds): import numpy as np +def expand2square(pil_img, background_color): + width, height = pil_img.size + if width == height: + return pil_img + elif width > height: + result = Image.new(pil_img.mode, (width, width), background_color) + result.paste(pil_img, (0, (width - height) // 2)) + return result + else: + result = Image.new(pil_img.mode, (height, height), background_color) + result.paste(pil_img, ((height - width) // 2, 0)) + return result + +def process_image(image, image_processor): + image = expand2square(image, tuple(int(x*255) for x in image_processor.image_mean)) + image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0] + return image + def resample_data(data, N): # If N is equal to the length of the list, return the list if N == -1 or N == 0: @@ -98,11 +116,12 @@ def extract_rgb_number(path): class MimicitDataset(Dataset): - def __init__(self, args, dataset_info, task_group=""): + def __init__(self, args, image_processor,dataset_info, task_group=""): self.args = args self.tokenizer = args.tokenizer self.keep_symbols = args.keep_symbols if hasattr(args, "keep_symbols") else True self.task_group = task_group + 
self.image_processor = image_processor # remove more symbols in the question and answer, make the question and answer more clean and training loss more stable. self.mimicit_paths = [] @@ -130,17 +149,18 @@ def __init__(self, args, dataset_info, task_group=""): self.resample_frames = args.resample_frames self.wrap_sys = f"<>\nYou are a helpful vision language assistant. You are able to understand the visual content\n<>\n\n" (self.mean, self.std) = (IDEFICS_STANDARD_MEAN, IDEFICS_STANDARD_STD) if args.model_name == "idefics" else (FLAMINGO_MEAN, FLAMINGO_STD) - if args.model_name == "otter" or args.model_name == "fuyu": - self.patch_resize_transform = transforms.Compose( - [ - transforms.Resize( - (args.patch_image_size, args.patch_image_size), - interpolation=transforms.InterpolationMode.BICUBIC, - ), - transforms.ToTensor(), - transforms.Normalize(mean=self.mean, std=self.std), - ] - ) + if args.model_name == "otter": + self.patch_resize_transform = lambda x,y: process_image(x,y).squeeze(0) + # self.patch_resize_transform = transforms.Compose( + # [ + # transforms.Resize( + # (args.patch_image_size, args.patch_image_size), + # interpolation=transforms.InterpolationMode.BICUBIC, + # ), + # transforms.ToTensor(), + # transforms.Normalize(mean=self.mean, std=self.std), + # ] + # ) elif args.model_name == "idefics": checkpoint_path = os.environ.get("IDEFICS_LOCAL_PATH", "HuggingFaceM4/idefics-9b-instruct") master_print(f"Local Idefics Checkpoints Path: {checkpoint_path}") @@ -182,8 +202,10 @@ def __init__(self, args, dataset_info, task_group=""): # Load the dataset assert os.path.exists(cur_mimicit_path), f"Error: The local mimicit_path {cur_mimicit_path} not exists!" + # import pdb;pdb.set_trace() with open(cur_mimicit_path, "rb") as f: cur_mimicit_data = orjson.loads(f.read())["data"] + # import pdb;pdb.set_trace() self.dataset.update(cur_mimicit_data) # Load the train_config @@ -352,7 +374,7 @@ def process_images(self, image_ids, is_video=False, in_context=False): if self.args.model_name == self.args.model_name == "fuyu": pil_images.append(cur_image) # fuyu doesnt need following process. 
else: - cur_patch_image = self.patch_resize_transform(cur_image).unsqueeze(0) + cur_patch_image = self.patch_resize_transform(cur_image,self.image_processor).unsqueeze(0) if len(patch_images) == 0: patch_images = cur_patch_image else: @@ -403,7 +425,7 @@ def process_general(self, instruction_id, in_context_example_ids, task_group): # patch_images = torch.tensor([]) # import pdb;pdb.set_trace() if task_group == "TEXT_ONLY": - patch_images = torch.zeros(3, self.patch_image_size, self.patch_image_size).unsqueeze(0).unsqueeze(0) + patch_images = torch.zeros(3, self.patch_image_size, self.patch_image_size).double().unsqueeze(0).unsqueeze(0) pil_images = [Image.fromarray(patch_images[0, 0].numpy().astype(np.uint8).transpose(1, 2, 0))] elif task_group == "IMAGE_TEXT": pil_images, patch_images = self.process_images(all_image_ids, is_video=False, in_context=False) @@ -689,6 +711,8 @@ def preload_dataset(path): sys.path.append("/mnt/petrelfs/zhangyuanhan/Otter") from pipeline.train.train_utils import DistributedProxySampler from itertools import cycle + from transformers import CLIPImageProcessor + parser = argparse.ArgumentParser(description="Main training script for the model") @@ -711,7 +735,9 @@ def preload_dataset(path): args.tokenizer = text_tokenizer dataset_info = preload_dataset("/mnt/petrelfs/zhangyuanhan/Otter/shared_scripts/llava_sft_noconv_nogrounp.yaml") - dataset = MimicitDataset(args, dataset_info["IMAGE_TEXT"], "IMAGE_TEXT") + image_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14-336") + + dataset = MimicitDataset(args, image_processor,dataset_info["IMAGE_TEXT"], "IMAGE_TEXT") sampler = RandomSampler(dataset, replacement=True, num_samples=len(dataset)) # sampler = DistributedProxySampler(sampler, num_replicas=8, rank=7) # import pdb;pdb.set_trace() diff --git a/pipeline/train/instruction_following.py b/pipeline/train/instruction_following.py index 7b2b6f9b..e9873183 100755 --- a/pipeline/train/instruction_following.py +++ b/pipeline/train/instruction_following.py @@ -309,7 +309,7 @@ def masking(masking_number: int = -100): def main(): args = parse_args() - verify_yaml(args) + # verify_yaml(args) accelerator = Accelerator( gradient_accumulation_steps=args.gradient_accumulation_steps, mixed_precision="bf16", @@ -337,8 +337,20 @@ def main(): ) args.tokenizer = model.text_tokenizer tokenizer = model.text_tokenizer - image_processor = CLIPImageProcessor() + image_processor = model.image_processor # model.gradient_checkpointing_enable() + if args.enable_lora: + lora_config = LoraConfig( + r=64, + lora_alpha=32, + lora_dropout=0.05, + task_type=TaskType.CAUSAL_LM, + target_modules=["q_proj", "v_proj"], + ) + model.lang_encoder = get_peft_model(model.lang_encoder, lora_config) + model.lang_encoder.print_trainable_parameters() + model.lang_encoder.__class__.__name__ = f"{model.lang_encoder.__class__.__name__}LoRA" + master_print(f"Init LoRA model with config {lora_config}") elif args.model_name.lower() == "flamingo": model = FlamingoForConditionalGeneration.from_pretrained( @@ -462,7 +474,7 @@ def main(): name=args.run_name, ) - mimicit_loaders = get_data(args, image_processor, tokenizer, "mimicit") + mimicit_loaders = get_data(args, image_processor, tokenizer, args.dataset_type) total_training_steps = sum(len(dataloader) for dataloader in mimicit_loaders) * args.num_epochs resume_from_epoch = 0 args.external_save_dir = os.path.join(args.external_save_dir, args.run_name) if args.external_save_dir else args.run_name diff --git a/pipeline/train/pretrain_llava.py 
b/pipeline/train/pretrain_llava.py index 6cf7a542..c30097d8 100755 --- a/pipeline/train/pretrain_llava.py +++ b/pipeline/train/pretrain_llava.py @@ -39,12 +39,11 @@ master_print, random_seed, save_checkpoint, - save_final_weights, + save_hf_weights, verify_yaml, - get_weights_for_dataloaders, - get_next_dataloader, find_and_remove_tokens, delete_tensors_from_dict, + precompute_dataloader_sequence, ) from src.otter_ai.models.flamingo.modeling_flamingo import FlamingoForConditionalGeneration from src.otter_ai.models.otter.modeling_otter import OtterForConditionalGeneration @@ -116,8 +115,16 @@ def forward_pass(args, model, tokenizer, images, input_ids, attention_mask, labe def train_one_epoch(args, model, epoch, mimicit_loaders, tokenizer, optimizer, lr_scheduler, device_id, accelerator, wandb): dataloader_iterators = [cycle(dataloader) for dataloader in mimicit_loaders] - weights = get_weights_for_dataloaders(mimicit_loaders) - num_batches_per_epoch = sum(len(dataloader) for dataloader in mimicit_loaders) // args.gradient_accumulation_steps + num_batches_per_epoch = sum(len(dataloader) for dataloader in mimicit_loaders) + # Precompute the sequence before starting the training loop + num_dataloaders = len(dataloader_iterators) + seed = 42 + epoch + dataloader_sequence = precompute_dataloader_sequence(mimicit_loaders, num_batches_per_epoch, seed) + + def get_dataloader_from_sequence(sequence, current_step): + index = current_step % len(sequence) + dataloader_index = sequence[index] + return dataloader_iterators[dataloader_index] # Special Design for Idefics Model's prompt strategy if args.model_name.lower() == "idefics": @@ -148,21 +155,30 @@ def train_one_epoch(args, model, epoch, mimicit_loaders, tokenizer, optimizer, l if num_steps == num_batches_per_epoch: break data_time_m.update(time.time() - end) - dataloader_iterator = get_next_dataloader(dataloader_iterators, weights) - batch_mimicit = next(dataloader_iterator) # Fetch a batch from the chosen dataloader global_step = num_steps + epoch * num_batches_per_epoch + # dataloader_iterator = get_next_dataloader(dataloader_iterators, weights) + dataloader_iterator = get_dataloader_from_sequence(dataloader_sequence, num_steps) + # import pdb;pdb.set_trace() + batch_mimicit = next(dataloader_iterator) # Fetch a batch from the chosen dataloader #### MIMIC-IT FORWARD PASS #### - net_input = batch_mimicit.pop("net_input") - images = net_input.pop("patch_images").to(device_id, non_blocking=True) - input_ids = net_input.pop("input_ids").to(device_id, non_blocking=True) - attention_mask = net_input.pop("attention_masks").to(device_id, non_blocking=True) - - labels = input_ids.clone() - labels[labels == tokenizer.pad_token_id] = -100 - labels[:, 0] = -100 - labels[labels == media_token_id] = -100 - labels[labels == endofchunk_token_id] = -100 + try: + net_input = batch_mimicit.pop("net_input") + images = net_input.pop("patch_images").to(device_id, non_blocking=True) + input_ids = net_input.pop("input_ids").to(device_id, non_blocking=True) + attention_mask = net_input.pop("attention_masks").to(device_id, non_blocking=True) + + labels = input_ids.clone() + labels[labels == tokenizer.pad_token_id] = -100 + labels[:, 0] = -100 + labels[labels == media_token_id] = -100 + labels[labels == endofchunk_token_id] = -100 + except: + master_print(e) + # print("batch_mimicit",batch_mimicit) + # print("dataloader_iterator":dataloader_iterator) + # import pdb;pdb.set_trace() + continue # import pdb;pdb.set_trace() if args.remove_answer_token: @@ -177,18 +193,23 @@ 
def train_one_epoch(args, model, epoch, mimicit_loaders, tokenizer, optimizer, l master_print(f"model: {unwrapped_model.__class__.__name__}") master_print(f"model dtype: {unwrapped_model.dtype if hasattr(unwrapped_model, 'dtype') else 'None'}") - loss_mimicit = forward_pass( - args, - model, - tokenizer, - images, - input_ids, - attention_mask, - labels, - device_id, - autocast_type, - batch_mimicit, - ) + try: + loss_mimicit = forward_pass( + args, + model, + tokenizer, + images, + input_ids, + attention_mask, + labels, + device_id, + autocast_type, + batch_mimicit, + ) + except Exception as e: + # import pdb;pdb.set_trace() + master_print(batch_mimicit) + continue if accelerator.mixed_precision == "fp16": accelerator.backward(loss_mimicit.to(device_id)) @@ -199,23 +220,6 @@ def train_one_epoch(args, model, epoch, mimicit_loaders, tokenizer, optimizer, l mean_loss = loss_mimicit.detach().mean() cur_batch_max_tokens = input_ids.shape[1] - def mask_embedding(m): - if m.weight.requires_grad: - zero_mask = torch.zeros_like(m.weight.grad) - zero_mask[answer_token_id] = torch.ones_like(zero_mask[answer_token_id]) - # zero_mask[media_token_id] = torch.ones_like(zero_mask[media_token_id]) - # zero_mask[endofchunk_token_id] = torch.ones_like(zero_mask[endofchunk_token_id]) - m.weight.grad = m.weight.grad * zero_mask - - if args.mask_lm_head and args.distributed_type != "DEEPSPEED": - unwrapped_model = accelerator.unwrap_model(model) - if isinstance(unwrapped_model, IdeficsForVisionText2Text): - unwrapped_model.lm_head.apply(mask_embedding) - elif unwrapped_model.lang_encoder.__class__.__name__ in ["MPTForCausalLM", "MosaicGPT"]: - unwrapped_model.lang_encoder.transformer.wte.apply(mask_embedding) - elif "LlamaForCausalLM" in unwrapped_model.lang_encoder.__class__.__name__: - unwrapped_model.lang_encoder.model.embed_tokens.apply(mask_embedding) - unwrapped_model.lang_encoder.lm_head.apply(mask_embedding) if accelerator.sync_gradients: accelerator.clip_grad_norm_(model.parameters(), 1.0) @@ -246,7 +250,7 @@ def mask_embedding(m): "mimicit_samples_per_second_per_gpu": mimicit_samples_per_second_per_gpu, "lr": optimizer.param_groups[0]["lr"], "loss_mimicit": mean_loss, - "global_step": global_step // args.gradient_accumulation_steps, + "global_step": global_step, group_name: mean_loss, }, commit=True, @@ -281,7 +285,7 @@ def mask_embedding(m): def main(): args = parse_args() - verify_yaml(args) + # verify_yaml(args) accelerator = Accelerator( gradient_accumulation_steps=args.gradient_accumulation_steps, mixed_precision="bf16", @@ -485,25 +489,26 @@ def main(): accelerator.wait_for_everyone() if args.save_ckpt_each_epoch: # save_checkpoint(epoch, model, args, accelerator) - save_final_weights( + save_hf_weights( model, args, accelerator, processor=processor if "idefics" in args.model_name.lower() or "fuyu" in args.model_name.lower() else None, tokenizer=tokenizer if "llama2" in args.model_name.lower() else None, + epoch=epoch + 1, ) master_print(f"Saved checkpoint at epoch {epoch+1}.") accelerator.wait_for_everyone() # Save the final weights - save_final_weights( + save_hf_weights( model, args, accelerator, processor=processor if "idefics" in args.model_name.lower() or "fuyu" in args.model_name.lower() else None, tokenizer=tokenizer if "llama2" in args.model_name.lower() else None, ) - # accelerator.wait_for_everyone() + accelerator.wait_for_everyone() if __name__ == "__main__": diff --git a/pipeline/train/train_args.py b/pipeline/train/train_args.py index b3e79ca3..ee07ea38 100644 --- 
a/pipeline/train/train_args.py +++ b/pipeline/train/train_args.py @@ -45,6 +45,13 @@ def parse_args(): choices=["simple", "llama2", "idefics", "fuyu","pretrain"], help="simple is for mpt/llama1, rest are in different instruction templates.", ) + parser.add_argument( + "--dataset_type", + type=str, + default="mimicit", + choices=["laion", "mmc4", "mimicit", "cc3m","llava_pretrain"], + help="dataset_type.", + ) parser.add_argument( "--training_data_yaml", type=str, diff --git a/pipeline/train/train_utils.py b/pipeline/train/train_utils.py index 57a2074c..5c855773 100755 --- a/pipeline/train/train_utils.py +++ b/pipeline/train/train_utils.py @@ -236,7 +236,14 @@ def save_hf_weights(model, args, accelerator, processor=None, tokenizer=None, ep is_main_process = accelerator.is_main_process save_path = args.external_save_dir if epoch is None else f"{args.external_save_dir}/epoch_{epoch}" unwrapped_model.config.save_pretrained(save_path) - unwrapped_model.save_pretrained(save_path, is_main_process=is_main_process, accelerator=accelerator, max_shard_size="5GB", safe_serialization=False) + unwrapped_model.save_pretrained( + save_path, + is_main_process=is_main_process, + save_function=accelerator.save, + state_dict=accelerator.get_state_dict(model), + safe_serialization=False + ) + # unwrapped_model.save_pretrained(save_path, is_main_process=is_main_process, accelerator=accelerator, max_shard_size="5GB", safe_serialization=False) model_name = args.model_name.lower() if "idefics" in model_name or "fuyu" in model_name: diff --git a/shared_scripts/benchmark.yaml b/shared_scripts/benchmark.yaml index c6de0e6b..89f3f2c9 100644 --- a/shared_scripts/benchmark.yaml +++ b/shared_scripts/benchmark.yaml @@ -28,5 +28,5 @@ datasets: models: # - name: llava_model # model_path: /mnt/petrelfs/zhangyuanhan/LLaVA/checkpoints/llava-v1.5-7b - - name: otter_image - model_path: /mnt/petrelfs/zhangyuanhan/Otter/checkpoints/otter_llava_sft_nonconv_nogroup \ No newline at end of file + - name: otter_image_llava + model_path: /mnt/petrelfs/zhangyuanhan/Otter/checkpoints/otter_llava_direct_sft_nopretrain_llava_transform_1epoch/epoch_1/ \ No newline at end of file diff --git a/src/otter_ai/models/otter/converting_otter_to_lora.py b/src/otter_ai/models/otter/converting_otter_to_lora.py index 972807e0..7d8b35d3 100755 --- a/src/otter_ai/models/otter/converting_otter_to_lora.py +++ b/src/otter_ai/models/otter/converting_otter_to_lora.py @@ -2,7 +2,7 @@ import torch import sys -from .modeling_otter import OtterForConditionalGeneration +from modeling_otter import OtterForConditionalGeneration from peft import get_peft_model, LoraConfig, TaskType MODEL_CLASSES = { @@ -33,8 +33,11 @@ # Parse the input arguments args = parser.parse_args() +precision = {"torch_dtype": torch.bfloat16} + # Load the model -model = OtterForConditionalGeneration.from_pretrained(args.checkpoint_path, device_map="auto") +model = OtterForConditionalGeneration.from_pretrained(args.checkpoint_path, device_map="auto", **precision) + # adding lora standard_modules = ["q_proj", "v_proj"] @@ -47,15 +50,16 @@ "mpt": ["Wqkv"], } lora_config = LoraConfig( - r=16, - lora_alpha=32, + r=128, + lora_alpha=256, lora_dropout=0.05, task_type=TaskType.CAUSAL_LM, target_modules=model_to_lora_modules[lang_encoder_short_name], ) -model.config.update({"lora_config": {"r": 16, "lora_alpha": 32, "lora_dropout": 0.05}}) +model.config.update({"lora_config": {"r": 128, "lora_alpha": 256, "lora_dropout": 0.05}}) model.lang_encoder = get_peft_model(model.lang_encoder, lora_config) # Save 
diff --git a/src/otter_ai/models/otter/converting_otter_to_lora.py b/src/otter_ai/models/otter/converting_otter_to_lora.py
index 972807e0..7d8b35d3 100755
--- a/src/otter_ai/models/otter/converting_otter_to_lora.py
+++ b/src/otter_ai/models/otter/converting_otter_to_lora.py
@@ -2,7 +2,7 @@
 import torch
 import sys

-from .modeling_otter import OtterForConditionalGeneration
+from modeling_otter import OtterForConditionalGeneration
 from peft import get_peft_model, LoraConfig, TaskType

 MODEL_CLASSES = {
@@ -33,8 +33,11 @@
 # Parse the input arguments
 args = parser.parse_args()

+precision = {"torch_dtype": torch.bfloat16}
+
 # Load the model
-model = OtterForConditionalGeneration.from_pretrained(args.checkpoint_path, device_map="auto")
+model = OtterForConditionalGeneration.from_pretrained(args.checkpoint_path, device_map="auto", **precision)
+
 # adding lora
 standard_modules = ["q_proj", "v_proj"]
@@ -47,15 +50,16 @@
     "mpt": ["Wqkv"],
 }
 lora_config = LoraConfig(
-    r=16,
-    lora_alpha=32,
+    r=128,
+    lora_alpha=256,
     lora_dropout=0.05,
     task_type=TaskType.CAUSAL_LM,
     target_modules=model_to_lora_modules[lang_encoder_short_name],
 )
-model.config.update({"lora_config": {"r": 16, "lora_alpha": 32, "lora_dropout": 0.05}})
+model.config.update({"lora_config": {"r": 128, "lora_alpha": 256, "lora_dropout": 0.05}})
 model.lang_encoder = get_peft_model(model.lang_encoder, lora_config)

 # Save the model
 checkpoint_path = args.save_path
+
 OtterForConditionalGeneration.save_pretrained(model, checkpoint_path)
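The converting_otter_to_lora.py changes above switch to a local import, load the base checkpoint in bfloat16, and raise the LoRA rank/alpha from 16/32 to 128/256 (keeping alpha/r = 2), recording the new values in the model config. A generic sketch of the same PEFT setup, assuming a plain Hugging Face causal LM instead of OtterForConditionalGeneration; the checkpoint path is a placeholder:

    import torch
    from peft import LoraConfig, TaskType, get_peft_model
    from transformers import AutoModelForCausalLM

    base = AutoModelForCausalLM.from_pretrained(
        "path/to/llama-style-checkpoint",  # placeholder path
        torch_dtype=torch.bfloat16,        # matches the added precision kwargs
        device_map="auto",
    )
    lora_config = LoraConfig(
        r=128,                                # rank raised from 16
        lora_alpha=256,                       # alpha raised from 32
        lora_dropout=0.05,
        task_type=TaskType.CAUSAL_LM,
        target_modules=["q_proj", "v_proj"],  # the standard LLaMA attention projections
    )
    peft_model = get_peft_model(base, lora_config)
    peft_model.print_trainable_parameters()   # reports adapter vs. total parameter counts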
""" - - gated_cross_attn_layers = nn.ModuleList( - [ - OtterGatedCrossAttentionBlock( - dim=self.config.hidden_size, - dim_visual=vis_hidden_size, - ) - if (layer_idx + 1) % cross_attn_every_n_layers == 0 - else None - for layer_idx, _ in enumerate(self._get_decoder_layers()) - ] - ) + # import pdb;pdb.set_trace() + if cross_attn_every_n_layers >= 32: + gated_cross_attn_layers = nn.ModuleList( + [ + OtterGatedCrossAttentionBlock( + dim=self.config.hidden_size, + dim_visual=vis_hidden_size, + ) + # if (layer_idx + 1) % cross_attn_every_n_layers == 0 + if (layer_idx) % cross_attn_every_n_layers == 0 + else None + for layer_idx, _ in enumerate(self._get_decoder_layers()) + ] + ) + else: + gated_cross_attn_layers = nn.ModuleList( + [ + OtterGatedCrossAttentionBlock( + dim=self.config.hidden_size, + dim_visual=vis_hidden_size, + ) + if (layer_idx + 1) % cross_attn_every_n_layers == 0 + else None + for layer_idx, _ in enumerate(self._get_decoder_layers()) + ] + ) self._set_decoder_layers(nn.ModuleList([OtterLayer(gated_cross_attn_layer, decoder_layer) for gated_cross_attn_layer, decoder_layer in zip(gated_cross_attn_layers, self._get_decoder_layers())])) self.media_token_id = media_token_id self.use_media_placement_augmentation = use_media_placement_augmentation @@ -621,7 +636,7 @@ def __init__( target_modules=model_to_lora_modules[lang_encoder_short_name], ) self.lang_encoder = get_peft_model(self.lang_encoder, lora_config) - self.lang_encoder.master_print_trainable_parameters() + self.lang_encoder.print_trainable_parameters() self.post_init() @@ -768,6 +783,7 @@ def __init__( lang_encoder = RWForCausalLM(config=config.text_config) elif config.text_config.architectures[0] == "LlamaForCausalLM": text_tokenizer = AutoTokenizer.from_pretrained(config.text_config._name_or_path) + # torch.set_default_dtype(torch.bfloat16) lang_encoder = LlamaForCausalLM(config=config.text_config) else: import pdb @@ -777,6 +793,7 @@ def __init__( text_tokenizer = AutoTokenizer.from_pretrained(config.text_config._name_or_path) lang_encoder = LlamaForCausalLM(config=config.text_config) vision_encoder = CLIPVisionModel(config=config.vision_config) + self.image_processor = CLIPImageProcessor.from_pretrained(config.vision_config._name_or_path) text_tokenizer.add_special_tokens({"additional_special_tokens": ["<|endofchunk|>", "", ""]}) if text_tokenizer.pad_token is None: @@ -793,6 +810,7 @@ def __init__( self.lang_encoder = lang_encoder self.cross_attn_every_n_layers = config.cross_attn_every_n_layers + self.resampler_dim = config.resampler_dim if hasattr(config, "resampler_dim") else 64 # use_media_placement_augmentation is strictly false for Otter model self.use_media_placement_augmentation = False # config.use_media_placement_augmentation self.max_num_frames = config.max_num_frames if hasattr(config, "max_num_frames") else None @@ -807,7 +825,7 @@ def __init__( self.vision_encoder = vision_encoder self.vis_dim = 1024 - self.perceiver = OtterPerceiverResampler(dim=self.vis_dim, max_num_frames=self.max_num_frames) + self.perceiver = OtterPerceiverResampler(dim=self.vis_dim, max_num_frames=self.max_num_frames, dim_head = self.resampler_dim) self.lang_encoder.init_otter( media_token_id=self.media_token_id, @@ -836,7 +854,7 @@ def __init__( target_modules=model_to_lora_modules[lang_encoder_short_name], ) self.lang_encoder = get_peft_model(self.lang_encoder, lora_config) - self.lang_encoder.master_print_trainable_parameters() + self.lang_encoder.print_trainable_parameters() self.lang_encoder.__class__.__name__ = 
f"{original_architecture_name}LoRA" self.post_init() @@ -865,38 +883,18 @@ def init_weights(self): for param in self.parameters(): param.requires_grad = False - # Freeze all parameters in vision encoder + # Unfreeze all parameters in vision encoder if "train_vision_encoder" in self.config.__dict__ and self.config.train_vision_encoder is True: master_print("Unfreeze vision encoder.") for param in self.vision_encoder.parameters(): param.requires_grad = True - # Freeze all parameters in lang encoders except gated_cross_attn_layers + # Unfreeze all parameters in lang encoders except gated_cross_attn_layers if "train_lang_encoder" in self.config.__dict__ and self.config.train_lang_encoder is True: master_print("Unfreeze language decoder.") for name, param in self.lang_encoder.named_parameters(): param.requires_grad = True - # Freeze all parameters in vision encoder - if "train_vision_encoder" in self.config.__dict__ and self.config.train_vision_encoder is True: - for param in self.vision_encoder.parameters(): - param.requires_grad = True - - # Freeze all parameters in lang encoders except gated_cross_attn_layers - if "train_lang_encoder" in self.config.__dict__ and self.config.train_lang_encoder is True: - for name, param in self.lang_encoder.named_parameters(): - param.requires_grad = True - - # Freeze all parameters in vision encoder - if "train_vision_encoder" in self.config.__dict__ and self.config.train_vision_encoder is True: - for param in self.vision_encoder.parameters(): - param.requires_grad = True - - # Freeze all parameters in lang encoders except gated_cross_attn_layers - if "train_lang_encoder" in self.config.__dict__ and self.config.train_lang_encoder is True: - for name, param in self.lang_encoder.named_parameters(): - param.requires_grad = True - if "lora_config" in self.config.__dict__: # Use another logic to unfreeze gated_cross_attn_layers and perceivers master_print(f"LoRA trainable param: {(sum(param.numel() for name, param in self.lang_encoder.named_parameters() if 'lora' in name)) / 1e6:.3f} M")