diff --git a/examples/audio-classification/run_audio_classification.py b/examples/audio-classification/run_audio_classification.py index 74e148efd5..bcf643692b 100644 --- a/examples/audio-classification/run_audio_classification.py +++ b/examples/audio-classification/run_audio_classification.py @@ -46,8 +46,8 @@ def check_optimum_habana_min_version(*a, **b): logger = logging.getLogger(__name__) # Will error if the minimal version of Transformers and Optimum Habana are not installed. Remove at your own risks. -check_min_version("4.40.0") -check_optimum_habana_min_version("1.11.0") +check_min_version("4.43.0") +check_optimum_habana_min_version("1.12.0") require_version("datasets>=1.14.0", "To fix: pip install -r examples/pytorch/audio-classification/requirements.txt") diff --git a/examples/contrastive-image-text/clip_media_pipe.py b/examples/contrastive-image-text/clip_media_pipe.py index a4248959c7..be2fa4a419 100755 --- a/examples/contrastive-image-text/clip_media_pipe.py +++ b/examples/contrastive-image-text/clip_media_pipe.py @@ -31,7 +31,6 @@ media_ext_reader_op_impl, media_ext_reader_op_tensor_info, ) - from habana_frameworks.torch.hpu import get_device_name except ImportError: pass @@ -47,7 +46,7 @@ class read_image_text_from_dataset(media_ext_reader_op_impl): """ - def __init__(self, params): + def __init__(self, params, fw_params): self.batch_size = 1 params = params["priv_params"] self.meta_dtype = params["label_dtype"] @@ -64,9 +63,7 @@ def __init__(self, params): else: self.max_file = get_max_file([img["path"] for img in self.dataset["image"]]) logger.info(f"The largest file is {self.max_file}.") - - def set_params(self, params): - self.batch_size = params.batch_size + self.batch_size = fw_params.batch_size def gen_output_info(self): out_info = [] @@ -134,7 +131,7 @@ class ClipMediaPipe(MediaPipe): instance_count = 0 def __init__(self, dataset=None, sampler=None, batch_size=512, drop_last=False, queue_depth=1): - self.device = get_device_name() + self.device = "legacy" self.dataset = dataset self.drop_last = drop_last self.sampler = sampler @@ -157,7 +154,7 @@ def __init__(self, dataset=None, sampler=None, batch_size=512, drop_last=False, def_output_image_size = [self.image_size, self.image_size] res_pp_filter = ftype.BICUBIC self.decode = fn.ImageDecoder( - device=self.device, + device="hpu", output_format=imgtype.RGB_P, random_crop_type=randomCropType.CENTER_CROP, resize=def_output_image_size, diff --git a/examples/contrastive-image-text/run_bridgetower.py b/examples/contrastive-image-text/run_bridgetower.py index 0d3498a511..65bf2a35f0 100644 --- a/examples/contrastive-image-text/run_bridgetower.py +++ b/examples/contrastive-image-text/run_bridgetower.py @@ -56,8 +56,8 @@ def check_optimum_habana_min_version(*a, **b): logger = logging.getLogger(__name__) # Will error if the minimal version of Transformers and Optimum Habana are not installed. Remove at your own risks. -check_min_version("4.40.0") -check_optimum_habana_min_version("1.11.0") +check_min_version("4.43.0") +check_optimum_habana_min_version("1.12.0") require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/contrastive-image-text/requirements.txt") diff --git a/examples/contrastive-image-text/run_clip.py b/examples/contrastive-image-text/run_clip.py index c1abae0011..ae16c041f8 100644 --- a/examples/contrastive-image-text/run_clip.py +++ b/examples/contrastive-image-text/run_clip.py @@ -61,8 +61,8 @@ def check_optimum_habana_min_version(*a, **b): logger = logging.getLogger(__name__) # Will error if the minimal version of Transformers and Optimum Habana are not installed. Remove at your own risks. -check_min_version("4.40.0") -check_optimum_habana_min_version("1.11.0") +check_min_version("4.43.0") +check_optimum_habana_min_version("1.12.0") require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/contrastive-image-text/requirements.txt") diff --git a/examples/image-classification/README.md b/examples/image-classification/README.md index bad96c134a..4118195015 100644 --- a/examples/image-classification/README.md +++ b/examples/image-classification/README.md @@ -45,7 +45,7 @@ python run_image_classification.py \ --num_train_epochs 5 \ --per_device_train_batch_size 128 \ --per_device_eval_batch_size 64 \ - --evaluation_strategy epoch \ + --eval_strategy epoch \ --save_strategy epoch \ --load_best_model_at_end True \ --save_total_limit 3 \ @@ -197,7 +197,7 @@ python ../gaudi_spawn.py \ --num_train_epochs 5 \ --per_device_train_batch_size 128 \ --per_device_eval_batch_size 64 \ - --evaluation_strategy epoch \ + --eval_strategy epoch \ --save_strategy epoch \ --load_best_model_at_end True \ --save_total_limit 3 \ @@ -237,7 +237,7 @@ python ../gaudi_spawn.py \ --num_train_epochs 5 \ --per_device_train_batch_size 128 \ --per_device_eval_batch_size 64 \ - --evaluation_strategy epoch \ + --eval_strategy epoch \ --save_strategy epoch \ --load_best_model_at_end True \ --save_total_limit 3 \ diff --git a/examples/image-classification/run_image_classification.py b/examples/image-classification/run_image_classification.py index 30779dbc03..344f43308b 100644 --- a/examples/image-classification/run_image_classification.py +++ b/examples/image-classification/run_image_classification.py @@ -63,8 +63,8 @@ def check_optimum_habana_min_version(*a, **b): logger = logging.getLogger(__name__) # Will error if the minimal version of Transformers and Optimum Habana are not installed. Remove at your own risks. -check_min_version("4.40.0") -check_optimum_habana_min_version("1.11.0") +check_min_version("4.43.0") +check_optimum_habana_min_version("1.12.0") require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/image-classification/requirements.txt") diff --git a/examples/image-to-text/README.md b/examples/image-to-text/README.md index f42fe0c0a0..0f1a2624d4 100644 --- a/examples/image-to-text/README.md +++ b/examples/image-to-text/README.md @@ -136,9 +136,9 @@ QUANT_CONFIG=./quantization_config/maxabs_quant.json python run_pipeline.py \ ### Inference with FusedSDPA -Habana FusedSDPA is a fused and optimized implementation of torch.nn.functional.scaled_dot_product_attention() for Gaudi. For more details, refer to [Gaudi online documentation](https://docs.habana.ai/en/latest/PyTorch/Model_Optimization_PyTorch/Optimization_in_PyTorch_Models.html?highlight=fusedsdpa#using-fused-scaled-dot-product-attention-fusedsdpa). Currently FusedSDPA works with BF16 precision for Llava models. +Habana FusedSDPA is a fused and optimized implementation of torch.nn.functional.scaled_dot_product_attention() for Gaudi. For more details, refer to [Gaudi online documentation](https://docs.habana.ai/en/latest/PyTorch/Model_Optimization_PyTorch/Optimization_in_PyTorch_Models.html?highlight=fusedsdpa#using-fused-scaled-dot-product-attention-fusedsdpa). -Use the following commands to run Llava-1.5-7b inference with FusedSDPA +Use the following command to run Llava-1.5-7b BF16 inference with FusedSDPA ```bash python3 run_pipeline.py \ --model_name_or_path llava-hf/llava-1.5-7b-hf \ @@ -149,7 +149,7 @@ python3 run_pipeline.py \ ``` -Use the following commands to run Llava-v1.6-mistral-7b inference with FusedSDPA +Use the following command to run Llava-v1.6-mistral-7b BF16 inference with FusedSDPA ```bash python3 run_pipeline.py \ --model_name_or_path llava-hf/llava-v1.6-mistral-7b-hf \ @@ -157,4 +157,25 @@ python3 run_pipeline.py \ --use_hpu_graphs \ --bf16 \ --use_flash_attention -``` \ No newline at end of file +``` + + +Use the following commands to run Llava-v1.6-mistral-7b FP8 inference with FusedSDPA + +Here is an example of measuring the tensor quantization statistics on Llava-v1.6-mistral-7b: +```bash +QUANT_CONFIG=./quantization_config/maxabs_measure.json python run_pipeline.py \ +--model_name_or_path llava-hf/llava-v1.6-mistral-7b-hf \ +--image_path "https://llava-vl.github.io/static/images/view.jpg" \ +--use_hpu_graphs \ +--bf16 --use_flash_attention +``` + +Here is an example of quantizing the model based on previous measurements for Llava-v1.6-mistral-7b: +```bash +QUANT_CONFIG=./quantization_config/maxabs_quant.json python run_pipeline.py \ +--model_name_or_path llava-hf/llava-v1.6-mistral-7b-hf \ +--image_path "https://llava-vl.github.io/static/images/view.jpg" \ +--use_hpu_graphs \ +--bf16 --use_flash_attention +``` diff --git a/examples/language-modeling/README.md b/examples/language-modeling/README.md index f6667cc012..dc4f5fdc7e 100644 --- a/examples/language-modeling/README.md +++ b/examples/language-modeling/README.md @@ -377,7 +377,7 @@ python3 run_lora_clm.py \ --output_dir ./model_lora_llama \ --num_train_epochs 3 \ --per_device_train_batch_size 16 \ - --evaluation_strategy "no" \ + --eval_strategy "no" \ --save_strategy "no" \ --learning_rate 1e-4 \ --warmup_ratio 0.03 \ @@ -410,7 +410,7 @@ LOWER_LIST=ops_bf16.txt python3 run_lora_clm.py \ --per_device_train_batch_size 1 \ --per_device_eval_batch_size 1 \ --gradient_accumulation_steps 16 \ - --evaluation_strategy "no" \ + --eval_strategy "no" \ --save_strategy "no" \ --learning_rate 3e-4 \ --max_grad_norm 0.3 \ @@ -445,7 +445,7 @@ python ../gaudi_spawn.py \ --num_train_epochs 3 \ --per_device_train_batch_size 8 \ --gradient_accumulation_steps 2 \ - --evaluation_strategy "no" \ + --eval_strategy "no" \ --save_strategy "no" \ --learning_rate 3e-4 \ --warmup_ratio 0.03 \ @@ -480,7 +480,7 @@ LOWER_LIST=ops_bf16.txt python ../gaudi_spawn.py \ --num_train_epochs 3 \ --per_device_train_batch_size 16 \ --gradient_accumulation_steps 1 \ - --evaluation_strategy "no" \ + --eval_strategy "no" \ --save_strategy "no" \ --learning_rate 3e-4 \ --warmup_ratio 0.03 \ @@ -518,7 +518,7 @@ python ../gaudi_spawn.py \ --num_train_epochs 5 \ --per_device_train_batch_size 4 \ --per_device_eval_batch_size 4 \ - --evaluation_strategy "no" \ + --eval_strategy "no" \ --save_strategy "no" \ --learning_rate 1e-4 \ --logging_steps 1 \ @@ -547,7 +547,7 @@ LOWER_LIST=ops_bf16.txt python3 ../gaudi_spawn.py \ --per_device_train_batch_size 1 \ --per_device_eval_batch_size 1 \ --gradient_accumulation_steps 16 \ - --evaluation_strategy "no" \ + --eval_strategy "no" \ --save_strategy "no" \ --learning_rate 4e-4 \ --max_grad_norm 0.3 \ @@ -589,7 +589,7 @@ python3 ../gaudi_spawn.py --use_deepspeed --world_size 8 run_lora_clm.py \ --per_device_train_batch_size 10 \ --per_device_eval_batch_size 1 \ --gradient_checkpointing \ - --evaluation_strategy epoch \ + --eval_strategy epoch \ --eval_delay 2 \ --save_strategy no \ --learning_rate 0.0018 \ @@ -641,7 +641,7 @@ python3 ../gaudi_spawn.py --world_size 8 --use_mpi run_lora_clm.py \ --fsdp_config fsdp_config.json \ --fsdp auto_wrap \ --num_train_epochs 2 \ - --evaluation_strategy epoch \ + --eval_strategy epoch \ --per_device_eval_batch_size 1 \ --eval_delay 2 \ --do_eval \ @@ -668,7 +668,7 @@ DEEPSPEED_HPU_ZERO3_SYNC_MARK_STEP_REQUIRED=1 LOWER_LIST=ops_bf16.txt python3 .. --per_device_train_batch_size 1 \ --per_device_eval_batch_size 1 \ --gradient_accumulation_steps 16 \ - --evaluation_strategy "no" \ + --eval_strategy "no" \ --save_strategy "no" \ --learning_rate 4e-4 \ --max_grad_norm 0.3 \ diff --git a/examples/language-modeling/requirements.txt b/examples/language-modeling/requirements.txt index 3a09ba2535..955398ad19 100644 --- a/examples/language-modeling/requirements.txt +++ b/examples/language-modeling/requirements.txt @@ -4,4 +4,4 @@ sentencepiece != 0.1.92 protobuf evaluate scikit-learn -peft == 0.10.0 +peft == 0.12.0 diff --git a/examples/language-modeling/run_clm.py b/examples/language-modeling/run_clm.py index 9433e8f3bf..06dfdf89bb 100644 --- a/examples/language-modeling/run_clm.py +++ b/examples/language-modeling/run_clm.py @@ -62,8 +62,8 @@ def check_optimum_habana_min_version(*a, **b): logger = logging.getLogger(__name__) # Will error if the minimal version of Transformers and Optimum Habana are not installed. Remove at your own risks. -check_min_version("4.40.0") -check_optimum_habana_min_version("1.11.0") +check_min_version("4.43.0") +check_optimum_habana_min_version("1.12.0") require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt") diff --git a/examples/language-modeling/run_mlm.py b/examples/language-modeling/run_mlm.py index 18015ca515..a129e7e4e7 100644 --- a/examples/language-modeling/run_mlm.py +++ b/examples/language-modeling/run_mlm.py @@ -61,8 +61,8 @@ def check_optimum_habana_min_version(*a, **b): logger = logging.getLogger(__name__) # Will error if the minimal version of Transformers and Optimum Habana are not installed. Remove at your own risks. -check_min_version("4.40.0") -check_optimum_habana_min_version("1.11.0") +check_min_version("4.43.0") +check_optimum_habana_min_version("1.12.0") require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt") diff --git a/examples/protein-folding/README.md b/examples/protein-folding/README.md index d5003e1e41..8997c75143 100644 --- a/examples/protein-folding/README.md +++ b/examples/protein-folding/README.md @@ -66,7 +66,7 @@ python ../gaudi_spawn.py --world_size 8 --use_mpi run_sequence_classification.py --num_train_epochs 100 \ --lr_scheduler_type constant \ --do_eval \ - --evaluation_strategy epoch \ + --eval_strategy epoch \ --per_device_eval_batch_size 32 \ --logging_strategy epoch \ --save_strategy epoch \ diff --git a/examples/protein-folding/run_esmfold.py b/examples/protein-folding/run_esmfold.py index 096f18055a..d211a5b6a0 100644 --- a/examples/protein-folding/run_esmfold.py +++ b/examples/protein-folding/run_esmfold.py @@ -40,7 +40,7 @@ def check_optimum_habana_min_version(*a, **b): # Will error if the minimal version of Optimum Habana is not installed. Remove at your own risks. -check_optimum_habana_min_version("1.11.0") +check_optimum_habana_min_version("1.12.0") def convert_outputs_to_pdb(outputs): diff --git a/examples/protein-folding/run_sequence_classification.py b/examples/protein-folding/run_sequence_classification.py index 502aec7d12..f41e7535cb 100644 --- a/examples/protein-folding/run_sequence_classification.py +++ b/examples/protein-folding/run_sequence_classification.py @@ -41,7 +41,7 @@ def check_optimum_habana_min_version(*a, **b): # Will error if the minimal version of Optimum Habana is not installed. Remove at your own risks. -check_optimum_habana_min_version("1.11.0") +check_optimum_habana_min_version("1.12.0") logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) diff --git a/examples/protein-folding/run_zero_shot_eval.py b/examples/protein-folding/run_zero_shot_eval.py index 56b6e16411..dd79a1d12c 100644 --- a/examples/protein-folding/run_zero_shot_eval.py +++ b/examples/protein-folding/run_zero_shot_eval.py @@ -36,7 +36,7 @@ def check_optimum_habana_min_version(*a, **b): # Will error if the minimal version of Optimum Habana is not installed. Remove at your own risks. -check_optimum_habana_min_version("1.11.0") +check_optimum_habana_min_version("1.12.0") logging.basicConfig( diff --git a/examples/question-answering/run_qa.py b/examples/question-answering/run_qa.py index b7022310c2..7976c632b1 100644 --- a/examples/question-answering/run_qa.py +++ b/examples/question-answering/run_qa.py @@ -60,8 +60,8 @@ def check_optimum_habana_min_version(*a, **b): logger = logging.getLogger(__name__) # Will error if the minimal version of Transformers and Optimum Habana are not installed. Remove at your own risks. -check_min_version("4.40.0") -check_optimum_habana_min_version("1.11.0") +check_min_version("4.43.0") +check_optimum_habana_min_version("1.12.0") require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt") diff --git a/examples/question-answering/run_seq2seq_qa.py b/examples/question-answering/run_seq2seq_qa.py index ff56d5b4e6..9bf6f7ff07 100644 --- a/examples/question-answering/run_seq2seq_qa.py +++ b/examples/question-answering/run_seq2seq_qa.py @@ -56,8 +56,8 @@ def check_optimum_habana_min_version(*a, **b): logger = logging.getLogger(__name__) # Will error if the minimal version of Transformers and Optimum Habana are not installed. Remove at your own risks. -check_min_version("4.40.0") -check_optimum_habana_min_version("1.11.0") +check_min_version("4.43.0") +check_optimum_habana_min_version("1.12.0") require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt") diff --git a/examples/speech-recognition/README.md b/examples/speech-recognition/README.md index 92e576be3f..1f75a778d3 100644 --- a/examples/speech-recognition/README.md +++ b/examples/speech-recognition/README.md @@ -237,7 +237,7 @@ python run_speech_recognition_seq2seq.py \ --logging_steps="25" \ --learning_rate="1e-5" \ --warmup_steps="500" \ - --evaluation_strategy="steps" \ + --eval_strategy="steps" \ --eval_steps="1000" \ --save_strategy="steps" \ --save_steps="1000" \ diff --git a/examples/speech-recognition/run_speech_recognition_ctc.py b/examples/speech-recognition/run_speech_recognition_ctc.py index c01778d6d9..429df6e815 100644 --- a/examples/speech-recognition/run_speech_recognition_ctc.py +++ b/examples/speech-recognition/run_speech_recognition_ctc.py @@ -59,8 +59,8 @@ def check_optimum_habana_min_version(*a, **b): logger = logging.getLogger(__name__) # Will error if the minimal version of Transformers and Optimum Habana are not installed. Remove at your own risks. -check_min_version("4.40.0") -check_optimum_habana_min_version("1.11.0") +check_min_version("4.43.0") +check_optimum_habana_min_version("1.12.0") require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt") diff --git a/examples/speech-recognition/run_speech_recognition_seq2seq.py b/examples/speech-recognition/run_speech_recognition_seq2seq.py index 5985825601..05243af132 100755 --- a/examples/speech-recognition/run_speech_recognition_seq2seq.py +++ b/examples/speech-recognition/run_speech_recognition_seq2seq.py @@ -55,8 +55,8 @@ def check_optimum_habana_min_version(*a, **b): # Will error if the minimal version of Transformers is not installed. Remove at your own risks. -check_min_version("4.40.0") -check_optimum_habana_min_version("1.11.0") +check_min_version("4.43.0") +check_optimum_habana_min_version("1.12.0") require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt") diff --git a/examples/stable-diffusion/image_to_image_generation.py b/examples/stable-diffusion/image_to_image_generation.py index d24b2eba4a..1c9d5b086d 100755 --- a/examples/stable-diffusion/image_to_image_generation.py +++ b/examples/stable-diffusion/image_to_image_generation.py @@ -40,7 +40,7 @@ def check_optimum_habana_min_version(*a, **b): # Will error if the minimal version of Optimum Habana is not installed. Remove at your own risks. -check_optimum_habana_min_version("1.10.0") +check_optimum_habana_min_version("1.12.0") logger = logging.getLogger(__name__) diff --git a/examples/stable-diffusion/image_to_video_generation.py b/examples/stable-diffusion/image_to_video_generation.py index 7beb73a1ac..b5d614f6f6 100755 --- a/examples/stable-diffusion/image_to_video_generation.py +++ b/examples/stable-diffusion/image_to_video_generation.py @@ -34,7 +34,7 @@ def check_optimum_habana_min_version(*a, **b): # Will error if the minimal version of Optimum Habana is not installed. Remove at your own risks. -check_optimum_habana_min_version("1.8.1") +check_optimum_habana_min_version("1.12.0") logger = logging.getLogger(__name__) diff --git a/examples/stable-diffusion/text_to_image_generation.py b/examples/stable-diffusion/text_to_image_generation.py index 0b3dee4699..689665b66b 100755 --- a/examples/stable-diffusion/text_to_image_generation.py +++ b/examples/stable-diffusion/text_to_image_generation.py @@ -39,7 +39,7 @@ def check_optimum_habana_min_version(*a, **b): # Will error if the minimal version of Optimum Habana is not installed. Remove at your own risks. -check_optimum_habana_min_version("1.11.0") +check_optimum_habana_min_version("1.12.0") logger = logging.getLogger(__name__) diff --git a/examples/stable-diffusion/training/media_pipe_imgdir.py b/examples/stable-diffusion/training/media_pipe_imgdir.py index 6e7a618453..cf6536f338 100644 --- a/examples/stable-diffusion/training/media_pipe_imgdir.py +++ b/examples/stable-diffusion/training/media_pipe_imgdir.py @@ -31,8 +31,6 @@ try: from habana_frameworks.mediapipe import fn - from habana_frameworks.mediapipe.backend.nodes import opnode_tensor_info - from habana_frameworks.mediapipe.backend.operator_specs import schema from habana_frameworks.mediapipe.media_types import ( dtype, ftype, @@ -41,8 +39,10 @@ ) from habana_frameworks.mediapipe.mediapipe import MediaPipe from habana_frameworks.mediapipe.operators.cpu_nodes.cpu_nodes import media_function - from habana_frameworks.mediapipe.operators.media_nodes import MediaReaderNode - from habana_frameworks.torch.hpu import get_device_name + from habana_frameworks.mediapipe.operators.reader_nodes.reader_nodes import ( + media_ext_reader_op_impl, + media_ext_reader_op_tensor_info, + ) except ImportError: pass @@ -65,14 +65,14 @@ def gen(): return DatasetHF.from_generator(gen) -class ReadImageTextFromDataset(MediaReaderNode): +class ReadImageTextFromDataset(media_ext_reader_op_impl): """ Class defining read image/text from directory node. """ - def __init__(self, name, guid, device, inputs, params, cparams, node_attr): - super().__init__(name, guid, device, inputs, params, cparams, node_attr) - self.dataset = params["dataset"] + def __init__(self, params, fw_params): + priv_params = params["priv_params"] + self.dataset = priv_params["dataset"] self.dataset_image = [] self.dataset_prompt_embeds = [] @@ -92,33 +92,31 @@ def __init__(self, name, guid, device, inputs, params, cparams, node_attr): self.dataset_original_sizes = np.array(self.dataset_original_sizes, dtype=np.uint32) self.dataset_crop_top_lefts = np.array(self.dataset_crop_top_lefts, dtype=np.uint32) self.epoch = 0 - self.batch_sampler = params["batch_sampler"] + self.batch_sampler = priv_params["batch_sampler"] self.num_imgs_slice = len(self.batch_sampler.sampler) self.num_batches_slice = len(self.batch_sampler) logger.info("Finding largest file ...") self.max_file = max(self.dataset["image"], key=lambda x: len(x)) - - def set_params(self, params): - self.batch_size = params.batch_size + self.batch_size = fw_params.batch_size def gen_output_info(self): out_info = [] - o = opnode_tensor_info(dtype.NDT, np.array([self.batch_size], dtype=np.uint32), "") + o = media_ext_reader_op_tensor_info(dtype.NDT, np.array([self.batch_size], dtype=np.uint32), "") out_info.append(o) sample = self.dataset[0] sample["pooled_prompt_embeds"] d0 = len(sample["pooled_prompt_embeds"]) d1 = len(sample["prompt_embeds"]) d2 = len(sample["prompt_embeds"][0]) - o = opnode_tensor_info(dtype.FLOAT32, np.array([d2, d1, self.batch_size], dtype=np.uint32), "") + o = media_ext_reader_op_tensor_info(dtype.FLOAT32, np.array([d2, d1, self.batch_size], dtype=np.uint32), "") out_info.append(o) - o = opnode_tensor_info(dtype.FLOAT32, np.array([d0, self.batch_size], dtype=np.uint32), "") + o = media_ext_reader_op_tensor_info(dtype.FLOAT32, np.array([d0, self.batch_size], dtype=np.uint32), "") out_info.append(o) - o = opnode_tensor_info("uint32", np.array([2, self.batch_size], dtype=np.uint32), "") + o = media_ext_reader_op_tensor_info("uint32", np.array([2, self.batch_size], dtype=np.uint32), "") out_info.append(o) - o = opnode_tensor_info("uint32", np.array([2, self.batch_size], dtype=np.uint32), "") + o = media_ext_reader_op_tensor_info("uint32", np.array([2, self.batch_size], dtype=np.uint32), "") out_info.append(o) return out_info @@ -161,25 +159,6 @@ def __next__(self): ) -read_image_text_from_dataset_params = {"dataset": None, "batch_sampler": []} - -schema.add_operator( - "SDXLDataReader", - None, - 0, - 0, - [], - 5, - read_image_text_from_dataset_params, - None, - ReadImageTextFromDataset, - dtype.NDT, -) -op_class = fn.operator_add("SDXLDataReader", False) -op_class.__module__ = fn.__name__ -setattr(fn, "SDXLDataReader", op_class) - - class RandomFlipFunction(media_function): """ Class to randomly generate input for RandomFlip media node. @@ -228,7 +207,7 @@ def __init__( drop_last=True, queue_depth=5, ): - self.device = get_device_name() + self.device = "legacy" self.dataset = dataset self.batch_size = batch_size @@ -248,11 +227,16 @@ def __init__( pipe_name=pipe_name, ) - self.input = fn.SDXLDataReader(dataset=self.dataset, batch_sampler=self.batch_sampler) + priv_params = {} + priv_params["dataset"] = self.dataset + priv_params["batch_sampler"] = self.batch_sampler + + self.input = fn.MediaExtReaderOp(impl=ReadImageTextFromDataset, num_outputs=5, priv_params=priv_params) + def_output_image_size = [self.image_size, self.image_size] res_pp_filter = ftype.BI_LINEAR self.decode = fn.ImageDecoder( - device=self.device, + device="hpu", output_format=imgtype.RGB_P, # random_crop_type=randomCropType.CENTER_CROP, resize=def_output_image_size, @@ -279,7 +263,7 @@ def __init__( dtype=dtype.UINT8, seed=100, ) - self.random_flip = fn.RandomFlip(horizontal=1, device=self.device) + self.random_flip = fn.RandomFlip(horizontal=1) SDXLMediaPipe.instance_count += 1 diff --git a/examples/summarization/run_summarization.py b/examples/summarization/run_summarization.py index 2d83568092..2ea2a59528 100755 --- a/examples/summarization/run_summarization.py +++ b/examples/summarization/run_summarization.py @@ -65,8 +65,8 @@ def check_optimum_habana_min_version(*a, **b): logger = logging.getLogger(__name__) # Will error if the minimal version of Transformers and Optimum Habana are not installed. Remove at your own risks. -check_min_version("4.40.0") -check_optimum_habana_min_version("1.11.0") +check_min_version("4.43.0") +check_optimum_habana_min_version("1.12.0") require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt") diff --git a/examples/text-classification/run_glue.py b/examples/text-classification/run_glue.py index 69a15579b0..9c827cad98 100755 --- a/examples/text-classification/run_glue.py +++ b/examples/text-classification/run_glue.py @@ -57,8 +57,8 @@ def check_optimum_habana_min_version(*a, **b): logger = logging.getLogger(__name__) # Will error if the minimal version of Transformers and Optimum Habana are not installed. Remove at your own risks. -check_min_version("4.40.0") -check_optimum_habana_min_version("1.11.0") +check_min_version("4.43.0") +check_optimum_habana_min_version("1.12.0") require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt") diff --git a/examples/text-generation/README.md b/examples/text-generation/README.md index fad94bdbcb..a2b541d391 100755 --- a/examples/text-generation/README.md +++ b/examples/text-generation/README.md @@ -120,6 +120,7 @@ Here are a few settings you may be interested in: - `--limit_hpu_graphs` to skip HPU Graph usage for first token to save memory - `--use_kv_cache` to use the [key/value cache](https://huggingface.co/docs/transformers/main/en/main_classes/text_generation#transformers.GenerationConfig.use_cache) to speed up generation - `--do_sample` or `--num_beams` to generate new tokens doing sampling or beam search (greedy search is the default) +- `--top_k` and `--penalty_alpha` to generate new tokens doing contrastive search (greedy search is the default) - `--prompt` to benchmark the model on one or several prompts of your choice - `--attn_softmax_bf16` to run attention softmax layer in bfloat16 precision provided that the model (such as Llama) supports it - `--trim_logits` to calculate logits only for the last token in the first time step provided that the model (such as Llama) supports it @@ -298,7 +299,7 @@ PT_ENABLE_INT64_SUPPORT=1 PT_HPU_LAZY_MODE=0 python ../gaudi_spawn.py --world_s ### Running with FP8 -Llama2-70b, Llama2-7b, Llama3-70b, Llama3-8b, Mixtral-8x7B, Falcon-7B, Falcon-40B, Falcon-180B and phi-2 in FP8 are enabled using the Quantization Toolkit (HQT), which provides model measurement and quantization capabilities in PyTorch. +Llama2-70b, Llama2-7b, Llama3-70b, Llama3-8b, Mixtral-8x7B, Falcon-7B, Falcon-40B, Falcon-180B and phi-2 in FP8 are enabled using the Intel Neural Compressor (INC), which provides model measurement and quantization capabilities in PyTorch. More information on enabling fp8 in SynapseAI is available here: https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Inference_Using_FP8.html @@ -472,7 +473,7 @@ More information on usage of the unifier script can be found in fp8 Habana docs: ### CPU memory reduction on single card Some models can fit on HPU DRAM but can't fit on the CPU RAM. -When we run a model on single card and don't use deepspeed, the `--disk_offload` flag allows to offload weights to disk during model quantization in HQT. When this flag is mentioned, during the quantization process, each weight first is loaded from disk to CPU RAM, when brought to HPU DRAM and quantized there. This way not all the model is on the CPU RAM but only one weight each time. +When we run a model on single card and don't use deepspeed, the `--disk_offload` flag allows to offload weights to disk during model quantization in INC. When this flag is mentioned, during the quantization process, each weight first is loaded from disk to CPU RAM, when brought to HPU DRAM and quantized there. This way not all the model is on the CPU RAM but only one weight each time. To enable this weights offload mechanism, add `--disk_offload` flag to the topology command line. Here is an example of using disk_offload in quantize command. Please follow the "Running FP8 models on single device" section first before running the cmd below. @@ -497,6 +498,7 @@ python run_generation.py \ --flash_attention_recompute ``` + ### Using Habana Flash Attention Habana Flash Attention addresses large sequence lengths on prompt stage of inference. Using causal attention mask on prompt stage requires input sequences in batch to be of the same length, but can provide a memory saving, thus enabling higher batch sizes. diff --git a/examples/text-generation/quantization_config/act_maxabs_hw_weights_pcs_maxabs_pow2_quant.json b/examples/text-generation/quantization_config/act_maxabs_hw_weights_pcs_maxabs_pow2_quant.json deleted file mode 100644 index 602a147baa..0000000000 --- a/examples/text-generation/quantization_config/act_maxabs_hw_weights_pcs_maxabs_pow2_quant.json +++ /dev/null @@ -1,10 +0,0 @@ -{ - "method": "HOOKS", - "mode": "QUANTIZE", - "observer": "maxabs", - "scale_method": "ACT_MAXABS_POW2_WEIGHTS_PCS_OPT_POW2", - "allowlist": {"types": [], "names": []}, - "blocklist": {"types": [], "names": []}, - "dump_stats_path": "./hqt_output/measure", - "dump_stats_xlsx_path": "./hqt_output/measure/fp8stats.xlsx" -} diff --git a/examples/text-generation/quantization_config/act_maxabs_pow2_weights_pcs_opt_pow2_quant.json b/examples/text-generation/quantization_config/act_maxabs_pow2_weights_pcs_opt_pow2_quant.json index 602a147baa..bfb932f098 100644 --- a/examples/text-generation/quantization_config/act_maxabs_pow2_weights_pcs_opt_pow2_quant.json +++ b/examples/text-generation/quantization_config/act_maxabs_pow2_weights_pcs_opt_pow2_quant.json @@ -5,6 +5,5 @@ "scale_method": "ACT_MAXABS_POW2_WEIGHTS_PCS_OPT_POW2", "allowlist": {"types": [], "names": []}, "blocklist": {"types": [], "names": []}, - "dump_stats_path": "./hqt_output/measure", - "dump_stats_xlsx_path": "./hqt_output/measure/fp8stats.xlsx" + "dump_stats_path": "./hqt_output/measure" } diff --git a/examples/text-generation/quantization_config/maxabs_measure_include_outputs.json b/examples/text-generation/quantization_config/maxabs_measure_include_outputs.json index 6de845a54d..72dff310ee 100644 --- a/examples/text-generation/quantization_config/maxabs_measure_include_outputs.json +++ b/examples/text-generation/quantization_config/maxabs_measure_include_outputs.json @@ -5,6 +5,5 @@ "measure_exclude": "NONE", "allowlist": {"types": [], "names": []}, "blocklist": {"types": [], "names": []}, - "dump_stats_path": "./hqt_output/measure", - "dump_stats_xlsx_path": "./hqt_output/measure/fp8stats.xlsx" + "dump_stats_path": "./hqt_output/measure" } \ No newline at end of file diff --git a/examples/text-generation/quantization_config/maxabs_quant.json b/examples/text-generation/quantization_config/maxabs_quant.json index 02314a728e..34fab4601d 100644 --- a/examples/text-generation/quantization_config/maxabs_quant.json +++ b/examples/text-generation/quantization_config/maxabs_quant.json @@ -5,6 +5,5 @@ "scale_method": "maxabs_hw", "allowlist": {"types": [], "names": []}, "blocklist": {"types": [], "names": []}, - "dump_stats_path": "./hqt_output/measure", - "dump_stats_xlsx_path": "./hqt_output/measure/fp8stats.xlsx" + "dump_stats_path": "./hqt_output/measure" } \ No newline at end of file diff --git a/examples/text-generation/quantization_config/maxabs_quant_mixtral.json b/examples/text-generation/quantization_config/maxabs_quant_mixtral.json index b3fd2e26db..87dc52d08a 100644 --- a/examples/text-generation/quantization_config/maxabs_quant_mixtral.json +++ b/examples/text-generation/quantization_config/maxabs_quant_mixtral.json @@ -8,6 +8,5 @@ "model.layers.1.block_sparse_moe.experts.(3|4).w2", "model.layers.[29-31].block_sparse_moe.experts.[0-7].w2" ]}, - "dump_stats_path": "./hqt_output/measure", - "dump_stats_xlsx_path": "./hqt_output/measure/fp8stats.xlsx" + "dump_stats_path": "./hqt_output/measure" } \ No newline at end of file diff --git a/examples/text-generation/quantization_config/maxabs_quant_phi.json b/examples/text-generation/quantization_config/maxabs_quant_phi.json index 8f13c2aa38..a77200c99f 100644 --- a/examples/text-generation/quantization_config/maxabs_quant_phi.json +++ b/examples/text-generation/quantization_config/maxabs_quant_phi.json @@ -9,6 +9,5 @@ "matmul_av", "lm_head" ]}, - "dump_stats_path": "./hqt_output/measure", - "dump_stats_xlsx_path": "./hqt_output/measure/fp8stats.xlsx" + "dump_stats_path": "./hqt_output/measure" } diff --git a/examples/text-generation/quantization_config/unit_scale_quant.json b/examples/text-generation/quantization_config/unit_scale_quant.json index caad4bb2a4..6bbbde8672 100644 --- a/examples/text-generation/quantization_config/unit_scale_quant.json +++ b/examples/text-generation/quantization_config/unit_scale_quant.json @@ -5,6 +5,5 @@ "scale_method": "unit_scale", "allowlist": {"types": [], "names": []}, "blocklist": {"types": [], "names": []}, - "dump_stats_path": "./hqt_output/measure", - "dump_stats_xlsx_path": "./hqt_output/measure/fp8stats.xlsx" + "dump_stats_path": "./hqt_output/measure" } diff --git a/examples/text-generation/run_generation.py b/examples/text-generation/run_generation.py index c41664ebf3..243b9926bf 100755 --- a/examples/text-generation/run_generation.py +++ b/examples/text-generation/run_generation.py @@ -28,7 +28,7 @@ from pathlib import Path import torch -from utils import adjust_batch, count_hpu_graphs, initialize_model +from utils import adjust_batch, count_hpu_graphs, finalize_quantization, initialize_model from optimum.habana.utils import get_hpu_memory_stats @@ -102,6 +102,18 @@ def setup_parser(parser): type=int, help="Number of beams used for beam search generation. 1 means greedy search will be performed.", ) + parser.add_argument( + "--top_k", + default=None, + type=int, + help="Size of candidate set used for re-ranking in contrastive search. top_k > 1 enables contrastive search.", + ) + parser.add_argument( + "--penalty_alpha", + default=None, + type=float, + help="Degeneration penalty for contrastive search. penalty_alpha > 0 enables contrastive search.", + ) parser.add_argument( "--trim_logits", action="store_true", @@ -303,6 +315,9 @@ def setup_parser(parser): if not args.use_hpu_graphs: args.limit_hpu_graphs = False + if args.use_flash_attention and not args.flash_attention_fast_softmax: + args.flash_attention_fast_softmax = True + args.quant_config = os.getenv("QUANT_CONFIG", "") if args.quant_config == "" and args.disk_offload: logger.warning( @@ -661,9 +676,7 @@ def generate_dataset(batch): print(f"Graph compilation duration = {compilation_duration} seconds") print(separator) if args.quant_config: - import habana_quantization_toolkit - - habana_quantization_toolkit.finish_measurements(model) + finalize_quantization(model) if args.const_serialization_path and os.path.isdir(args.const_serialization_path): import shutil diff --git a/examples/text-generation/run_lm_eval.py b/examples/text-generation/run_lm_eval.py index ac96f0a522..8f4b7a4416 100644 --- a/examples/text-generation/run_lm_eval.py +++ b/examples/text-generation/run_lm_eval.py @@ -33,7 +33,7 @@ import torch import torch.nn.functional as F from run_generation import setup_parser -from utils import initialize_model +from utils import finalize_quantization, initialize_model from optimum.habana.utils import get_hpu_memory_stats @@ -218,9 +218,8 @@ def main(): json.dump(results, open(args.output_file, "w"), indent=2) print(json.dumps(results, indent=2)) if args.quant_config: - import habana_quantization_toolkit + finalize_quantization(model) - habana_quantization_toolkit.finish_measurements(model) if args.const_serialization_path and os.path.isdir(args.const_serialization_path): import shutil diff --git a/examples/text-generation/utils.py b/examples/text-generation/utils.py index 3ad2cc2352..ee1c624230 100644 --- a/examples/text-generation/utils.py +++ b/examples/text-generation/utils.py @@ -186,6 +186,44 @@ def get_torch_compiled_model(model): return model +def setup_quantization(model, args): + if os.getenv("USE_INC", "1") != "0": + try: + from neural_compressor.torch.quantization import FP8Config, convert, prepare + except ImportError: + raise ImportError( + "Module neural_compressor is missing. Please use a newer Synapse version to use quantization, or set the environment variable to USE_INC=0" + ) + + config = FP8Config.from_json_file(args.quant_config) + if config.measure: + model = prepare(model, config) + elif config.quantize: + model = convert(model, config) + else: + import habana_quantization_toolkit + + habana_quantization_toolkit.prep_model(model) + + return model + + +def finalize_quantization(model): + if os.getenv("USE_INC", "1") != "0": + try: + from neural_compressor.torch.quantization import finalize_calibration + except ImportError: + raise ImportError( + "Module neural_compressor is missing. Please use a newer Synapse version to use quantization, or set the environment variable to USE_INC=0" + ) + + finalize_calibration(model) + else: + import habana_quantization_toolkit + + habana_quantization_toolkit.finish_measurements(model) + + def setup_model(args, model_dtype, model_kwargs, logger): logger.info("Single-device run.") if args.assistant_model is None: @@ -220,11 +258,7 @@ def setup_model(args, model_dtype, model_kwargs, logger): args.model_name_or_path, torch_dtype=model_dtype, **model_kwargs ) if args.quant_config: - import habana_quantization_toolkit - - habana_quantization_toolkit.prep_model(model) - if args.assistant_model is not None: - habana_quantization_toolkit.quantize_model(assistant_model) + model = setup_quantization(model, args) model = model.eval().to(args.device) if args.assistant_model is not None: @@ -385,11 +419,7 @@ def setup_distributed_model(args, model_dtype, model_kwargs, logger): patch_scoped_linear_all_reduce(model) if args.quant_config: - import habana_quantization_toolkit - - habana_quantization_toolkit.prep_model(model) - if args.assistant_model is not None: - habana_quantization_toolkit.prep_model(assistant_model) + model = setup_quantization(model, args) if args.torch_compile and model.config.model_type == "llama": model = get_torch_compiled_model(model) @@ -526,6 +556,8 @@ def setup_generation_config(args, model, assistant_model, tokenizer): generation_config.bucket_internal = args.bucket_internal generation_config.do_sample = args.do_sample generation_config.num_beams = args.num_beams + generation_config.top_k = args.top_k + generation_config.penalty_alpha = args.penalty_alpha generation_config.bad_words_ids = bad_words_ids generation_config.force_words_ids = force_words_ids generation_config.num_return_sequences = args.num_return_sequences @@ -586,6 +618,7 @@ def initialize_model(args, logger): "token": args.token, "trust_remote_code": args.trust_remote_code, } + if args.trust_remote_code: logger.warning("`trust_remote_code` is set, there is no guarantee this model works properly and it may fail") diff --git a/examples/translation/run_translation.py b/examples/translation/run_translation.py index 942503c4ad..fc97e8e0db 100644 --- a/examples/translation/run_translation.py +++ b/examples/translation/run_translation.py @@ -62,8 +62,8 @@ def check_optimum_habana_min_version(*a, **b): logger = logging.getLogger(__name__) # Will error if the minimal version of Transformers and Optimum Habana are not installed. Remove at your own risks. -check_min_version("4.40.0") -check_optimum_habana_min_version("1.11.0") +check_min_version("4.43.0") +check_optimum_habana_min_version("1.12.0") require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/translation/requirements.txt") diff --git a/examples/trl/README.md b/examples/trl/README.md index 747e95be79..3649a81c13 100644 --- a/examples/trl/README.md +++ b/examples/trl/README.md @@ -19,7 +19,7 @@ The following example is for the supervised Lora finetune with Qwen2 model for c --output_dir ./model_qwen \ --num_train_epochs 1 \ --per_device_train_batch_size 16 \ - --evaluation_strategy "no" \ + --eval_strategy "no" \ --save_strategy "no" \ --learning_rate 3e-4 \ --warmup_ratio 0.03 \ @@ -244,10 +244,10 @@ python run_generation.py \ ### Training The following example is for fine-tuning stable diffusion using Denoising Diffusion Policy Optimization -([DDPO](https://huggingface.co/docs/trl/en/ddpo_trainer)). The implementation supports LoRA and +([DDPO](https://huggingface.co/docs/trl/en/ddpo_trainer)). The implementation supports LoRA and non-LoRA-based training. LoRA based training is faster and less finicky to converge than non-LoRA -based training. Recommendations for non-Lora based training (described [here](https://huggingface.co/blog/trl-ddpo)) -are setting the learning rate relatively low (e.g., 1e-5) and disabling mixed precision training. +based training. Recommendations for non-Lora based training (described [here](https://huggingface.co/blog/trl-ddpo)) +are setting the learning rate relatively low (e.g., 1e-5) and disabling mixed precision training. HPU graphs are enabled by default for better performance. There are two main steps to the DDPO training process: @@ -272,7 +272,7 @@ python ddpo.py \ --hf_hub_model_id="ddpo-finetuned-stable-diffusion" \ --push_to_hub False ``` - + 2. Inference using the fine-tuned LoRA weights as shown in the example below: ```python import torch diff --git a/examples/trl/dpo.py b/examples/trl/dpo.py index 5779296bbb..bc9049a8a7 100644 --- a/examples/trl/dpo.py +++ b/examples/trl/dpo.py @@ -149,7 +149,7 @@ def return_prompt_and_responses(samples) -> Dict[str, str]: gradient_accumulation_steps=script_args.gradient_accumulation_steps, gradient_checkpointing=script_args.gradient_checkpointing, learning_rate=script_args.learning_rate, - evaluation_strategy="steps", + eval_strategy="steps", eval_steps=script_args.eval_steps, output_dir=script_args.output_dir, report_to=script_args.report_to, diff --git a/examples/trl/reward_modeling.py b/examples/trl/reward_modeling.py index 1bd8e65ecf..ec81d2eca8 100644 --- a/examples/trl/reward_modeling.py +++ b/examples/trl/reward_modeling.py @@ -134,7 +134,7 @@ class ScriptArguments: per_device_eval_batch_size=script_args.per_device_eval_batch_size, num_train_epochs=script_args.num_train_epochs, weight_decay=script_args.weight_decay, - evaluation_strategy="steps", + eval_strategy="steps", eval_steps=script_args.eval_steps, save_strategy="steps", save_steps=script_args.save_steps, diff --git a/optimum/habana/accelerate/accelerator.py b/optimum/habana/accelerate/accelerator.py index f324aebd6a..1d97842a47 100644 --- a/optimum/habana/accelerate/accelerator.py +++ b/optimum/habana/accelerate/accelerator.py @@ -34,6 +34,7 @@ from accelerate.tracking import GeneralTracker, filter_trackers from accelerate.utils import ( AutocastKwargs, + DataLoaderConfiguration, DeepSpeedPlugin, DistributedDataParallelKwargs, DistributedType, @@ -44,6 +45,7 @@ LoggerType, MegatronLMPlugin, PrecisionType, + ProfileKwargs, ProjectConfiguration, RNGType, check_os_kernel, @@ -82,6 +84,12 @@ logger = get_logger(__name__) +# Sentinel values for defaults +_split_batches = object() +_dispatch_batches = object() +_even_batches = object() +_use_seedable_sampler = object() + class GaudiAccelerator(Accelerator): """ @@ -91,10 +99,11 @@ class GaudiAccelerator(Accelerator): def __init__( self, device_placement: bool = True, - split_batches: bool = False, + split_batches: bool = _split_batches, mixed_precision: PrecisionType | str | None = None, gradient_accumulation_steps: int = 1, cpu: bool = False, + dataloader_config: DataLoaderConfiguration | None = None, deepspeed_plugin: DeepSpeedPlugin | None = None, fsdp_plugin: GaudiFullyShardedDataParallelPlugin | None = None, megatron_lm_plugin: MegatronLMPlugin | None = None, @@ -103,9 +112,9 @@ def __init__( project_dir: str | os.PathLike | None = None, project_config: ProjectConfiguration | None = None, gradient_accumulation_plugin: GradientAccumulationPlugin | None = None, - dispatch_batches: bool | None = None, - even_batches: bool = True, - use_seedable_sampler: bool = False, + dispatch_batches: bool | None = _dispatch_batches, + even_batches: bool = _even_batches, + use_seedable_sampler: bool = _use_seedable_sampler, step_scheduler_with_optimizer: bool = True, kwargs_handlers: list[KwargsHandler] | None = None, dynamo_backend: GaudiDynamoBackend | str | None = None, @@ -179,6 +188,9 @@ def __init__( self.init_handler = None self.fp8_recipe_handler = None self.autocast_handler = None + self.profile_handler = None + self.has_lomo_optimizer = False + if kwargs_handlers is not None: for handler in kwargs_handlers: assert isinstance( @@ -209,6 +221,11 @@ def __init__( raise ValueError("You can only pass one `AutocastKwargs` in `kwargs_handler`.") else: self.autocast_handler = handler + elif isinstance(handler, ProfileKwargs): + if self.profile_handler is not None: + raise ValueError("You can only pass one `ProfileKwargs` in `kwargs_handler`.") + else: + self.profile_handler = handler kwargs = self.init_handler.to_kwargs() if self.init_handler is not None else {} self.state = GaudiAcceleratorState( @@ -222,6 +239,16 @@ def __init__( **kwargs, ) + self.delayed_fp8_autocast = False + if self.fp8_recipe_handler is not None: + # We already check if FP8 is available during `self.state` + if self.state.mixed_precision != "fp8": + raise ValueError("Passing in a `FP8RecipeKwargs` object requires setting `mixed_precision='fp8'`.") + self.delayed_fp8_autocast = self.fp8_recipe_handler.backend == "TE" and self.distributed_type in ( + DistributedType.MULTI_GPU, + DistributedType.FSDP, + ) + if self.state.is_fp8_enabled: if self.fp8_recipe_handler is None: self.fp8_recipe_handler = GaudiFP8RecipeKwargs() @@ -256,10 +283,32 @@ def __init__( ) self.device_placement = device_placement - self.split_batches = split_batches - self.dispatch_batches = dispatch_batches - self.even_batches = even_batches - self.use_seedable_sampler = use_seedable_sampler + if dataloader_config is None: + dataloader_config = DataLoaderConfiguration() + self.dataloader_config = dataloader_config + # Deal with deprecated args + # TODO: Remove in v1.0.0 + deprecated_dl_args = {} + if dispatch_batches is not _dispatch_batches: + deprecated_dl_args["dispatch_batches"] = dispatch_batches + self.dataloader_config.dispatch_batches = dispatch_batches + if split_batches is not _split_batches: + deprecated_dl_args["split_batches"] = split_batches + self.dataloader_config.split_batches = split_batches + if even_batches is not _even_batches: + deprecated_dl_args["even_batches"] = even_batches + self.dataloader_config.even_batches = even_batches + if use_seedable_sampler is not _use_seedable_sampler: + deprecated_dl_args["use_seedable_sampler"] = use_seedable_sampler + self.dataloader_config.use_seedable_sampler = use_seedable_sampler + if len(deprecated_dl_args) > 0: + values = ", ".join([f"{k}={v}" for k, v in deprecated_dl_args.items()]) + warnings.warn( + f"Passing the following arguments to `Accelerator` is deprecated and will be removed in version 1.0 of Accelerate: {deprecated_dl_args.keys()}. " + "Please pass an `accelerate.DataLoaderConfiguration` instead: \n" + f"dataloader_config = DataLoaderConfiguration({values})", + FutureWarning, + ) self.step_scheduler_with_optimizer = step_scheduler_with_optimizer # Mixed precision attributes @@ -351,6 +400,7 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e model.forward = MethodType(convert_outputs_to_fp32(model.forward.__func__), model) else: model.forward = convert_outputs_to_fp32(new_forward) + if self.state.is_fp8_enabled: model = convert_model(model) @@ -364,16 +414,19 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e " In order to use 8-bit models that have been loaded across multiple GPUs the solution is to use Naive Pipeline Parallelism." " Therefore you should not specify that you are under any distributed regime in your accelerate config." ) - current_device = list(model_devices)[0] - current_device_index = current_device.index if isinstance(current_device, torch.device) else current_device + elif len(model_devices) == 1: + current_device = list(model_devices)[0] + current_device_index = ( + current_device.index if isinstance(current_device, torch.device) else current_device + ) - if torch.device(current_device_index) != self.device: - # if on the first device (GPU 0) we don't care - if (self.device.index is not None) or (current_device_index != 0): - raise ValueError( - "You can't train a model that has been loaded in 8-bit precision on a different device than the one " - "you're training on. Make sure you loaded the model on the correct device using for example `device_map={'':torch.cuda.current_device() or device_map={'':torch.xpu.current_device()}" - ) + if torch.device(current_device_index) != self.device: + # if on the first device (GPU 0) we don't care + if (self.device.index is not None) or (current_device_index != 0): + raise ValueError( + "You can't train a model that has been loaded in 8-bit precision on a different device than the one " + "you're training on. Make sure you loaded the model on the correct device using for example `device_map={'':torch.cuda.current_device() or device_map={'':torch.xpu.current_device()}" + ) if "cpu" in model_devices or "disk" in model_devices: raise ValueError( @@ -386,6 +439,8 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e if any(p.requires_grad for p in model.parameters()): kwargs = self.ddp_handler.to_kwargs() if self.ddp_handler is not None else {} model = torch.nn.parallel.DistributedDataParallel(model, **kwargs) + if self.ddp_handler is not None: + self.ddp_handler.register_comm_hook(model) elif self.distributed_type == GaudiDistributedType.FSDP: from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP @@ -430,6 +485,72 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e ), auto_wrap_policy=fsdp_plugin.auto_wrap_policy, ) + # In the event the model had been loaded in low precision, but + # mixed precision had also been activated, then we follow DeepSpeed's + # strategy to hold the parameters in full precision. + # - assume that trainer.args.bf16 and trainer.args.fp16 are already checked against + # fsdp_plugin.mixed_precision_policy. + # - NOTE: we do not check the mixed_precision attribute on the FSDP root wrapper. + # * this attribute will always set by init_utils.init_core_state so its always not None. + # * mixed_precision.param_dtype only regards _fwd_bwd_param_dtype + # * if model is loaded in 16bit, and even if mixed_precision.param_dtype is None, + # we sill want to upcast the flat_param. + if self.mixed_precision != "no": # if mixed precision is set + upcasted_log = [] + for module in FSDP.fsdp_modules(model): + # Referencing DeepSpeed Zero3 + # - in Init, params are converted to 16bit while partitioning. + # - in accelerator.prepare, deepspeed.initalize is called to: + # * creates the DeepSpeeedEngine. + # * since zero_optimization() is True , calls engine._configure_zero_optimizer. + # + # Inside the DeepSpeed Zero3 optimizer configuration, which initalizes + # DeepSpeedZeroOptimizer_Stage3, during which: + # * trainable_param_groups are obtained from the attached optimizer + # (already partitioned in 16bit). + # * then _setup_for_real_optimizer -> _create_fp32_partitions + # which performs the fp32 upcasting. + + # To mimick DeepSeepds's casting in FSDP, we look at the (single) FlatParameter held + # within an FSDP wrapper. This FlatParameter will be seen by the optimizer. + # - even though there is a torch.device('meta') guard below, we + # expect _init_utils._init_param_handle_from_module to already + # sync the parameter. + + if not module._has_params: + continue # skip if FSDP module not managing parameters + param = module._flat_param + if ( + param.dtype != torch.float32 + and param.device != torch.device("meta") + and param.requires_grad + ): + # keep log of names_params that was upcasted + # NOTE: resorted to this because warnings.simplefilter("once") is somehow not working + name_param_log = (module.module.__class__.__name__, ", ".join(module._flat_param._fqns)) + if name_param_log not in upcasted_log: + upcasted_log.append(name_param_log) + + # this works because of FSDP's _runtime_utils.lazy_init. + # Have to be careful not to call anything before this that + # triggers lazy_init (e.g., _is_fsdp_root). + param.data = param.data.to(torch.float32) # upcasting + module._handle._orig_param_dtype = torch.float32 # update + + # report the warnings + # some messages can be quite repetitive, especially when reporting about layers that have identical architecture. + if self.is_main_process: + for name_log, param_log in upcasted_log: + warnings.warn( + f"Upcasted low precision parameters in {name_log} because mixed precision turned on in FSDP. " + f"Affects: {param_log}." + ) + + if len(upcasted_log) > 0: + warnings.warn( + "FSDP upcast of low precision parameters may affect the precision of model checkpoints." + ) + # if the previous and current models are same, delete the previous one if len(self._models) > 1 and (self._models[-2] is self._models[-1]): del self._models[-2] @@ -551,6 +672,8 @@ def _prepare_deepspeed(self, *args): ) if model is not None: + # if the model is an MOE, set the appropriate MOE layers as leaf Z3 modules + deepspeed_plugin.set_moe_leaf_modules(model) # deal with config keys that use `auto` value and rely on model's hidden_size hidden_size_based_keys = [ "zero_optimization.reduce_bucket_size", @@ -580,7 +703,7 @@ def _prepare_deepspeed(self, *args): config_kwargs.update( { "zero_optimization.reduce_bucket_size": hidden_size * hidden_size, - "zero_optimization.stage3_prefetch_bucket_size": 0.9 * hidden_size * hidden_size, + "zero_optimization.stage3_prefetch_bucket_size": int(0.9 * hidden_size * hidden_size), "zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size, } ) @@ -640,6 +763,11 @@ def _prepare_deepspeed(self, *args): os.environ["DEEPSPEED_USE_HPU"] = "true" engine, optimizer, _, lr_scheduler = deepspeed.initialize(**kwargs) + # torch.compile should be called if dynamo plugin backend is set and only if the model isn't already compiled. + if self.state.dynamo_plugin.backend == GaudiDynamoBackend.HPU_BACKEND and not is_compiled_module( + kwargs["model"] + ): + engine.compile() if optimizer is not None: optimizer = DeepSpeedOptimizerWrapper(optimizer) if scheduler is not None: @@ -723,6 +851,7 @@ def prepare_data_loader( even_batches=self.even_batches, slice_fn_for_dispatch=slice_fn_for_dispatch, use_seedable_sampler=self.use_seedable_sampler, + non_blocking=self.non_blocking, ) self._dataloaders.append(prepared_data_loader) return prepared_data_loader diff --git a/optimum/habana/accelerate/data_loader.py b/optimum/habana/accelerate/data_loader.py index aa9f14d1b7..ae00b8976d 100644 --- a/optimum/habana/accelerate/data_loader.py +++ b/optimum/habana/accelerate/data_loader.py @@ -8,6 +8,7 @@ DataLoaderShard, IterableDatasetShard, SeedableRandomSampler, + get_sampler, ) from accelerate.state import GradientState from accelerate.utils import ( @@ -55,7 +56,14 @@ class GaudiDataLoaderDispatcher(DataLoaderDispatcher, DataLoader): """ def __init__( - self, dataset, split_batches: bool = False, skip_batches=0, _drop_last: bool = False, slice_fn=None, **kwargs + self, + dataset, + split_batches: bool = False, + skip_batches=0, + _drop_last: bool = False, + _non_blocking: bool = False, + slice_fn=None, + **kwargs, ): shuffle = False if is_torch_version(">=", "1.11.0"): @@ -72,6 +80,7 @@ def __init__( self.gradient_state = GradientState() self.state = GaudiAcceleratorState() self._drop_last = _drop_last + self._non_blocking = _non_blocking self.skip_batches = skip_batches self.slice_fn = slice_tensors if slice_fn is None else slice_fn @@ -144,7 +153,7 @@ def __iter__(self): if self.state.process_index != 0: # Initialize tensors on other processes than process 0. batch = initialize_tensors(batch_info[0]) - batch = send_to_device(batch, self.state.device) + batch = send_to_device(batch, self.state.device, non_blocking=self._non_blocking) # Broadcast the batch before splitting it. batch = broadcast(batch, from_process=0) @@ -210,6 +219,7 @@ def gaudi_prepare_data_loader( even_batches: bool = True, slice_fn_for_dispatch: Optional[Callable] = None, use_seedable_sampler: bool = False, + non_blocking: bool = False, ) -> DataLoader: """ Wraps a PyTorch `DataLoader` to generate batches for one of the processes only. @@ -266,7 +276,11 @@ def gaudi_prepare_data_loader( use_seedable_sampler (`bool`, *optional*, defaults to `False`): Whether to use the [`~data_loader.SeedableRandomSampler`] instead of a `RandomSampler` for better reproducability. Comes at a cost of potentially different performances due to different shuffling - algorithms but ensures results will be the *exact* same. + algorithms but ensures results will be the *exact* same. Should be paired with `set_seed()` at every + `self.set_epoch` + non_blocking (`bool`, *optional*, defaults to `False`): + If set to `True`, dataloader will utilize non-blocking host-to-device transfers. If the dataloader has + `pin_memory` set to `True`, this will help to increase overlap between data transfer and computations. Returns: `torch.utils.data.dataloader.DataLoader`: A new data loader that will yield the portion of the batches @@ -294,23 +308,34 @@ def gaudi_prepare_data_loader( process_index = state.process_index # Sanity check - batch_size = dataloader.batch_size if dataloader.batch_size is not None else dataloader.batch_sampler.batch_size - if split_batches and batch_size > 1 and batch_size % num_processes != 0: - raise ValueError( - f"To use a `DataLoader` in `split_batches` mode, the batch size ({dataloader.batch_size}) " - f"needs to be a round multiple of the number of processes ({num_processes})." - ) + if split_batches: + if dataloader.batch_size is not None: + batch_size_for_check = dataloader.batch_size + else: + # For custom batch_sampler + if hasattr(dataloader.batch_sampler, "batch_size"): + batch_size_for_check = dataloader.batch_sampler.batch_size + else: + raise ValueError( + "In order to use `split_batches==True` you must have a `batch_size` attribute either in the passed " + "`dataloader` or `dataloader.batch_sampler` objects, and it has to return a natural number. " + "Your `dataloader.batch_size` is None and `dataloader.batch_sampler` " + f"(`{type(dataloader.batch_sampler)}`) does not have the `batch_size` attribute set." + ) + + if batch_size_for_check > 1 and batch_size_for_check % num_processes != 0: + raise ValueError( + f"To use a `DataLoader` in `split_batches` mode, the batch size ({dataloader.batch_size}) " + f"needs to be a round multiple of the number of processes ({num_processes})." + ) new_dataset = dataloader.dataset # Iterable dataset doesn't like batch_sampler, but data_loader creates a default one for it new_batch_sampler = dataloader.batch_sampler if not isinstance(new_dataset, IterableDataset) else None - sampler_is_batch_sampler = False - synchronized_generator = None sampler_is_batch_sampler = isinstance(dataloader.sampler, BatchSampler) - if sampler_is_batch_sampler: - sampler = dataloader.sampler.sampler - else: - sampler = dataloader.batch_sampler.sampler + synchronized_generator = None + + sampler = get_sampler(dataloader) # Commenting the block below as it makes the accuracy decrease quite a lot for a few models and tasks # e.g. audio classification with Wav2Vec2 or Seq2SeqQA with T5 # if isinstance(sampler, RandomSampler) and use_seedable_sampler: @@ -343,16 +368,10 @@ def gaudi_prepare_data_loader( # for a few models and tasks e.g. audio classification with Wav2Vec2 or Seq2SeqQA with T5 # Keeping it for now # New batch sampler for the current process. - sampler_is_batch_sampler = isinstance(dataloader.sampler, BatchSampler) - if sampler_is_batch_sampler: - sampler = dataloader.sampler.sampler - else: - sampler = dataloader.batch_sampler.sampler if hasattr(sampler, "generator"): if sampler.generator is None: sampler.generator = torch.Generator() synchronized_generator = sampler.generator - batch_sampler = dataloader.sampler if sampler_is_batch_sampler else dataloader.batch_sampler new_batch_sampler = BatchSamplerShard( batch_sampler, @@ -386,11 +405,6 @@ def gaudi_prepare_data_loader( kwargs["batch_size"] = ( dataloader.batch_size // num_processes if split_batches and not dispatch_batches else dataloader.batch_size ) - if isinstance(sampler, SeedableRandomSampler) and use_seedable_sampler: - if sampler_is_batch_sampler: - dataloader.sampler.sampler = sampler - else: - dataloader.batch_sampler.sampler = sampler if dispatch_batches: kwargs.pop("generator") dataloader = GaudiDataLoaderDispatcher( @@ -398,6 +412,7 @@ def gaudi_prepare_data_loader( split_batches=split_batches, batch_sampler=new_batch_sampler, _drop_last=dataloader.drop_last, + _non_blocking=non_blocking, slice_fn=slice_fn_for_dispatch, **kwargs, ) @@ -409,6 +424,7 @@ def gaudi_prepare_data_loader( batch_size=dataloader.batch_size, rng_types=rng_types, _drop_last=dataloader.drop_last, + _non_blocking=non_blocking, synchronized_generator=synchronized_generator, **kwargs, ) @@ -420,7 +436,11 @@ def gaudi_prepare_data_loader( rng_types=rng_types, synchronized_generator=synchronized_generator, _drop_last=dataloader.drop_last, + _non_blocking=non_blocking, **kwargs, ) + if isinstance(sampler, SeedableRandomSampler) and use_seedable_sampler: + dataloader.set_sampler(sampler) + return dataloader diff --git a/optimum/habana/accelerate/utils/dataclasses.py b/optimum/habana/accelerate/utils/dataclasses.py index 1db6980ee7..2f50035f22 100644 --- a/optimum/habana/accelerate/utils/dataclasses.py +++ b/optimum/habana/accelerate/utils/dataclasses.py @@ -14,6 +14,7 @@ # limitations under the License. import os +import warnings from dataclasses import dataclass from enum import Enum @@ -46,7 +47,7 @@ class GaudiDistributedType(str, Enum): class GaudiDynamoBackend(str, BaseEnum): """ - Represents a dynamo backend (see https://github.com/pytorch/torchdynamo). + Represents a dynamo backend (see https://pytorch.org/docs/stable/torch.compiler.html). Values: @@ -141,6 +142,13 @@ def __post_init__(self): self.forward_prefetch = str_to_bool(os.environ.get(prefix + "FORWARD_PREFETCH", "False")) == 1 self.activation_checkpointing = str_to_bool(os.environ.get(prefix + "ACTIVATION_CHECKPOINTING", "False")) == 1 + if str_to_bool(os.environ.get("FSDP_CPU_RAM_EFFICIENT_LOADING", "False")) == 1 and not self.sync_module_states: + warnings.warn( + "sync_module_states cannot be False since efficient cpu ram loading enabled. " + "Setting sync_module_states to True." + ) + self.sync_module_states = True + if self.sync_module_states: device = torch.device("hpu", torch.hpu.current_device()) self.param_init_fn = lambda x: x.to_empty(device=device, recurse=False) diff --git a/optimum/habana/accelerate/utils/operations.py b/optimum/habana/accelerate/utils/operations.py index 21b01680df..6cdbdbe62e 100644 --- a/optimum/habana/accelerate/utils/operations.py +++ b/optimum/habana/accelerate/utils/operations.py @@ -13,7 +13,7 @@ # limitations under the License. """ -A set of basic tensor ops compatible with tpu, gpu, and multigpu +A set of basic tensor ops compatible with hpu """ import torch diff --git a/optimum/habana/checkpoint_utils.py b/optimum/habana/checkpoint_utils.py index 895b721518..aa88252868 100644 --- a/optimum/habana/checkpoint_utils.py +++ b/optimum/habana/checkpoint_utils.py @@ -94,7 +94,7 @@ def model_on_meta(config): """ Checks if load the model to meta. """ - return config.model_type in ["bloom", "llama", "falcon", "mixtral"] + return config.model_type in ["bloom", "llama", "falcon", "mixtral", "qwen2"] def get_optimized_model_name(config): diff --git a/optimum/habana/diffusers/models/attention_processor.py b/optimum/habana/diffusers/models/attention_processor.py new file mode 100755 index 0000000000..b0461a272b --- /dev/null +++ b/optimum/habana/diffusers/models/attention_processor.py @@ -0,0 +1,189 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +import os +from typing import Optional, Union + +import torch +import torch.nn.functional as F +from diffusers.models.attention_processor import Attention +from diffusers.utils import USE_PEFT_BACKEND, logging +from diffusers.utils.import_utils import is_xformers_available +from torch import nn + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +if is_xformers_available(): + import xformers + import xformers.ops +else: + xformers = None + + +class Softmax(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, dim=None, invAttnHead=None): + return torch.ops.hpu.softmax_fp8(x, dim, None, None, invAttnHead) + + +class Matmul(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, *args, **kwargs): + return torch.matmul(*args, **kwargs) + + +# ScaledDotProductAttention is based on torch.nn.functional.scaled_dot_product_attention +class ScaledDotProductAttention(nn.Module): + def __init__(self): + super().__init__() + self.bmm1 = Matmul() + self.bmm2 = Matmul() + self.softmax = Softmax() + + def forward(self, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None) -> torch.Tensor: + # Efficient implementation: + L, S = query.size(-2), key.size(-2) + scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale + invAttnHead = torch.tensor(scale_factor, dtype=torch.float32).to("hpu") + attn_bias = torch.zeros(L, S, dtype=query.dtype) + + if is_causal: + assert attn_mask is None + temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0) + attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) + attn_bias.to(query.dtype) + + if attn_mask is not None: + if attn_mask.dtype == torch.bool: + attn_mask.masked_fill_(attn_mask.logical_not(), float("-inf")) + else: + attn_bias += attn_mask + + if S < 128: + attn_weight = self.bmm1(key, query.transpose(-2, -1)) + attn_weight = self.softmax(attn_weight, dim=-2, invAttnHead=invAttnHead) + attn_weight = torch.dropout(attn_weight, dropout_p, train=True) + return self.bmm2(attn_weight.transpose(-2, -1), value) + else: + attn_weight = self.bmm1(query, key.transpose(-2, -1)) + attn_weight = self.softmax(attn_weight, dim=-1, invAttnHead=invAttnHead) + attn_weight = torch.dropout(attn_weight, dropout_p, train=True) + return self.bmm2(attn_weight, value) + + +# Copied from diffusers.models.attention_processor.AttnProcessor2_0 +class AttnProcessor2_0: + r""" + Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). + """ + + def __init__(self, attention_module=None): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + self.attention_module = attention_module + + def __call__( + self, + attn: Attention, + hidden_states: torch.FloatTensor, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + temb: Optional[torch.FloatTensor] = None, + scale: float = 1.0, + ) -> torch.FloatTensor: + residual = hidden_states + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + args = () if USE_PEFT_BACKEND else (scale,) + query = attn.to_q(hidden_states, *args) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states, *args) + value = attn.to_v(encoder_hidden_states, *args) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + # hidden_states = F.scaled_dot_product_attention( + # query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + # ) + if os.environ.get("PATCH_SDPA") is not None: + hidden_states = self.attention_module( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + else: + import habana_frameworks.torch.hpu as ht + from habana_frameworks.torch.hpex.kernels import FusedSDPA + + with ht.sdp_kernel(enable_recompute=True): + hidden_states = FusedSDPA.apply(query, key, value, attention_mask, 0.0, False) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # linear proj + hidden_states = attn.to_out[0](hidden_states, *args) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +AttentionProcessor = Union[AttnProcessor2_0,] diff --git a/optimum/habana/diffusers/models/unet_2d_condition.py b/optimum/habana/diffusers/models/unet_2d_condition.py index 4eca573665..204d023d65 100644 --- a/optimum/habana/diffusers/models/unet_2d_condition.py +++ b/optimum/habana/diffusers/models/unet_2d_condition.py @@ -1,13 +1,13 @@ from typing import Any, Dict, Optional, Tuple, Union +import habana_frameworks.torch.core as htcore import torch +import torch.utils.checkpoint from diffusers.models.unets.unet_2d_condition import UNet2DConditionOutput -from diffusers.utils import USE_PEFT_BACKEND, deprecate, scale_lora_layers, unscale_lora_layers +from diffusers.utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers -from optimum.utils import logging - -logger = logging.get_logger(__name__) +logger = logging.get_logger(__name__) # pylint: disable=invalid-name def gaudi_unet_2d_condition_model_forward( @@ -27,9 +27,11 @@ def gaudi_unet_2d_condition_model_forward( return_dict: bool = True, ) -> Union[UNet2DConditionOutput, Tuple]: r""" - Copied from: https://github.com/huggingface/diffusers/blob/v0.19.3/src/diffusers/models/unet_2d_condition.py#L700 + Copied from: https://github.com/huggingface/diffusers/blob/v0.26.3/src/diffusers/models/unets/unet_2d_condition.py#L843 - Adds a workaround to be able to compute `conv_in` with Torch Autocast and full bf16 precision. + Changes: + - Adds a workaround to be able to compute `conv_in` with Torch Autocast and full bf16 precision. + - Added mark_step in unet forward """ # By default samples have to be AT least a multiple of the overall upsampling factor. # The overall upsampling factor is equal to 2 ** (# num of upsampling layers). @@ -90,6 +92,7 @@ def gaudi_unet_2d_condition_model_forward( timesteps = timesteps.expand(sample.shape[0]) t_emb = self.time_proj(timesteps) + htcore.mark_step() # `Timesteps` does not contain any weights and will always return f32 tensors # but time_embedding might actually be running in fp16. so we need to cast here. @@ -234,8 +237,8 @@ def gaudi_unet_2d_condition_model_forward( "T2I should not use down_block_additional_residuals", "1.3.0", "Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \ - and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used \ - for ControlNet. Please make sure use `down_intrablock_additional_residuals` instead. ", + and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used \ + for ControlNet. Please make sure use `down_intrablock_additional_residuals` instead. ", standard_warn=False, ) down_intrablock_additional_residuals = down_block_additional_residuals diff --git a/optimum/habana/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_mlperf.py b/optimum/habana/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_mlperf.py new file mode 100644 index 0000000000..78b4b52b79 --- /dev/null +++ b/optimum/habana/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_mlperf.py @@ -0,0 +1,708 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import habana_frameworks.torch.core as htcore +import torch +from diffusers import StableDiffusionXLPipeline +from diffusers.image_processor import PipelineImageInput +from diffusers.models import AutoencoderKL, UNet2DConditionModel +from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import ( + StableDiffusionXLPipelineOutput, + rescale_noise_cfg, +) +from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import ( + retrieve_timesteps as retrieve_timesteps_hpu, +) +from diffusers.schedulers import KarrasDiffusionSchedulers +from diffusers.utils import deprecate +from transformers import ( + CLIPImageProcessor, + CLIPTextModel, + CLIPTextModelWithProjection, + CLIPTokenizer, + CLIPVisionModelWithProjection, +) + +from optimum.utils import logging + +from ...models.attention_processor import ( + AttentionProcessor, + AttnProcessor2_0, + ScaledDotProductAttention, +) +from ...models.unet_2d_condition import gaudi_unet_2d_condition_model_forward + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor +def set_attn_processor_hpu(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): + r""" + Sets the attention processor to use to compute attention. + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if os.environ.get("PATCH_SDPA") is not None: + setattr(module, "attention_module", ScaledDotProductAttention()) + module.set_processor(processor(module.attention_module)) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + +# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor +def set_default_attn_processor_hpu(self): + """ + Disables custom attention processors and sets the default attention implementation from HPU. + """ + processor = AttnProcessor2_0 + set_attn_processor_hpu(self, processor) + + +class StableDiffusionXLPipeline_HPU(StableDiffusionXLPipeline): + r""" + Pipeline for text-to-image generation using Stable Diffusion XL. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + In addition the pipeline inherits the following loading methods: + - *LoRA*: [`StableDiffusionXLPipeline.load_lora_weights`] + - *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`] + + as well as the following saving methods: + - *LoRA*: [`loaders.StableDiffusionXLPipeline.save_lora_weights`] + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion XL uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + text_encoder_2 ([` CLIPTextModelWithProjection`]): + Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection), + specifically the + [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k) + variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_2 (`CLIPTokenizer`): + Second Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `"True"`): + Whether the negative prompt embeddings shall be forced to always be set to 0. Also see the config of + `stabilityai/stable-diffusion-xl-base-1-0`. + add_watermarker (`bool`, *optional*): + Whether to use the [invisible_watermark library](https://github.com/ShieldMnt/invisible-watermark/) to + watermark output images. If not defined, it will default to True if the package is installed, otherwise no + watermarker will be used. + """ + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + text_encoder_2: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + tokenizer_2: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + image_encoder: CLIPVisionModelWithProjection = None, + feature_extractor: CLIPImageProcessor = None, + force_zeros_for_empty_prompt: bool = True, + add_watermarker: Optional[bool] = None, + ): + super().__init__( + vae, + text_encoder, + text_encoder_2, + tokenizer, + tokenizer_2, + unet, + scheduler, + image_encoder, + feature_extractor, + force_zeros_for_empty_prompt, + add_watermarker, + ) + self.unet.set_default_attn_processor = set_default_attn_processor_hpu + self.unet.forward = gaudi_unet_2d_condition_model_forward + + def run_unet( + self, + unet, + latents, + timesteps, + t, + i, + add_text_embeds, + add_time_ids, + prompt_embeds, + extra_step_kwargs, + negative_prompt_embeds, + negative_add_time_ids, + negative_pooled_prompt_embeds, + num_warmup_steps, + progress_bar, + callback, + callback_steps, + ip_adapter_image, + image_embeds, + timestep_cond, + callback_on_step_end, + callback_on_step_end_tensor_inputs, + ): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + if ip_adapter_image is not None: + added_cond_kwargs["image_embeds"] = image_embeds + noise_pred = unet( + unet, + sample=latent_model_input, + timestep=t, + encoder_hidden_states=prompt_embeds, + timestep_cond=timestep_cond, + cross_attention_kwargs=self.cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + htcore.mark_step() + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + if self.do_classifier_free_guidance and self.guidance_rescale > 0.0: + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds) + negative_pooled_prompt_embeds = callback_outputs.pop( + "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds + ) + add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids) + negative_add_time_ids = callback_outputs.pop("negative_add_time_ids", negative_add_time_ids) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + return latents + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + timesteps: List[int] = None, + denoising_end: Optional[float] = None, + guidance_scale: float = 5.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + ip_adapter_image: Optional[PipelineImageInput] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guidance_rescale: float = 0.0, + original_size: Optional[Tuple[int, int]] = None, + crops_coords_top_left: Tuple[int, int] = (0, 0), + target_size: Optional[Tuple[int, int]] = None, + negative_original_size: Optional[Tuple[int, int]] = None, + negative_crops_coords_top_left: Tuple[int, int] = (0, 0), + negative_target_size: Optional[Tuple[int, int]] = None, + clip_skip: Optional[int] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + **kwargs, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + Anything below 512 pixels won't work well for + [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) + and checkpoints that are not specifically fine-tuned on low resolutions. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + Anything below 512 pixels won't work well for + [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) + and checkpoints that are not specifically fine-tuned on low resolutions. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + denoising_end (`float`, *optional*): + When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be + completed before it is intentionally prematurely terminated. As a result, the returned sample will + still retain a substantial amount of noise as determined by the discrete timesteps selected by the + scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a + "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image + Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output) + guidance_scale (`float`, *optional*, defaults to 5.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead + of a plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + guidance_rescale (`float`, *optional*, defaults to 0.0): + Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of + [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). + Guidance rescale factor should fix overexposure when using zero terminal SNR. + original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. + `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as + explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position + `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting + `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + For most cases, `target_size` should be set to the desired height and width of the generated image. If + not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in + section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a specific image resolution. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a target image resolution. It should be as same + as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + + Examples: + + Returns: + [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is a list with the generated images. + """ + + callback = kwargs.pop("callback", None) + callback_steps = kwargs.pop("callback_steps", None) + + if callback is not None: + deprecate( + "callback", + "1.0.0", + "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", + ) + if callback_steps is not None: + deprecate( + "callback_steps", + "1.0.0", + "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", + ) + + # 0. Default height and width to unet + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + original_size = original_size or (height, width) + target_size = target_size or (height, width) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + height, + width, + callback_steps, + negative_prompt, + negative_prompt_2, + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._guidance_rescale = guidance_rescale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + self._denoising_end = denoising_end + self._interrupt = False + image_embeds = None + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. Encode input prompt + lora_scale = ( + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + ) + + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + lora_scale=lora_scale, + clip_skip=self.clip_skip, + ) + + # 4. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps_hpu(self.scheduler, num_inference_steps, device, timesteps) + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Prepare added time ids & embeddings + add_text_embeds = pooled_prompt_embeds + if self.text_encoder_2 is None: + text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) + else: + text_encoder_projection_dim = self.text_encoder_2.config.projection_dim + + add_time_ids = self._get_add_time_ids( + original_size, + crops_coords_top_left, + target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + if negative_original_size is not None and negative_target_size is not None: + negative_add_time_ids = self._get_add_time_ids( + negative_original_size, + negative_crops_coords_top_left, + negative_target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + else: + negative_add_time_ids = add_time_ids + + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) + add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0) + + prompt_embeds = prompt_embeds.to(device) + add_text_embeds = add_text_embeds.to(device) + add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) + + if ip_adapter_image is not None: + image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, device, batch_size * num_images_per_prompt + ) + + # 8. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + # 8.1 Apply denoising_end + if ( + self.denoising_end is not None + and isinstance(self.denoising_end, float) + and self.denoising_end > 0 + and self.denoising_end < 1 + ): + discrete_timestep_cutoff = int( + round( + self.scheduler.config.num_train_timesteps + - (self.denoising_end * self.scheduler.config.num_train_timesteps) + ) + ) + num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) + timesteps = timesteps[:num_inference_steps] + + # 9. Optionally get Guidance Scale Embedding + timestep_cond = None + if self.unet.config.time_cond_proj_dim is not None: + guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) + timestep_cond = self.get_guidance_scale_embedding( + guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim + ).to(device=device, dtype=latents.dtype) + + self._num_timesteps = len(timesteps) + with self.progress_bar(total=num_inference_steps) as progress_bar: + timesteps = [t.item() for t in timesteps] + if self.quantized: + for i, t in enumerate(timesteps[0:-2]): + if self.interrupt: + continue + latents = self.run_unet( + self.unet, + latents, + timesteps, + t, + i, + add_text_embeds, + add_time_ids, + prompt_embeds, + extra_step_kwargs, + negative_prompt_embeds, + negative_add_time_ids, + negative_pooled_prompt_embeds, + num_warmup_steps, + progress_bar, + callback, + callback_steps, + ip_adapter_image, + image_embeds, + timestep_cond, + callback_on_step_end, + callback_on_step_end_tensor_inputs, + ) + for i, t in enumerate(timesteps[-2:], 18): + if self.interrupt: + continue + latents = self.run_unet( + self.unet_bf16, + latents, + timesteps, + t, + i, + add_text_embeds, + add_time_ids, + prompt_embeds, + extra_step_kwargs, + negative_prompt_embeds, + negative_add_time_ids, + negative_pooled_prompt_embeds, + num_warmup_steps, + progress_bar, + callback, + callback_steps, + ip_adapter_image, + image_embeds, + timestep_cond, + callback_on_step_end, + callback_on_step_end_tensor_inputs, + ) + else: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + latents = self.run_unet( + self.unet, + latents, + timesteps, + t, + i, + add_text_embeds, + add_time_ids, + prompt_embeds, + extra_step_kwargs, + negative_prompt_embeds, + negative_add_time_ids, + negative_pooled_prompt_embeds, + num_warmup_steps, + progress_bar, + callback, + callback_steps, + ip_adapter_image, + image_embeds, + timestep_cond, + callback_on_step_end, + callback_on_step_end_tensor_inputs, + ) + if not output_type == "latent": + # make sure the VAE is in float32 mode, as it overflows in float16 + needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast + + if needs_upcasting: + self.upcast_vae() + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + + # cast back to fp16 if needed + if needs_upcasting: + self.vae.to(dtype=torch.float16) + else: + image = latents + + if not output_type == "latent": + # apply watermark if available + if self.watermark is not None: + image = self.watermark.apply_watermark(image) + + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return StableDiffusionXLPipelineOutput(images=image) diff --git a/optimum/habana/diffusers/schedulers/scheduling_euler_discrete.py b/optimum/habana/diffusers/schedulers/scheduling_euler_discrete.py index 8fe62c34b9..977b196e29 100644 --- a/optimum/habana/diffusers/schedulers/scheduling_euler_discrete.py +++ b/optimum/habana/diffusers/schedulers/scheduling_euler_discrete.py @@ -11,14 +11,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - from typing import List, Optional, Tuple, Union import numpy as np import torch +from diffusers import EulerDiscreteScheduler from diffusers.configuration_utils import register_to_config -from diffusers.schedulers import EulerDiscreteScheduler from diffusers.schedulers.scheduling_euler_discrete import EulerDiscreteSchedulerOutput +from diffusers.utils.torch_utils import randn_tensor from optimum.utils import logging @@ -104,6 +104,7 @@ def __init__( self._initial_timestep = None self.reset_timestep_dependent_params() + self.hpu_opt = False def reset_timestep_dependent_params(self): self.are_timestep_dependent_params_set = False @@ -160,7 +161,13 @@ def scale_model_input( A scaled input sample. """ - sigma, _ = self.get_params(timestep) + if self.hpu_opt: + if self.step_index is None: + self._init_step_index(timestep) + self.sigmas = self.sigmas.to(sample.dtype) + sigma = self.sigmas[self.step_index] + else: + sigma, _ = self.get_params(timestep) sample = sample / ((sigma**2 + 1) ** 0.5) self.is_scale_input_called = True return sample @@ -224,19 +231,33 @@ def step( "See `StableDiffusionPipeline` for a usage example." ) + if self.hpu_opt and self.step_index is None: + self._init_step_index(timestep) + # Upcast to avoid precision issues when computing prev_sample sample = sample.to(torch.float32) - sigma, sigma_next = self.get_params(timestep) + if self.hpu_opt: + sigma = self.sigmas[self.step_index] + else: + sigma, sigma_next = self.get_params(timestep) - gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0 + if self.hpu_opt and sigma.device.type == "hpu": + gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) + else: + gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0 - device = model_output.device + if self.hpu_opt: + noise = randn_tensor( + model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator + ) + else: + device = model_output.device - # torch.randn is broken on HPU so running it on CPU - noise = torch.randn(model_output.shape, dtype=model_output.dtype, device="cpu", generator=generator) - if device.type == "hpu": - noise = noise.to(device) + # torch.randn is broken on HPU so running it on CPU + noise = torch.randn(model_output.shape, dtype=model_output.dtype, device="cpu", generator=generator) + if device.type == "hpu": + noise = noise.to(device) eps = noise * s_noise sigma_hat = sigma * (gamma + 1) @@ -262,7 +283,10 @@ def step( # 2. Convert to an ODE derivative derivative = (sample - pred_original_sample) / sigma_hat - dt = sigma_next - sigma_hat + if self.hpu_opt: + dt = self.sigmas[self.step_index + 1] - sigma_hat + else: + dt = sigma_next - sigma_hat prev_sample = sample + derivative * dt @@ -271,7 +295,8 @@ def step( # upon completion increase step index by one self._step_index += 1 - self.roll_params() + if not self.hpu_opt: + self.roll_params() if not return_dict: return (prev_sample,) diff --git a/optimum/habana/transformers/generation/candidate_generator.py b/optimum/habana/transformers/generation/candidate_generator.py index 633e39e9d6..171161074f 100644 --- a/optimum/habana/transformers/generation/candidate_generator.py +++ b/optimum/habana/transformers/generation/candidate_generator.py @@ -20,17 +20,17 @@ def __init__( input_ids: torch.LongTensor, assistant_model: "PreTrainedModel", generation_config: "GaudiGenerationConfig", - logits_processor: "LogitsProcessorList", model_kwargs: Dict, inputs_tensor: Optional[torch.Tensor] = None, + logits_processor: "LogitsProcessorList" = None, ): super().__init__( input_ids, assistant_model, generation_config, - logits_processor, model_kwargs, inputs_tensor, + logits_processor, ) # Remove model kwargs that are specific to optimized models diff --git a/optimum/habana/transformers/generation/utils.py b/optimum/habana/transformers/generation/utils.py index d12c5ba169..94ecaa89cd 100755 --- a/optimum/habana/transformers/generation/utils.py +++ b/optimum/habana/transformers/generation/utils.py @@ -22,6 +22,7 @@ import torch import torch.distributed as dist +from transformers.cache_utils import Cache, DynamicCache, EncoderDecoderCache, QuantizedCacheConfig from transformers.generation.beam_constraints import DisjunctiveConstraint, PhrasalConstraint from transformers.generation.beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer from transformers.generation.candidate_generator import ( @@ -37,10 +38,11 @@ MaxLengthCriteria, MaxTimeCriteria, StoppingCriteriaList, - validate_stopping_criteria, + StopStringCriteria, ) from transformers.generation.utils import ( NEED_SETUP_CACHE_CLASSES_MAPPING, + QUANT_BACKEND_CLASSES_MAPPING, GenerateBeamDecoderOnlyOutput, GenerateBeamEncoderDecoderOutput, GenerateBeamOutput, @@ -50,12 +52,14 @@ GenerateOutput, GenerationMixin, GenerationMode, + _ranking_fast, _split_model_inputs, _split_model_outputs, stack_model_outputs, ) from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled -from transformers.utils import ModelOutput, is_torchdynamo_compiling +from transformers.modeling_outputs import CausalLMOutputWithPast, Seq2SeqLMOutput +from transformers.utils import ModelOutput, is_hqq_available, is_quanto_available, is_torchdynamo_compiling from optimum.utils import logging @@ -68,6 +72,7 @@ if TYPE_CHECKING: from transformers import PreTrainedModel from transformers.streamers import BaseStreamer + from transformers.tokenization_utils_base import PreTrainedTokenizerBase from .candidate_generator import GaudiCandidateGenerator @@ -173,8 +178,7 @@ def _prepare_decoder_input_ids_for_generation( batch_size: int, model_input_name: str, model_kwargs: Dict[str, torch.Tensor], - decoder_start_token_id: Union[int, List[int]] = None, - bos_token_id: int = None, + decoder_start_token_id: torch.Tensor, device: torch.device = None, max_new_tokens: int = None, pad_token_id: int = None, @@ -182,7 +186,6 @@ def _prepare_decoder_input_ids_for_generation( """Prepares `decoder_input_ids` for generation with encoder-decoder models""" # 1. Check whether the user has defined `decoder_input_ids` manually. To facilitate in terms of input naming, # we also allow the user to pass it under `input_ids`, if the encoder does not use it as the main input. - if model_kwargs is not None and "decoder_input_ids" in model_kwargs: decoder_input_ids = model_kwargs.pop("decoder_input_ids") elif "input_ids" in model_kwargs and model_input_name != "input_ids": @@ -192,60 +195,57 @@ def _prepare_decoder_input_ids_for_generation( token_idx = model_kwargs.get("token_idx", None) - # 2. Encoder-decoder models expect the `decoder_input_ids` to start with a special token. Let's ensure that. - decoder_start_token_id = self._get_decoder_start_token_id(decoder_start_token_id, bos_token_id) + # 2. `decoder_start_token_id` must have shape (batch_size, 1) if device is None: device = self.device if token_idx is None: - if isinstance(decoder_start_token_id, list): - if len(decoder_start_token_id) != batch_size: + if decoder_start_token_id.ndim == 1: + if decoder_start_token_id.shape[0] != batch_size: raise ValueError( - f"`decoder_start_token_id` expcted to have length {batch_size} but got {len(decoder_start_token_id)}" + f"`decoder_start_token_id` expected to have length {batch_size} but got {decoder_start_token_id.shape[0]}" ) - decoder_input_ids_start = torch.tensor(decoder_start_token_id, dtype=torch.long, device=device) - decoder_input_ids_start = decoder_input_ids_start.view(-1, 1) + decoder_start_token_id = decoder_start_token_id.view(-1, 1) else: - decoder_input_ids_start = ( + decoder_start_token_id = ( torch.ones((batch_size, 1), dtype=torch.long, device=device) * decoder_start_token_id ) else: # creating padded decoder_input_ids to achieve static shapes. Later new tokens once generated are copied in to decoder_input_ids based on token_idx max_length = max_new_tokens + 1 if max_new_tokens is not None else self.generation_config.max_length - decoder_input_ids_start = ( + decoder_start_token_id = ( torch.ones((batch_size, 1), dtype=torch.long, device=device) * decoder_start_token_id ) - decoder_input_ids_start = torch.nn.functional.pad( - decoder_input_ids_start, (0, max_length - 1), value=pad_token_id + decoder_start_token_id = torch.nn.functional.pad( + decoder_start_token_id, (0, max_length - 1), value=pad_token_id ) + # 3. Encoder-decoder models expect the `decoder_input_ids` to start with a special token. Let's ensure that. # no user input -> use decoder_start_token_id as decoder_input_ids if decoder_input_ids is None: - decoder_input_ids = decoder_input_ids_start - # exception: Donut checkpoints have task-specific decoder starts and don't expect a BOS token - elif self.config.model_type == "vision-encoder-decoder" and "donut" in self.name_or_path.lower(): + decoder_input_ids = decoder_start_token_id + # exception: Donut checkpoints have task-specific decoder starts and don't expect a BOS token. Note that the + # original checkpoints can't be detected through `self.__class__.__name__.lower()`, needing custom logic. + # See: https://github.com/huggingface/transformers/pull/31470 + elif "donut" in self.__class__.__name__.lower() or ( + self.config.model_type == "vision-encoder-decoder" and "donut" in self.config.encoder.model_type.lower() + ): pass elif self.config.model_type in ["whisper"]: pass # user input but doesn't start with decoder_start_token_id -> prepend decoder_start_token_id (and adjust # decoder_attention_mask if provided) - elif ( - isinstance(decoder_start_token_id, int) - and (decoder_input_ids[:, 0] != decoder_start_token_id).all().item() - ) or ( - isinstance(decoder_start_token_id, torch.Tensor) - and (decoder_input_ids[:, 0] != decoder_start_token_id[:, 0]).all().item() - ): + elif (decoder_input_ids[:, 0] != decoder_start_token_id[:, 0]).all().item(): if token_idx is None: - decoder_input_ids = torch.cat([decoder_input_ids_start, decoder_input_ids], dim=-1) + decoder_input_ids = torch.cat([decoder_start_token_id, decoder_input_ids], dim=-1) else: max_length = max_new_tokens + 2 if max_new_tokens is not None else self.generation_config.max_length - if max_length != decoder_input_ids_start.shape[-1]: - decoder_input_ids_start = torch.nn.functional.pad( - decoder_input_ids_start, - (0, max_length - decoder_input_ids_start.shape[-1]), + if max_length != decoder_start_token_id.shape[-1]: + decoder_start_token_id = torch.nn.functional.pad( + decoder_start_token_id, + (0, max_length - decoder_start_token_id.shape[-1]), value=pad_token_id, ) - decoder_input_ids = decoder_input_ids_start.index_copy(1, token_idx, decoder_input_ids) + decoder_input_ids = decoder_start_token_id.index_copy(1, token_idx, decoder_input_ids) token_idx.add_(1) if "decoder_attention_mask" in model_kwargs: decoder_attention_mask = model_kwargs["decoder_attention_mask"] @@ -344,6 +344,7 @@ def _update_model_kwargs_for_generation( model_kwargs: Dict[str, Any], is_encoder_decoder: bool = False, standardize_cache_format: bool = False, + num_new_tokens: int = 1, ) -> Dict[str, Any]: """ Copied from Transformers: https://github.com/huggingface/transformers/blob/527ab894e59b6582578008e3b47648a65063f73d/src/transformers/generation/utils.py#L745 @@ -353,10 +354,11 @@ def _update_model_kwargs_for_generation( # mark to identify starting from second token model_kwargs["first_token"] = False if not model_kwargs.get("pad_done", False): - # update past_key_values - model_kwargs["past_key_values"] = self._extract_past_from_model_output( + # update past_key_values keeping its naming used in model code + cache_name, cache = self._extract_past_from_model_output( outputs, standardize_cache_format=standardize_cache_format ) + model_kwargs[cache_name] = cache if getattr(outputs, "state", None) is not None: model_kwargs["state"] = outputs.state @@ -400,7 +402,14 @@ def _update_model_kwargs_for_generation( model_kwargs["token_idx_cpu"] += 1 if "cache_position" in model_kwargs and model_kwargs["cache_position"] is not None: - model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + 1 + if model_kwargs.get("use_cache", True): + model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + num_new_tokens + else: + past_positions = model_kwargs.pop("cache_position") + new_positions = torch.arange( + past_positions[-1] + 1, past_positions[-1] + num_new_tokens + 1, dtype=past_positions.dtype + ).to(past_positions.device) + model_kwargs["cache_position"] = torch.cat((past_positions, new_positions)) return model_kwargs @@ -495,9 +504,9 @@ def _get_candidate_generator( input_ids=input_ids, assistant_model=assistant_model, generation_config=generation_config, - logits_processor=logits_processor, model_kwargs=model_kwargs, inputs_tensor=inputs_tensor, + logits_processor=logits_processor, ) return candidate_generator @@ -505,7 +514,8 @@ def _get_stopping_criteria( self, generation_config: GaudiGenerationConfig, stopping_criteria: Optional[StoppingCriteriaList], - ignore_eos: bool = False, + tokenizer: Optional["PreTrainedTokenizerBase"] = None, + **kwargs, ) -> StoppingCriteriaList: criteria = StoppingCriteriaList() if generation_config.max_length is not None: @@ -518,8 +528,16 @@ def _get_stopping_criteria( ) if generation_config.max_time is not None: criteria.append(MaxTimeCriteria(max_time=generation_config.max_time)) - if not ignore_eos and generation_config.eos_token_id is not None: - criteria.append(EosTokenCriteria(eos_token_id=generation_config.eos_token_id)) + if generation_config.stop_strings is not None: + if tokenizer is None: + raise ValueError( + "There are one or more stop strings, either in the arguments to `generate` or in the " + "model's generation config, but we could not locate a tokenizer. When generating with " + "stop strings, you must pass the model's tokenizer to the `tokenizer` argument of `generate`." + ) + criteria.append(StopStringCriteria(stop_strings=generation_config.stop_strings, tokenizer=tokenizer)) + if not generation_config.ignore_eos and generation_config._eos_token_tensor is not None: + criteria.append(EosTokenCriteria(eos_token_id=generation_config._eos_token_tensor)) criteria = self._merge_criteria_processor_list(criteria, stopping_criteria) return criteria @@ -581,7 +599,7 @@ def _prepare_generated_length( return generation_config def _prepare_generation_config( - self, generation_config: GaudiGenerationConfig, **kwargs: Dict + self, generation_config: Optional[GaudiGenerationConfig], **kwargs: Dict ) -> Tuple[GaudiGenerationConfig, Dict]: """ Copied from https://github.com/huggingface/transformers/blob/v4.40.2/src/transformers/generation/utils.py#L1230 @@ -594,6 +612,7 @@ def _prepare_generation_config( # the parameterization in `fullgraph=False` so as to enable `fullgraph=True`. # priority: `generation_config` argument > `model.generation_config` (the default generation config) + using_model_generation_config = False if generation_config is None: # legacy: users may modify the model configuration to control generation. To trigger this legacy behavior, # three conditions must be met @@ -616,6 +635,7 @@ def _prepare_generation_config( " https://huggingface.co/docs/transformers/generation_strategies#default-text-generation-configuration )" ) self.generation_config = new_generation_config + using_model_generation_config = True generation_config = self.generation_config # `torch.compile` can't compile `copy.deepcopy`, arguments in `kwargs` that are part of `generation_config` @@ -646,6 +666,16 @@ def _prepare_generation_config( if self.config.model_type == "falcon" and "token_type_ids" in kwargs.keys(): for key in ["token_type_ids"]: model_kwargs.pop(key, None) + # If `generation_config` is provided, let's fallback ALL special tokens to the default values for the model + if not using_model_generation_config: + if generation_config.bos_token_id is None: + generation_config.bos_token_id = self.generation_config.bos_token_id + if generation_config.eos_token_id is None: + generation_config.eos_token_id = self.generation_config.eos_token_id + if generation_config.pad_token_id is None: + generation_config.pad_token_id = self.generation_config.pad_token_id + if generation_config.decoder_start_token_id is None: + generation_config.decoder_start_token_id = self.generation_config.decoder_start_token_id return generation_config, model_kwargs @@ -773,6 +803,7 @@ def generate( # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call self._validate_model_class() + tokenizer = kwargs.pop("tokenizer", None) # Pull this out first, we only use it for stopping criteria if hpu_graphs and not lazy_mode: raise ValueError( "`hpu_graphs` is True but `lazy_mode` is False. HPU graphs require `lazy_mode` to be set to True." @@ -780,6 +811,7 @@ def generate( num_virtual_tokens = kwargs.pop("num_virtual_tokens", 0) generation_config, model_kwargs = self._prepare_generation_config(generation_config, **kwargs) self._validate_model_kwargs(model_kwargs.copy()) + self._validate_assistant(assistant_model) # 2. Set generation parameters if not already defined if synced_gpus is None: @@ -787,36 +819,39 @@ def generate( synced_gpus = True else: synced_gpus = False + logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() - if generation_config.pad_token_id is None and generation_config.eos_token_id is not None: - if model_kwargs.get("attention_mask", None) is None: - logger.warning( - "The attention mask and the pad token id were not set. As a consequence, you may observe " - "unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results." - ) - eos_token_id = generation_config.eos_token_id - if isinstance(eos_token_id, list): - eos_token_id = eos_token_id[0] - logger.warning( - f"Setting `pad_token_id` to `eos_token_id`:{generation_config.eos_token_id} for open-end generation." - ) - generation_config.pad_token_id = eos_token_id + accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys()) + requires_attention_mask = "encoder_outputs" not in model_kwargs + kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None # 3. Define model inputs - # inputs_tensor has to be defined - # model_input_name is defined if model-specific keyword input is passed - # otherwise model_input_name is None - # all model-specific keyword inputs are removed from `model_kwargs` inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs( inputs, generation_config.bos_token_id, model_kwargs ) batch_size = inputs_tensor.shape[0] + device = inputs_tensor.device + self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=device) + + # decoder-only models must use left-padding for batched generation. + if not self.config.is_encoder_decoder and not is_torchdynamo_compiling(): + # If `input_ids` was given, check if the last id in any sequence is `pad_token_id` + # Note: If using, `inputs_embeds` this check does not work, because we want to be more hands-off. + if ( + generation_config._pad_token_tensor is not None + and batch_size > 1 + and len(inputs_tensor.shape) == 2 + and torch.sum(inputs_tensor[:, -1] == generation_config._pad_token_tensor) > 0 + ): + logger.warning( + "A decoder-only architecture is being used, but right-padding was detected! For correct " + "generation results, please set `padding_side='left'` when initializing the tokenizer." + ) + # 4. Define other model kwargs - model_kwargs["output_attentions"] = generation_config.output_attentions - model_kwargs["output_hidden_states"] = generation_config.output_hidden_states # decoder-only models with inputs_embeds forwarding must use caching (otherwise we can't detect whether we are # generating the first new token or not, and we only want to use the embeddings for the first new token) if not self.config.is_encoder_decoder and model_input_name == "inputs_embeds": @@ -826,21 +861,22 @@ def generate( self.generation_config.max_length = generation_config.max_length - accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys()) - requires_attention_mask = "encoder_outputs" not in model_kwargs - - if model_kwargs.get("attention_mask", None) is None and requires_attention_mask and accepts_attention_mask: + if not kwargs_has_attention_mask and requires_attention_mask and accepts_attention_mask: model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation( - inputs_tensor, generation_config.pad_token_id, generation_config.eos_token_id + inputs_tensor, generation_config._pad_token_tensor, generation_config._eos_token_tensor ) is_greedy_or_beam_and_bucket = ( not generation_config.bucket_internal and generation_config.bucket_size > 0 - and ( - generation_config.get_generation_mode(assistant_model) == GenerationMode.GREEDY_SEARCH - or generation_config.get_generation_mode(assistant_model) == GenerationMode.BEAM_SEARCH - ) + and generation_config.get_generation_mode(assistant_model) + in [ + GenerationMode.GREEDY_SEARCH, + GenerationMode.SAMPLE, + GenerationMode.BEAM_SEARCH, + GenerationMode.BEAM_SAMPLE, + GenerationMode.CONTRASTIVE_SEARCH, + ] ) model_kwargs["bucket_size"] = generation_config.bucket_size if generation_config.static_shapes else -1 model_kwargs["bucket_internal"] = generation_config.bucket_internal @@ -908,26 +944,10 @@ def generate( inputs_tensor.device, ) - # decoder-only models should use left-padding for generation - if not self.config.is_encoder_decoder: - # If `input_ids` was given, check if the last id in any sequence is `pad_token_id` - # Note: If using, `inputs_embeds` this check does not work, because we want to be more hands-off. - if generation_config.pad_token_id is not None: - position = model_kwargs["token_idx"] - 1 if "token_idx" in model_kwargs else -1 - if ( - len(inputs_tensor.shape) == 2 - and torch.sum(inputs_tensor[:, position] == generation_config.pad_token_id) > 0 - ): - logger.warning( - "A decoder-only architecture is being used, but right-padding was detected! For correct " - "generation results, please set `padding_side='left'` when initializing the tokenizer." - ) - if self.config.is_encoder_decoder and "encoder_outputs" not in model_kwargs: - # if model is encoder decoder encoder_outputs are created - # and added to `model_kwargs` + # if model is encoder decoder encoder_outputs are created and added to `model_kwargs` model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation( - inputs_tensor, model_kwargs, model_input_name + inputs_tensor, model_kwargs, model_input_name, generation_config ) # 5. Prepare `input_ids` which will be used for auto-regressive generation @@ -936,8 +956,7 @@ def generate( batch_size=batch_size, model_input_name=model_input_name, model_kwargs=model_kwargs, - decoder_start_token_id=generation_config.decoder_start_token_id, - bos_token_id=generation_config.bos_token_id, + decoder_start_token_id=generation_config._decoder_start_token_tensor, device=inputs_tensor.device, max_new_tokens=generation_config.max_new_tokens, pad_token_id=generation_config.pad_token_id, @@ -945,6 +964,9 @@ def generate( else: input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids") + if generation_config.token_healing: + input_ids = self.heal_tokens(input_ids, tokenizer) + if streamer is not None: streamer.put(input_ids.cpu()) @@ -962,19 +984,76 @@ def generate( has_token_idx="token_idx" in model_kwargs, ) - if generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING: - if generation_config.cache_implementation == "static": - if model_kwargs.get("past_key_values", False) is not False: + use_dynamic_cache_by_default = False + if "mamba" in self.__class__.__name__.lower(): + cache_name = "cache_params" + else: + cache_name = "past_key_values" + if generation_config.cache_implementation is not None and (model_kwargs.get(cache_name) is not None): + raise ValueError( + f"Passing both `cache_implementation` (used to initialize certain caches) and `{cache_name}` (a " + "Cache object) is unsupported. Please use only one of the two." + ) + elif generation_config.cache_implementation is not None: + if generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING: + if generation_config.cache_implementation == "static" and not self._supports_static_cache: raise ValueError( - "Using `past_key_values` argument with `generate()` when using a static KV cache is not supported. Please open an issue in Transformers GitHub repository." + "This model does not support `cache_implementation='static'`. Please check the following " + "issue: https://github.com/huggingface/transformers/issues/28981" ) - cache_cls = NEED_SETUP_CACHE_CLASSES_MAPPING["static"] - if not callable(getattr(self, "_setup_cache", None)): + model_kwargs[cache_name] = self._get_cache( + generation_config.cache_implementation, + getattr(generation_config, "num_beams", 1) * batch_size, + generation_config.max_length, + model_kwargs, + ) + elif generation_config.cache_implementation == "quantized": + if not self._supports_quantized_cache: raise ValueError( - "The `generation_config` defines a `cache_implementation` that is not compatible with this model." - " Make sure it has a `_setup_cache` function." + "This model does not support the quantized cache. If you want your model to support quantized " + "cache, please open an issue." ) - self._setup_cache(cache_cls, max_batch_size=batch_size, max_cache_len=generation_config.max_length) + + cache_config = ( + generation_config.cache_config + if generation_config.cache_config is not None + else QuantizedCacheConfig() + ) + cache_class = QUANT_BACKEND_CLASSES_MAPPING[cache_config.backend] + + if cache_config.backend == "quanto" and not is_quanto_available(): + raise ImportError( + "You need to install `quanto` in order to use KV cache quantization with quanto backend. " + "Please install it via with `pip install quanto`" + ) + elif cache_config.backend == "HQQ" and not is_hqq_available(): + raise ImportError( + "You need to install `HQQ` in order to use KV cache quantization with HQQ backend. " + "Please install it via with `pip install hqq`" + ) + + model_kwargs[cache_name] = cache_class(cache_config) + # Use DynamicCache() instance by default. This will avoid back and forth from legacy format that + # keeps copying the cache thus using much more memory + # elif generation_config.cache_implementation is None and self._supports_default_dynamic_cache(): + # past = model_kwargs.get(cache_name, None) + # requires_cross_attention_cache = ( + # self.config.is_encoder_decoder or model_kwargs.get("encoder_outputs") is not None + # ) + # if past is None: + # model_kwargs[cache_name] = ( + # DynamicCache() + # if not requires_cross_attention_cache + # else EncoderDecoderCache(DynamicCache(), DynamicCache()) + # ) + # use_dynamic_cache_by_default = True + # elif isinstance(past, tuple): + # model_kwargs[cache_name] = ( + # DynamicCache.from_legacy_cache(past) + # if not requires_cross_attention_cache + # else EncoderDecoderCache.from_legacy_cache(past) + # ) + # use_dynamic_cache_by_default = True self._validate_generated_length( generation_config, @@ -1029,11 +1108,13 @@ def generate( assert generation_config.static_shapes, "bucket_size > 0 can be set only when static_shapes is set" # if generation_config.bucket_size <= 0, padding is handled by the generating fn (like greedy_search) if generation_config.static_shapes and generation_config.bucket_size > 0: - assert ( - generation_mode == GenerationMode.GREEDY_SEARCH - or generation_mode == GenerationMode.SAMPLE - or generation_mode == GenerationMode.BEAM_SEARCH - ), "generation_config.bucket_size > 0 supported only for greedy mode" + assert generation_mode in [ + GenerationMode.GREEDY_SEARCH, + GenerationMode.SAMPLE, + GenerationMode.BEAM_SEARCH, + GenerationMode.BEAM_SAMPLE, + GenerationMode.CONTRASTIVE_SEARCH, + ], "generation_config.bucket_size > 0 supported only for greedy mode" if streamer is not None and (generation_config.num_beams > 1): raise ValueError( @@ -1060,6 +1141,7 @@ def generate( encoder_input_ids=inputs_tensor, prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, logits_processor=logits_processor, + device=inputs_tensor.device, model_kwargs=model_kwargs, negative_prompt_ids=negative_prompt_ids, negative_prompt_attention_mask=negative_prompt_attention_mask, @@ -1070,7 +1152,8 @@ def generate( prepared_stopping_criteria = self._get_stopping_criteria( generation_config=generation_config, stopping_criteria=stopping_criteria, - ignore_eos=self.generation_config.ignore_eos, + tokenizer=tokenizer, + **kwargs, ) # In lazy mode, import Habana torch to be able to add mark_step() @@ -1090,6 +1173,14 @@ def generate( raise ValueError("assisted generate is only supported for batch_size = 1") if not model_kwargs["use_cache"]: raise ValueError("assisted generate requires `use_cache=True`") + if generation_config.cache_implementation == "static": + raise ValueError("assisted generate is not supported with `static_cache`") + if self._is_stateful: + # In assisted generation we need the ability to confirm whether the model would pick certain tokens, + # which is not possible with stateful models (they can't reset to a previous subset of generated text) + raise ValueError( + f"assisted generation is not supported with stateful models, such as {self.__class__.__name__}" + ) # 11. Get the candidate generator, given the parameterization candidate_generator = self._get_candidate_generator( @@ -1101,18 +1192,24 @@ def generate( model_kwargs=model_kwargs, ) - # 12. run assisted generate + # 12. prepare logits warper (if `do_sample` is `True`) + prepared_logits_warper = ( + self._get_logits_warper( + generation_config, + device=input_ids.device, + ) + if generation_config.do_sample + else None + ) + + # 13. run assisted generate result = self._assisted_decoding( input_ids, candidate_generator=candidate_generator, - do_sample=generation_config.do_sample, logits_processor=prepared_logits_processor, - logits_warper=self._get_logits_warper(generation_config) if generation_config.do_sample else None, + logits_warper=prepared_logits_warper, stopping_criteria=prepared_stopping_criteria, - pad_token_id=generation_config.pad_token_id, - output_scores=generation_config.output_scores, - output_logits=generation_config.output_logits, - return_dict_in_generate=generation_config.return_dict_in_generate, + generation_config=generation_config, synced_gpus=synced_gpus, streamer=streamer, lazy_mode=lazy_mode, @@ -1122,44 +1219,47 @@ def generate( hb_gen_time=hb_gen_time, **model_kwargs, ) - if generation_mode == GenerationMode.GREEDY_SEARCH: - # 11. run greedy search - result = self._greedy_search( + elif generation_mode == GenerationMode.DOLA_GENERATION: + if self._is_stateful: + # DoLa decoding was not designed for stateful models, and would require some changes + raise ValueError( + f"dola decoding is not supported with stateful models, such as {self.__class__.__name__}" + ) + prepared_logits_warper = ( + self._get_logits_warper(generation_config, device=input_ids.device) + if generation_config.do_sample + else None + ) + result = self._dola_decoding( input_ids, + dola_layers=generation_config.dola_layers, logits_processor=prepared_logits_processor, + logits_warper=prepared_logits_warper, stopping_criteria=prepared_stopping_criteria, - pad_token_id=generation_config.pad_token_id, - output_scores=generation_config.output_scores, - output_logits=generation_config.output_logits, - return_dict_in_generate=generation_config.return_dict_in_generate, + generation_config=generation_config, synced_gpus=synced_gpus, streamer=streamer, - lazy_mode=lazy_mode, - ignore_eos=generation_config.ignore_eos, - profiling_warmup_steps=profiling_warmup_steps, - profiling_steps=profiling_steps, - hb_gen_time=hb_gen_time, - profiling_record_shapes=profiling_record_shapes, **model_kwargs, ) elif generation_mode == GenerationMode.CONTRASTIVE_SEARCH: if not model_kwargs["use_cache"]: raise ValueError("Contrastive search requires `use_cache=True`") + if self._is_stateful: + # Just like assisted generation, we need to be able to rollback to a previous state (see comment above) + raise ValueError( + f"contrastive search is not supported with stateful models, such as {self.__class__.__name__}" + ) result = self._contrastive_search( input_ids, - top_k=generation_config.top_k, - penalty_alpha=generation_config.penalty_alpha, logits_processor=prepared_logits_processor, stopping_criteria=prepared_stopping_criteria, - pad_token_id=generation_config.pad_token_id, - output_scores=generation_config.output_scores, - output_logits=generation_config.output_logits, - return_dict_in_generate=generation_config.return_dict_in_generate, + generation_config=generation_config, synced_gpus=synced_gpus, streamer=streamer, - sequential=generation_config.low_memory, + lazy_mode=lazy_mode, + ignore_eos=generation_config.ignore_eos, profiling_warmup_steps=profiling_warmup_steps, profiling_steps=profiling_steps, hb_gen_time=hb_gen_time, @@ -1167,9 +1267,13 @@ def generate( **model_kwargs, ) - elif generation_mode == GenerationMode.SAMPLE: + elif generation_mode in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH): # 11. prepare logits warper - logits_warper = self._get_logits_warper(generation_config) + prepared_logits_warper = ( + self._get_logits_warper(generation_config, device=input_ids.device) + if generation_config.do_sample + else None + ) # 12. expand input_ids with `num_return_sequences` additional sequences per batch input_ids, model_kwargs = self._expand_inputs_for_generation( @@ -1179,16 +1283,13 @@ def generate( **model_kwargs, ) - # 13. run sample + # 13. run sample (it degenerates to greedy search when `generation_config.do_sample=False`) result = self._sample( input_ids, logits_processor=prepared_logits_processor, - logits_warper=logits_warper, + logits_warper=prepared_logits_warper, stopping_criteria=prepared_stopping_criteria, - pad_token_id=generation_config.pad_token_id, - output_scores=generation_config.output_scores, - output_logits=generation_config.output_logits, - return_dict_in_generate=generation_config.return_dict_in_generate, + generation_config=generation_config, synced_gpus=synced_gpus, streamer=streamer, lazy_mode=lazy_mode, @@ -1200,47 +1301,13 @@ def generate( **model_kwargs, ) - elif generation_mode == GenerationMode.BEAM_SEARCH: - # 11. prepare beam search scorer - beam_scorer = BeamSearchScorer( - batch_size=batch_size, - num_beams=generation_config.num_beams, - device=inputs_tensor.device, - length_penalty=generation_config.length_penalty, - do_early_stopping=generation_config.early_stopping, - num_beam_hyps_to_keep=generation_config.num_return_sequences, - max_length=generation_config.max_length, - ) - # 12. interleave input_ids with `num_beams` additional sequences per batch - input_ids, model_kwargs = self._expand_inputs_for_generation( - input_ids=input_ids, - expand_size=generation_config.num_beams, - is_encoder_decoder=self.config.is_encoder_decoder, - **model_kwargs, - ) - # 13. run beam search - result = self._beam_search( - input_ids, - beam_scorer, - logits_processor=prepared_logits_processor, - stopping_criteria=prepared_stopping_criteria, - pad_token_id=generation_config.pad_token_id, - output_scores=generation_config.output_scores, - output_logits=generation_config.output_logits, - return_dict_in_generate=generation_config.return_dict_in_generate, - synced_gpus=synced_gpus, - sequential=generation_config.low_memory, - lazy_mode=lazy_mode, - profiling_warmup_steps=profiling_warmup_steps, - profiling_steps=profiling_steps, - hb_gen_time=hb_gen_time, - profiling_record_shapes=profiling_record_shapes, - **model_kwargs, - ) - - elif generation_mode == GenerationMode.BEAM_SAMPLE: + elif generation_mode in (GenerationMode.BEAM_SAMPLE, GenerationMode.BEAM_SEARCH): # 11. prepare logits warper - logits_warper = self._get_logits_warper(generation_config) + prepared_logits_warper = ( + self._get_logits_warper(generation_config, device=input_ids.device) + if generation_config.do_sample + else None + ) # 12. prepare beam search scorer beam_scorer = BeamSearchScorer( @@ -1262,16 +1329,13 @@ def generate( ) # 14. run beam sample - result = self._beam_sample( + result = self._beam_search( input_ids, beam_scorer, logits_processor=prepared_logits_processor, - logits_warper=logits_warper, + logits_warper=prepared_logits_warper, stopping_criteria=prepared_stopping_criteria, - pad_token_id=generation_config.pad_token_id, - output_scores=generation_config.output_scores, - output_logits=generation_config.output_logits, - return_dict_in_generate=generation_config.return_dict_in_generate, + generation_config=generation_config, synced_gpus=synced_gpus, lazy_mode=lazy_mode, profiling_warmup_steps=profiling_warmup_steps, @@ -1306,10 +1370,7 @@ def generate( beam_scorer, logits_processor=prepared_logits_processor, stopping_criteria=prepared_stopping_criteria, - pad_token_id=generation_config.pad_token_id, - output_scores=generation_config.output_scores, - output_logits=generation_config.output_logits, - return_dict_in_generate=generation_config.return_dict_in_generate, + generation_config=generation_config, synced_gpus=synced_gpus, lazy_mode=lazy_mode, profiling_warmup_steps=profiling_warmup_steps, @@ -1384,10 +1445,7 @@ def typeerror(): constrained_beam_scorer=constrained_beam_scorer, logits_processor=prepared_logits_processor, stopping_criteria=prepared_stopping_criteria, - pad_token_id=generation_config.pad_token_id, - output_scores=generation_config.output_scores, - output_logits=generation_config.output_logits, - return_dict_in_generate=generation_config.return_dict_in_generate, + generation_config=generation_config, synced_gpus=synced_gpus, lazy_mode=lazy_mode, profiling_warmup_steps=profiling_warmup_steps, @@ -1397,153 +1455,79 @@ def typeerror(): **model_kwargs, ) - if generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING: - if not callable(getattr(self, "_reset_cache", None)): - raise ValueError( - "A `static_cache` was used to generate but there was a failure when trying to release the cache. " - " Make sure this model implements a `_reset_cache` function." - ) - self._reset_cache() + # Convert to legacy cache if needed + if use_dynamic_cache_by_default and generation_config.return_legacy_cache: + if isinstance(result, ModelOutput) and hasattr(result, "past_key_values"): + if isinstance(result.past_key_values, (DynamicCache, EncoderDecoderCache)): + result.past_key_values = result.past_key_values.to_legacy_cache() return result - @torch.no_grad() - def _contrastive_search( + def _dola_decoding( self, input_ids: torch.LongTensor, - top_k: Optional[int] = 1, - penalty_alpha: Optional[float] = 0, - logits_processor: Optional[LogitsProcessorList] = None, - logits_warper: Optional[LogitsProcessorList] = None, - stopping_criteria: Optional[StoppingCriteriaList] = None, - pad_token_id: Optional[int] = None, - eos_token_id: Optional[Union[int, List[int]]] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - output_scores: Optional[bool] = None, - output_logits: Optional[bool] = None, - return_dict_in_generate: Optional[bool] = None, - synced_gpus: bool = False, - streamer: Optional["BaseStreamer"] = None, - sequential: Optional[bool] = None, - lazy_mode: Optional[bool] = False, - profiling_warmup_steps: Optional[int] = 0, - profiling_steps: Optional[int] = 0, - hb_gen_time: Optional[HabanaGenerationtime] = None, - profiling_record_shapes: Optional[bool] = False, + dola_layers: Union[str, List[int]], + logits_processor: LogitsProcessorList, + stopping_criteria: StoppingCriteriaList, + generation_config: GaudiGenerationConfig, + synced_gpus: bool, + streamer: "BaseStreamer", + logits_warper: Optional[LogitsProcessorList], **model_kwargs, ) -> Union[GenerateNonBeamOutput, torch.LongTensor]: r""" - Generates sequences of token ids for models with a language modeling head using **contrastive search** and can - be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models. - - - - In most cases, you do not need to call [`~generation.GenerationMixin._contrastive_search`] directly. Use - generate() instead. For an overview of generation strategies and code examples, check the [following - guide](../generation_strategies). - - + Generates sequences of token ids for models with a language modeling head using **dola decoding** and can be + used for decoder-only text models. + The method is based on the paper "DoLa: Decoding by Contrasting Layers Improves Factuality in Large Language + Models" (https://arxiv.org/abs/2309.03883) in ICLR 2024. Parameters: input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): The sequence used as a prompt for the generation. - top_k (`int`, *optional*, defaults to 1): - The size of the candidate set that is used to re-rank for contrastive search - penalty_alpha (`float`, *optional*, defaults to 0): - The degeneration penalty for contrastive search; activate when it is larger than 0 - logits_processor (`LogitsProcessorList`, *optional*): + dola_layers (`Union[str, List[int]]`): + The candidate layers used in contrasting layers of DoLa. It can be either 1) 'low' or 'high', which + means the lower part or higher part of the model layers, respectively, or 2) a list of layer indices + to be used for candidate layers. The 0-th layer is the word embedding layer of the model. + logits_processor (`LogitsProcessorList`): An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`] used to modify the prediction scores of the language modeling head applied at each generation step. - logits_warper (`LogitsProcessorList`, *optional*): - An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used - to warp the prediction score distribution of the language modeling head applied before multinomial - sampling at each generation step. stopping_criteria (`StoppingCriteriaList`, *optional*): An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`] used to tell if the generation loop should stop. - pad_token_id (`int`, *optional*): - The id of the *padding* token. - eos_token_id (`Union[int, List[int]]`, *optional*): - The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens. - output_attentions (`bool`, *optional*, defaults to `False`): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more details. - output_hidden_states (`bool`, *optional*, defaults to `False`): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors - for more details. - output_scores (`bool`, *optional*, defaults to `False`): - Whether or not to return the prediction scores. See `scores` under returned tensors for more details. - output_logits (`bool`, *optional*, defaults to `False`): - Whether or not to return the raw prediction logit scores. See `logits` under returned tensors - for more details. - return_dict_in_generate (`bool`, *optional*, defaults to `False`): - Whether or not to return a [`transformers.generationutils.ModelOutput`] instead of a plain tuple. - synced_gpus (`bool`, *optional*, defaults to `False`): + generation_config ([`~generation.GenerationConfig`]): + The generation configuration to be used as parametrization of the decoding method. + synced_gpus (`bool`): Whether to continue running the while loop until max_length (needed for ZeRO stage 3) streamer (`BaseStreamer`, *optional*): Streamer object that will be used to stream the generated sequences. Generated tokens are passed through `streamer.put(token_ids)` and the streamer is responsible for any further processing. - lazy_mode (`bool`, *optional*, defaults to `False`): - Whether the run is executed in lazy mode or not (i.e. eager mode). - profiling_warmup_steps (`int`, *optional*, defaults to 0): - Number of steps to ignore for profling. - profiling_steps (`int`, *optional*, defaults to 0): - Number of steps to be captured when enabling profiling. - profiling_record_shapes (`bool`, *optional*, defaults to False): - Record shapes when enabling profiling. + logits_warper (`LogitsProcessorList`, *optional*): + An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used + to warp the prediction score distribution of the language modeling head applied before multinomial + sampling at each generation step. model_kwargs: Additional model specific keyword arguments will be forwarded to the `forward` function of the model. If model is an encoder-decoder model the kwargs should include `encoder_outputs`. Return: - [`transformers.generation.GenerateDecoderOnlyOutput`], - [`transformers.generation.GenerateEncoderDecoderOutput`] or `torch.LongTensor`: A `torch.LongTensor` - containing the generated tokens (default behaviour) or a - [`transformers.generation.GenerateDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and - `return_dict_in_generate=True` or a [`transformers.generation.GenerateEncoderDecoderOutput`] if + [`~generation.GenerateDecoderOnlyOutput`], [`~generation.GenerateEncoderDecoderOutput`] + or `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a + [`~generation.GenerateDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and + `return_dict_in_generate=True` or a [`~generation.GenerateEncoderDecoderOutput`] if `model.config.is_encoder_decoder=True`. + """ + + raise NotImplementedError("Dola decoding is not supported by optimum-habana yet.") - Examples: - ```python - >>> from transformers import ( - ... AutoTokenizer, - ... AutoModelForCausalLM, - ... StoppingCriteriaList, - ... MaxLengthCriteria, - ... ) - - >>> tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m") - >>> model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m") - >>> # set pad_token_id to eos_token_id because OPT does not have a PAD token - >>> model.config.pad_token_id = model.config.eos_token_id - >>> input_prompt = "DeepMind Company is" - >>> input_ids = tokenizer(input_prompt, return_tensors="pt") - >>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=64)]) - >>> outputs = model._contrastive_search( - ... **input_ids, penalty_alpha=0.6, top_k=4, stopping_criteria=stopping_criteria - ... ) - >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) - ['DeepMind Company is a company that focuses on the development and commercialization of artificial intelligence (AI). DeepMind’s mission is to help people understand and solve problems that are difficult to solve in the world today.\n\nIn this post, we talk about the benefits of deep learning in business and how it'] - ```""" - - raise NotImplementedError("Contrastive search is not supported by optimum-habana yet.") - - def _greedy_search( + @torch.no_grad() + def _contrastive_search( self, input_ids: torch.LongTensor, - logits_processor: Optional[LogitsProcessorList] = None, - stopping_criteria: Optional[StoppingCriteriaList] = None, - max_length: Optional[int] = None, - pad_token_id: Optional[int] = None, - eos_token_id: Optional[Union[int, List[int]]] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - output_scores: Optional[bool] = None, - output_logits: Optional[bool] = None, - return_dict_in_generate: Optional[bool] = None, - synced_gpus: bool = False, - streamer: Optional["BaseStreamer"] = None, + logits_processor: LogitsProcessorList, + stopping_criteria: StoppingCriteriaList, + generation_config: GaudiGenerationConfig, + synced_gpus: bool, + streamer: Optional["BaseStreamer"], lazy_mode: Optional[bool] = False, ignore_eos: Optional[bool] = False, profiling_warmup_steps: Optional[int] = 0, @@ -1553,48 +1537,27 @@ def _greedy_search( **model_kwargs, ) -> Union[GenerateNonBeamOutput, torch.LongTensor]: r""" - Generates sequences of token ids for models with a language modeling head using **greedy decoding** and can be - used for text-decoder, text-to-text, speech-to-text, and vision-to-text models. - - - - In most cases, you do not need to call [`~generation.GenerationMixin._greedy_search`] directly. Use generate() - instead. For an overview of generation strategies and code examples, check the [following - guide](../generation_strategies). + Generates sequences of token ids for models with a language modeling head using **contrastive search** and can + be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models. - + Adapted from: https://github.com/huggingface/transformers/blob/v4.43.3/src/transformers/generation/utils.py#L2453 + The changes are: + - support lazy mode and HPU graphs on Gaudi + - support static shapes and bucketing Parameters: input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): The sequence used as a prompt for the generation. - logits_processor (`LogitsProcessorList`, *optional*): + logits_processor (`LogitsProcessorList`): An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`] used to modify the prediction scores of the language modeling head applied at each generation step. - stopping_criteria (`StoppingCriteriaList`, *optional*): + stopping_criteria (`StoppingCriteriaList`): An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`] used to tell if the generation loop should stop. - max_length (`int`, *optional*, defaults to 20): - **DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated - tokens. The maximum length of the sequence to be generated. - pad_token_id (`int`, *optional*): - The id of the *padding* token. - eos_token_id (`Union[int, List[int]]`, *optional*): - The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens. - output_attentions (`bool`, *optional*, defaults to `False`): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more details. - output_hidden_states (`bool`, *optional*, defaults to `False`): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors - for more details. - output_scores (`bool`, *optional*, defaults to `False`): - Whether or not to return the prediction scores. See `scores` under returned tensors for more details. - output_logits (`bool`, *optional*, defaults to `False`): - Whether or not to return the raw prediction logit scores. See `logits` under returned tensors - for more details. - return_dict_in_generate (`bool`, *optional*, defaults to `False`): - Whether or not to return a [`transformers.generationutils.ModelOutput`] instead of a plain tuple. - synced_gpus (`bool`, *optional*, defaults to `False`): + generation_config ([`~generation.GenerationConfig`]): + The generation configuration to be used as parametrization of the decoding method. + synced_gpus (`bool`): Whether to continue running the while loop until max_length (needed for ZeRO stage 3) streamer (`BaseStreamer`, *optional*): Streamer object that will be used to stream the generated sequences. Generated tokens are passed @@ -1614,97 +1577,24 @@ def _greedy_search( If model is an encoder-decoder model the kwargs should include `encoder_outputs`. Return: - [`transformers.generation.GenerateDecoderOnlyOutput`], [`transformers.generation.GenerateEncoderDecoderOutput`] - or `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a + [`transformers.generation.GenerateDecoderOnlyOutput`], + [`transformers.generation.GenerateEncoderDecoderOutput`] or `torch.LongTensor`: A `torch.LongTensor` + containing the generated tokens (default behaviour) or a [`transformers.generation.GenerateDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and `return_dict_in_generate=True` or a [`transformers.generation.GenerateEncoderDecoderOutput`] if `model.config.is_encoder_decoder=True`. - - Examples: - - ```python - >>> from transformers import ( - ... AutoTokenizer, - ... AutoModelForCausalLM, - ... LogitsProcessorList, - ... MinLengthLogitsProcessor, - ... StoppingCriteriaList, - ... MaxLengthCriteria, - ... ) - - >>> tokenizer = AutoTokenizer.from_pretrained("gpt2") - >>> model = AutoModelForCausalLM.from_pretrained("gpt2") - - >>> # set pad_token_id to eos_token_id because GPT2 does not have a PAD token - >>> model.generation_config.pad_token_id = model.generation_config.eos_token_id - - >>> input_prompt = "It might be possible to" - >>> input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids - - >>> # instantiate logits processors - >>> logits_processor = LogitsProcessorList( - ... [ - ... MinLengthLogitsProcessor(10, eos_token_id=model.generation_config.eos_token_id), - ... ] - ... ) - >>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)]) - - >>> outputs = model._greedy_search( - ... input_ids, logits_processor=logits_processor, stopping_criteria=stopping_criteria - ... ) - - >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) - ["It might be possible to get a better understanding of the nature of the problem, but it's not"] - ```""" + """ # init values - logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() - stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() - if max_length is not None: - warnings.warn( - ( - "`max_length` is deprecated in this function, use" - " `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead." - ), - UserWarning, - ) - stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) - pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id - if not self.generation_config.ignore_eos: - if eos_token_id is not None: - logger.warning_once( - "`eos_token_id` is deprecated in this function and will be removed in v4.41, use" - " `stopping_criteria=StoppingCriteriaList([EosTokenCriteria(eos_token_id=eos_token_id)])` instead." - " Otherwise make sure to set `model.generation_config.eos_token_id`", - FutureWarning, - ) - stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id)) - else: - # TODO remove when the method is totally private - # need to get `eos_token_id` and add stopping criteria, so that generation does not go forever - eos_token_id = [ - criteria.eos_token_id.tolist() - for criteria in stopping_criteria - if hasattr(criteria, "eos_token_id") - ] - eos_token_id = eos_token_id[0] if eos_token_id else None - if eos_token_id is None and self.generation_config.eos_token_id is not None: - eos_token_id = self.generation_config.eos_token_id - stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id)) - - if isinstance(eos_token_id, int): - eos_token_id = [eos_token_id] - output_scores = output_scores if output_scores is not None else self.generation_config.output_scores - output_attentions = ( - output_attentions if output_attentions is not None else self.generation_config.output_attentions - ) - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states - ) - return_dict_in_generate = ( - return_dict_in_generate - if return_dict_in_generate is not None - else self.generation_config.return_dict_in_generate - ) + has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria) + top_k = generation_config.top_k + penalty_alpha = generation_config.penalty_alpha + pad_token_id = generation_config._pad_token_tensor + output_attentions = generation_config.output_attentions + output_hidden_states = generation_config.output_hidden_states + output_scores = generation_config.output_scores + output_logits = generation_config.output_logits + return_dict_in_generate = generation_config.return_dict_in_generate + sequential = generation_config.low_memory # init attention / hidden states / scores tuples raw_logits = () if (return_dict_in_generate and output_logits) else None @@ -1724,11 +1614,13 @@ def _greedy_search( batch_size, cur_len = input_ids.shape if "inputs_embeds" in model_kwargs: cur_len = model_kwargs["inputs_embeds"].shape[1] - this_peer_finished = False + if not ignore_eos: unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device) model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device) + this_peer_finished = False + hb_profer = HabanaProfile( warmup=profiling_warmup_steps, active=profiling_steps, record_shapes=profiling_record_shapes ) @@ -1743,7 +1635,9 @@ def _greedy_search( inc = iter(incrementor(bucket_size, cur_len)) if bucket_size > 0: assert "position_ids" not in model_kwargs, "Untested path" + token_idx = model_kwargs.get("token_idx", None) + top_k_ids = None if token_idx is not None: # Update cur_len in case of static shapes cur_len = token_idx.item() @@ -1751,6 +1645,9 @@ def _greedy_search( time_to_first_token_done = False model_kwargs["pad_done"] = False model_kwargs["lazy_mode"] = lazy_mode + + batch_indices = torch.arange(batch_size, device=input_ids.device) + while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device): if lazy_mode: self.htcore_generation.mark_step() @@ -1762,55 +1659,114 @@ def _greedy_search( params, input_ids, model_kwargs, pad_token_id, bucket_size, reduce_recompile ) - # prepare model inputs - model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) + # if the first step in the loop, encode all the prefix and obtain: (1) past_key_values; + # (2) last_hidden_states; (3) logit_for_next_step; (4) update model kwargs for the next step + if model_kwargs.get("past_key_values") is None or ( + isinstance(model_kwargs["past_key_values"], (Cache, EncoderDecoderCache)) + and model_kwargs["past_key_values"].get_seq_length() == 0 + ): + # prepare inputs + model_kwargs["use_cache"] = True + model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) - hpu_graphs_kwargs = self._get_hpu_graphs_kwargs(model_kwargs) + hpu_graphs_kwargs = self._get_hpu_graphs_kwargs(model_kwargs) - # forward pass to get next token - outputs = self( - **model_inputs, - return_dict=True, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - **hpu_graphs_kwargs, - ) - if synced_gpus and this_peer_finished: - continue # don't waste resources running the code we don't need + # encode the given prefix and prepare model inputs; encoder-decoder model process the prefix and save + # the `encoder_outputs` + outputs = self( + **model_inputs, + return_dict=True, + output_hidden_states=True, + output_attentions=output_attentions, + **hpu_graphs_kwargs, + ) - token_idx = model_kwargs.get("token_idx", None) - if token_idx is not None and outputs.logits.shape[-2] > 1: - # case1 (w/o KV caching): outputs.logits.shape: [batch_size, max_length, vocab_size] + # last decoder hidden states will be used to compute the degeneration penalty (cosine similarity with + # previous tokens) if self.config.is_encoder_decoder: - next_token_logits = outputs.logits[:, token_idx - 1, :] - next_tokens_scores = logits_processor(input_ids[:, :token_idx], next_token_logits) + last_hidden_states = outputs.decoder_hidden_states[-1] else: - if model_kwargs.get("num_virtual_tokens", 0) > 0: - # for prompt tuning, the output logit shape > model_inputs["input_ids"].shape[-1] - if model_kwargs.get("reuse_cache", False): - output_idx = torch.tensor(outputs.logits.shape[-2], device=input_ids.device) - else: - output_idx = token_idx + outputs.logits.shape[-2] - input_ids.shape[-1] - next_token_logits = torch.index_select(outputs.logits, -2, output_idx - 1).squeeze(-2) + last_hidden_states = outputs.hidden_states[-1] + + # next logit for contrastive search to select top-k candidate tokens + token_idx = model_kwargs.get("token_idx", None) + if token_idx is not None and outputs.logits.shape[-2] > 1: + last_hidden_states = last_hidden_states[:, :token_idx, :] + # case1 (w/o KV caching): outputs.logits.shape: [batch_size, max_length, vocab_size] + if self.config.is_encoder_decoder: + logit_for_next_step = outputs.logits[:, token_idx - 1, :] else: - next_token_logits = torch.index_select(outputs.logits, -2, token_idx - 1).squeeze(-2) + logit_for_next_step = torch.index_select(outputs.logits, -2, token_idx - 1).squeeze(-2) + else: + # Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for this first iteration + # (the clone itself is always small) + logit_for_next_step = outputs.logits[:, -1, :].clone() + + model_kwargs = self._update_model_kwargs_for_generation( + outputs, + model_kwargs, + is_encoder_decoder=self.config.is_encoder_decoder, + standardize_cache_format=True, + ) + + if not sequential: + # Expands model inputs top_k times, for batched forward passes (akin to beam search). + _, model_kwargs = self._expand_inputs_for_generation( + expand_size=top_k, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs + ) + + past_key_values = model_kwargs.get("past_key_values") + if past_key_values is None: + raise ValueError( + f"{self.__class__.__name__} does not support caching and therefore **can't** be used " + "for contrastive search." + ) + elif ( + ( + not isinstance(past_key_values[0], (tuple, torch.Tensor)) + and not isinstance(past_key_values[0], (list, torch.Tensor)) + ) # Added list type to support GaudiLlamaForCausalLM + or past_key_values[0][0].shape[0] != batch_size + ): + raise ValueError( + f"{self.__class__.__name__} does not have a standard cache format and therefore **can't** be " + "used for contrastive search without further modifications." + ) - next_tokens_scores = logits_processor(input_ids, next_token_logits) + if lazy_mode: + self.htcore_generation.mark_step() + + # contrastive_search main logic start: + # contrastive search decoding consists of two steps: (1) candidate tokens recall; (2) candidate re-rank by + # degeneration penalty + if token_idx is not None and self.config.is_encoder_decoder: + processed_logit_for_next_step = logits_processor(input_ids[:, :token_idx], logit_for_next_step) else: - next_token_logits = outputs.logits[:, -1, :] - if token_idx is not None and self.config.is_encoder_decoder: - # case2 (with KV caching): outputs.logits.shape: [batch_size, 1, vocab_size] - next_tokens_scores = logits_processor(input_ids[:, :token_idx], next_token_logits) - else: - # case3 (default case): token_idx is None - next_tokens_scores = logits_processor(input_ids, next_token_logits) + processed_logit_for_next_step = logits_processor(input_ids, logit_for_next_step) + + next_probs = torch.nn.functional.softmax(processed_logit_for_next_step, dim=-1) + + if token_idx is not None: + if top_k_ids is None: + top_k_ids = torch.full( + (batch_size, top_k, input_ids.shape[-1]), pad_token_id, dtype=torch.int64 + ).to(input_ids.device) + elif bucket_size > 0 and not bucket_internal: + if input_ids.shape[-1] > top_k_ids.shape[-1]: # needs expansion + pad_amount = input_ids.shape[-1] - top_k_ids.shape[-1] + top_k_ids = torch.nn.functional.pad(top_k_ids, (0, pad_amount), value=pad_token_id) + + top_k_probs, top_k_prob_ids = torch.topk(next_probs, dim=-1, k=top_k) + top_k_ids[:, :, token_idx - 1] = top_k_prob_ids + else: + top_k_probs, top_k_ids = torch.topk(next_probs, dim=-1, k=top_k) # Store scores, attentions and hidden_states when required if return_dict_in_generate: - if output_scores: - scores += (next_tokens_scores,) if output_logits: - raw_logits += (next_token_logits,) + raw_logits += (logit_for_next_step,) + if output_scores: + scores += (processed_logit_for_next_step,) if output_attentions: decoder_attentions += ( (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) @@ -1825,24 +1781,207 @@ def _greedy_search( else (outputs.hidden_states,) ) - # argmax - next_tokens = torch.argmax(next_tokens_scores, dim=-1) + # This is needed to properly delete outputs.logits which may be very large for this first iteration + # Otherwise a reference to outputs.logits is kept all along until after the next call to self.forward() + del outputs + + if not sequential: + # Replicates the new past_key_values to match the `top_k` candidates + past = model_kwargs["past_key_values"] + # If it is a static cache, modify it in-place layer after layer to save memory + if isinstance(past, DynamicCache) or ( + isinstance(past, EncoderDecoderCache) and isinstance(past.self_attention_cache, DynamicCache) + ): + past.batch_repeat_interleave(top_k) + else: + new_key_values = [] + for layer in past: + items = [] + # item is either the key or the value matrix + for item in layer: + items.append(item.repeat_interleave(top_k, dim=0)) + new_key_values.append(tuple(items)) + + past = tuple(new_key_values) + + model_kwargs["past_key_values"] = past + + if sequential: + all_outputs = [] + for i in range(top_k): + # compute the candidate tokens by the language model and collect their hidden_states + if token_idx is not None: + next_model_inputs = self.prepare_inputs_for_generation( + top_k_ids[:, i, :].view(-1, input_ids.shape[-1]), **model_kwargs + ) + else: + next_model_inputs = self.prepare_inputs_for_generation( + top_k_ids[:, i].view(-1, 1), **model_kwargs + ) + + outputs = self( + **next_model_inputs, + return_dict=True, + output_hidden_states=True, + output_attentions=output_attentions, + ) + if isinstance(outputs["past_key_values"], DynamicCache) or ( + isinstance(outputs["past_key_values"], EncoderDecoderCache) + and isinstance(outputs["past_key_values"].self_attention_cache, DynamicCache) + ): + # Remove past K-V from output since we don't need to stack later + outputs["past_key_values"] = None + # Remove last token from past K-V since we don't want to append it at this point + model_kwargs["past_key_values"].crop(-1) + + all_outputs.append(outputs) + outputs = stack_model_outputs(all_outputs) + + else: + # compute the candidate tokens by the language model and collect their hidden_states + # assembles top_k_ids into batch of size k + if token_idx is not None: + next_model_inputs = self.prepare_inputs_for_generation( + top_k_ids.view(-1, input_ids.shape[-1]), **model_kwargs + ) + else: + next_model_inputs = self.prepare_inputs_for_generation(top_k_ids.view(-1, 1), **model_kwargs) + + outputs = self( + **next_model_inputs, + return_dict=True, + output_hidden_states=True, + output_attentions=output_attentions, + ) + + # This is essential to avoid having a last reference to the big past K-V and double the necesary memory + # in the next loop + del next_model_inputs + + # name is different for encoder-decoder and decoder-only models + if self.config.is_encoder_decoder: + next_hidden = outputs.decoder_hidden_states[-1] + full_hidden_states = outputs.decoder_hidden_states + else: + next_hidden = outputs.hidden_states[-1] + full_hidden_states = outputs.hidden_states + + logits = outputs.logits[:, -1, :] + context_hidden = last_hidden_states.repeat_interleave(top_k, dim=0) + + # compute the degeneration penalty and re-rank the candidates based on the degeneration penalty and the + # model confidence. Keeping `selected_idx` on CPU enables multi-device contrastive search and doesn't + # introduce (noticeable) slowdowns on single-device runs. + selected_idx = _ranking_fast(context_hidden, next_hidden, top_k_probs, penalty_alpha, top_k) + + # This will be used instead of the previous inneficient torch.stack(torch.split()) + augmented_idx = torch.tensor( + [x + i * top_k for i, x in enumerate(selected_idx)], device=selected_idx.device + ) + + # prepare for the next step: (1) next token_id; (2) past_key_values; (3) last_hidden_states for computing + # the degeneration penalty; (4) logits for selecting next top-k candidates; (5) selected tokens scores + # (model confidence minus degeneration penalty); (6) decoder hidden_states + top_k_indices = torch.arange(len(top_k_ids), device=input_ids.device) + if token_idx is not None: + next_tokens = top_k_ids[top_k_indices, selected_idx, token_idx - 1] + else: + next_tokens = top_k_ids[top_k_indices, selected_idx] + next_hidden = torch.stack(torch.split(next_hidden.squeeze(dim=1), top_k)) + next_hidden = next_hidden[batch_indices, selected_idx, :] + last_hidden_states = torch.cat([last_hidden_states, next_hidden.unsqueeze(1)], dim=1) + + next_decoder_hidden_states = () + for layer in full_hidden_states: + layer = torch.stack(torch.split(layer, top_k))[batch_indices, selected_idx, :] + next_decoder_hidden_states += (layer,) + + # generate past_key_values cache of only the selected token + if sequential: + if token_idx is not None: + next_model_input = self.prepare_inputs_for_generation( + top_k_ids[:, selected_idx, :].view(-1, input_ids.shape[-1]), **model_kwargs + ) + else: + next_model_input = self.prepare_inputs_for_generation( + top_k_ids[:, selected_idx].view(-1, 1), **model_kwargs + ) + + selected_outputs = self( + **next_model_input, + return_dict=True, + output_hidden_states=False, + output_attentions=False, + ) + next_past_key_values = selected_outputs["past_key_values"] + + else: + _, next_past_key_values = self._extract_past_from_model_output(outputs, standardize_cache_format=True) + # Do it in-place layer per layer to save memory + if isinstance(next_past_key_values, DynamicCache) or ( + isinstance(next_past_key_values, EncoderDecoderCache) + and isinstance(next_past_key_values.self_attention_cache, DynamicCache) + ): + next_past_key_values.batch_select_indices(augmented_idx) + else: + new_key_values = [] + for layer in next_past_key_values: + items = [] + # item is either the key or the value matrix + for item in layer: + items.append(item[augmented_idx, ...]) + new_key_values.append(tuple(items)) + + next_past_key_values = tuple(new_key_values) + + logit_for_next_step = torch.stack(torch.split(logits, top_k))[batch_indices, selected_idx, :] + + # Rebuilds the relevant parts of the model output for the selected token, for use in the next iteration + if self.config.is_encoder_decoder: + next_step_cross_attentions = () + next_step_decoder_attentions = () + if output_attentions: + for layer in outputs.cross_attentions: + layer = torch.stack(torch.split(layer, top_k, dim=0))[batch_indices, selected_idx, ...] + next_step_cross_attentions += (layer,) + for layer in outputs.decoder_attentions: + layer = torch.stack(torch.split(layer, top_k, dim=0))[batch_indices, selected_idx, ...] + next_step_decoder_attentions += (layer,) + outputs = Seq2SeqLMOutput( + past_key_values=next_past_key_values, + decoder_hidden_states=next_decoder_hidden_states, + decoder_attentions=next_step_decoder_attentions or None, + cross_attentions=next_step_cross_attentions or None, + ) + else: + next_step_attentions = () + if output_attentions: + for layer in outputs.attentions: + layer = torch.stack(torch.split(layer, top_k, dim=0))[batch_indices, selected_idx, ...] + next_step_attentions += (layer,) + outputs = CausalLMOutputWithPast( + past_key_values=next_past_key_values, + hidden_states=next_decoder_hidden_states, + attentions=next_step_attentions or None, + ) + # contrastive_search main logic end + + if synced_gpus and this_peer_finished: + continue # don't waste resources running the code we don't need + # finished sentences should have their next token be a padding token - if not ignore_eos and eos_token_id is not None: - if pad_token_id is None: - raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.") + if not ignore_eos and has_eos_stopping_criteria: next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) # update generated ids, model inputs, and length for next step - if not lazy_mode: - next_tokens = next_tokens.to(input_ids.dtype) - if token_idx is not None: + # Use token_idx-1 since token index is incremented twice in first iteration input_ids.index_copy_( - 1, token_idx, next_tokens.unsqueeze(-1) if next_tokens.dim() == 1 else next_tokens + 1, token_idx - 1, next_tokens.unsqueeze(-1) if next_tokens.dim() == 1 else next_tokens ) else: input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) + if streamer is not None: streamer.put(next_tokens.cpu()) model_kwargs = self._update_model_kwargs_for_generation( @@ -1850,6 +1989,9 @@ def _greedy_search( model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, ) + + # increase cur_len + cur_len = cur_len + 1 if bucket_size > 0 and bucket_internal: # Calculate slice idx for kv cache during the decode phase. # Breaking down the kv cache in the attention block helps to reduce computation time. @@ -1860,37 +2002,46 @@ def _greedy_search( prev_idx = idx else: model_kwargs["cache_idx"] = model_kwargs["kv_cache_len"] - cur_len = cur_len + 1 + # stop when each sentence is finished if ignore_eos: this_peer_finished = stopping_criteria( - input_ids, scores, token_idx=cur_len, ignore_eos=ignore_eos, eos_token_id=eos_token_id + input_ids, + scores, + token_idx=cur_len, + ignore_eos=ignore_eos, + eos_token_id=generation_config.eos_token_id, ) else: unfinished_sequences = unfinished_sequences & ~stopping_criteria( - input_ids, scores, token_idx=cur_len, ignore_eos=ignore_eos, eos_token_id=eos_token_id + input_ids, + scores, + token_idx=cur_len, + ignore_eos=ignore_eos, + eos_token_id=generation_config.eos_token_id, ) this_peer_finished = unfinished_sequences.max() == 0 - hb_profer.step() - if hb_gen_time is not None: - if not time_to_first_token_done: - time_to_first_token_done = True - import habana_frameworks.torch.hpu as torch_hpu - - torch_hpu.synchronize() - hb_gen_time.step() if ( not model_kwargs.get("pad_done", False) and not model_kwargs.get("reuse_cache", False) and bucket_internal ): - # Pad the returned past key values tensors from prefill phase forward run to maximum length + # Pad the returned pask key values tensors from prefill phase forward run to maximum length # before starting the decode phase. - if outputs.past_key_values[0][0].shape[2] == model_inputs["input_ids"].shape[1]: - self._pad_past_key_values(model_kwargs) + self._pad_past_key_values(model_kwargs) model_kwargs["pad_done"] = True + hb_profer.step() + + if hb_gen_time is not None: + if not time_to_first_token_done: + time_to_first_token_done = True + import habana_frameworks.torch.hpu as torch_hpu + + torch_hpu.synchronize() + hb_gen_time.step() + if ( model_kwargs.get("use_hpu_graphs", False) and model_kwargs.get("limit_hpu_graphs", False) @@ -1907,6 +2058,23 @@ def _greedy_search( streamer.end() if return_dict_in_generate: + # Contrastive search works by forward looking at the next token, so we need to exclude it from + # `past_key_values` to be consistent with the other decoding methods + if model_kwargs.get("past_key_values") is not None: + if isinstance(model_kwargs["past_key_values"], DynamicCache) or ( + isinstance(model_kwargs["past_key_values"], EncoderDecoderCache) + and isinstance(model_kwargs["past_key_values"].self_attention_cache, DynamicCache) + ): + model_kwargs["past_key_values"].crop(-1) + else: + past_key_values = [] + for layer in model_kwargs["past_key_values"]: + layer_past_key_values = [] + for item in layer: + layer_past_key_values.append(item[..., :-1, :]) + past_key_values.append(tuple(layer_past_key_values)) + model_kwargs["past_key_values"] = tuple(past_key_values) + if self.config.is_encoder_decoder: return GenerateEncoderDecoderOutput( sequences=input_ids, @@ -1934,19 +2102,12 @@ def _greedy_search( def _sample( self, input_ids: torch.LongTensor, - logits_processor: Optional[LogitsProcessorList] = None, - stopping_criteria: Optional[StoppingCriteriaList] = None, - logits_warper: Optional[LogitsProcessorList] = None, - max_length: Optional[int] = None, - pad_token_id: Optional[int] = None, - eos_token_id: Optional[Union[int, List[int]]] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - output_scores: Optional[bool] = None, - output_logits: Optional[bool] = None, - return_dict_in_generate: Optional[bool] = None, - synced_gpus: bool = False, - streamer: Optional["BaseStreamer"] = None, + logits_processor: LogitsProcessorList, + stopping_criteria: StoppingCriteriaList, + generation_config: GaudiGenerationConfig, + synced_gpus: bool, + streamer: Optional["BaseStreamer"], + logits_warper: Optional[LogitsProcessorList], lazy_mode: Optional[bool] = False, ignore_eos: Optional[bool] = False, profiling_warmup_steps: Optional[int] = 0, @@ -1959,52 +2120,27 @@ def _sample( Generates sequences of token ids for models with a language modeling head using **multinomial sampling** and can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models. - - - In most cases, you do not need to call [`~generation.GenerationMixin._sample`] directly. Use generate() instead. - For an overview of generation strategies and code examples, check the [following - guide](../generation_strategies). - - - Parameters: input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): The sequence used as a prompt for the generation. - logits_processor (`LogitsProcessorList`, *optional*): + logits_processor (`LogitsProcessorList`): An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`] used to modify the prediction scores of the language modeling head applied at each generation step. - stopping_criteria (`StoppingCriteriaList`, *optional*): + stopping_criteria (`StoppingCriteriaList`): An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`] used to tell if the generation loop should stop. - logits_warper (`LogitsProcessorList`, *optional*): - An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used - to warp the prediction score distribution of the language modeling head applied before multinomial - sampling at each generation step. - max_length (`int`, *optional*, defaults to 20): - **DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated - tokens. The maximum length of the sequence to be generated. - pad_token_id (`int`, *optional*): - The id of the *padding* token. - eos_token_id (`Union[int, List[int]]`, *optional*): - The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens. - output_attentions (`bool`, *optional*, defaults to `False`): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more details. - output_hidden_states (`bool`, *optional*, defaults to `False`): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors - for more details. - output_scores (`bool`, *optional*, defaults to `False`): - Whether or not to return the prediction scores. See `scores` under returned tensors for more details. - output_logits (`bool`, *optional*, defaults to `False`): - Whether or not to return the raw prediction logit scores. See `logits` under returned tensors for - more details. - return_dict_in_generate (`bool`, *optional*, defaults to `False`): - Whether or not to return a [`transformers.generationutils.ModelOutput`] instead of a plain tuple. - synced_gpus (`bool`, *optional*, defaults to `False`): + generation_config ([`GaudiGenerationConfig`]): + The generation configuration to be used as parametrization of the decoding method. + synced_gpus (`bool`): Whether to continue running the while loop until max_length (needed for ZeRO stage 3) streamer (`BaseStreamer`, *optional*): Streamer object that will be used to stream the generated sequences. Generated tokens are passed through `streamer.put(token_ids)` and the streamer is responsible for any further processing. + logits_warper (`LogitsProcessorList`, *optional*): + An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used + to warp the prediction score distribution of the language modeling head applied before multinomial + sampling at each generation step. Only required with sampling strategies (i.e. `do_sample` is set in + `generation_config`) lazy_mode (`bool`, *optional*, defaults to `False`): Whether the run is executed in lazy mode or not (i.e. eager mode). ignore_eos (`bool`, *optional*, defaults to `False`): @@ -2026,110 +2162,21 @@ def _sample( `return_dict_in_generate=True` or a [`transformers.generation.GenerateEncoderDecoderOutput`] if `model.config.is_encoder_decoder=True`. - Examples: - - ```python - >>> from transformers import ( - ... AutoTokenizer, - ... AutoModelForCausalLM, - ... LogitsProcessorList, - ... MinLengthLogitsProcessor, - ... TopKLogitsWarper, - ... TemperatureLogitsWarper, - ... StoppingCriteriaList, - ... MaxLengthCriteria, - ... ) - >>> import torch - - >>> tokenizer = AutoTokenizer.from_pretrained("gpt2") - >>> model = AutoModelForCausalLM.from_pretrained("gpt2") - - >>> # set pad_token_id to eos_token_id because GPT2 does not have a EOS token - >>> model.config.pad_token_id = model.config.eos_token_id - >>> model.generation_config.pad_token_id = model.config.eos_token_id - - >>> input_prompt = "Today is a beautiful day, and" - >>> input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids - - >>> # instantiate logits processors - >>> logits_processor = LogitsProcessorList( - ... [ - ... MinLengthLogitsProcessor(15, eos_token_id=model.generation_config.eos_token_id), - ... ] - ... ) - >>> # instantiate logits processors - >>> logits_warper = LogitsProcessorList( - ... [ - ... TopKLogitsWarper(50), - ... TemperatureLogitsWarper(0.7), - ... ] - ... ) - - >>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)]) - - >>> torch.manual_seed(0) # doctest: +IGNORE_RESULT - >>> outputs = model._sample( - ... input_ids, - ... logits_processor=logits_processor, - ... logits_warper=logits_warper, - ... stopping_criteria=stopping_criteria, - ... ) - - >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) - ['Today is a beautiful day, and we must do everything possible to make it a day of celebration.'] - ```""" - + """ # init values - logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() - stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() - if max_length is not None: - warnings.warn( - ( - "`max_length` is deprecated in this function, use" - " `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead.", - ), - UserWarning, + pad_token_id = generation_config._pad_token_tensor + output_attentions = generation_config.output_attentions + output_hidden_states = generation_config.output_hidden_states + output_scores = generation_config.output_scores + output_logits = generation_config.output_logits + return_dict_in_generate = generation_config.return_dict_in_generate + has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria) + do_sample = generation_config.do_sample + if do_sample is True and not isinstance(logits_warper, LogitsProcessorList): + raise ValueError( + "`do_sample` is set to `True`, `logits_warper` must be a `LogitsProcessorList` instance (it is " + f"{logits_warper})." ) - stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) - logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList() - pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id - if not self.generation_config.ignore_eos: - if eos_token_id is not None: - logger.warning_once( - "`eos_token_id` is deprecated in this function and will be removed in v4.41, use" - " `stopping_criteria=StoppingCriteriaList([EosTokenCriteria(eos_token_id=eos_token_id)])` instead." - " Otherwise make sure to set `model.generation_config.eos_token_id`", - FutureWarning, - ) - stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id)) - else: - # TODO remove when the method is totally private - # need to get `eos_token_id` and add stopping criteria, so that generation does not go forever - eos_token_id = [ - criteria.eos_token_id.tolist() - for criteria in stopping_criteria - if hasattr(criteria, "eos_token_id") - ] - eos_token_id = eos_token_id[0] if eos_token_id else None - if eos_token_id is None and self.generation_config.eos_token_id is not None: - eos_token_id = self.generation_config.eos_token_id - stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id)) - - if isinstance(eos_token_id, int): - eos_token_id = [eos_token_id] - output_scores = output_scores if output_scores is not None else self.generation_config.output_scores - output_logits = output_logits if output_logits is not None else self.generation_config.output_logits - output_attentions = ( - output_attentions if output_attentions is not None else self.generation_config.output_attentions - ) - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states - ) - return_dict_in_generate = ( - return_dict_in_generate - if return_dict_in_generate is not None - else self.generation_config.return_dict_in_generate - ) # init attention / hidden states / scores tuples scores = () if (return_dict_in_generate and output_scores) else None @@ -2146,13 +2193,10 @@ def _sample( ) # keep track of which sequences are already finished - # TODO: no ignore_eos check here since there is a compilation error, will add ignore_eos here if fixed batch_size, cur_len = input_ids.shape - if "inputs_embeds" in model_kwargs: - cur_len = model_kwargs["inputs_embeds"].shape[1] this_peer_finished = False unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device) - model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device) + model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs) bucket_size = model_kwargs.get("bucket_size", -1) prev_idx = -1 # avoiding calculate cache_idx when its value is not changing @@ -2175,7 +2219,6 @@ def _sample( # Update cur_len in case of static shapes cur_len = token_idx.item() - # auto-regressive generation time_to_first_token_done = False model_kwargs["pad_done"] = False model_kwargs["lazy_mode"] = lazy_mode @@ -2193,14 +2236,16 @@ def _sample( # prepare model inputs model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) + # prepare variable output controls (note: some models won't accept all output controls) + model_inputs.update({"output_attentions": output_attentions} if output_attentions else {}) + model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {}) + hpu_graphs_kwargs = self._get_hpu_graphs_kwargs(model_kwargs) # forward pass to get next token outputs = self( **model_inputs, return_dict=True, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, **hpu_graphs_kwargs, ) @@ -2209,21 +2254,35 @@ def _sample( token_idx = model_kwargs.get("token_idx", None) if token_idx is not None and outputs.logits.shape[-2] > 1: - if model_kwargs.get("num_virtual_tokens", 0) > 0: - # for prompt tuning, the output logit shape > model_inputs["input_ids"].shape[-1] - if model_kwargs.get("reuse_cache", False): - output_idx = torch.tensor(outputs.logits.shape[-2], device=input_ids.device) - else: - output_idx = token_idx + outputs.logits.shape[-2] - input_ids.shape[-1] - next_token_logits = torch.index_select(outputs.logits, -2, output_idx - 1).squeeze(-2) + # case1 (w/o KV caching): outputs.logits.shape: [batch_size, max_length, vocab_size] + if self.config.is_encoder_decoder: + next_token_logits = outputs.logits[:, token_idx - 1, :] + next_token_scores = logits_processor(input_ids[:, :token_idx], next_token_logits) else: - next_token_logits = torch.index_select(outputs.logits, -2, token_idx - 1).squeeze(-2) + if model_kwargs.get("num_virtual_tokens", 0) > 0: + # for prompt tuning, the output logit shape > model_inputs["input_ids"].shape[-1] + if model_kwargs.get("reuse_cache", False): + output_idx = torch.tensor(outputs.logits.shape[-2], device=input_ids.device) + else: + output_idx = token_idx + outputs.logits.shape[-2] - input_ids.shape[-1] + next_token_logits = torch.index_select(outputs.logits, -2, output_idx - 1).squeeze(-2) + else: + next_token_logits = torch.index_select(outputs.logits, -2, token_idx - 1).squeeze(-2) + next_token_scores = logits_processor(input_ids, next_token_logits) else: - next_token_logits = outputs.logits[:, -1, :] + # Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration + # (the clone itself is always small) + next_token_logits = outputs.logits[:, -1, :].clone() + if token_idx is not None and self.config.is_encoder_decoder: + # case2 (with KV caching): outputs.logits.shape: [batch_size, 1, vocab_size] + next_token_scores = logits_processor(input_ids[:, :token_idx], next_token_logits) + else: + # case3 (default case): token_idx is None + next_token_scores = logits_processor(input_ids, next_token_logits) # pre-process distribution - next_token_scores = logits_processor(input_ids, next_token_logits) - next_token_scores = logits_warper(input_ids, next_token_scores) + if do_sample: + next_token_scores = logits_warper(input_ids, next_token_scores) # Store scores, attentions and hidden_states when required if return_dict_in_generate: @@ -2245,15 +2304,16 @@ def _sample( else (outputs.hidden_states,) ) - # sample - probs = torch.nn.functional.softmax(next_token_scores, dim=-1) - next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) + # token selection + if do_sample: + probs = torch.nn.functional.softmax(next_token_scores, dim=-1) + next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) + else: + next_tokens = torch.argmax(next_token_scores, dim=-1) # finished sentences should have their next token be a padding token # TODO: no ignore_eos check here since there is a compilation error, will add ignore_eos here if fixed - if eos_token_id is not None: - if pad_token_id is None: - raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.") + if has_eos_stopping_criteria: next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) # update generated ids, model inputs, and length for next step @@ -2266,13 +2326,16 @@ def _sample( ) else: input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) + if streamer is not None: streamer.put(next_tokens.cpu()) + model_kwargs = self._update_model_kwargs_for_generation( outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, ) + cur_len = cur_len + 1 if bucket_size > 0 and bucket_internal: # Calculate slice idx for kv cache during the decode phase. @@ -2287,13 +2350,22 @@ def _sample( if ignore_eos: this_peer_finished = stopping_criteria( - input_ids, scores, token_idx=cur_len, ignore_eos=ignore_eos, eos_token_id=eos_token_id + input_ids, + scores, + token_idx=cur_len, + ignore_eos=ignore_eos, + eos_token_id=generation_config.eos_token_id, ) else: unfinished_sequences = unfinished_sequences & ~stopping_criteria( - input_ids, scores, token_idx=cur_len, ignore_eos=ignore_eos, eos_token_id=eos_token_id + input_ids, + scores, + token_idx=cur_len, + ignore_eos=ignore_eos, + eos_token_id=generation_config.eos_token_id, ) this_peer_finished = unfinished_sequences.max() == 0 + hb_profer.step() if hb_gen_time is not None: if not time_to_first_token_done: @@ -2314,6 +2386,10 @@ def _sample( self._pad_past_key_values(model_kwargs) model_kwargs["pad_done"] = True + # This is needed to properly delete outputs.logits which may be very large for first iteration + # Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration + del outputs + if ( model_kwargs.get("use_hpu_graphs", False) and model_kwargs.get("limit_hpu_graphs", False) @@ -2326,6 +2402,7 @@ def _sample( self._remove_past_key_values(model_kwargs) hb_profer.stop() + if streamer is not None: streamer.end() @@ -2358,18 +2435,11 @@ def _beam_search( self, input_ids: torch.LongTensor, beam_scorer: BeamScorer, - logits_processor: Optional[LogitsProcessorList] = None, - stopping_criteria: Optional[StoppingCriteriaList] = None, - max_length: Optional[int] = None, - pad_token_id: Optional[int] = None, - eos_token_id: Optional[Union[int, List[int]]] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - output_scores: Optional[bool] = None, - output_logits: Optional[bool] = None, - return_dict_in_generate: Optional[bool] = None, - synced_gpus: bool = False, - sequential: Optional[bool] = None, + logits_processor: LogitsProcessorList, + stopping_criteria: StoppingCriteriaList, + generation_config: GaudiGenerationConfig, + synced_gpus: bool, + logits_warper: Optional[LogitsProcessorList], lazy_mode: Optional[bool] = False, profiling_warmup_steps: Optional[int] = 0, profiling_steps: Optional[int] = 0, @@ -2381,52 +2451,27 @@ def _beam_search( Generates sequences of token ids for models with a language modeling head using **beam search decoding** and can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models. - - - In most cases, you do not need to call [`~generation.GenerationMixin._beam_search`] directly. Use generate() - instead. For an overview of generation strategies and code examples, check the [following - guide](../generation_strategies). - - - Parameters: input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): The sequence used as a prompt for the generation. beam_scorer (`BeamScorer`): An derived instance of [`BeamScorer`] that defines how beam hypotheses are constructed, stored and sorted during generation. For more information, the documentation of [`BeamScorer`] should be read. - logits_processor (`LogitsProcessorList`, *optional*): + logits_processor (`LogitsProcessorList`): An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`] used to modify the prediction scores of the language modeling head applied at each generation step. - stopping_criteria (`StoppingCriteriaList`, *optional*): + stopping_criteria (`StoppingCriteriaList`: An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`] used to tell if the generation loop should stop. - max_length (`int`, *optional*, defaults to 20): - **DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated - tokens. The maximum length of the sequence to be generated. - pad_token_id (`int`, *optional*): - The id of the *padding* token. - eos_token_id (`Union[int, List[int]]`, *optional*): - The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens. - output_attentions (`bool`, *optional*, defaults to `False`): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more details. - output_hidden_states (`bool`, *optional*, defaults to `False`): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors - for more details. - output_scores (`bool`, *optional*, defaults to `False`): - Whether or not to return the prediction scores. See `scores` under returned tensors for more details. - output_logits (`bool`, *optional*, defaults to `False`): - Whether or not to return the raw prediction logit scores. See `logits` under returned tensors for - more details. - return_dict_in_generate (`bool`, *optional*, defaults to `False`): - Whether or not to return a [`transformers.generationutils.ModelOutput`] instead of a plain tuple. - synced_gpus (`bool`, *optional*, defaults to `False`): + generation_config ([`GaudiGenerationConfig`]): + The generation configuration to be used as parametrization of the decoding method. + synced_gpus (`bool`): Whether to continue running the while loop until max_length (needed for ZeRO stage 3) - sequential (`bool`, defaults to `False`): - By default, beam search has `batch_size * num_beams` as effective batch size (see `beam_search()` for - more details). This flag will avoid parallelizing the beam search and will instead run beam search - sequentially. + logits_warper (`LogitsProcessorList`, *optional*): + An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used + to warp the prediction score distribution of the language modeling head applied before multinomial + sampling at each generation step. Only required with sampling strategies (i.e. `do_sample` is set in + `generation_config`) lazy_mode (`bool`, *optional*, defaults to `False`): Whether the run is executed in lazy mode or not (i.e. eager mode). profiling_warmup_steps (`int`, *optional*, defaults to 0): @@ -2445,110 +2490,22 @@ def _beam_search( [`transformers.generation.GenerateBeamDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and `return_dict_in_generate=True` or a [`transformers.generation.GenerateBeamEncoderDecoderOutput`] if `model.config.is_encoder_decoder=True`. - - Examples: - - ```python - >>> from transformers import ( - ... AutoTokenizer, - ... AutoModelForSeq2SeqLM, - ... LogitsProcessorList, - ... MinLengthLogitsProcessor, - ... BeamSearchScorer, - ... ) - >>> import torch - - >>> tokenizer = AutoTokenizer.from_pretrained("t5-base") - >>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base") - - >>> encoder_input_str = "translate English to German: How old are you?" - >>> encoder_input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids - - >>> # lets run beam search using 3 beams - >>> num_beams = 3 - >>> # define decoder start token ids - >>> input_ids = torch.ones((num_beams, 1), device=model.device, dtype=torch.long) - >>> input_ids = input_ids * model.config.decoder_start_token_id - - >>> # add encoder_outputs to model keyword arguments - >>> model_kwargs = { - ... "encoder_outputs": model.get_encoder()( - ... encoder_input_ids.repeat_interleave(num_beams, dim=0), return_dict=True - ... ) - ... } - - >>> # instantiate beam scorer - >>> beam_scorer = BeamSearchScorer( - ... batch_size=1, - ... num_beams=num_beams, - ... device=model.device, - ... ) - - >>> # instantiate logits processors - >>> logits_processor = LogitsProcessorList( - ... [ - ... MinLengthLogitsProcessor(5, eos_token_id=model.config.eos_token_id), - ... ] - ... ) - - >>> outputs = model._beam_search(input_ids, beam_scorer, logits_processor=logits_processor, **model_kwargs) - - >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) - ['Wie alt bist du?'] - ```""" + """ # init values - logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() - stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() - sequential = sequential if sequential is not None else self.generation_config.low_memory - if max_length is not None: - warnings.warn( - ( - "`max_length` is deprecated in this function, use" - " `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead.", - ), - UserWarning, + pad_token_id = generation_config._pad_token_tensor + eos_token_id = generation_config._eos_token_tensor + output_attentions = generation_config.output_attentions + output_hidden_states = generation_config.output_hidden_states + output_scores = generation_config.output_scores + output_logits = generation_config.output_logits + return_dict_in_generate = generation_config.return_dict_in_generate + sequential = generation_config.low_memory + do_sample = generation_config.do_sample + if do_sample is True and not isinstance(logits_warper, LogitsProcessorList): + raise ValueError( + "`do_sample` is set to `True`, `logits_warper` must be a `LogitsProcessorList` instance (it is " + f"{logits_warper})." ) - stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) - if len(stopping_criteria) == 0: - warnings.warn("You don't have defined any stopping_criteria, this will likely loop forever", UserWarning) - pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id - if not self.generation_config.ignore_eos: - if eos_token_id is not None: - logger.warning_once( - "`eos_token_id` is deprecated in this function and will be removed in v4.41, use" - " `stopping_criteria=StoppingCriteriaList([EosTokenCriteria(eos_token_id=eos_token_id)])` instead." - " Otherwise make sure to set `model.generation_config.eos_token_id`", - FutureWarning, - ) - stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id)) - else: - # TODO remove when the method is totally private and beam scorer refactored - # need to get `eos_token_id` and add stopping criteria, so that generation does not go forever - eos_token_id = [ - criteria.eos_token_id.tolist() - for criteria in stopping_criteria - if hasattr(criteria, "eos_token_id") - ] - eos_token_id = eos_token_id[0] if eos_token_id else None - if eos_token_id is None and self.generation_config.eos_token_id is not None: - eos_token_id = self.generation_config.eos_token_id - stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id)) - - if isinstance(eos_token_id, int): - eos_token_id = [eos_token_id] - output_scores = output_scores if output_scores is not None else self.generation_config.output_scores - output_logits = output_logits if output_logits is not None else self.generation_config.output_logits - output_attentions = ( - output_attentions if output_attentions is not None else self.generation_config.output_attentions - ) - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states - ) - return_dict_in_generate = ( - return_dict_in_generate - if return_dict_in_generate is not None - else self.generation_config.return_dict_in_generate - ) batch_size = len(beam_scorer._beam_hyps) num_beams = beam_scorer.num_beams @@ -2708,6 +2665,10 @@ def expand_if_needed(tensor, new_size, value, dim=-1): model_kwargs["lazy_mode"] = lazy_mode model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) + # prepare variable output controls (note: some models won't accept all output controls) + model_inputs.update({"output_attentions": output_attentions} if output_attentions else {}) + model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {}) + # if sequential is True, split the input to batches of batch_size and run sequentially if sequential: if any( @@ -2733,13 +2694,7 @@ def expand_if_needed(tensor, new_size, value, dim=-1): model_inputs, split_size=batch_size, full_batch_size=batch_beam_size ) outputs_per_sub_batch = [ - self( - **inputs_per_sub_batch, - return_dict=True, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - ) - for inputs_per_sub_batch in inputs_per_sub_batches + self(**inputs_per_sub_batch, return_dict=True) for inputs_per_sub_batch in inputs_per_sub_batches ] outputs = stack_model_outputs(outputs_per_sub_batch) @@ -2748,8 +2703,6 @@ def expand_if_needed(tensor, new_size, value, dim=-1): outputs = self( **model_inputs, return_dict=True, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, **hpu_graphs_kwargs, ) @@ -2769,7 +2722,9 @@ def expand_if_needed(tensor, new_size, value, dim=-1): else: next_token_logits = torch.index_select(outputs.logits, -2, token_idx - 1).squeeze(-2) else: - next_token_logits = outputs.logits[:, -1, :] + # Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration + # (the clone itself is always small) + next_token_logits = outputs.logits[:, -1, :].clone() next_token_scores = torch.nn.functional.log_softmax( next_token_logits, dim=-1 @@ -2779,6 +2734,8 @@ def expand_if_needed(tensor, new_size, value, dim=-1): next_token_scores_processed = logits_processor(input_ids[:, :token_idx], next_token_scores) else: next_token_scores_processed = logits_processor(input_ids, next_token_scores) + if do_sample: + next_token_scores_processed = logits_warper(input_ids, next_token_scores_processed) next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as( next_token_scores_processed ) @@ -2806,11 +2763,20 @@ def expand_if_needed(tensor, new_size, value, dim=-1): vocab_size = next_token_scores.shape[-1] next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size) - # Sample 1 + len(eos_token_id) next tokens for each beam so we have at least 1 non eos token per beam. - n_eos_tokens = len(eos_token_id) if eos_token_id else 0 - next_token_scores, next_tokens = torch.topk( - next_token_scores, max(2, 1 + n_eos_tokens) * num_beams, dim=1, largest=True, sorted=True - ) + # Beam token selection: pick 1 + eos_token_id.shape[0] next tokens for each beam so we have at least 1 + # non eos token per beam. + n_eos_tokens = eos_token_id.shape[0] if eos_token_id is not None else 0 + n_tokens_to_keep = max(2, 1 + n_eos_tokens) * num_beams + if do_sample: + probs = torch.nn.functional.softmax(next_token_scores, dim=-1) + next_tokens = torch.multinomial(probs, num_samples=n_tokens_to_keep) + next_token_scores = torch.gather(next_token_scores, -1, next_tokens) + next_token_scores, _indices = torch.sort(next_token_scores, descending=True, dim=1) + next_tokens = torch.gather(next_tokens, -1, _indices) + else: + next_token_scores, next_tokens = torch.topk( + next_token_scores, n_tokens_to_keep, dim=1, largest=True, sorted=True + ) next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor") if self.generation_config.static_shapes: @@ -2854,6 +2820,7 @@ def expand_if_needed(tensor, new_size, value, dim=-1): beam_idx = static_beam_indices else: next_tokens = next_tokens % vocab_size + # stateless beam_outputs = beam_scorer.process( input_ids, @@ -2868,6 +2835,7 @@ def expand_if_needed(tensor, new_size, value, dim=-1): beam_scores = beam_outputs["next_beam_scores"] beam_next_tokens = beam_outputs["next_beam_tokens"] beam_idx = beam_outputs["next_beam_indices"] + if token_idx is not None: input_ids = torch.index_select(input_ids, 0, beam_idx) input_ids.index_copy_( @@ -2889,6 +2857,12 @@ def expand_if_needed(tensor, new_size, value, dim=-1): model_kwargs["past_key_values"], beam_idx ) + # This is needed to properly delete outputs.logits which may be very large for first iteration + # Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration + # IMPORTANT: Note that this should appear BEFORE the call to _reorder_cache() to save the maximum memory + # (that way the memory peak does not include outputs.logits) + del outputs + if return_dict_in_generate and output_scores: beam_indices = tuple((beam_indices[beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices)))) @@ -2990,176 +2964,14 @@ def move(obj, device): else: return sequence_outputs["sequences"] - def _beam_sample( - self, - input_ids: torch.LongTensor, - beam_scorer: BeamScorer, - logits_processor: Optional[LogitsProcessorList] = None, - stopping_criteria: Optional[StoppingCriteriaList] = None, - logits_warper: Optional[LogitsProcessorList] = None, - max_length: Optional[int] = None, - pad_token_id: Optional[int] = None, - eos_token_id: Optional[Union[int, List[int]]] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - output_scores: Optional[bool] = None, - output_logits: Optional[bool] = None, - return_dict_in_generate: Optional[bool] = None, - synced_gpus: bool = False, - lazy_mode: Optional[bool] = False, - profiling_warmup_steps: Optional[int] = 0, - profiling_steps: Optional[int] = 0, - hb_gen_time: Optional[HabanaGenerationtime] = None, - profiling_record_shapes: Optional[bool] = False, - **model_kwargs, - ) -> Union[GenerateBeamOutput, torch.LongTensor]: - r""" - Generates sequences of token ids for models with a language modeling head using **beam search multinomial - sampling** and can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models. - - - - In most cases, you do not need to call [`~generation._GenerationMixin.beam_sample`] directly. Use generate() - instead. For an overview of generation strategies and code examples, check the [following - guide](../generation_strategies). - - - - Parameters: - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - The sequence used as a prompt for the generation. - beam_scorer (`BeamScorer`): - A derived instance of [`BeamScorer`] that defines how beam hypotheses are constructed, stored and - sorted during generation. For more information, the documentation of [`BeamScorer`] should be read. - logits_processor (`LogitsProcessorList`, *optional*): - An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`] - used to modify the prediction scores of the language modeling head applied at each generation step. - stopping_criteria (`StoppingCriteriaList`, *optional*): - An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`] - used to tell if the generation loop should stop. - logits_warper (`LogitsProcessorList`, *optional*): - An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used - to warp the prediction score distribution of the language modeling head applied before multinomial - sampling at each generation step. - max_length (`int`, *optional*, defaults to 20): - **DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated - tokens. The maximum length of the sequence to be generated. - pad_token_id (`int`, *optional*): - The id of the *padding* token. - eos_token_id (`Union[int, List[int]]`, *optional*): - The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens. - output_attentions (`bool`, *optional*, defaults to `False`): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more details. - output_hidden_states (`bool`, *optional*, defaults to `False`): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors - for more details. - output_scores (`bool`, *optional*, defaults to `False`): - Whether or not to return the prediction scores. See `scores` under returned tensors for more details. - output_logits (`bool`, *optional*, defaults to `False`): - Whether or not to return the raw prediction logit scores. See `logits` under returned tensors for - more details. - return_dict_in_generate (`bool`, *optional*, defaults to `False`): - Whether or not to return a [`transformers.generationutils.ModelOutput`] instead of a plain tuple. - synced_gpus (`bool`, *optional*, defaults to `False`): - Whether to continue running the while loop until max_length (needed for ZeRO stage 3) - lazy_mode (`bool`, *optional*, defaults to `False`): - Whether the run is executed in lazy mode or not (i.e. eager mode). - profiling_warmup_steps (`int`, *optional*, defaults to 0): - Number of steps to ignore for profling. - profiling_steps (`int`, *optional*, defaults to 0): - Number of steps to be captured when enabling profiling. - profiling_record_shapes (`bool`, *optional*, defaults to False): - Record shapes when enabling profiling. - model_kwargs: - Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is - an encoder-decoder model the kwargs should include `encoder_outputs`. - - Return: - [`transformers.generation.GenerateBeamDecoderOnlyOutput`], [`transformers.generation.GenerateBeamEncoderDecoderOutput`] or - `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a - [`transformers.generation.GenerateBeamDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and - `return_dict_in_generate=True` or a [`transformers.generation.GenerateBeamEncoderDecoderOutput`] if - `model.config.is_encoder_decoder=True`. - - Examples: - - ```python - >>> from transformers import ( - ... AutoTokenizer, - ... AutoModelForSeq2SeqLM, - ... LogitsProcessorList, - ... MinLengthLogitsProcessor, - ... TopKLogitsWarper, - ... TemperatureLogitsWarper, - ... BeamSearchScorer, - ... ) - >>> import torch - - >>> tokenizer = AutoTokenizer.from_pretrained("t5-base") - >>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base") - - >>> encoder_input_str = "translate English to German: How old are you?" - >>> encoder_input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids - - >>> # lets run beam search using 3 beams - >>> num_beams = 3 - >>> # define decoder start token ids - >>> input_ids = torch.ones((num_beams, 1), device=model.device, dtype=torch.long) - >>> input_ids = input_ids * model.config.decoder_start_token_id - - >>> # add encoder_outputs to model keyword arguments - >>> model_kwargs = { - ... "encoder_outputs": model.get_encoder()( - ... encoder_input_ids.repeat_interleave(num_beams, dim=0), return_dict=True - ... ) - ... } - - >>> # instantiate beam scorer - >>> beam_scorer = BeamSearchScorer( - ... batch_size=1, - ... max_length=model.config.max_length, - ... num_beams=num_beams, - ... device=model.device, - ... ) - - >>> # instantiate logits processors - >>> logits_processor = LogitsProcessorList( - ... [MinLengthLogitsProcessor(5, eos_token_id=model.config.eos_token_id)] - ... ) - >>> # instantiate logits processors - >>> logits_warper = LogitsProcessorList( - ... [ - ... TopKLogitsWarper(50), - ... TemperatureLogitsWarper(0.7), - ... ] - ... ) - - >>> outputs = model._beam_sample( - ... input_ids, beam_scorer, logits_processor=logits_processor, logits_warper=logits_warper, **model_kwargs - ... ) - - >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) - ['Wie alt bist du?'] - ```""" - - raise NotImplementedError("Beam search sampling is not supported by optimum-habana yet.") - def _group_beam_search( self, input_ids: torch.LongTensor, beam_scorer: BeamScorer, - logits_processor: Optional[LogitsProcessorList] = None, - stopping_criteria: Optional[StoppingCriteriaList] = None, - max_length: Optional[int] = None, - pad_token_id: Optional[int] = None, - eos_token_id: Optional[Union[int, List[int]]] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - output_scores: Optional[bool] = None, - output_logits: Optional[bool] = None, - return_dict_in_generate: Optional[bool] = None, - synced_gpus: bool = False, + logits_processor: LogitsProcessorList, + stopping_criteria: StoppingCriteriaList, + generation_config: GaudiGenerationConfig, + synced_gpus: bool, lazy_mode: Optional[bool] = False, profiling_warmup_steps: Optional[int] = 0, profiling_steps: Optional[int] = 0, @@ -3171,47 +2983,21 @@ def _group_beam_search( Generates sequences of token ids for models with a language modeling head using **diverse beam search decoding** and can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models. - - - In most cases, you do not need to call [`~generation.GenerationMixin._group_beam_search`] directly. Use - generate() instead. For an overview of generation strategies and code examples, check the [following - guide](../generation_strategies). - - - Parameters: input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): The sequence used as a prompt for the generation. beam_scorer (`BeamScorer`): An derived instance of [`BeamScorer`] that defines how beam hypotheses are constructed, stored and sorted during generation. For more information, the documentation of [`BeamScorer`] should be read. - logits_processor (`LogitsProcessorList`, *optional*): + logits_processor (`LogitsProcessorList`): An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`] used to modify the prediction scores of the language modeling head applied at each generation step. - stopping_criteria (`StoppingCriteriaList`, *optional*): + stopping_criteria (`StoppingCriteriaList`): An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`] used to tell if the generation loop should stop. - max_length (`int`, *optional*, defaults to 20): - **DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated - tokens. The maximum length of the sequence to be generated. - pad_token_id (`int`, *optional*): - The id of the *padding* token. - eos_token_id (`Union[int, List[int]]`, *optional*): - The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens. - output_attentions (`bool`, *optional*, defaults to `False`): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more details. - output_hidden_states (`bool`, *optional*, defaults to `False`): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors - for more details. - output_scores (`bool`, *optional*, defaults to `False`): - Whether or not to return the prediction scores. See `scores` under returned tensors for more details. - output_logits (`bool`, *optional*, defaults to `False`): - Whether or not to return the raw prediction logit scores. See `logits` under returned tensors for - more details. - return_dict_in_generate (`bool`, *optional*, defaults to `False`): - Whether or not to return a [`transformers.generationutils.ModelOutput`] instead of a plain tuple. - synced_gpus (`bool`, *optional*, defaults to `False`): + generation_config ([`GaudiGenerationConfig`]): + The generation configuration to be used as parametrization of the decoding method. + synced_gpus (`bool`): Whether to continue running the while loop until max_length (needed for ZeRO stage 3) lazy_mode (`bool`, *optional*, defaults to `False`): Whether the run is executed in lazy mode or not (i.e. eager mode). @@ -3231,63 +3017,7 @@ def _group_beam_search( [`transformers.generation.GenerateBeamDecoderOnlyOutput`] if [`transformers.generation.BeamSearchDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and `return_dict_in_generate=True` or a [`transformers.generation.GenerateBeamEncoderDecoderOutput`] if `model.config.is_encoder_decoder=True`. - - Examples: - - ```python - >>> from transformers import ( - ... AutoTokenizer, - ... AutoModelForSeq2SeqLM, - ... LogitsProcessorList, - ... MinLengthLogitsProcessor, - ... HammingDiversityLogitsProcessor, - ... BeamSearchScorer, - ... ) - >>> import torch - - >>> tokenizer = AutoTokenizer.from_pretrained("t5-base") - >>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base") - - >>> encoder_input_str = "translate English to German: How old are you?" - >>> encoder_input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids - - >>> # lets run diverse beam search using 6 beams - >>> num_beams = 6 - >>> # define decoder start token ids - >>> input_ids = torch.ones((num_beams, 1), device=model.device, dtype=torch.long) - >>> input_ids = input_ids * model.config.decoder_start_token_id - - >>> # add encoder_outputs to model keyword arguments - >>> model_kwargs = { - ... "encoder_outputs": model.get_encoder()( - ... encoder_input_ids.repeat_interleave(num_beams, dim=0), return_dict=True - ... ) - ... } - - >>> # instantiate beam scorer - >>> beam_scorer = BeamSearchScorer( - ... batch_size=1, - ... max_length=model.config.max_length, - ... num_beams=num_beams, - ... device=model.device, - ... num_beam_groups=3, - ... ) - - >>> # instantiate logits processors - >>> logits_processor = LogitsProcessorList( - ... [ - ... HammingDiversityLogitsProcessor(5.5, num_beams=6, num_beam_groups=3), - ... MinLengthLogitsProcessor(5, eos_token_id=model.config.eos_token_id), - ... ] - ... ) - - >>> outputs = model._group_beam_search( - ... input_ids, beam_scorer, logits_processor=logits_processor, **model_kwargs - ... ) - - >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) - ['Wie alt bist du?'] - ```""" + """ raise NotImplementedError("Group beam search is not supported by optimum-habana yet.") @@ -3295,17 +3025,10 @@ def _constrained_beam_search( self, input_ids: torch.LongTensor, constrained_beam_scorer: ConstrainedBeamSearchScorer, - logits_processor: Optional[LogitsProcessorList] = None, - stopping_criteria: Optional[StoppingCriteriaList] = None, - max_length: Optional[int] = None, - pad_token_id: Optional[int] = None, - eos_token_id: Optional[Union[int, List[int]]] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - output_scores: Optional[bool] = None, - output_logits: Optional[bool] = None, - return_dict_in_generate: Optional[bool] = None, - synced_gpus: Optional[bool] = None, + logits_processor: LogitsProcessorList, + stopping_criteria: StoppingCriteriaList, + generation_config: GaudiGenerationConfig, + synced_gpus: bool, lazy_mode: Optional[bool] = False, profiling_warmup_steps: Optional[int] = 0, profiling_steps: Optional[int] = 0, @@ -3317,14 +3040,6 @@ def _constrained_beam_search( Generates sequences of token ids for models with a language modeling head using **constrained beam search decoding** and can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models. - - - In most cases, you do not need to call [`~generation.GenerationMixin._constrained_beam_search`] directly. Use - generate() instead. For an overview of generation strategies and code examples, check the [following - guide](../generation_strategies). - - - Parameters: input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): The sequence used as a prompt for the generation. @@ -3332,37 +3047,19 @@ def _constrained_beam_search( A derived instance of [`BeamScorer`] that defines how beam hypotheses are constructed, stored and sorted during generation, while satisfying a list of positive constraints. For more information, the documentation of [`ConstrainedBeamSearchScorer`] should be read. - logits_processor (`LogitsProcessorList`, *optional*): + logits_processor (`LogitsProcessorList`): An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`] used to modify the prediction scores of the language modeling head applied at each generation step. - stopping_criteria (`StoppingCriteriaList`, *optional*): + stopping_criteria (`StoppingCriteriaList`): An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`] used to tell if the generation loop should stop. - logits_warper (`LogitsProcessorList`, *optional*): + logits_warper (`LogitsProcessorList`): An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used to warp the prediction score distribution of the language modeling head applied before multinomial sampling at each generation step. - max_length (`int`, *optional*, defaults to 20): - **DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated - tokens. The maximum length of the sequence to be generated. - pad_token_id (`int`, *optional*): - The id of the *padding* token. - eos_token_id (`Union[int, List[int]]`, *optional*): - The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens. - output_attentions (`bool`, *optional*, defaults to `False`): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more details. - output_hidden_states (`bool`, *optional*, defaults to `False`): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors - for more details. - output_scores (`bool`, *optional*, defaults to `False`): - Whether or not to return the prediction scores. See `scores` under returned tensors for more details. - output_logits (`bool`, *optional*, defaults to `False`): - Whether or not to return the raw prediction logit scores. See `logits` under returned tensors for - more details. - return_dict_in_generate (`bool`, *optional*, defaults to `False`): - Whether or not to return a [`transformers.generationutils.ModelOutput`] instead of a plain tuple. - synced_gpus (`bool`, *optional*, defaults to `False`): + generation_config ([`GaudiGenerationConfig`]): + The generation configuration to be used as parametrization of the decoding method. + synced_gpus (`bool`): Whether to continue running the while loop until max_length (needed for ZeRO stage 3) lazy_mode (`bool`, *optional*, defaults to `False`): Whether the run is executed in lazy mode or not (i.e. eager mode). @@ -3382,113 +3079,15 @@ def _constrained_beam_search( [`transformers.generation.GenerateBeamDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and `return_dict_in_generate=True` or a [`transformers.generation.GenerateBeamEncoderDecoderOutput`] if `model.config.is_encoder_decoder=True`. - - Examples: - - ```python - >>> from transformers import ( - ... AutoTokenizer, - ... AutoModelForSeq2SeqLM, - ... LogitsProcessorList, - ... MinLengthLogitsProcessor, - ... ConstrainedBeamSearchScorer, - ... PhrasalConstraint, - ... ) - >>> import torch - - >>> tokenizer = AutoTokenizer.from_pretrained("t5-base") - >>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base") - - >>> encoder_input_str = "translate English to German: How old are you?" - >>> encoder_input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids - - >>> # lets run beam search using 3 beams - >>> num_beams = 3 - >>> # define decoder start token ids - >>> input_ids = torch.ones((num_beams, 1), device=model.device, dtype=torch.long) - >>> input_ids = input_ids * model.config.decoder_start_token_id - - >>> # add encoder_outputs to model keyword arguments - >>> model_kwargs = { - ... "encoder_outputs": model.get_encoder()( - ... encoder_input_ids.repeat_interleave(num_beams, dim=0), return_dict=True - ... ) - ... } - - >>> constraint_str = "Sie" - >>> constraint_token_ids = tokenizer.encode(constraint_str)[:-1] # slice to remove eos token - >>> constraints = [PhrasalConstraint(token_ids=constraint_token_ids)] - - >>> # instantiate beam scorer - >>> beam_scorer = ConstrainedBeamSearchScorer( - ... batch_size=1, num_beams=num_beams, device=model.device, constraints=constraints - ... ) - - >>> # instantiate logits processors - >>> logits_processor = LogitsProcessorList( - ... [ - ... MinLengthLogitsProcessor(5, eos_token_id=model.config.eos_token_id), - ... ] - ... ) - - >>> outputs = model._constrained_beam_search( - ... input_ids, beam_scorer, constraints=constraints, logits_processor=logits_processor, **model_kwargs - ... ) - - >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) - ['Wie alt sind Sie?'] - ```""" - + """ # init values - logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() - stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() - if max_length is not None: - warnings.warn( - "`max_length` is deprecated in this function, use" - " `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead.", - UserWarning, - ) - stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) - if len(stopping_criteria) == 0: - warnings.warn("You don't have defined any stopping_criteria, this will likely loop forever", UserWarning) - pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id - if not self.generation_config.ignore_eos: - if eos_token_id is not None: - logger.warning_once( - "`eos_token_id` is deprecated in this function and will be removed in v4.41, use" - " `stopping_criteria=StoppingCriteriaList([EosTokenCriteria(eos_token_id=eos_token_id)])` instead." - " Otherwise make sure to set `model.generation_config.eos_token_id`", - FutureWarning, - ) - stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id)) - else: - # TODO remove when the method is totally private and beam scorer refactored - # need to get `eos_token_id` and add stopping criteria, so that generation does not go forever - eos_token_id = [ - criteria.eos_token_id.tolist() - for criteria in stopping_criteria - if hasattr(criteria, "eos_token_id") - ] - eos_token_id = eos_token_id[0] if eos_token_id else None - if eos_token_id is None and self.generation_config.eos_token_id is not None: - eos_token_id = self.generation_config.eos_token_id - stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id)) - - if isinstance(eos_token_id, int): - eos_token_id = [eos_token_id] - output_scores = output_scores if output_scores is not None else self.generation_config.output_scores - output_logits = output_logits if output_logits is not None else self.generation_config.output_logits - output_attentions = ( - output_attentions if output_attentions is not None else self.generation_config.output_attentions - ) - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states - ) - return_dict_in_generate = ( - return_dict_in_generate - if return_dict_in_generate is not None - else self.generation_config.return_dict_in_generate - ) + pad_token_id = generation_config._pad_token_tensor + eos_token_id = generation_config._eos_token_tensor + output_attentions = generation_config.output_attentions + output_hidden_states = generation_config.output_hidden_states + output_scores = generation_config.output_scores + output_logits = generation_config.output_logits + return_dict_in_generate = generation_config.return_dict_in_generate batch_size = len(constrained_beam_scorer._beam_hyps) num_beams = constrained_beam_scorer.num_beams @@ -3544,19 +3143,22 @@ def _constrained_beam_search( model_kwargs["lazy_mode"] = lazy_mode model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) + # prepare variable output controls (note: some models won't accept all output controls) + model_inputs.update({"output_attentions": output_attentions} if output_attentions else {}) + model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {}) + hpu_graphs_kwargs = self._get_hpu_graphs_kwargs(model_kwargs) outputs = self( **model_inputs, return_dict=True, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, **hpu_graphs_kwargs, ) if synced_gpus and this_peer_finished: cur_len = cur_len + 1 continue # don't waste resources running the code we don't need + if token_idx is not None and outputs.logits.shape[-2] > 1: if model_kwargs.get("num_virtual_tokens", 0) > 0: # for prompt tuning, the output logit shape > model_inputs["input_ids"].shape[-1] @@ -3568,7 +3170,9 @@ def _constrained_beam_search( else: next_token_logits = torch.index_select(outputs.logits, -2, token_idx - 1).squeeze(-2) else: - next_token_logits = outputs.logits[:, -1, :] + # Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration + # (the clone itself is always small) + next_token_logits = outputs.logits[:, -1, :].clone() next_token_scores = torch.nn.functional.log_softmax( next_token_logits, dim=-1 @@ -3606,7 +3210,7 @@ def _constrained_beam_search( next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size) # Sample 1 + len(eos_token_id) next tokens for each beam so we have at least 1 non eos token per beam. - n_eos_tokens = len(eos_token_id) if eos_token_id else 0 + n_eos_tokens = eos_token_id.shape[0] if eos_token_id is not None else 0 next_token_scores, next_tokens = torch.topk( next_token_scores, max(2, 1 + n_eos_tokens) * num_beams, dim=1, largest=True, sorted=True ) @@ -3642,6 +3246,13 @@ def _constrained_beam_search( model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, ) + + # This is needed to properly delete outputs.logits which may be very large for first iteration + # Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration + # IMPORTANT: Note that this should appear BEFORE the call to _reorder_cache() to save the maximum memory + # (that way the memory peak does not include outputs.logits) + del outputs + if model_kwargs.get("past_key_values", None) is not None: model_kwargs["past_key_values"] = self._temporary_reorder_cache( model_kwargs["past_key_values"], beam_idx @@ -3715,20 +3326,13 @@ def _constrained_beam_search( def _assisted_decoding( self, input_ids: torch.LongTensor, - candidate_generator: Optional["GaudiCandidateGenerator"] = None, - do_sample: bool = False, - logits_processor: Optional[LogitsProcessorList] = None, - logits_warper: Optional[LogitsProcessorList] = None, - stopping_criteria: Optional[StoppingCriteriaList] = None, - pad_token_id: Optional[int] = None, - eos_token_id: Optional[Union[int, List[int]]] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - output_scores: Optional[bool] = None, - output_logits: Optional[bool] = None, - return_dict_in_generate: Optional[bool] = None, - synced_gpus: bool = False, - streamer: Optional["BaseStreamer"] = None, + candidate_generator: "GaudiCandidateGenerator", + logits_processor: LogitsProcessorList, + logits_warper: LogitsProcessorList, + stopping_criteria: StoppingCriteriaList, + generation_config: GaudiGenerationConfig, + synced_gpus: bool, + streamer: Optional["BaseStreamer"], lazy_mode: Optional[bool] = False, ignore_eos: Optional[bool] = False, profiling_warmup_steps: Optional[int] = 0, @@ -3743,50 +3347,25 @@ def _assisted_decoding( candidate decoding strategy. Can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models. - - - In most cases, you do not need to call [`transformers.generation.GenerationMixin._assisted_decoding`] directly. Use - generate() instead. For an overview of generation strategies and code examples, check the [following - guide](../generation_strategies). - - - Parameters: input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): The sequence used as a prompt for the generation. - candidate_generator (`CandidateGenerator`, *optional*): + candidate_generator (`CandidateGenerator`): A derived instance of [`CandidateGenerator`] that defines how candidate sequences are generated. For more information, the documentation of [`CandidateGenerator`] should be read. - do_sample (`bool`, *optional*, defaults to `False`): - Whether or not to use sampling ; use greedy decoding otherwise. - logits_processor (`LogitsProcessorList`, *optional*): + logits_processor (`LogitsProcessorList`): An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`] used to modify the prediction scores of the language modeling head applied at each generation step. - logits_warper (`LogitsProcessorList`, *optional*): + logits_warper (`LogitsProcessorList`): An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used to warp the prediction score distribution of the language modeling head applied before multinomial - sampling at each generation step. - stopping_criteria (`StoppingCriteriaList`, *optional*): + sampling at each generation step. Only used if sampling is active. + stopping_criteria (`StoppingCriteriaList`): An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`] used to tell if the generation loop should stop. - pad_token_id (`int`, *optional*): - The id of the *padding* token. - eos_token_id (`Union[int, List[int]]`, *optional*): - The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens. - output_attentions (`bool`, *optional*, defaults to `False`): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more details. - output_hidden_states (`bool`, *optional*, defaults to `False`): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors - for more details. - output_scores (`bool`, *optional*, defaults to `False`): - Whether or not to return the prediction scores. See `scores` under returned tensors for more details. - output_logits (`bool`, *optional*, defaults to `False`): - Whether or not to return the raw prediction logit scores. See `logits` under returned tensors for - more details. - return_dict_in_generate (`bool`, *optional*, defaults to `False`): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. - synced_gpus (`bool`, *optional*, defaults to `False`): + generation_config ([`~generation.GenerationConfig`]): + The generation configuration to be used as parametrization of the decoding method. + synced_gpus (`bool`): Whether to continue running the while loop until max_length (needed for ZeRO stage 3) streamer (`BaseStreamer`, *optional*): Streamer object that will be used to stream the generated sequences. Generated tokens are passed @@ -3809,90 +3388,14 @@ def _assisted_decoding( [`transformers.generation.GenerateDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and `return_dict_in_generate=True` or a [`transformers.generation.GenerateEncoderDecoderOutput`] if `model.config.is_encoder_decoder=True`. - - Examples: - - ```python - >>> from transformers import ( - ... AutoTokenizer, - ... AutoModelForCausalLM, - ... LogitsProcessorList, - ... MinLengthLogitsProcessor, - ... StoppingCriteriaList, - ... MaxLengthCriteria, - ... ) - >>> from transformers.generation import AssistedCandidateGenerator - - >>> tokenizer = AutoTokenizer.from_pretrained("gpt2") - >>> model = AutoModelForCausalLM.from_pretrained("gpt2") - >>> assistant_model = AutoModelForCausalLM.from_pretrained("distilgpt2") - >>> # set pad_token_id to eos_token_id because GPT2 does not have a PAD token - >>> model.generation_config.pad_token_id = model.generation_config.eos_token_id - >>> input_prompt = "It might be possible to" - >>> input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids - >>> # instantiate logits processors - >>> logits_processor = LogitsProcessorList( - ... [ - ... MinLengthLogitsProcessor(10, eos_token_id=model.generation_config.eos_token_id), - ... ] - ... ) - >>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)]) - >>> candidate_generator = AssistedCandidateGenerator( - ... input_ids=input_ids, - ... assistant_model=assistant_model, - ... generation_config=model.generation_config, - ... logits_processor=logits_processor, - ... model_kwargs={}, - ... ) - >>> outputs = model._assisted_decoding( - ... input_ids, - ... candidate_generator=candidate_generator, - ... logits_processor=logits_processor, - ... stopping_criteria=stopping_criteria, - ... ) - >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) - ["It might be possible to get a better understanding of the nature of the problem, but it's not"] - ```""" + """ # init values - # do_sample = logits_warper is not None - logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() - logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList() - stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() - pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id - if eos_token_id is not None: - logger.warning_once( - "`eos_token_id` is deprecated in this function and will be removed in v4.41, use" - " `stopping_criteria=StoppingCriteriaList([EosTokenCriteria(eos_token_id=eos_token_id)])` instead." - " Otherwise make sure to set `model.generation_config.eos_token_id`", - FutureWarning, - ) - stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id)) - else: - # TODO remove when the method is totally private and beam scorer refactored - # need to get `eos_token_id` and add stopping criteria, so that generation does not go forever - eos_token_id = [ - criteria.eos_token_id.tolist() for criteria in stopping_criteria if hasattr(criteria, "eos_token_id") - ] - eos_token_id = eos_token_id[0] if eos_token_id else None - if eos_token_id is None and self.generation_config.eos_token_id is not None: - eos_token_id = self.generation_config.eos_token_id - stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id)) - - if isinstance(eos_token_id, int): - eos_token_id = [eos_token_id] - output_scores = output_scores if output_scores is not None else self.generation_config.output_scores - output_logits = output_logits if output_logits is not None else self.generation_config.output_logits - output_attentions = ( - output_attentions if output_attentions is not None else self.generation_config.output_attentions - ) - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states - ) - return_dict_in_generate = ( - return_dict_in_generate - if return_dict_in_generate is not None - else self.generation_config.return_dict_in_generate - ) + do_sample = logits_warper is not None + output_attentions = generation_config.output_attentions + output_hidden_states = generation_config.output_hidden_states + output_scores = generation_config.output_scores + output_logits = generation_config.output_logits + return_dict_in_generate = generation_config.return_dict_in_generate # init attention / hidden states / scores tuples scores = () if (return_dict_in_generate and output_scores) else None @@ -3910,11 +3413,19 @@ def _assisted_decoding( # keep track of which sequences are already finished batch_size, cur_len = input_ids.shape - if "inputs_embeds" in model_kwargs: - cur_len = model_kwargs["inputs_embeds"].shape[1] if not ignore_eos: unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device) - model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device) + model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs) + + # This is needed if return_dict_in_generate is True + start_from_empty_dynamic_cache = False + past_key_values = model_kwargs.get("past_key_values", None) + if isinstance(past_key_values, DynamicCache) or ( + isinstance(past_key_values, EncoderDecoderCache) + and isinstance(past_key_values.self_attention_cache, DynamicCache) + ): + if len(past_key_values) == 0: + start_from_empty_dynamic_cache = True hb_profer = HabanaProfile(warmup=profiling_warmup_steps, active=profiling_steps) hb_profer.start() @@ -3954,30 +3465,33 @@ def _assisted_decoding( # we use this forward pass to also pick the subsequent logits in the original model. # 2.1. Prepare the model inputs - model_kwargs = _prepare_attention_mask( - model_kwargs, candidate_input_ids.shape[1], self.config.is_encoder_decoder + candidate_kwargs = copy.copy(model_kwargs) + candidate_kwargs = _prepare_attention_mask( + candidate_kwargs, candidate_input_ids.shape[1], self.config.is_encoder_decoder ) - model_kwargs = _prepare_token_type_ids(model_kwargs, candidate_input_ids.shape[1]) - if "cache_position" in model_kwargs: - model_kwargs["cache_position"] = torch.cat( + candidate_kwargs = _prepare_token_type_ids(candidate_kwargs, candidate_input_ids.shape[1]) + if "cache_position" in candidate_kwargs: + candidate_kwargs["cache_position"] = torch.cat( ( - model_kwargs["cache_position"], + candidate_kwargs["cache_position"], torch.arange(cur_len, cur_len + candidate_length, device=input_ids.device, dtype=torch.long), ), dim=0, ) - model_inputs = self.prepare_inputs_for_generation(candidate_input_ids, **model_kwargs) + model_inputs = self.prepare_inputs_for_generation(candidate_input_ids, **candidate_kwargs) if "num_logits_to_keep" in model_inputs: model_inputs["num_logits_to_keep"] = candidate_length + 1 hpu_graphs_kwargs = self._get_hpu_graphs_kwargs(model_kwargs) # 2.2. Run a forward pass on the candidate sequence + # prepare variable output controls (note: some models won't accept all output controls) + model_inputs.update({"output_attentions": output_attentions} if output_attentions else {}) + model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {}) + outputs = self( **model_inputs, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, **hpu_graphs_kwargs, ) @@ -4052,8 +3566,10 @@ def _assisted_decoding( if output_logits: raw_logits += (next_token_logits,) - if "past_key_values" not in model_kwargs: + if "past_key_values" not in model_kwargs or start_from_empty_dynamic_cache: added_len = new_cur_len + # set it to false for other iterations + start_from_empty_dynamic_cache = False else: added_len = n_matches + 1 @@ -4091,15 +3607,24 @@ def _assisted_decoding( outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, + num_new_tokens=n_matches + 1, ) if ignore_eos: this_peer_finished = stopping_criteria( - input_ids, scores, token_idx=None, ignore_eos=ignore_eos, eos_token_id=eos_token_id + input_ids, + scores, + token_idx=None, + ignore_eos=ignore_eos, + eos_token_id=generation_config.eos_token_id, ) else: unfinished_sequences = unfinished_sequences & ~stopping_criteria( - input_ids, scores, token_idx=None, ignore_eos=ignore_eos, eos_token_id=eos_token_id + input_ids, + scores, + token_idx=None, + ignore_eos=ignore_eos, + eos_token_id=generation_config.eos_token_id, ) this_peer_finished = unfinished_sequences.max() == 0 diff --git a/optimum/habana/transformers/modeling_utils.py b/optimum/habana/transformers/modeling_utils.py index 5e8c390a88..621e391bfb 100644 --- a/optimum/habana/transformers/modeling_utils.py +++ b/optimum/habana/transformers/modeling_utils.py @@ -47,6 +47,7 @@ GaudiGemmaForCausalLM, GaudiGPT2Attention, GaudiGPT2Block, + GaudiGPT2DoubleHeadsModel, GaudiGPT2LMHeadModel, GaudiGPTBigCodeForCausalLM, GaudiGPTJAttention, @@ -54,6 +55,7 @@ GaudiGPTJForCausalLM, GaudiGPTJModel, GaudiGPTNeoXForCausalLM, + GaudiGPTNeoXLayer, GaudiLlamaAttention, GaudiLlamaDecoderLayer, GaudiLlamaDynamicNTKScalingRotaryEmbedding, @@ -139,7 +141,6 @@ gaudi_gpt_bigcode_block_forward, gaudi_gpt_bigcode_model_forward, gaudi_gpt_neox_attention_forward, - gaudi_gpt_neox_layer_forward, gaudi_gpt_neox_model_forward, gaudi_gpt_neox_rotary_embedding_set_cos_sin_cache, gaudi_invert_attention_mask, @@ -177,7 +178,6 @@ gaudi_SpeechT5DecoderLayer_forward, gaudi_stablelm_attention_forward, gaudi_stablelm_model_forward, - gaudi_swin_get_attn_mask, gaudi_t5_layernorm_forward, gaudi_T5Attention_forward, gaudi_T5Block_forward, @@ -211,9 +211,6 @@ def adapt_transformers_to_gaudi(): # Optimization tweak for ViT transformers.models.vit.modeling_vit.ViTSelfAttention.forward = gaudi_vit_self_attention_forward - # Optimization tweak for Swin - transformers.models.swin.modeling_swin.SwinLayer.get_attn_mask = gaudi_swin_get_attn_mask - # Optimization tweak for Wav2Vec2 transformers.models.wav2vec2.modeling_wav2vec2._compute_mask_indices = _gaudi_wav2vec2_compute_mask_indices # transformers.models.wav2vec2.modeling_wav2vec2._sample_negative_indices = _gaudi_wav2vec2_sample_negative_indices @@ -254,12 +251,12 @@ def adapt_transformers_to_gaudi(): transformers.generation.GenerationMixin._prepare_generated_length = GaudiGenerationMixin._prepare_generated_length transformers.generation.GenerationMixin._get_stopping_criteria = GaudiGenerationMixin._get_stopping_criteria transformers.generation.GenerationMixin._validate_model_kwargs = GaudiGenerationMixin._validate_model_kwargs - transformers.generation.GenerationMixin._greedy_search = GaudiGenerationMixin._greedy_search + transformers.generation.GenerationMixin._dola_decoding = GaudiGenerationMixin._dola_decoding transformers.generation.GenerationMixin._sample = GaudiGenerationMixin._sample transformers.generation.GenerationMixin._beam_search = GaudiGenerationMixin._beam_search - transformers.generation.GenerationMixin._beam_sample = GaudiGenerationMixin._beam_sample transformers.generation.GenerationMixin._group_beam_search = GaudiGenerationMixin._group_beam_search transformers.generation.GenerationMixin._constrained_beam_search = GaudiGenerationMixin._constrained_beam_search + transformers.generation.GenerationMixin._contrastive_search = GaudiGenerationMixin._contrastive_search transformers.generation.GenerationMixin._assisted_decoding = GaudiGenerationMixin._assisted_decoding transformers.generation.GenerationMixin._get_candidate_generator = GaudiGenerationMixin._get_candidate_generator transformers.generation.GenerationConfig = GaudiGenerationConfig @@ -322,6 +319,7 @@ def adapt_transformers_to_gaudi(): transformers.models.gpt2.modeling_gpt2.GPT2Attention = GaudiGPT2Attention transformers.models.gpt2.modeling_gpt2.GPT2Model.forward = gaudi_gpt2_forward transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel = GaudiGPT2LMHeadModel + transformers.models.gpt2.modeling_gpt2.GPT2DoubleHeadsModel = GaudiGPT2DoubleHeadsModel transformers.models.gpt2.modeling_gpt2.GPT2Block = GaudiGPT2Block models_with_tracing_support.extend((GaudiGPT2Attention, GaudiGPT2LMHeadModel)) @@ -356,7 +354,7 @@ def adapt_transformers_to_gaudi(): # Optimization for gpt-neox generation on Gaudi transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXForCausalLM = GaudiGPTNeoXForCausalLM transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXModel.forward = gaudi_gpt_neox_model_forward - transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXLayer.forward = gaudi_gpt_neox_layer_forward + transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXLayer = GaudiGPTNeoXLayer transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXAttention.forward = gaudi_gpt_neox_attention_forward transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXRotaryEmbedding._set_cos_sin_cache = ( gaudi_gpt_neox_rotary_embedding_set_cos_sin_cache diff --git a/optimum/habana/transformers/models/__init__.py b/optimum/habana/transformers/models/__init__.py index 5fe87144bd..99ef65c4e4 100644 --- a/optimum/habana/transformers/models/__init__.py +++ b/optimum/habana/transformers/models/__init__.py @@ -70,7 +70,13 @@ gaudi_gemma_attention_forward, gaudi_gemma_model_forward, ) -from .gpt2 import GaudiGPT2Attention, GaudiGPT2Block, GaudiGPT2LMHeadModel, gaudi_gpt2_forward +from .gpt2 import ( + GaudiGPT2Attention, + GaudiGPT2Block, + GaudiGPT2DoubleHeadsModel, + GaudiGPT2LMHeadModel, + gaudi_gpt2_forward, +) from .gpt_bigcode import ( GaudiGPTBigCodeForCausalLM, gaudi_gpt_bigcode_attention_forward, @@ -79,8 +85,8 @@ ) from .gpt_neox import ( GaudiGPTNeoXForCausalLM, + GaudiGPTNeoXLayer, gaudi_gpt_neox_attention_forward, - gaudi_gpt_neox_layer_forward, gaudi_gpt_neox_model_forward, gaudi_gpt_neox_rotary_embedding_set_cos_sin_cache, ) @@ -196,7 +202,6 @@ GaudiStarcoder2ForCausalLM, GaudiStarcoder2Model, ) -from .swin import gaudi_swin_get_attn_mask from .t5 import ( gaudi_t5_layernorm_forward, gaudi_T5Attention_forward, diff --git a/optimum/habana/transformers/models/bart/modeling_bart.py b/optimum/habana/transformers/models/bart/modeling_bart.py index f68b1e3f81..239c65a9af 100644 --- a/optimum/habana/transformers/models/bart/modeling_bart.py +++ b/optimum/habana/transformers/models/bart/modeling_bart.py @@ -342,7 +342,7 @@ def gaudi_BartEncoder_forward( raise ValueError("You have to specify either input_ids or inputs_embeds") if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale + inputs_embeds = self.embed_tokens(input_ids) embed_pos = self.embed_positions(input) import habana_frameworks.torch.core as htcore @@ -461,7 +461,7 @@ def gaudi_BartDecoder_forward( tensor_past_key_values_length = token_idx - 1 if use_cache else torch.tensor(past_key_values_length) if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input) * self.embed_scale + inputs_embeds = self.embed_tokens(input) if self._use_sdpa and not output_attentions and cross_attn_head_mask is None: # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on diff --git a/optimum/habana/transformers/models/bert/modeling_bert.py b/optimum/habana/transformers/models/bert/modeling_bert.py index b49095ba60..549cb670c2 100644 --- a/optimum/habana/transformers/models/bert/modeling_bert.py +++ b/optimum/habana/transformers/models/bert/modeling_bert.py @@ -52,9 +52,6 @@ def gaudi_BertModel_forward( # past_key_values_length past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 - if attention_mask is None: - attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) - if token_type_ids is None: if hasattr(self.embeddings, "token_type_ids"): buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length] @@ -63,10 +60,21 @@ def gaudi_BertModel_forward( else: token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) + embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids, + inputs_embeds=inputs_embeds, + past_key_values_length=past_key_values_length, + ) + + if attention_mask is None: + attention_mask = torch.ones((batch_size, seq_length + past_key_values_length), device=device) + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] # ourselves in which case we just need to make it broadcastable to all heads. dtype = torch.hpu.get_autocast_hpu_dtype() if torch.hpu.is_autocast_hpu_enabled() else self.dtype - extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, dtype=dtype) + extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, dtype=dtype) # If a 2D or 3D attention mask is provided for the cross-attention # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] @@ -86,13 +94,6 @@ def gaudi_BertModel_forward( # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) - embedding_output = self.embeddings( - input_ids=input_ids, - position_ids=position_ids, - token_type_ids=token_type_ids, - inputs_embeds=inputs_embeds, - past_key_values_length=past_key_values_length, - ) encoder_outputs = self.encoder( embedding_output, attention_mask=extended_attention_mask, diff --git a/optimum/habana/transformers/models/blip/modeling_blip.py b/optimum/habana/transformers/models/blip/modeling_blip.py index 6545a0662d..512664210a 100644 --- a/optimum/habana/transformers/models/blip/modeling_blip.py +++ b/optimum/habana/transformers/models/blip/modeling_blip.py @@ -13,6 +13,7 @@ def gaudi_BlipForConditionalGeneration_generate( pixel_values: torch.FloatTensor, input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.LongTensor] = None, + interpolate_pos_encoding: bool = False, **generate_kwargs, ) -> torch.LongTensor: """ @@ -29,7 +30,10 @@ def gaudi_BlipForConditionalGeneration_generate( self.text_decoder = wrap_in_hpu_graph(self.text_decoder) batch_size = pixel_values.shape[0] - vision_outputs = self.vision_model(pixel_values=pixel_values) + vision_outputs = self.vision_model( + pixel_values=pixel_values, + interpolate_pos_encoding=interpolate_pos_encoding, + ) image_embeds = vision_outputs[0] @@ -66,6 +70,7 @@ def gaudi_BlipForQuestionAnswering_generate( input_ids: torch.LongTensor, pixel_values: torch.FloatTensor, attention_mask: Optional[torch.LongTensor] = None, + interpolate_pos_encoding: bool = False, **generate_kwargs, ) -> torch.LongTensor: """ @@ -84,7 +89,10 @@ def gaudi_BlipForQuestionAnswering_generate( if not hasattr(self.text_decoder, "clear_cache"): self.text_decoder = wrap_in_hpu_graph(self.text_decoder) - vision_outputs = self.vision_model(pixel_values=pixel_values) + vision_outputs = self.vision_model( + pixel_values=pixel_values, + interpolate_pos_encoding=interpolate_pos_encoding, + ) image_embeds = vision_outputs[0] diff --git a/optimum/habana/transformers/models/clip/modeling_clip.py b/optimum/habana/transformers/models/clip/modeling_clip.py index 99854a3799..b22c61972d 100644 --- a/optimum/habana/transformers/models/clip/modeling_clip.py +++ b/optimum/habana/transformers/models/clip/modeling_clip.py @@ -3,7 +3,9 @@ import torch from torch import nn from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling +from transformers.models.clip.configuration_clip import CLIPConfig from transformers.models.clip.modeling_clip import ( + CLIPMLP, CLIPAttention, CLIPEncoder, CLIPEncoderLayer, @@ -76,7 +78,7 @@ def forward( causal_attention_mask: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = False, use_flash_attention: Optional[bool] = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: """ Copied from CLIPAttention.forward: https://github.com/huggingface/transformers/blob/ab0f050b42d903f34d6eb97f3f8c0c07f0517ad2/src/transformers/models/clip/modeling_clip.py The only differences are: @@ -161,6 +163,14 @@ def forward( class GaudiCLIPEncoderLayer(CLIPEncoderLayer): + def __init__(self, config: CLIPConfig): + super(CLIPEncoderLayer, self).__init__() + self.embed_dim = config.hidden_size + self.self_attn = GaudiCLIPAttention(config) + self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + self.mlp = CLIPMLP(config) + self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + def forward( self, hidden_states: torch.Tensor, diff --git a/optimum/habana/transformers/models/codegen/modeling_codegen.py b/optimum/habana/transformers/models/codegen/modeling_codegen.py index b568085971..536cb5d423 100644 --- a/optimum/habana/transformers/models/codegen/modeling_codegen.py +++ b/optimum/habana/transformers/models/codegen/modeling_codegen.py @@ -313,7 +313,9 @@ class GaudiCodeGenForCausalLM(CodeGenForCausalLM): - when KV cache is enabled, slice next_position_ids from position_ids based on the token_idx """ - def prepare_inputs_for_generation(self, input_ids, past_key_values=None, token_idx=None, **kwargs): + def prepare_inputs_for_generation( + self, input_ids, inputs_embeds=None, past_key_values=None, token_idx=None, **kwargs + ): token_type_ids = kwargs.get("token_type_ids", None) # Omit tokens covered by past_key_values if past_key_values: @@ -339,15 +341,23 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, token_i else: position_ids = position_ids[:, -1] - return { - "input_ids": input_ids, - "past_key_values": past_key_values, - "use_cache": kwargs.get("use_cache"), - "position_ids": position_ids, - "attention_mask": attention_mask, - "token_type_ids": token_type_ids, - "token_idx": token_idx, - } + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids.contiguous()} + + model_inputs.update( + { + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "position_ids": position_ids, + "attention_mask": attention_mask, + "token_type_ids": token_type_ids, + "token_idx": token_idx, + } + ) + return model_inputs def forward( self, diff --git a/optimum/habana/transformers/models/decilm/modeling_decilm.py b/optimum/habana/transformers/models/decilm/modeling_decilm.py index 92b1c52a1d..562033f2cb 100644 --- a/optimum/habana/transformers/models/decilm/modeling_decilm.py +++ b/optimum/habana/transformers/models/decilm/modeling_decilm.py @@ -30,6 +30,7 @@ from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging from ...modeling_attn_mask_utils import _gaudi_prepare_4d_causal_attention_mask +from ..llama.modeling_llama import GaudiLlamaRotaryEmbedding from .configuration_decilm import DeciLMConfig @@ -64,7 +65,7 @@ def __init__(self, config: DeciLMConfig, layer_idx: int): self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) - self._init_rope() + self.rotary_emb = GaudiLlamaRotaryEmbedding(config=self.config) def forward( self, diff --git a/optimum/habana/transformers/models/falcon/modeling_falcon.py b/optimum/habana/transformers/models/falcon/modeling_falcon.py index 190a94e7a5..a7a0c0e920 100644 --- a/optimum/habana/transformers/models/falcon/modeling_falcon.py +++ b/optimum/habana/transformers/models/falcon/modeling_falcon.py @@ -1,7 +1,6 @@ import contextlib import math import os -import warnings from typing import Optional, Tuple, Union import torch @@ -348,12 +347,7 @@ def pre_attn_forward( use_flash_attention: Optional[bool] = False, flash_attention_recompute: Optional[bool] = False, flash_attention_causal_mask: Optional[bool] = False, - **kwargs, ): - if "padding_mask" in kwargs: - warnings.warn( - "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" - ) fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size] # 3 x [batch_size, seq_length, num_heads, head_dim] @@ -631,11 +625,8 @@ def forward( flash_attention_causal_mask: Optional[bool] = False, **kwargs, ): - if "padding_mask" in kwargs: - warnings.warn( - "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" - ) residual = hidden_states + ( hidden_states, present, @@ -657,7 +648,6 @@ def forward( use_flash_attention=use_flash_attention, flash_attention_recompute=flash_attention_recompute, flash_attention_causal_mask=flash_attention_causal_mask, - **kwargs, ) self.self_attention.attention_all_reduce(hidden_states) @@ -674,6 +664,13 @@ def forward( ) mlp_layernorm_out = self.post_attention_layernorm(residual) + if ( + self.config.new_decoder_architecture + and self.config.parallel_attn + and self.config.num_ln_in_parallel_attn == 1 + ): + mlp_layernorm_out = attention_layernorm_out + outputs = (present, attn_scores) hidden_states = self.mlp.pre_mlp_forward(mlp_layernorm_out) @@ -709,7 +706,7 @@ def pre_attn( flash_attention_recompute: Optional[bool] = False, flash_attention_causal_mask: Optional[bool] = False, ): - if self.config.new_decoder_architecture: + if self.config.new_decoder_architecture and self.config.num_ln_in_parallel_attn == 2: attention_layernorm_out = self.ln_attn(hidden_states) mlp_layernorm_out = self.ln_mlp(hidden_states) else: @@ -978,6 +975,7 @@ def prepare_inputs_for_generation( past_key_values: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, token_idx: Optional[torch.Tensor] = None, **kwargs, ) -> dict: @@ -1019,19 +1017,26 @@ def prepare_inputs_for_generation( else: position_ids = position_ids[:, -input_ids.shape[1] :] - return { - "input_ids": input_ids, - "position_ids": position_ids, - "past_key_values": past_key_values, - "use_cache": kwargs.get("use_cache"), - "attention_mask": attention_mask, - "token_idx": token_idx, - "reuse_cache": reuse_cache, - "cache_idx": kwargs.get("cache_idx"), - "use_flash_attention": kwargs.get("use_flash_attention"), - "flash_attention_recompute": kwargs.get("flash_attention_recompute"), - "flash_attention_causal_mask": kwargs.get("flash_attention_causal_mask"), - } + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "attention_mask": attention_mask, + "token_idx": token_idx, + "reuse_cache": reuse_cache, + "cache_idx": kwargs.get("cache_idx"), + "use_flash_attention": kwargs.get("use_flash_attention"), + "flash_attention_recompute": kwargs.get("flash_attention_recompute"), + "flash_attention_causal_mask": kwargs.get("flash_attention_causal_mask"), + } + ) + return model_inputs def forward( self, diff --git a/optimum/habana/transformers/models/gemma/modeling_gemma.py b/optimum/habana/transformers/models/gemma/modeling_gemma.py index 7f57e838b8..6c537dfa31 100644 --- a/optimum/habana/transformers/models/gemma/modeling_gemma.py +++ b/optimum/habana/transformers/models/gemma/modeling_gemma.py @@ -19,14 +19,12 @@ # limitations under the License. """PyTorch Gemma model.""" -import math -import warnings from typing import List, Optional, Tuple, Union import torch from torch import nn from torch.nn import CrossEntropyLoss -from transformers.cache_utils import Cache, DynamicCache, StaticCache +from transformers.cache_utils import Cache, DynamicCache from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from transformers.models.gemma.modeling_gemma import ( GemmaAttention, @@ -75,10 +73,8 @@ def gaudi_gemma_attention_forward( key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - past_key_value = getattr(self, "past_key_value", past_key_value) - cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: if token_idx is not None: @@ -98,7 +94,7 @@ def gaudi_gemma_attention_forward( key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling if attention_mask is not None: # no matter the length, we just slice it attn_weights = attn_weights + attention_mask @@ -135,7 +131,7 @@ def forward( hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, @@ -147,12 +143,6 @@ def forward( The only differences are: - add new args token_idx """ - - if "padding_mask" in kwargs: - warnings.warn( - "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" - ) - residual = hidden_states hidden_states = self.input_layernorm(hidden_states) @@ -193,7 +183,7 @@ def gaudi_gemma_model_forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, @@ -235,12 +225,13 @@ def gaudi_gemma_model_forward( inputs_embeds = self.embed_tokens(input_ids) past_seen_tokens = 0 - if use_cache: # kept for BC (cache positions) - if not isinstance(past_key_values, StaticCache): - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - past_seen_tokens = past_key_values.get_seq_length() + return_legacy_cache = False # noqa: F841 + if use_cache and not isinstance(past_key_values, Cache): # kept for BC (non `Cache` `past_key_values` inputs) + return_legacy_cache = True # noqa: F841 + past_key_values = DynamicCache.from_legacy_cache(past_key_values) if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) @@ -258,6 +249,13 @@ def gaudi_gemma_model_forward( # normalized hidden_states = hidden_states * (self.config.hidden_size**0.5) + if use_cache and not isinstance(past_key_values, Cache): # kept for BC (non `Cache` `past_key_values` inputs) + return_legacy_cache = True + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + logger.warning_once( + "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. " + "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)" + ) # decoder layers all_hidden_states = () if output_hidden_states else None @@ -305,11 +303,9 @@ def gaudi_gemma_model_forward( if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = None - if use_cache: - next_cache = ( - next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache - ) + next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = next_cache.to_legacy_cache() if not return_dict: return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) @@ -327,7 +323,7 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, @@ -392,7 +388,15 @@ def forward( ) def prepare_inputs_for_generation( - self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, cache_position=None, **kwargs + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + position_ids=None, + use_cache=True, + **kwargs, ): """ Inherits from GemmaForCausalLM: https://github.com/huggingface/transformers/blob/v4.38.1/src/transformers/models/gemma/modeling_gemma.py @@ -405,48 +409,18 @@ def prepare_inputs_for_generation( token_idx = kwargs.get("token_idx", None) - past_length = 0 if past_key_values is not None: if token_idx is None: - if isinstance(past_key_values, Cache): - past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length() - max_cache_length = ( - torch.tensor(past_key_values.get_max_length(), device=input_ids.device) - if past_key_values.get_max_length() is not None - else None - ) - cache_length = ( - past_length if max_cache_length is None else torch.min(max_cache_length, past_length) - ) - # TODO joao: remove this `else` after `generate` prioritizes `Cache` objects - else: - cache_length = past_length = past_key_values[0][0].shape[2] - max_cache_length = None - - # Keep only the unprocessed tokens: - # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where - # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as - # input) - if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: - input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] - # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard - # input_ids based on the past_length. - elif past_length < input_ids.shape[1]: - input_ids = input_ids[:, past_length:] - # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. - - # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. - if ( - max_cache_length is not None - and attention_mask is not None - and cache_length + input_ids.shape[1] > max_cache_length - ): - attention_mask = attention_mask[:, -max_cache_length:] + if inputs_embeds is not None: # Exception 1 + input_ids = input_ids[:, -cache_position.shape[0] :] + elif ( + input_ids.shape[1] != cache_position.shape[0] + ): # Default case (the "else", a no op, is Exception 2) + input_ids = input_ids[:, cache_position] else: # past_length += token_idx input_ids = torch.index_select(input_ids, 1, token_idx - 1) - position_ids = kwargs.get("position_ids", None) if attention_mask is not None and position_ids is None: # create position_ids on the fly for batch generation position_ids = attention_mask.long().cumsum(-1) - 1 @@ -470,18 +444,12 @@ def prepare_inputs_for_generation( else: model_inputs = {"input_ids": input_ids.contiguous()} - input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1] - if cache_position is None: - cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device) - else: - cache_position = cache_position[-input_length:] - model_inputs.update( { "position_ids": position_ids, "cache_position": cache_position, "past_key_values": past_key_values, - "use_cache": kwargs.get("use_cache"), + "use_cache": use_cache, "attention_mask": attention_mask, "token_idx": token_idx, } diff --git a/optimum/habana/transformers/models/gpt2/__init__.py b/optimum/habana/transformers/models/gpt2/__init__.py index 4052373a27..a76e916b42 100644 --- a/optimum/habana/transformers/models/gpt2/__init__.py +++ b/optimum/habana/transformers/models/gpt2/__init__.py @@ -1 +1,7 @@ -from .modeling_gpt2 import GaudiGPT2Attention, GaudiGPT2Block, GaudiGPT2LMHeadModel, gaudi_gpt2_forward +from .modeling_gpt2 import ( + GaudiGPT2Attention, + GaudiGPT2Block, + GaudiGPT2DoubleHeadsModel, + GaudiGPT2LMHeadModel, + gaudi_gpt2_forward, +) diff --git a/optimum/habana/transformers/models/gpt2/modeling_gpt2.py b/optimum/habana/transformers/models/gpt2/modeling_gpt2.py index 3a2e85d24b..b2ec2c0229 100644 --- a/optimum/habana/transformers/models/gpt2/modeling_gpt2.py +++ b/optimum/habana/transformers/models/gpt2/modeling_gpt2.py @@ -3,7 +3,14 @@ import torch from torch.nn import CrossEntropyLoss from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions -from transformers.models.gpt2.modeling_gpt2 import GPT2MLP, GPT2Attention, GPT2LMHeadModel, logger +from transformers.models.gpt2.modeling_gpt2 import ( + GPT2MLP, + GPT2Attention, + GPT2DoubleHeadsModel, + GPT2DoubleHeadsModelOutput, + GPT2LMHeadModel, + logger, +) class GaudiGPT2Attention(GPT2Attention): @@ -315,9 +322,14 @@ def gaudi_gpt2_forward( position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device) position_ids = position_ids.unsqueeze(0) + if inputs_embeds is None: + inputs_embeds = self.wte(input_ids) + position_embeds = self.wpe(position_ids) + hidden_states = inputs_embeds + position_embeds + # GPT2Attention mask. + attention_mask = attention_mask.view(batch_size, -1) if attention_mask is not None else None if attention_mask is not None: - attention_mask = attention_mask.view(batch_size, -1) # We create a 3D attention mask from a 2D tensor mask. # Sizes are [batch_size, 1, 1, to_seq_length] # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] @@ -352,11 +364,6 @@ def gaudi_gpt2_forward( # head_mask has shape n_layer x batch x n_heads x N x N head_mask = self.get_head_mask(head_mask, self.config.n_layer) - if inputs_embeds is None: - inputs_embeds = self.wte(input_ids) - position_embeds = self.wpe(position_ids) - hidden_states = inputs_embeds + position_embeds - if token_type_ids is not None: token_type_embeds = self.wte(token_type_ids) hidden_states = hidden_states + token_type_embeds @@ -586,3 +593,142 @@ def forward( attentions=transformer_outputs.attentions, cross_attentions=transformer_outputs.cross_attentions, ) + + +class GaudiGPT2DoubleHeadsModel(GPT2DoubleHeadsModel): + """ + Copied from GPT2DoubleHeadsModel: https://github.com/huggingface/transformers/blob/v4.40.2/src/transformers/models/gpt2/modeling_gpt2.py#L1377 + The only differences are: + - add new args token_idx to support static shapes + """ + + def prepare_inputs_for_generation( + self, input_ids, inputs_embeds=None, past_key_values=None, token_idx=None, **kwargs + ): + token_type_ids = kwargs.get("token_type_ids", None) + # Omit tokens covered by past_key_values + if past_key_values: + if token_idx is not None: + input_ids = torch.index_select(input_ids, 1, token_idx - 1) + else: + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] + + if token_type_ids is not None: + token_type_ids = token_type_ids[:, -input_ids.shape[1] :] + + attention_mask = kwargs.get("attention_mask", None) + position_ids = kwargs.get("position_ids", None) + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + if token_idx is not None: + position_ids = torch.index_select(position_ids, 1, token_idx - 1) + else: + position_ids = position_ids[:, -input_ids.shape[1] :] + else: + position_ids = None + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids.contiguous()} + + model_inputs.update( + { + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "position_ids": position_ids, + "attention_mask": attention_mask, + "token_type_ids": token_type_ids, + "token_idx": token_idx, + } + ) + + return model_inputs + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + mc_token_ids: Optional[torch.LongTensor] = None, + labels: Optional[torch.LongTensor] = None, + mc_labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + token_idx: Optional[torch.Tensor] = None, + **kwargs, + ) -> Union[Tuple, GPT2DoubleHeadsModelOutput]: + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.transformer( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + token_idx=token_idx, + ) + + hidden_states = transformer_outputs[0] + + # Set device for model parallelism + if self.model_parallel: + torch.cuda.set_device(self.transformer.first_device) + hidden_states = hidden_states.to(self.lm_head.weight.device) + + lm_logits = self.lm_head(hidden_states) + mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids).squeeze(-1) + + mc_loss = None + if mc_labels is not None: + loss_fct = CrossEntropyLoss() + mc_loss = loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1)) + lm_loss = None + if labels is not None: + labels = labels.to(lm_logits.device) + shift_logits = lm_logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + loss_fct = CrossEntropyLoss() + lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + + if not return_dict: + output = (lm_logits, mc_logits) + transformer_outputs[1:] + if mc_loss is not None: + output = (mc_loss,) + output + return ((lm_loss,) + output) if lm_loss is not None else output + + return GPT2DoubleHeadsModelOutput( + loss=lm_loss, + mc_loss=mc_loss, + logits=lm_logits, + mc_logits=mc_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) diff --git a/optimum/habana/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/optimum/habana/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py index ca0d06aebc..53b714ae62 100644 --- a/optimum/habana/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +++ b/optimum/habana/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -252,6 +252,7 @@ def gaudi_gpt_bigcode_block_forward( flash_attention_recompute: Optional[bool] = False, flash_attention_fast_softmax: Optional[bool] = False, flash_attention_causal_mask: Optional[bool] = False, + **kwargs, ) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]: """ Copied from GPTBigCodeBlock.forward: https://github.com/huggingface/transformers/blob/v4.40-release/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py diff --git a/optimum/habana/transformers/models/gpt_neox/__init__.py b/optimum/habana/transformers/models/gpt_neox/__init__.py index cceb114b82..495e9ef766 100644 --- a/optimum/habana/transformers/models/gpt_neox/__init__.py +++ b/optimum/habana/transformers/models/gpt_neox/__init__.py @@ -1,7 +1,7 @@ from .modeling_gpt_neox import ( GaudiGPTNeoXForCausalLM, + GaudiGPTNeoXLayer, gaudi_gpt_neox_attention_forward, - gaudi_gpt_neox_layer_forward, gaudi_gpt_neox_model_forward, gaudi_gpt_neox_rotary_embedding_set_cos_sin_cache, ) diff --git a/optimum/habana/transformers/models/gpt_neox/modeling_gpt_neox.py b/optimum/habana/transformers/models/gpt_neox/modeling_gpt_neox.py index d4614fe959..aa6423d2b1 100644 --- a/optimum/habana/transformers/models/gpt_neox/modeling_gpt_neox.py +++ b/optimum/habana/transformers/models/gpt_neox/modeling_gpt_neox.py @@ -3,7 +3,17 @@ import torch from torch.nn import CrossEntropyLoss from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast -from transformers.models.gpt_neox.modeling_gpt_neox import GPTNeoXForCausalLM, apply_rotary_pos_emb, logger +from transformers.models.gpt_neox.modeling_gpt_neox import ( + GPTNeoXAttention, + GPTNeoXForCausalLM, + GPTNeoXLayer, + GPTNeoXMLP, + GPTNeoXModel, + apply_rotary_pos_emb, + logger, +) + +from ...modeling_attn_mask_utils import _gaudi_prepare_4d_causal_attention_mask try: @@ -22,7 +32,6 @@ def gaudi_gpt_neox_attention_forward( layer_past: Optional[Tuple[torch.Tensor]] = None, use_cache: Optional[bool] = False, output_attentions: Optional[bool] = False, - padding_mask: Optional[torch.Tensor] = None, token_idx: Optional[torch.Tensor] = None, ): """ @@ -93,57 +102,68 @@ def gaudi_gpt_neox_attention_forward( return outputs -def gaudi_gpt_neox_layer_forward( - self, - hidden_states: Optional[torch.FloatTensor], - attention_mask: Optional[torch.FloatTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = False, - layer_past: Optional[Tuple[torch.Tensor]] = None, - output_attentions: Optional[bool] = False, - token_idx: Optional[torch.Tensor] = None, -): - """ - Copied from GPTNeoxLayer.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py - The only differences are: - - add new args token_idx - """ - attention_layer_outputs = self.attention( - self.input_layernorm(hidden_states), - attention_mask=attention_mask, - position_ids=position_ids, - layer_past=layer_past, - head_mask=head_mask, - use_cache=use_cache, - output_attentions=output_attentions, - token_idx=token_idx, - ) - attn_output = attention_layer_outputs[0] # output_attn: attn_output, present, (attn_weights) - attn_output = self.post_attention_dropout(attn_output) - outputs = attention_layer_outputs[1:] - - if self.use_parallel_residual: - # pseudocode: - # x = x + attn(ln1(x)) + mlp(ln2(x)) - mlp_output = self.mlp(self.post_attention_layernorm(hidden_states)) - mlp_output = self.post_mlp_dropout(mlp_output) - hidden_states = mlp_output + attn_output + hidden_states - else: - # pseudocode: - # x = x + attn(ln1(x)) - # x = x + mlp(ln2(x)) - attn_output = attn_output + hidden_states - mlp_output = self.mlp(self.post_attention_layernorm(attn_output)) - mlp_output = self.post_mlp_dropout(mlp_output) - hidden_states = mlp_output + attn_output - - if use_cache: - outputs = (hidden_states,) + outputs # hidden_states, present, (attn_weights) - else: - outputs = (hidden_states,) + outputs[1:] # hidden_states, (attn_weights) +class GaudiGPTNeoXLayer(GPTNeoXLayer): + def __init__(self, config): + super(GPTNeoXLayer, self).__init__() + self.use_parallel_residual = config.use_parallel_residual + self.input_layernorm = torch.nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.post_attention_layernorm = torch.nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.post_attention_dropout = torch.nn.Dropout(config.hidden_dropout) + self.post_mlp_dropout = torch.nn.Dropout(config.hidden_dropout) + self.attention = GPTNeoXAttention(config) + self.mlp = GPTNeoXMLP(config) - return outputs + def forward( + self, + hidden_states: Optional[torch.FloatTensor], + attention_mask: Optional[torch.FloatTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = False, + layer_past: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + token_idx: Optional[torch.Tensor] = None, + ): + """ + Copied from GPTNeoxLayer.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py + The only differences are: + - add new args token_idx + """ + attention_layer_outputs = self.attention( + self.input_layernorm(hidden_states), + attention_mask=attention_mask, + position_ids=position_ids, + layer_past=layer_past, + head_mask=head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + token_idx=token_idx, + ) + attn_output = attention_layer_outputs[0] # output_attn: attn_output, present, (attn_weights) + attn_output = self.post_attention_dropout(attn_output) + outputs = attention_layer_outputs[1:] + + if self.use_parallel_residual: + # pseudocode: + # x = x + attn(ln1(x)) + mlp(ln2(x)) + mlp_output = self.mlp(self.post_attention_layernorm(hidden_states)) + mlp_output = self.post_mlp_dropout(mlp_output) + hidden_states = mlp_output + attn_output + hidden_states + else: + # pseudocode: + # x = x + attn(ln1(x)) + # x = x + mlp(ln2(x)) + attn_output = attn_output + hidden_states + mlp_output = self.mlp(self.post_attention_layernorm(attn_output)) + mlp_output = self.post_mlp_dropout(mlp_output) + hidden_states = mlp_output + attn_output + + if use_cache: + outputs = (hidden_states,) + outputs # hidden_states, present, (attn_weights) + else: + outputs = (hidden_states,) + outputs[1:] # hidden_states, (attn_weights) + + return outputs def gaudi_gpt_neox_model_forward( @@ -195,24 +215,17 @@ def gaudi_gpt_neox_model_forward( position_ids = torch.arange(past_length, seq_length + past_length, dtype=torch.long, device=device) position_ids = position_ids.unsqueeze(0) + if inputs_embeds is None: + inputs_embeds = self.embed_in(input_ids) + # Attention mask. - if attention_mask is not None: - assert batch_size > 0, "batch_size has to be defined and > 0" - attention_mask = attention_mask.view(batch_size, -1) - # We create a 3D attention mask from a 2D tensor mask. - # Sizes are [batch_size, 1, 1, to_seq_length] - # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] - # this attention mask is more simple than the triangular masking of causal attention - # used in OpenAI GPT, we just need to prepare the broadcast dimension here. - attention_mask = attention_mask[:, None, None, :] - - # Since attention_mask is 1.0 for positions we want to attend and 0.0 for - # masked positions, this operation will create a tensor which is 0.0 for - # positions we want to attend and the dtype's smallest value for masked positions. - # Since we are adding it to the raw scores before the softmax, this is - # effectively the same as removing these entirely. - attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility - attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min + attention_mask = attention_mask.view(batch_size, -1) if attention_mask is not None else None + attention_mask = _gaudi_prepare_4d_causal_attention_mask( + attention_mask=attention_mask, + input_shape=(batch_size, seq_length), + inputs_embeds=inputs_embeds, + past_key_values_length=past_length, + ) # Prepare head mask if needed # 1.0 in head_mask indicate we keep the head @@ -221,9 +234,6 @@ def gaudi_gpt_neox_model_forward( # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) - if inputs_embeds is None: - inputs_embeds = self.embed_in(input_ids) - hidden_states = self.emb_dropout(inputs_embeds) if self.gradient_checkpointing and self.training: @@ -295,6 +305,16 @@ class GaudiGPTNeoXForCausalLM(GPTNeoXForCausalLM): - from step2 when enable KV cache, slice next_position_ids from position_ids base on the token_idx """ + def __init__(self, config): + super(GPTNeoXForCausalLM, self).__init__(config) + + config._attn_implementation = "eager" + self.gpt_neox = GPTNeoXModel(config) + self.embed_out = torch.nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + def forward( self, input_ids: Optional[torch.LongTensor] = None, @@ -397,6 +417,7 @@ def prepare_inputs_for_generation( "attention_mask": attention_mask, "past_key_values": past_key_values, "position_ids": position_ids, + "use_cache": kwargs.get("use_cache"), "token_idx": token_idx, } ) diff --git a/optimum/habana/transformers/models/llama/configuration_llama.py b/optimum/habana/transformers/models/llama/configuration_llama.py index 12ad78e29a..ce754dadb5 100644 --- a/optimum/habana/transformers/models/llama/configuration_llama.py +++ b/optimum/habana/transformers/models/llama/configuration_llama.py @@ -1,4 +1,3 @@ -# TODO: To remove when the repo is upgraded to Transformers >= 4.41.0 from transformers.models.llama.configuration_llama import LlamaConfig @@ -51,9 +50,9 @@ def __init__( rope_scaling, attention_bias, attention_dropout, + mlp_bias, **kwargs, ) - self.mlp_bias = mlp_bias self.fused_qkv = fused_qkv self.parallel_strategy = parallel_strategy diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py index 7d10ddfcf2..429b0a6ce3 100755 --- a/optimum/habana/transformers/models/llama/modeling_llama.py +++ b/optimum/habana/transformers/models/llama/modeling_llama.py @@ -1,7 +1,6 @@ import copy import math import os -import warnings from typing import List, Optional, Tuple, Union import torch @@ -10,7 +9,7 @@ from transformers.activations import ACT2FN from transformers.cache_utils import Cache, DynamicCache, StaticCache from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast -from transformers.models.llama.configuration_llama import LlamaConfig +from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS from transformers.models.llama.modeling_llama import ( LlamaAttention, LlamaDecoderLayer, @@ -31,6 +30,7 @@ from ...modeling_attn_mask_utils import ( _gaudi_prepare_4d_causal_attention_mask, ) +from .configuration_llama import LlamaConfig try: @@ -82,19 +82,54 @@ def gaudi_llama_rmsnorm_forward(self, hidden_states): class GaudiLlamaRotaryEmbedding(torch.nn.Module): - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): + def __init__( + self, + dim=None, + max_position_embeddings=2048, + base=10000, + device=None, + scaling_factor=1.0, + rope_type="default", + config: Optional[LlamaConfig] = None, + ): super().__init__() - self.scaling_factor = scaling_factor - self.dim = dim - self.max_position_embeddings = max_position_embeddings - self.base = base - inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) + # TODO (joao): remove the `if` below, only used for BC + self.rope_kwargs = {} + if config is None: + logger.warning_once( + "`LlamaRotaryEmbedding` can now be fully parameterized by passing the model config through the " + "`config` argument. All other arguments will be removed in v4.45" + ) + self.rope_kwargs = { + "rope_type": rope_type, + "factor": scaling_factor, + "dim": dim, + "base": base, + "max_position_embeddings": max_position_embeddings, + } + self.rope_type = rope_type + self.max_seq_len_cached = max_position_embeddings + self.original_max_seq_len = max_position_embeddings + else: + # BC: "rope_type" was originally "type" + if config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq # Build here to make `torch.jit.trace` work. self._set_cos_sin_cache( - seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() + seq_len=self.max_seq_len_cached, device=self.inv_freq.device, dtype=torch.get_default_dtype() ) def _set_cos_sin_cache(self, seq_len, device, dtype): @@ -107,18 +142,49 @@ def _set_cos_sin_cache(self, seq_len, device, dtype): self.register_buffer("_cos_cached", emb.cos().to(dtype), persistent=False) self.register_buffer("_sin_cached", emb.sin().to(dtype), persistent=False) + def _dynamic_frequency_update(self, seq_len, device): + """ + dynamic RoPE layers should recompute `inv_freq` in the following situations: + 1 - growing beyond the cached sequence length (allow scaling) + 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) + """ + # seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_seq_len_cached: # growth + inv_freq, self.attention_scaling = self.rope_init_fn( + self.config, device, seq_len=seq_len, **self.rope_kwargs + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation + self.max_seq_len_cached = seq_len + + if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset + self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) + self.max_seq_len_cached = self.original_max_seq_len + + @torch.no_grad() def forward(self, x, seq_len=None): # x: [bs, num_attention_heads, seq_len, head_size] + + if "dynamic" in self.rope_type: + self._dynamic_frequency_update(seq_len, device=x.device) + if seq_len > self.max_seq_len_cached: self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) return ( - self._cos_cached[:seq_len].to(dtype=x.dtype), - self._sin_cached[:seq_len].to(dtype=x.dtype), + self._cos_cached[:seq_len].to(dtype=x.dtype) * self.attention_scaling, + self._sin_cached[:seq_len].to(dtype=x.dtype) * self.attention_scaling, ) class GaudiLlamaLinearScalingRotaryEmbedding(GaudiLlamaRotaryEmbedding): + def __init__(self, *args, **kwargs): + logger.warning_once( + "`LlamaLinearScalingRotaryEmbedding` is deprecated an will be removed in v4.45. Please use " + "`LlamaRotaryEmbedding`, which now also does linear scaling (simply pass the model config to __init__)." + ) + kwargs["rope_type"] = "linear" + super().__init__(*args, **kwargs) + def _set_cos_sin_cache(self, seq_len, device, dtype): self.max_seq_len_cached = seq_len t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) @@ -132,6 +198,15 @@ def _set_cos_sin_cache(self, seq_len, device, dtype): class GaudiLlamaDynamicNTKScalingRotaryEmbedding(GaudiLlamaRotaryEmbedding): + def __init__(self, *args, **kwargs): + logger.warning_once( + "`LlamaDynamicNTKScalingRotaryEmbedding` is deprecated an will be removed in v4.45. Please use " + "`LlamaRotaryEmbedding`, which now also does dynamic ntk scaling (simply pass the model config to " + "__init__)." + ) + kwargs["rope_type"] = "dynamic" + super().__init__(*args, **kwargs) + def _set_cos_sin_cache(self, seq_len, device, dtype): self.max_seq_len_cached = seq_len @@ -157,10 +232,9 @@ def __init__(self, config): self.config = config self.hidden_size = config.hidden_size self.intermediate_size = config.intermediate_size - bias = config.mlp_bias if hasattr(config, "mlp_bias") else False - self.gate_proj = torch.nn.Linear(self.hidden_size, self.intermediate_size, bias=bias) - self.up_proj = torch.nn.Linear(self.hidden_size, self.intermediate_size, bias=bias) - self.down_proj = torch.nn.Linear(self.intermediate_size, self.hidden_size, bias=bias) + self.gate_proj = torch.nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.up_proj = torch.nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.down_proj = torch.nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias) self.act_fn = ACT2FN[config.hidden_act] def pre_mlp_forward(self, x): @@ -333,7 +407,7 @@ def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None): self.k_cache = KVCache() self.v_cache = KVCache() self.fused_scaled_dot_product_attention = ModuleFusedSDPA(FusedSDPA) if FusedSDPA else None - if config.fused_qkv: + if hasattr(config, "fused_qkv") and config.fused_qkv: self.num_heads = config.num_attention_heads self.head_dim = config.hidden_size // self.num_heads self.dim1 = self.num_heads * self.head_dim @@ -387,6 +461,7 @@ def pre_attn_forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45 token_idx: Optional[torch.Tensor] = None, attn_softmax_bf16: Optional[bool] = False, reuse_cache: Optional[bool] = False, @@ -431,7 +506,7 @@ def pre_attn_forward( value_states = torch.cat(value_states, dim=-1) else: - if self.config.fused_qkv: + if hasattr(self.config, "fused_qkv") and self.config.fused_qkv: qkv_states = self.qkv_proj(hidden_states) query_states, key_states, value_states = torch.split( qkv_states, [self.dim1, self.dim2, self.dim2], dim=-1 @@ -440,6 +515,7 @@ def pre_attn_forward( query_states = self.q_proj(hidden_states) key_states = self.k_proj(hidden_states) value_states = self.v_proj(hidden_states) + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) # TODO: update when auto mp params is enabled in DeepSpeed (cf. https://github.com/HabanaAI/DeepSpeed/blob/94309c7b5dfc1a69858f5c9f25737b2f81a332a5/deepspeed/module_inject/replace_module.py#L440) key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) @@ -461,8 +537,18 @@ def pre_attn_forward( else: kv_seq_len = past_key_value[0].shape[-2] - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.45 `position_ids` will be " + "removed and `position_embeddings` will be mandatory." + ) + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + else: + cos, sin = position_embeddings query_states, key_states = apply_customized_rope(query_states, key_states, cos, sin, position_ids) + if use_cache: # reuse k, v, self_attention if reuse_cache: @@ -564,7 +650,7 @@ def pre_attn_forward( attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + attn_output = attn_output.reshape(bsz, q_len, -1) attn_output = self.o_proj(attn_output) @@ -653,6 +739,7 @@ def pre_attn_forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45 token_idx: Optional[torch.Tensor] = None, attn_softmax_bf16: Optional[bool] = False, reuse_cache: Optional[bool] = False, @@ -672,6 +759,7 @@ def pre_attn_forward( output_attentions, use_cache, cache_position, + position_embeddings, token_idx, attn_softmax_bf16, reuse_cache, @@ -712,10 +800,11 @@ def forward( hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45 token_idx: Optional[torch.Tensor] = None, attn_softmax_bf16: Optional[bool] = False, reuse_cache: Optional[bool] = False, @@ -738,12 +827,8 @@ def forward( - add new arg flash_attention_causal_mask - add new arg flash_attention_fast_softmax """ - if "padding_mask" in kwargs: - warnings.warn( - "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" - ) - residual = hidden_states + hidden_states, self_attn_weights, present_key_value = self.pre_attn( hidden_states, attention_mask, @@ -752,6 +837,7 @@ def forward( output_attentions, use_cache, cache_position, + position_embeddings, token_idx, attn_softmax_bf16, reuse_cache, @@ -786,6 +872,7 @@ def pre_attn( output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45 token_idx: Optional[torch.Tensor] = None, attn_softmax_bf16: Optional[bool] = False, reuse_cache: Optional[bool] = False, @@ -805,6 +892,7 @@ def pre_attn( output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, + position_embeddings=position_embeddings, token_idx=token_idx, attn_softmax_bf16=attn_softmax_bf16, reuse_cache=reuse_cache, @@ -870,6 +958,7 @@ def __init__(self, config: LlamaConfig): config.parallel_strategy = None self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = GaudiLlamaRotaryEmbedding(config=config) self.gradient_checkpointing = False # Initialize weights and apply final processing @@ -891,7 +980,7 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, @@ -938,6 +1027,7 @@ def forward( batch_size, seq_length = inputs_embeds.shape[:2] else: raise ValueError("You have to specify either input_ids or inputs_embeds") + if hasattr(self.config, "use_fused_rope") and self.config.use_fused_rope is False: global has_fused_rope has_fused_rope = False @@ -975,14 +1065,12 @@ def forward( if ignore_cache_position is False: if cache_position is None: - if isinstance(past_key_values, StaticCache): - raise ValueError("cache_position is a required argument when using StaticCache.") + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) if position_ids is None and cache_position: position_ids = cache_position.unsqueeze(0) - else: if position_ids is None: position_ids = torch.arange( @@ -1005,6 +1093,9 @@ def forward( # embed positions hidden_states = inputs_embeds + # create position embeddings to be shared across the decoder layers + position_embeddings = None # self.rotary_emb(hidden_states, position_ids) + # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None @@ -1034,6 +1125,7 @@ def forward( output_attentions, use_cache, cache_position, + position_embeddings, None, attn_softmax_bf16, False, @@ -1052,6 +1144,7 @@ def forward( output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, + position_embeddings=position_embeddings, token_idx=token_idx, attn_softmax_bf16=attn_softmax_bf16, reuse_cache=reuse_cache, @@ -1076,11 +1169,10 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = None - if use_cache: - next_cache = ( - next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache - ) + next_cache = next_decoder_cache if use_cache else None + if not use_new_cache and isinstance(next_cache, Cache): + next_cache = next_cache.to_legacy_cache() + if not return_dict: return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) return BaseModelOutputWithPast( @@ -1121,7 +1213,7 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, @@ -1221,58 +1313,29 @@ def prepare_inputs_for_generation( attention_mask=None, inputs_embeds=None, cache_position=None, + position_ids=None, + use_cache=True, token_idx=None, **kwargs, ): - past_length = 0 - reuse_cache = kwargs.get("reuse_cache") bucket_internal = kwargs.get("bucket_internal") if past_key_values is not None: if token_idx is not None: input_ids = torch.index_select(input_ids, 1, token_idx - 1) else: - if isinstance(past_key_values, Cache): - past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length() - max_cache_length = ( - torch.tensor(past_key_values.get_max_length(), device=input_ids.device) - if past_key_values.get_max_length() is not None - else None - ) - cache_length = ( - past_length if max_cache_length is None else torch.min(max_cache_length, past_length) - ) - # TODO joao: remove this `else` after `generate` prioritizes `Cache` objects - else: - cache_length = past_length = past_key_values[0][0].shape[2] - max_cache_length = None - - # Keep only the unprocessed tokens: - # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where - # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as - # input) - if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: - input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] - # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard - # input_ids based on the past_length. - elif past_length < input_ids.shape[1]: - input_ids = input_ids[:, past_length:] - # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. - - # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. - if ( - max_cache_length is not None - and attention_mask is not None - and cache_length + input_ids.shape[1] > max_cache_length - ): - attention_mask = attention_mask[:, -max_cache_length:] + if inputs_embeds is not None: # Exception 1 + input_ids = input_ids[:, -cache_position.shape[0] :] + elif ( + input_ids.shape[1] != cache_position.shape[0] + ): # Default case (the "else", a no op, is Exception 2) + input_ids = input_ids[:, cache_position] elif (reuse_cache or bucket_internal) and token_idx is not None: # KV cache is pre allocated with reuse cache or will be padded with bucket internal # hence for the 1st token we can slice the inputs till token idx for the fwd pass. input_ids = input_ids[:, :token_idx] attention_mask = attention_mask[:, :token_idx] - position_ids = kwargs.get("position_ids", None) if attention_mask is not None and position_ids is None: # create position_ids on the fly for batch generation position_ids = attention_mask.long().cumsum(-1) - 1 @@ -1290,9 +1353,6 @@ def prepare_inputs_for_generation( if inputs_embeds is not None and past_key_values is None: model_inputs = {"inputs_embeds": inputs_embeds} else: - # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise - # recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114 - # TODO: use `next_tokens` directly instead. model_inputs = {"input_ids": input_ids.contiguous()} model_inputs.update( @@ -1300,7 +1360,7 @@ def prepare_inputs_for_generation( "position_ids": position_ids, "cache_position": cache_position, "past_key_values": past_key_values, - "use_cache": kwargs.get("use_cache"), + "use_cache": use_cache, "attention_mask": attention_mask, "token_idx": token_idx, "trim_logits": kwargs.get("trim_logits"), diff --git a/optimum/habana/transformers/models/mamba/modeling_mamba.py b/optimum/habana/transformers/models/mamba/modeling_mamba.py index 7b917e163e..ea7c112c7d 100644 --- a/optimum/habana/transformers/models/mamba/modeling_mamba.py +++ b/optimum/habana/transformers/models/mamba/modeling_mamba.py @@ -14,10 +14,16 @@ def gaudi_MambaForCausalLM_update_model_kwargs_for_generation( - self, outputs: ModelOutput, model_kwargs: Dict[str, Any], **kwargs + self, outputs: ModelOutput, model_kwargs: Dict[str, Any], num_new_tokens: int = 1, **kwargs ) -> Dict[str, Any]: model_kwargs["cache_params"] = outputs.get("cache_params", None) token_idx = model_kwargs.get("token_idx", None) + if ( + model_kwargs.get("use_cache", True) + and "cache_position" in model_kwargs + and model_kwargs["cache_position"] is not None + ): + model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + num_new_tokens if token_idx is not None: token_idx.add_(1) if "token_idx_cpu" in model_kwargs: @@ -26,13 +32,34 @@ def gaudi_MambaForCausalLM_update_model_kwargs_for_generation( def gaudi_MambaForCausalLM_prepare_inputs_for_generation( - self, input_ids, cache_params: Optional[MambaCache] = None, inputs_embeds=None, attention_mask=None, **kwargs + self, + input_ids, + inputs_embeds=None, + use_cache=None, + cache_params: Optional[MambaCache] = None, + cache_position: Optional[torch.LongTensor] = None, + attention_mask=None, + **kwargs, ): token_idx = kwargs.get("token_idx", None) token_idx_cpu = kwargs.get("token_idx_cpu", None) - if cache_params is not None: + if use_cache: if token_idx is None: - input_ids = input_ids[:, -1].unsqueeze(-1) + # `cache_position` should have been initialized in `generate` + if cache_position is None: + raise ValueError( + "`cache_position` should not be None as it should have been initialized in " + "`model.generate`, you are responsible for passing in a valid `cache_position` if " + "you are calling `prepare_inputs_for_generation` directly with `use_cache=True`" + ) + if cache_position[0] > 0: + input_ids = input_ids[:, -1].unsqueeze(-1) + else: + # we initialize the `cache_position` to full size of `conv_states` at prefill stage + # considering padding will be applied when input length is shorter, and truncation + # will be applied when it is longer, so it will be equivalent to always have it match + # the length of `cache_params.conv_states`, which is `config.conv_kernel` + cache_position = torch.arange(0, self.config.conv_kernel, device=input_ids.device) else: input_ids = torch.index_select(input_ids, 1, token_idx - 1) else: @@ -41,6 +68,13 @@ def gaudi_MambaForCausalLM_prepare_inputs_for_generation( if inputs_embeds is not None and cache_params is None: model_inputs = {"inputs_embeds": inputs_embeds} else: - model_inputs = {"input_ids": input_ids} - model_inputs["cache_params"] = cache_params + model_inputs = {"input_ids": input_ids.contiguous()} + + model_inputs.update( + { + "cache_params": cache_params, + "use_cache": use_cache, + "cache_position": cache_position, + } + ) return model_inputs diff --git a/optimum/habana/transformers/models/mistral/configuration_mistral.py b/optimum/habana/transformers/models/mistral/configuration_mistral.py index 6e5968a00f..9af0c8975f 100644 --- a/optimum/habana/transformers/models/mistral/configuration_mistral.py +++ b/optimum/habana/transformers/models/mistral/configuration_mistral.py @@ -1,3 +1,4 @@ +from transformers.modeling_rope_utils import rope_config_validation from transformers.models.mistral.configuration_mistral import MistralConfig @@ -16,6 +17,7 @@ def __init__( num_hidden_layers=32, num_attention_heads=32, num_key_value_heads=8, + head_dim=None, hidden_act="silu", max_position_embeddings=4096 * 32, initializer_range=0.02, @@ -38,6 +40,7 @@ def __init__( num_hidden_layers, num_attention_heads, num_key_value_heads, + head_dim, hidden_act, max_position_embeddings, initializer_range, @@ -54,25 +57,6 @@ def __init__( ) self.rope_scaling = rope_scaling - self._rope_scaling_validation() - def _rope_scaling_validation(self): - """ - Taken from: https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/configuration_llama.py#L172 - """ - if self.rope_scaling is None: - return - - if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2: - raise ValueError( - "`rope_scaling` must be a dictionary with with two fields, `type` and `factor`, " - f"got {self.rope_scaling}" - ) - rope_scaling_type = self.rope_scaling.get("type", None) - rope_scaling_factor = self.rope_scaling.get("factor", None) - if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]: - raise ValueError( - f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}" - ) - if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0: - raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}") + # Validate the correctness of rotary position embeddings parameters + rope_config_validation(self) diff --git a/optimum/habana/transformers/models/mistral/modeling_mistral.py b/optimum/habana/transformers/models/mistral/modeling_mistral.py index 42b5735154..7d95e548ce 100644 --- a/optimum/habana/transformers/models/mistral/modeling_mistral.py +++ b/optimum/habana/transformers/models/mistral/modeling_mistral.py @@ -25,11 +25,9 @@ import habana_frameworks.torch.core as htcore import torch -import torch.nn.functional as F from torch import nn from torch.nn import CrossEntropyLoss from transformers.cache_utils import Cache, DynamicCache -from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask_for_sdpa from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from transformers.models.mistral.configuration_mistral import MistralConfig from transformers.models.mistral.modeling_mistral import ( @@ -134,34 +132,6 @@ def forward(self, x, y): return torch.matmul(x, y) -# Copy from GaudiMixtralAttentionLongSequence -class GaudiMistralAttentionLongSequence: - @staticmethod - def forward(q, k, v, mask, causal, q_block_size): - """ - Support long sequence at prompt phase - """ - q_len = q.size(-2) - q_tiles = (q_len // q_block_size) if (q_len % q_block_size == 0) else math.ceil(q_len / q_block_size) - q_padding = q_tiles * q_block_size - q_len - q = F.pad(q, (0, 0, 0, q_padding), "constant", 0) - if mask is not None: - mask = F.pad(mask, (0, 0, 0, q_padding), "constant", -10000.0) - attn_output = torch.zeros_like(q) - - for i in range(q_tiles): - s, e = i * q_block_size, (i + 1) * q_block_size - row_q = q[:, :, s:e, :] - row_mask = mask[:, :, s:e, :] - row_o = attn_output[:, :, s:e, :] - row_o.fill_(FusedSDPA.apply(row_q, k, v, row_mask, 0.0, causal, None)) - - if q_padding != 0: - attn_output = attn_output[:, :, :-q_padding, :] - - return attn_output - - def gaudi_mistral_repeat_kv( query_states: torch.Tensor, key_states: torch.Tensor, @@ -232,7 +202,6 @@ def __init__(self, config: MistralConfig, layer_idx: Optional[int] = None): self.inp_seq_len = -1 self._init_rope() self.norm_factor = 1.0 / math.sqrt(self.head_dim) - self.block_size = 1024 def _init_rope(self): """ @@ -301,6 +270,7 @@ def forward( past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, token_idx: Optional[torch.Tensor] = None, reuse_cache: Optional[bool] = False, cache_idx: Optional[int] = None, @@ -397,17 +367,7 @@ def forward( attn_output = self.fused_scaled_dot_product_attention( query_states, key_states, value_states, attention_mask, 0.0, False, None ) - elif FusedSDPA and not self.training and q_len == key_states.size(-2) and q_len > 8192: - htcore.mark_step() - attn_output = GaudiMistralAttentionLongSequence.forward( - query_states, - key_states, - value_states, - attention_mask, - False, - self.block_size, - ) - htcore.mark_step() + else: # repeat k/v heads if n_kv_heads < n_heads query_states, key_states, value_states, attention_mask = gaudi_mistral_repeat_kv( @@ -415,24 +375,9 @@ def forward( ) attn_weights = self.matmul_qk(query_states, key_states.transpose(-2, -1)) * self.norm_factor - if attn_weights.size() not in [ - (bsz, self.num_heads, q_len, kv_seq_len), - (bsz, self.num_key_value_heads, self.num_key_value_groups, q_len, kv_seq_len), - ]: - raise ValueError( - f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)} or" - f" {(bsz, self.num_key_value_heads, self.num_key_value_groups, q_len, kv_seq_len)}, but is" - f" {attn_weights.size()}" - ) - - if attention_mask is not None: - if attention_mask.size() not in [(bsz, 1, q_len, kv_seq_len), (bsz, 1, 1, q_len, kv_seq_len)]: - raise ValueError( - f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)} or {(bsz, 1, 1, q_len, kv_seq_len)}," - f" but is {attention_mask.size()}" - ) - - attn_weights = attn_weights + attention_mask + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask if attn_softmax_bf16: attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=query_states.dtype) @@ -450,8 +395,8 @@ def forward( ) attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + attn_output = attn_output.view(bsz, q_len, -1) attn_output = self.o_proj(attn_output) if not output_attentions: @@ -485,9 +430,10 @@ def forward( hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, token_idx: Optional[torch.Tensor] = None, reuse_cache: Optional[bool] = False, cache_idx: Optional[int] = None, @@ -514,6 +460,7 @@ def forward( past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, + cache_position=cache_position, token_idx=token_idx, reuse_cache=reuse_cache, cache_idx=cache_idx, @@ -521,6 +468,7 @@ def forward( use_flash_attention=use_flash_attention, flash_attention_recompute=flash_attention_recompute, flash_attention_causal_mask=flash_attention_causal_mask, + **kwargs, ) hidden_states = residual + hidden_states @@ -558,12 +506,13 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, token_idx: Optional[torch.Tensor] = None, reuse_cache: Optional[bool] = False, cache_idx: Optional[int] = None, @@ -597,53 +546,45 @@ def forward( else: raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False past_key_values_length = 0 - use_legacy_cache = True use_new_cache = False - if past_key_values is not None and use_cache and not reuse_cache and use_new_cache: - use_legacy_cache = not isinstance(past_key_values, Cache) - if use_legacy_cache: - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - past_key_values_length = past_key_values.get_usable_length(seq_length) - - if position_ids is None: - device = input_ids.device if input_ids is not None else inputs_embeds.device - position_ids = torch.arange( - past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device - ) - position_ids = position_ids.unsqueeze(0).view(-1, seq_length) - else: - position_ids = position_ids.view(-1, seq_length).long() + return_legacy_cache = False if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - if self._attn_implementation == "sdpa" and not output_attentions: - # output_attentions=True can not be supported when using SDPA, and we fall back on - # the manual implementation that requires a 4D causal mask in all cases. - attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( - attention_mask, - (batch_size, seq_length), - inputs_embeds, - past_key_values_length, + if use_cache and not isinstance(past_key_values, Cache) and use_new_cache: + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + return_legacy_cache = True + logger.warning_once( + "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. " + "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)" ) - else: - # 4d mask is passed through the layers - attention_mask = _gaudi_prepare_4d_causal_attention_mask( - attention_mask, - (batch_size, seq_length), - inputs_embeds, - past_key_values_length, - sliding_window=self.config.sliding_window, + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + + if cache_position is None: + cache_position = torch.arange( + past_key_values_length, past_key_values_length + inputs_embeds.shape[1], device=inputs_embeds.device ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + # 4d mask is passed through the layers + causal_mask = _gaudi_prepare_4d_causal_attention_mask( + attention_mask, + (batch_size, seq_length), + inputs_embeds, + past_key_values_length, + sliding_window=self.config.sliding_window, + ) + hidden_states = inputs_embeds # decoder layers @@ -668,11 +609,12 @@ def forward( layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, - attention_mask, + causal_mask, position_ids, None if past_key_values is None else past_key_values[layer_idx], output_attentions, use_cache, + cache_position, None, False, cache_idx, @@ -684,11 +626,12 @@ def forward( else: layer_outputs = decoder_layer( hidden_states, - attention_mask=attention_mask, + attention_mask=causal_mask, position_ids=position_ids, past_key_value=None if past_key_values is None else past_key_values[layer_idx], output_attentions=output_attentions, use_cache=use_cache, + cache_position=cache_position, token_idx=token_idx, reuse_cache=reuse_cache, cache_idx=cache_idx, @@ -717,8 +660,9 @@ def forward( next_cache = ( next_decoder_cache if not use_new_cache - else (next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache) + else (next_decoder_cache.to_legacy_cache() if return_legacy_cache else next_decoder_cache) ) + if not return_dict: return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) return BaseModelOutputWithPast( @@ -744,13 +688,14 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, token_idx: Optional[torch.Tensor] = None, reuse_cache: Optional[bool] = False, trim_logits: Optional[bool] = False, @@ -788,6 +733,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, token_idx=token_idx, reuse_cache=reuse_cache, cache_idx=cache_idx, @@ -833,7 +779,15 @@ def forward( ) def prepare_inputs_for_generation( - self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + position_ids=None, + use_cache=True, + **kwargs, ): """ Inherits from MistralForCausalLM: https://github.com/huggingface/transformers/blob/v4.34.1/src/transformers/models/mistral/modeling_mistral.py @@ -849,33 +803,12 @@ def prepare_inputs_for_generation( # Omit tokens covered by past_key_values if past_key_values is not None: if token_idx is None: - if isinstance(past_key_values, Cache): - cache_length = past_key_values.get_seq_length() - past_length = past_key_values.seen_tokens - max_cache_length = past_key_values.get_max_length() - else: - cache_length = past_length = past_key_values[0][0].shape[2] - max_cache_length = None - - # Keep only the unprocessed tokens: - # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where - # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as - # input) - if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: - input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] - # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard - # input_ids based on the past_length. - elif past_length < input_ids.shape[1]: - input_ids = input_ids[:, past_length:] - # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. - - # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. - if ( - max_cache_length is not None - and attention_mask is not None - and cache_length + input_ids.shape[1] > max_cache_length - ): - attention_mask = attention_mask[:, -max_cache_length:] + if inputs_embeds is not None: # Exception 1 + input_ids = input_ids[:, -cache_position.shape[0] :] + elif ( + input_ids.shape[1] != cache_position.shape[0] + ): # Default case (the "else", a no op, is Exception 2) + input_ids = input_ids[:, cache_position] else: input_ids = torch.index_select(input_ids, 1, token_idx - 1) elif reuse_cache and token_idx is not None: @@ -883,7 +816,6 @@ def prepare_inputs_for_generation( input_ids = input_ids[:, :token_idx] attention_mask = attention_mask[:, :token_idx] - position_ids = kwargs.get("position_ids", None) if attention_mask is not None and position_ids is None: # create position_ids on the fly for batch generation position_ids = attention_mask.long().cumsum(-1) - 1 @@ -898,13 +830,14 @@ def prepare_inputs_for_generation( if inputs_embeds is not None and past_key_values is None: model_inputs = {"inputs_embeds": inputs_embeds} else: - model_inputs = {"input_ids": input_ids} + model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases model_inputs.update( { "position_ids": position_ids, + "cache_position": cache_position, "past_key_values": past_key_values, - "use_cache": kwargs.get("use_cache"), + "use_cache": use_cache, "attention_mask": attention_mask, "token_idx": token_idx, "reuse_cache": kwargs.get("reuse_cache"), @@ -938,4 +871,4 @@ def apply_customized_rope(q, k, cos, sin, position_ids): k, cos.unsqueeze(0).unsqueeze(0).clone(), sin.unsqueeze(0).unsqueeze(0).clone(), position_ids ) else: - return apply_rotary_pos_emb(q, k, cos, sin, position_ids) + return apply_rotary_pos_emb(q, k, cos, sin) diff --git a/optimum/habana/transformers/models/mixtral/configuration_mixtral.py b/optimum/habana/transformers/models/mixtral/configuration_mixtral.py index 90a783f0d1..b9121cfbd4 100644 --- a/optimum/habana/transformers/models/mixtral/configuration_mixtral.py +++ b/optimum/habana/transformers/models/mixtral/configuration_mixtral.py @@ -1,3 +1,4 @@ +from transformers.modeling_rope_utils import rope_config_validation from transformers.models.mixtral.configuration_mixtral import MixtralConfig @@ -64,25 +65,6 @@ def __init__( ) self.rope_scaling = rope_scaling - self._rope_scaling_validation() - def _rope_scaling_validation(self): - """ - Taken from: https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/configuration_llama.py#L172 - """ - if self.rope_scaling is None: - return - - if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2: - raise ValueError( - "`rope_scaling` must be a dictionary with with two fields, `type` and `factor`, " - f"got {self.rope_scaling}" - ) - rope_scaling_type = self.rope_scaling.get("type", None) - rope_scaling_factor = self.rope_scaling.get("factor", None) - if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]: - raise ValueError( - f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}" - ) - if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0: - raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}") + # Validate the correctness of rotary position embeddings parameters + rope_config_validation(self) diff --git a/optimum/habana/transformers/models/mixtral/modeling_mixtral.py b/optimum/habana/transformers/models/mixtral/modeling_mixtral.py index 0dd6ffc47e..b6c750fa00 100644 --- a/optimum/habana/transformers/models/mixtral/modeling_mixtral.py +++ b/optimum/habana/transformers/models/mixtral/modeling_mixtral.py @@ -22,7 +22,6 @@ import contextlib import math -import warnings from typing import List, Optional, Tuple, Union import habana_frameworks.torch.core as htcore @@ -274,11 +273,11 @@ def forward( past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, token_idx: Optional[torch.Tensor] = None, reuse_cache: Optional[bool] = False, flash_attention_recompute: Optional[bool] = False, cache_idx: int = None, - **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """ Copied from MixtralAttention.forward: https://github.com/huggingface/transformers/blob/v4.37.0/src/transformers/models/mixtral/modeling_mixtral.py @@ -289,10 +288,6 @@ def forward( - add new args flash_attention_recompute - add new args cache_idx """ - if "padding_mask" in kwargs: - warnings.warn( - "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" - ) bsz, q_len, _ = hidden_states.size() query_states = self.q_proj(hidden_states) @@ -471,6 +466,7 @@ def forward( output_attentions: Optional[bool] = False, output_router_logits: Optional[bool] = False, use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, token_idx: Optional[torch.Tensor] = None, reuse_cache: Optional[bool] = False, flash_attention_recompute: Optional[bool] = False, @@ -485,11 +481,6 @@ def forward( - add new args flash_attention_recompute - add new args cache_idx """ - if "padding_mask" in kwargs: - warnings.warn( - "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" - ) - residual = hidden_states hidden_states = self.input_layernorm(hidden_states) @@ -502,6 +493,7 @@ def forward( past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, + cache_position=cache_position, token_idx=token_idx, reuse_cache=reuse_cache, flash_attention_recompute=flash_attention_recompute, @@ -549,6 +541,7 @@ def forward( output_hidden_states: Optional[bool] = None, output_router_logits: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, token_idx: Optional[torch.Tensor] = None, reuse_cache: Optional[bool] = False, flash_attention_recompute: Optional[bool] = False, @@ -604,27 +597,18 @@ def forward( else: past_key_values_length = past_key_values[0][0].shape[2] - if position_ids is None: - device = input_ids.device if input_ids is not None else inputs_embeds.device - position_ids = torch.arange( - past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) - position_ids = position_ids.unsqueeze(0).view(-1, seq_length) - else: - position_ids = position_ids.view(-1, seq_length).long() + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - if attention_mask is not None and self.config._attn_implementation == "flash_attention_2" and use_cache: - is_padding_right = attention_mask[:, -1].sum().item() != batch_size - if is_padding_right: - raise ValueError( - "You are attempting to perform batched generation with padding_side='right'" - " this may lead to unexpected behaviour for Flash Attention version of Mixtral. Make sure to " - " call `tokenizer.padding_side = 'left'` before tokenizing the input. " - ) - if self.config._attn_implementation == "flash_attention_2": # 2d mask is passed through the layers attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None @@ -669,6 +653,7 @@ def forward( output_attentions, output_router_logits, use_cache, + cache_position, ) else: layer_outputs = decoder_layer( @@ -679,6 +664,7 @@ def forward( output_attentions=output_attentions, output_router_logits=output_router_logits, use_cache=use_cache, + cache_position=cache_position, token_idx=token_idx, reuse_cache=reuse_cache, flash_attention_recompute=flash_attention_recompute, @@ -752,6 +738,7 @@ def forward( output_hidden_states: Optional[bool] = None, output_router_logits: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, token_idx: Optional[torch.Tensor] = None, reuse_cache: Optional[bool] = None, flash_attention_recompute: Optional[bool] = False, @@ -779,6 +766,7 @@ def forward( output_hidden_states=output_hidden_states, output_router_logits=output_router_logits, return_dict=return_dict, + cache_position=cache_position, token_idx=token_idx, reuse_cache=reuse_cache, flash_attention_recompute=flash_attention_recompute, @@ -830,9 +818,17 @@ def forward( ) def prepare_inputs_for_generation( - self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + output_router_logits=False, + position_ids=None, + use_cache=True, + **kwargs, ): - past_length = 0 reuse_cache = kwargs.get("reuse_cache") token_idx = kwargs.get("token_idx", None) @@ -841,39 +837,17 @@ def prepare_inputs_for_generation( if token_idx is not None: input_ids = torch.index_select(input_ids, 1, token_idx - 1) else: - if isinstance(past_key_values, Cache): - cache_length = past_key_values.get_seq_length() - past_length = past_key_values.seen_tokens - max_cache_length = past_key_values.get_max_length() - else: - cache_length = past_length = past_key_values[0][0].shape[2] - max_cache_length = None - - # Keep only the unprocessed tokens: - # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where - # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as - # input) - if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: - input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] - # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard - # input_ids based on the past_length. - elif past_length < input_ids.shape[1]: - input_ids = input_ids[:, past_length:] - # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. - - # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. - if ( - max_cache_length is not None - and attention_mask is not None - and cache_length + input_ids.shape[1] > max_cache_length - ): - attention_mask = attention_mask[:, -max_cache_length:] + if inputs_embeds is not None: # Exception 1 + input_ids = input_ids[:, -cache_position.shape[0] :] + elif ( + input_ids.shape[1] != cache_position.shape[0] + ): # Default case (the "else", a no op, is Exception 2) + input_ids = input_ids[:, cache_position] elif reuse_cache and token_idx is not None: # With reuse_cache, KV cache is pre allocated hence for the 1st token we can slice the inputs till token idx for the fwd pass input_ids = input_ids[:, :token_idx] attention_mask = attention_mask[:, :token_idx] - position_ids = kwargs.get("position_ids", None) if attention_mask is not None and position_ids is None: # create position_ids on the fly for batch generation position_ids = attention_mask.long().cumsum(-1) - 1 @@ -888,13 +862,14 @@ def prepare_inputs_for_generation( if inputs_embeds is not None and past_key_values is None: model_inputs = {"inputs_embeds": inputs_embeds} else: - model_inputs = {"input_ids": input_ids} + model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases model_inputs.update( { "position_ids": position_ids, + "cache_position": cache_position, "past_key_values": past_key_values, - "use_cache": kwargs.get("use_cache"), + "use_cache": use_cache, "attention_mask": attention_mask, "token_idx": token_idx, "reuse_cache": reuse_cache, diff --git a/optimum/habana/transformers/models/mpt/modeling_mpt.py b/optimum/habana/transformers/models/mpt/modeling_mpt.py index cb44f4dad2..3d495a320b 100755 --- a/optimum/habana/transformers/models/mpt/modeling_mpt.py +++ b/optimum/habana/transformers/models/mpt/modeling_mpt.py @@ -48,6 +48,9 @@ def gaudi_mpt_attention_forward( batch_size, seq_length = hidden_states.shape[:2] mixed_qkv = self.Wqkv(hidden_states) + if self.clip_qkv: + mixed_qkv = mixed_qkv.clamp(min=-self.clip_qkv, max=self.clip_qkv) + bs, seq_len, three_times_hidden_size = mixed_qkv.shape mixed_qkv = mixed_qkv.view(bs, seq_len, self.n_heads * 3, self.head_dim) mixed_qkv = mixed_qkv.transpose(1, 2) diff --git a/optimum/habana/transformers/models/persimmon/modeling_persimmon.py b/optimum/habana/transformers/models/persimmon/modeling_persimmon.py index 4ecfd3264f..4c7b24b988 100644 --- a/optimum/habana/transformers/models/persimmon/modeling_persimmon.py +++ b/optimum/habana/transformers/models/persimmon/modeling_persimmon.py @@ -25,6 +25,7 @@ def gaudi_persimmon_attention_forward( past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, token_idx: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """ @@ -93,7 +94,12 @@ def gaudi_persimmon_attention_forward( past_key_value.value_cache.append(value_states) else: # Specific to RoPE models with partial rotation - cache_kwargs = {"sin": sin, "cos": cos, "partial_rotation_size": self.rotary_emb.dim} + cache_kwargs = { + "sin": sin, + "cos": cos, + "partial_rotation_size": self.rotary_emb.dim, + "cache_position": cache_position, + } key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) @@ -104,12 +110,9 @@ def gaudi_persimmon_attention_forward( f" {attn_weights.size()}" ) - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" - ) - attn_weights = attn_weights + attention_mask + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask # upcast attention to fp32 attn_weights = nn.functional.softmax(attn_weights, dtype=torch.float32, dim=-1).to(query_states.dtype) @@ -142,6 +145,7 @@ def gaudi_persimmon_decoder_layer_forward( past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, token_idx: Optional[torch.Tensor] = None, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ @@ -161,6 +165,7 @@ def gaudi_persimmon_decoder_layer_forward( past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, + cache_position=cache_position, token_idx=token_idx, ) hidden_states = residual + hidden_states @@ -195,6 +200,7 @@ def gaudi_persimmon_model_forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, token_idx: Optional[torch.Tensor] = None, ) -> Union[Tuple, BaseModelOutputWithPast]: """ @@ -240,16 +246,17 @@ def gaudi_persimmon_model_forward( past_key_values_length = past_key_values.get_usable_length(seq_length) seq_length_with_past = seq_length_with_past + past_key_values_length - if position_ids is None: - device = input_ids.device if input_ids is not None else inputs_embeds.device - position_ids = torch.arange( - past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device - ) - position_ids = position_ids.unsqueeze(0) - if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - # embed positions + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + if attention_mask is None: attention_mask = torch.ones((batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device) attention_mask = _gaudi_prepare_4d_causal_attention_mask( @@ -275,6 +282,8 @@ def gaudi_persimmon_model_forward( position_ids, past_key_values, output_attentions, + use_cache, + cache_position, ) else: layer_outputs = decoder_layer( @@ -284,6 +293,7 @@ def gaudi_persimmon_model_forward( past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, + cache_position=cache_position, token_idx=token_idx, ) @@ -328,6 +338,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, token_idx: Optional[torch.Tensor] = None, ) -> Union[Tuple, CausalLMOutputWithPast]: """ @@ -353,6 +364,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, token_idx=token_idx, ) @@ -385,7 +397,15 @@ def forward( ) def prepare_inputs_for_generation( - self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + position_ids=None, + use_cache=True, + **kwargs, ): """ Inherits from PersimmonForCausalLM: https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/persimmon/modeling_persimmon.py @@ -398,37 +418,15 @@ def prepare_inputs_for_generation( token_idx = kwargs.get("token_idx", None) if past_key_values is not None: if token_idx is None: - if isinstance(past_key_values, Cache): - cache_length = past_key_values.get_seq_length() - past_length = past_key_values.seen_tokens - max_cache_length = past_key_values.get_max_length() - else: - cache_length = past_length = past_key_values[0][0].shape[2] - max_cache_length = None - - # Keep only the unprocessed tokens: - # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where - # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as - # input) - if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: - input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] - # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard - # input_ids based on the past_length. - elif past_length < input_ids.shape[1]: - input_ids = input_ids[:, past_length:] - # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. - - # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. - if ( - max_cache_length is not None - and attention_mask is not None - and cache_length + input_ids.shape[1] > max_cache_length - ): - attention_mask = attention_mask[:, -max_cache_length:] + if inputs_embeds is not None: # Exception 1 + input_ids = input_ids[:, -cache_position.shape[0] :] + elif ( + input_ids.shape[1] != cache_position.shape[0] + ): # Default case (the "else", a no op, is Exception 2) + input_ids = input_ids[:, cache_position] else: input_ids = torch.index_select(input_ids, 1, token_idx - 1) - position_ids = kwargs.get("position_ids", None) if attention_mask is not None and position_ids is None: # create position_ids on the fly for batch generation position_ids = attention_mask.long().cumsum(-1) - 1 @@ -443,13 +441,14 @@ def prepare_inputs_for_generation( if inputs_embeds is not None and past_key_values is None: model_inputs = {"inputs_embeds": inputs_embeds} else: - model_inputs = {"input_ids": input_ids} + model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases model_inputs.update( { "position_ids": position_ids, + "cache_position": cache_position, "past_key_values": past_key_values, - "use_cache": kwargs.get("use_cache"), + "use_cache": use_cache, "attention_mask": attention_mask, "token_idx": token_idx, } diff --git a/optimum/habana/transformers/models/phi/modeling_phi.py b/optimum/habana/transformers/models/phi/modeling_phi.py index 4c9d9dd4d2..07f4d0cd71 100644 --- a/optimum/habana/transformers/models/phi/modeling_phi.py +++ b/optimum/habana/transformers/models/phi/modeling_phi.py @@ -152,6 +152,7 @@ def forward( past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, token_idx: Optional[torch.Tensor] = None, reuse_cache: Optional[bool] = False, cache_idx: Optional[int] = None, @@ -312,6 +313,7 @@ def forward( output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, past_key_value: Optional[Tuple[torch.Tensor]] = None, + cache_position: Optional[torch.LongTensor] = None, token_idx: Optional[torch.Tensor] = None, reuse_cache: Optional[bool] = False, cache_idx: Optional[int] = None, @@ -337,6 +339,7 @@ def forward( past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, + cache_position=cache_position, token_idx=token_idx, reuse_cache=reuse_cache, cache_idx=cache_idx, @@ -372,6 +375,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, token_idx: Optional[torch.Tensor] = None, reuse_cache: Optional[bool] = False, cache_idx: Optional[int] = None, @@ -422,23 +426,23 @@ def forward( else: past_seen_tokens = past_key_values[0][0].shape[2] - if position_ids is None: - device = input_ids.device if input_ids is not None else inputs_embeds.device - position_ids = torch.arange( - past_seen_tokens, seq_length + past_seen_tokens, dtype=torch.long, device=device - ) - position_ids = position_ids.unsqueeze(0) - if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - inputs_embeds = self.embed_dropout(inputs_embeds) + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) # 4d mask is passed through the layers attention_mask = _gaudi_prepare_4d_causal_attention_mask( attention_mask, (batch_size, seq_length), inputs_embeds, past_seen_tokens ) + inputs_embeds = self.embed_dropout(inputs_embeds) hidden_states = inputs_embeds # decoder layers @@ -456,9 +460,10 @@ def forward( hidden_states, attention_mask, position_ids, - None if past_key_values is None else past_key_values[layer_idx], output_attentions, use_cache, + None if past_key_values is None else past_key_values[layer_idx], + cache_position, None, ) else: @@ -469,6 +474,7 @@ def forward( past_key_value=None if past_key_values is None else past_key_values[layer_idx], output_attentions=output_attentions, use_cache=use_cache, + cache_position=cache_position, token_idx=token_idx, reuse_cache=reuse_cache, cache_idx=cache_idx, @@ -520,6 +526,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, token_idx: Optional[torch.Tensor] = None, reuse_cache: Optional[bool] = False, trim_logits: Optional[bool] = False, @@ -550,6 +557,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, token_idx=token_idx, reuse_cache=reuse_cache, cache_idx=cache_idx, @@ -591,7 +599,16 @@ def forward( ) def prepare_inputs_for_generation( - self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, token_idx=None, **kwargs + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + position_ids=None, + use_cache=True, + token_idx=None, + **kwargs, ): """ Inherits from PhiForCausalLM: https://github.com/huggingface/transformers/blob/v4.37.1/src/transformers/models/phi/modeling_phi.py @@ -601,46 +618,23 @@ def prepare_inputs_for_generation( - from step2 when enable KV cache, slice next_input_ids from input_ids base on the token_idx - from step2 when enable KV cache, slice next_position_ids from position_ids base on the token_idx """ - past_length = 0 reuse_cache = kwargs.get("reuse_cache") # Omit tokens covered by past_key_values if past_key_values is not None: if token_idx is not None: input_ids = torch.index_select(input_ids, 1, token_idx - 1) else: - if isinstance(past_key_values, Cache): - cache_length = past_key_values.get_seq_length() - past_length = past_key_values.seen_tokens - max_cache_length = past_key_values.get_max_length() - else: - cache_length = past_length = past_key_values[0][0].shape[2] - max_cache_length = None - - # Keep only the unprocessed tokens: - # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where - # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as - # input) - if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: - input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] - # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard - # input_ids based on the past_length. - elif past_length < input_ids.shape[1]: - input_ids = input_ids[:, past_length:] - # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. - - # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. - if ( - max_cache_length is not None - and attention_mask is not None - and cache_length + input_ids.shape[1] > max_cache_length - ): - attention_mask = attention_mask[:, -max_cache_length:] + if inputs_embeds is not None: # Exception 1 + input_ids = input_ids[:, -cache_position.shape[0] :] + elif ( + input_ids.shape[1] != cache_position.shape[0] + ): # Default case (the "else", a no op, is Exception 2) + input_ids = input_ids[:, cache_position] elif reuse_cache and token_idx is not None: # With reuse_cache, KV cache is pre allocated hence for the 1st token we can slice the inputs till token idx for the fwd pass input_ids = input_ids[:, :token_idx] attention_mask = attention_mask[:, :token_idx] - position_ids = kwargs.get("position_ids", None) if attention_mask is not None and position_ids is None: # create position_ids on the fly for batch generation position_ids = attention_mask.long().cumsum(-1) - 1 @@ -655,13 +649,14 @@ def prepare_inputs_for_generation( if inputs_embeds is not None and past_key_values is None: model_inputs = {"inputs_embeds": inputs_embeds} else: - model_inputs = {"input_ids": input_ids} + model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases model_inputs.update( { "position_ids": position_ids, + "cache_position": cache_position, "past_key_values": past_key_values, - "use_cache": kwargs.get("use_cache"), + "use_cache": use_cache, "attention_mask": attention_mask, "token_idx": token_idx, "reuse_cache": kwargs.get("reuse_cache"), diff --git a/optimum/habana/transformers/models/qwen2/modeling_qwen2.py b/optimum/habana/transformers/models/qwen2/modeling_qwen2.py index 27f3319579..0c8970dd88 100644 --- a/optimum/habana/transformers/models/qwen2/modeling_qwen2.py +++ b/optimum/habana/transformers/models/qwen2/modeling_qwen2.py @@ -18,7 +18,6 @@ import math import os -import warnings from typing import List, Optional, Tuple, Union import torch @@ -280,11 +279,6 @@ def pre_attn_forward( - add new arg flash_attention_recompute - add new arg flash_attention_fast_softmax """ - if "padding_mask" in kwargs: - warnings.warn( - "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" - ) - bsz, q_len, _ = hidden_states.size() query_states = self.q_proj(hidden_states) @@ -610,13 +604,13 @@ def forward( # retrieve input_ids and inputs_embeds if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") elif input_ids is not None: - batch_size, seq_length = input_ids.shape[:2] + batch_size, seq_length = input_ids.shape elif inputs_embeds is not None: - batch_size, seq_length = inputs_embeds.shape[:2] + batch_size, seq_length, _ = inputs_embeds.shape else: - raise ValueError("You have to specify either input_ids or inputs_embeds") + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") if self.gradient_checkpointing and self.training: if use_cache: @@ -850,50 +844,35 @@ def forward( ) def prepare_inputs_for_generation( - self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, token_idx=None, **kwargs + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + position_ids=None, + use_cache=True, + token_idx=None, + **kwargs, ): - past_length = 0 - reuse_cache = kwargs.get("reuse_cache") bucket_internal = kwargs.get("bucket_internal") if past_key_values is not None: if token_idx is not None: input_ids = torch.index_select(input_ids, 1, token_idx - 1) else: - if isinstance(past_key_values, Cache): - cache_length = past_key_values.get_seq_length() - past_length = past_key_values.seen_tokens - max_cache_length = past_key_values.get_max_length() - else: - cache_length = past_length = past_key_values[0][0].shape[2] - max_cache_length = None - - # Keep only the unprocessed tokens: - # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where - # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as - # input) - if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: - input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] - # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard - # input_ids based on the past_length. - elif past_length < input_ids.shape[1]: - input_ids = input_ids[:, past_length:] - # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. - - # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. - if ( - max_cache_length is not None - and attention_mask is not None - and cache_length + input_ids.shape[1] > max_cache_length - ): - attention_mask = attention_mask[:, -max_cache_length:] + if inputs_embeds is not None: # Exception 1 + input_ids = input_ids[:, -cache_position.shape[0] :] + elif ( + input_ids.shape[1] != cache_position.shape[0] + ): # Default case (the "else", a no op, is Exception 2) + input_ids = input_ids[:, cache_position] elif (reuse_cache or bucket_internal) and token_idx is not None: # KV cache is pre allocated with reuse cache or will be padded with bucket internal # hence for the 1st token we can slice the inputs till token idx for the fwd pass. input_ids = input_ids[:, :token_idx] attention_mask = attention_mask[:, :token_idx] - position_ids = kwargs.get("position_ids", None) if attention_mask is not None and position_ids is None: # create position_ids on the fly for batch generation position_ids = attention_mask.long().cumsum(-1) - 1 @@ -910,14 +889,14 @@ def prepare_inputs_for_generation( if inputs_embeds is not None and past_key_values is None: model_inputs = {"inputs_embeds": inputs_embeds} else: - model_inputs = {"input_ids": input_ids} + model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases model_inputs.update( { "position_ids": position_ids.contiguous(), "cache_position": cache_position, "past_key_values": past_key_values, - "use_cache": kwargs.get("use_cache"), + "use_cache": use_cache, "attention_mask": attention_mask, "token_idx": token_idx, "trim_logits": kwargs.get("trim_logits"), diff --git a/optimum/habana/transformers/models/seamless_m4t/modeling_seamless_m4t.py b/optimum/habana/transformers/models/seamless_m4t/modeling_seamless_m4t.py index d4728c30f0..53cea37255 100644 --- a/optimum/habana/transformers/models/seamless_m4t/modeling_seamless_m4t.py +++ b/optimum/habana/transformers/models/seamless_m4t/modeling_seamless_m4t.py @@ -268,7 +268,7 @@ def gaudi_SeamlessM4TDecoder_forward( past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale + inputs_embeds = self.embed_tokens(input_ids) attention_mask = _gaudi_prepare_4d_causal_attention_mask( attention_mask, input_shape, inputs_embeds, past_key_values_length diff --git a/optimum/habana/transformers/models/stablelm/modeling_stablelm.py b/optimum/habana/transformers/models/stablelm/modeling_stablelm.py index f53994a5b1..08becc263a 100644 --- a/optimum/habana/transformers/models/stablelm/modeling_stablelm.py +++ b/optimum/habana/transformers/models/stablelm/modeling_stablelm.py @@ -32,6 +32,7 @@ def gaudi_stablelm_attention_forward( past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, token_idx: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """ @@ -97,7 +98,12 @@ def gaudi_stablelm_attention_forward( past_key_value.value_cache.append(value_states) else: # Specific to RoPE models with partial rotation - cache_kwargs = {"sin": sin, "cos": cos, "partial_rotation_size": self.rotary_emb.dim} + cache_kwargs = { + "sin": sin, + "cos": cos, + "partial_rotation_size": self.rotary_emb.dim, + "cache_position": cache_position, + } key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) # Repeat k/v heads if n_kv_heads < n_heads @@ -112,12 +118,9 @@ def gaudi_stablelm_attention_forward( f" {attn_weights.size()}" ) - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" - ) - attn_weights = attn_weights + attention_mask + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights += causal_mask # upcast attention to fp32 attn_weights = nn.functional.softmax(attn_weights, dtype=torch.float32, dim=-1).to(query_states.dtype) @@ -163,6 +166,7 @@ def forward( past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, token_idx: Optional[torch.Tensor] = None, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ @@ -182,6 +186,7 @@ def forward( past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, + cache_position=cache_position, token_idx=token_idx, ) @@ -223,6 +228,7 @@ def gaudi_stablelm_model_forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, token_idx: Optional[torch.Tensor] = None, ) -> Union[Tuple, BaseModelOutputWithPast]: """ @@ -267,16 +273,17 @@ def gaudi_stablelm_model_forward( past_key_values_length = past_key_values.get_usable_length(seq_length) seq_length_with_past = seq_length_with_past + past_key_values_length - if position_ids is None: - device = input_ids.device if input_ids is not None else inputs_embeds.device - position_ids = torch.arange( - past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device - ) - position_ids = position_ids.unsqueeze(0) - if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - # embed positions + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + if self._attn_implementation == "flash_attention_2": # 2d mask is passed through the layers attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None @@ -305,6 +312,8 @@ def gaudi_stablelm_model_forward( position_ids, past_key_values, output_attentions, + use_cache, + cache_position, None, ) else: @@ -315,6 +324,7 @@ def gaudi_stablelm_model_forward( past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, + cache_position=cache_position, token_idx=token_idx, ) @@ -359,6 +369,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, token_idx: Optional[torch.Tensor] = None, ) -> Union[Tuple, CausalLMOutputWithPast]: """ @@ -382,6 +393,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, token_idx=token_idx, ) @@ -414,7 +426,15 @@ def forward( ) def prepare_inputs_for_generation( - self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + position_ids=None, + use_cache=True, + **kwargs, ): """ Inherits from StableLmForCausalLM: https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/stablelm/modeling_stablelm.py @@ -427,37 +447,15 @@ def prepare_inputs_for_generation( token_idx = kwargs.get("token_idx", None) if past_key_values is not None: if token_idx is None: - if isinstance(past_key_values, Cache): - cache_length = past_key_values.get_seq_length() - past_length = past_key_values.seen_tokens - max_cache_length = past_key_values.get_max_length() - else: - cache_length = past_length = past_key_values[0][0].shape[2] - max_cache_length = None - - # Keep only the unprocessed tokens: - # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where - # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as - # input) - if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: - input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] - # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard - # input_ids based on the past_length. - elif past_length < input_ids.shape[1]: - input_ids = input_ids[:, past_length:] - # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. - - # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. - if ( - max_cache_length is not None - and attention_mask is not None - and cache_length + input_ids.shape[1] > max_cache_length - ): - attention_mask = attention_mask[:, -max_cache_length:] + if inputs_embeds is not None: # Exception 1 + input_ids = input_ids[:, -cache_position.shape[0] :] + elif ( + input_ids.shape[1] != cache_position.shape[0] + ): # Default case (the "else", a no op, is Exception 2) + input_ids = input_ids[:, cache_position] else: input_ids = torch.index_select(input_ids, 1, token_idx - 1) - position_ids = kwargs.get("position_ids", None) if attention_mask is not None and position_ids is None: # create position_ids on the fly for batch generation position_ids = attention_mask.long().cumsum(-1) - 1 @@ -472,13 +470,14 @@ def prepare_inputs_for_generation( if inputs_embeds is not None and past_key_values is None: model_inputs = {"inputs_embeds": inputs_embeds} else: - model_inputs = {"input_ids": input_ids} + model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases model_inputs.update( { "position_ids": position_ids, + "cache_position": cache_position, "past_key_values": past_key_values, - "use_cache": kwargs.get("use_cache"), + "use_cache": use_cache, "attention_mask": attention_mask, "token_idx": token_idx, } diff --git a/optimum/habana/transformers/models/starcoder2/modeling_starcoder2.py b/optimum/habana/transformers/models/starcoder2/modeling_starcoder2.py index 4ca9793dd3..36d5379e4f 100644 --- a/optimum/habana/transformers/models/starcoder2/modeling_starcoder2.py +++ b/optimum/habana/transformers/models/starcoder2/modeling_starcoder2.py @@ -17,7 +17,6 @@ ############################################################################### import math -import warnings from typing import List, Optional, Tuple, Union import torch @@ -244,11 +243,6 @@ def pre_attn_forward( - add new args use_flash_attention - add new arg flash_attention_recompute """ - if "padding_mask" in kwargs: - warnings.warn( - "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" - ) - bsz, q_len, _ = hidden_states.size() query_states = self.q_proj(hidden_states) @@ -339,9 +333,7 @@ def pre_attn_forward( attn_weights = self.matmul_qk(query_states, key_states.transpose(-2, -1)) * self.norm_factor if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask - if cache_position is not None: - causal_mask = attention_mask[:, :, cache_position, : key_states.shape[-2]] + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] attn_weights = attn_weights + causal_mask if attn_softmax_bf16: @@ -384,10 +376,16 @@ def post_attn_forward(self, attn_output): class GaudiStarcoder2DecoderLayer(Starcoder2DecoderLayer): def __init__(self, config: Starcoder2Config, layer_idx: int): - super().__init__(config, layer_idx) + super(Starcoder2DecoderLayer, self).__init__() self.hidden_size = config.hidden_size + self.self_attn = GaudiStarcoder2Attention(config, layer_idx) + self.mlp = GaudiStarcoder2MLP(config) + + self.input_layernorm = torch.nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon) + self.post_attention_layernorm = torch.nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon) + def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): self.self_attn.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len) @@ -416,16 +414,16 @@ def forward( ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: residual = hidden_states hidden_states, self_attn_weights, present_key_value = self.pre_attn( - hidden_states, - attention_mask, - position_ids, - past_key_value, - output_attentions, - use_cache, - cache_position, - token_idx, - attn_softmax_bf16, - reuse_cache, + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + token_idx=token_idx, + attn_softmax_bf16=attn_softmax_bf16, + reuse_cache=reuse_cache, use_flash_attention=use_flash_attention, flash_attention_recompute=flash_attention_recompute, flash_attention_causal_mask=flash_attention_causal_mask, @@ -513,6 +511,22 @@ def post_mlp(self, hidden_states, residual): class GaudiStarcoder2Model(Starcoder2Model): + def __init__(self, config: Starcoder2Config): + super(Starcoder2Model, self).__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = torch.nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.embedding_dropout = config.embedding_dropout + self.layers = torch.nn.ModuleList( + [GaudiStarcoder2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self._attn_implementation = "eager" + self.norm = torch.nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon) + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): for layer in self.layers: layer.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len) @@ -553,8 +567,6 @@ def forward( return_dict = return_dict if return_dict is not None else self.config.use_return_dict - self._attn_implementation = "eager" - # retrieve input_ids and inputs_embeds if input_ids is not None and inputs_embeds is not None: raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") @@ -691,6 +703,15 @@ def forward( class GaudiStarcoder2ForCausalLM(Starcoder2ForCausalLM): + def __init__(self, config): + super(Starcoder2ForCausalLM, self).__init__(config) + self.model = GaudiStarcoder2Model(config) + self.vocab_size = config.vocab_size + self.lm_head = torch.nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): self.model.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len) @@ -792,48 +813,33 @@ def forward( ) def prepare_inputs_for_generation( - self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, token_idx=None, **kwargs + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + position_ids=None, + use_cache=True, + token_idx=None, + **kwargs, ): - past_length = 0 - reuse_cache = kwargs.get("reuse_cache") if past_key_values is not None: if token_idx is not None: input_ids = torch.index_select(input_ids, 1, token_idx - 1) else: - if isinstance(past_key_values, Cache): - cache_length = past_key_values.get_seq_length() - past_length = past_key_values.seen_tokens - max_cache_length = past_key_values.get_max_length() - else: - cache_length = past_length = past_key_values[0][0].shape[2] - max_cache_length = None - - # Keep only the unprocessed tokens: - # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where - # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as - # input) - if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: - input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] - # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard - # input_ids based on the past_length. - elif past_length < input_ids.shape[1]: - input_ids = input_ids[:, past_length:] - # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. - - # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. - if ( - max_cache_length is not None - and attention_mask is not None - and cache_length + input_ids.shape[1] > max_cache_length - ): - attention_mask = attention_mask[:, -max_cache_length:] + if inputs_embeds is not None: # Exception 1 + input_ids = input_ids[:, -cache_position.shape[0] :] + elif ( + input_ids.shape[1] != cache_position.shape[0] + ): # Default case (the "else", a no op, is Exception 2) + input_ids = input_ids[:, cache_position] elif reuse_cache and token_idx is not None: # With reuse_cache, KV cache is pre allocated hence for the 1st token we can slice the inputs till token idx for the fwd pass input_ids = input_ids[:, :token_idx] attention_mask = attention_mask[:, :token_idx] - position_ids = kwargs.get("position_ids", None) if attention_mask is not None and position_ids is None: # create position_ids on the fly for batch generation position_ids = attention_mask.long().cumsum(-1) - 1 @@ -850,14 +856,14 @@ def prepare_inputs_for_generation( if inputs_embeds is not None and past_key_values is None: model_inputs = {"inputs_embeds": inputs_embeds} else: - model_inputs = {"input_ids": input_ids} + model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases model_inputs.update( { "position_ids": position_ids.contiguous(), "cache_position": cache_position, "past_key_values": past_key_values, - "use_cache": kwargs.get("use_cache"), + "use_cache": use_cache, "attention_mask": attention_mask, "token_idx": token_idx, "trim_logits": kwargs.get("trim_logits"), diff --git a/optimum/habana/transformers/models/swin/__init__.py b/optimum/habana/transformers/models/swin/__init__.py deleted file mode 100644 index 59dbee4d5d..0000000000 --- a/optimum/habana/transformers/models/swin/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .modeling_swin import gaudi_swin_get_attn_mask diff --git a/optimum/habana/transformers/models/swin/modeling_swin.py b/optimum/habana/transformers/models/swin/modeling_swin.py deleted file mode 100644 index 9ea3b9d28c..0000000000 --- a/optimum/habana/transformers/models/swin/modeling_swin.py +++ /dev/null @@ -1,52 +0,0 @@ -# coding=utf-8 -# Copyright 2022 Microsoft Research and The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""PyTorch Swin Transformer model.""" - -import torch -from transformers.models.swin.modeling_swin import window_partition - - -def gaudi_swin_get_attn_mask(self, height, width, dtype): - """ - Copied from SwinLayer.get_attn_mask : https://github.com/huggingface/transformers/blob/main/src/transformers/models/swin/modeling_swin.py - The only difference is moving img_mask to hpu for performance - """ - if self.shift_size > 0: - # calculate attention mask for SW-MSA - img_mask = torch.zeros((1, height, width, 1), dtype=dtype, device="hpu") - height_slices = ( - slice(0, -self.window_size), - slice(-self.window_size, -self.shift_size), - slice(-self.shift_size, None), - ) - width_slices = ( - slice(0, -self.window_size), - slice(-self.window_size, -self.shift_size), - slice(-self.shift_size, None), - ) - count = 0 - for height_slice in height_slices: - for width_slice in width_slices: - img_mask[:, height_slice, width_slice, :] = count - count += 1 - - mask_windows = window_partition(img_mask, self.window_size) - mask_windows = mask_windows.view(-1, self.window_size * self.window_size) - attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) - attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) - else: - attn_mask = None - - return attn_mask diff --git a/optimum/habana/transformers/models/wav2vec2/modeling_wav2vec2.py b/optimum/habana/transformers/models/wav2vec2/modeling_wav2vec2.py index c6dd9cb546..be5c2f6176 100644 --- a/optimum/habana/transformers/models/wav2vec2/modeling_wav2vec2.py +++ b/optimum/habana/transformers/models/wav2vec2/modeling_wav2vec2.py @@ -190,6 +190,81 @@ def _gaudi_wav2vec2_sample_negative_indices( return sampled_negative_indices +def gaudi_wav2vec2_encoder_forward( + self, + hidden_states: torch.tensor, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, +): + """ + Copied from Transformers: https://github.com/huggingface/transformers/blob/7790943c91411f4234d11dfbf4c2f21ce7caf088/src/transformers/models/wav2vec2/modeling_wav2vec2.py#L755 + The only difference is that torch.rand device is set to 'hpu' (required to capture operation as part of HPU graph) + """ + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + if attention_mask is not None: + # make sure padded tokens output 0 + expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2]) + hidden_states[~expand_attention_mask] = 0 + + # extend attention_mask + attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype) + attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min + attention_mask = attention_mask.expand( + attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1] + ) + + position_embeddings = self.pos_conv_embed(hidden_states) + hidden_states = hidden_states + position_embeddings + hidden_states = self.layer_norm(hidden_states) + hidden_states = self.dropout(hidden_states) + + deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled() + + for layer in self.layers: + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + dropout_probability = torch.rand([], device="hpu") + + skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False + if not skip_the_layer or deepspeed_zero3_is_enabled: + # under deepspeed zero3 all gpus must run in sync + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + layer.__call__, + hidden_states, + attention_mask, + output_attentions, + ) + else: + layer_outputs = layer( + hidden_states, attention_mask=attention_mask, output_attentions=output_attentions + ) + hidden_states = layer_outputs[0] + + if skip_the_layer: + layer_outputs = (None, None) + + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + def gaudi_wav2vec2_forward( self, input_values: Optional[torch.Tensor], @@ -300,100 +375,6 @@ def _gaudi_wav2vec2_mask_hidden_states( return hidden_states -def gaudi_wav2vec2_encoder_forward( - self, - hidden_states: torch.tensor, - attention_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, -): - """ - Copied from Transformers: https://github.com/huggingface/transformers/blob/7790943c91411f4234d11dfbf4c2f21ce7caf088/src/transformers/models/wav2vec2/modeling_wav2vec2.py#L755 - The only difference is that torch.rand device is set to 'hpu' (required to capture operation as part of HPU graph) - """ - all_hidden_states = () if output_hidden_states else None - all_self_attentions = () if output_attentions else None - - if attention_mask is not None: - # make sure padded tokens output 0 - expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2]) - hidden_states[~expand_attention_mask] = 0 - - # extend attention_mask - attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype) - attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min - attention_mask = attention_mask.expand( - attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1] - ) - - position_embeddings = self.pos_conv_embed(hidden_states) - hidden_states = hidden_states + position_embeddings - hidden_states = self.layer_norm(hidden_states) - hidden_states = self.dropout(hidden_states) - - deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled() - - for layer in self.layers: - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) - dropout_probability = torch.rand([], device="hpu") - - skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False - if not skip_the_layer or deepspeed_zero3_is_enabled: - # under deepspeed zero3 all gpus must run in sync - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer.__call__, - hidden_states, - attention_mask, - output_attentions, - ) - else: - layer_outputs = layer( - hidden_states, attention_mask=attention_mask, output_attentions=output_attentions - ) - hidden_states = layer_outputs[0] - - if skip_the_layer: - layer_outputs = (None, None) - - if output_attentions: - all_self_attentions = all_self_attentions + (layer_outputs[1],) - - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - if not return_dict: - return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) - return BaseModelOutput( - last_hidden_state=hidden_states, - hidden_states=all_hidden_states, - attentions=all_self_attentions, - ) - - -def gaudi_wav2vec2_tdnnlayer_forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - """ - Copied from Transformers: https://github.com/huggingface/transformers/blob/v4.37.2/src/transformers/models/wav2vec2/modeling_wav2vec2.py#L2290 - v4.38.2 implementation caused accuracy issue to run pytest Wav2Vec2RobustModelTest. - """ - hidden_states = hidden_states.unsqueeze(1) - hidden_states = torch.nn.functional.unfold( - hidden_states, - (self.kernel_size, self.in_conv_dim), - stride=(1, self.in_conv_dim), - dilation=(self.dilation, 1), - ) - hidden_states = hidden_states.transpose(1, 2) - hidden_states = self.kernel(hidden_states) - - hidden_states = self.activation(hidden_states) - return hidden_states - - def gaudi_wav2vec2forctc_forward( self, input_values: Optional[torch.Tensor], @@ -409,6 +390,10 @@ def gaudi_wav2vec2forctc_forward( changing flattened_targets tensor shapes across training iterations. """ return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if labels is not None and labels.max() >= self.config.vocab_size: + raise ValueError(f"Label values must be <= vocab_size: {self.config.vocab_size}") + outputs = self.wav2vec2( input_values, attention_mask=attention_mask, @@ -421,8 +406,6 @@ def gaudi_wav2vec2forctc_forward( logits = self.lm_head(hidden_states) loss = None if labels is not None: - if labels.max() >= self.config.vocab_size: - raise ValueError(f"Label values must be <= vocab_size: {self.config.vocab_size}") # retrieve loss input_lengths from attention_mask attention_mask = ( attention_mask @@ -463,3 +446,22 @@ def gaudi_wav2vec2forctc_forward( output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:] return ((loss,) + output) if loss is not None else output return CausalLMOutput(loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions) + + +def gaudi_wav2vec2_tdnnlayer_forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + """ + Copied from Transformers: https://github.com/huggingface/transformers/blob/v4.37.2/src/transformers/models/wav2vec2/modeling_wav2vec2.py#L2290 + v4.38.2 implementation caused accuracy issue to run pytest Wav2Vec2RobustModelTest. + """ + hidden_states = hidden_states.unsqueeze(1) + hidden_states = torch.nn.functional.unfold( + hidden_states, + (self.kernel_size, self.in_conv_dim), + stride=(1, self.in_conv_dim), + dilation=(self.dilation, 1), + ) + hidden_states = hidden_states.transpose(1, 2) + hidden_states = self.kernel(hidden_states) + + hidden_states = self.activation(hidden_states) + return hidden_states diff --git a/optimum/habana/transformers/trainer.py b/optimum/habana/transformers/trainer.py index 83e77460d7..5c418e66b7 100644 --- a/optimum/habana/transformers/trainer.py +++ b/optimum/habana/transformers/trainer.py @@ -15,8 +15,10 @@ import contextlib import copy +import functools import importlib.metadata import inspect +import json import math import os import random @@ -46,10 +48,10 @@ is_deepspeed_available, is_deepspeed_zero3_enabled, ) -from transformers.modeling_utils import PreTrainedModel, load_sharded_checkpoint, unwrap_model +from transformers.modeling_utils import PreTrainedModel, load_sharded_checkpoint from transformers.tokenization_utils_base import PreTrainedTokenizerBase from transformers.trainer import _get_fsdp_ckpt_kwargs -from transformers.trainer_callback import TrainerCallback, TrainerState +from transformers.trainer_callback import ExportableState, TrainerCallback, TrainerState from transformers.trainer_pt_utils import ( DistributedTensorGatherer, EvalLoopContainer, @@ -57,7 +59,6 @@ LengthGroupedSampler, SequentialDistributedSampler, find_batch_size, - get_dataloader_sampler, get_model_param_count, nested_concat, nested_detach, @@ -78,11 +79,13 @@ get_last_checkpoint, has_length, ) -from transformers.training_args import ParallelMode, TrainingArguments +from transformers.training_args import OptimizerNames, ParallelMode, TrainingArguments from transformers.utils import ( + ADAPTER_CONFIG_NAME, ADAPTER_SAFE_WEIGHTS_NAME, ADAPTER_WEIGHTS_NAME, CONFIG_NAME, + SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, WEIGHTS_INDEX_NAME, WEIGHTS_NAME, @@ -376,13 +379,18 @@ def create_optimizer(self): "eps": self.args.adam_epsilon, } else: - optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args, self.model) + optimizer_cls, optimizer_kwargs = self.get_optimizer_cls_and_kwargs(self.args, self.model) # Overwrite `params` in case it's created by `get_optimizer_cls_and_kwargs` # e.g. for GaLore optimizer. if "params" in optimizer_kwargs: optimizer_grouped_parameters = optimizer_kwargs.pop("params") + # Overwrite `model` in case it's created by `get_optimizer_cls_and_kwargs` + # e.g. for LOMO optimizer. + if "model" in optimizer_kwargs: + optimizer_grouped_parameters = optimizer_kwargs.pop("model") + # For layer-wise dummy optimizers we overwrite optimizer_grouped_parameters with `optimizer_dict` # to avoid arguments conflicts. if "optimizer_dict" in optimizer_kwargs: @@ -402,7 +410,7 @@ def _tune_save_checkpoint(self, checkpoint_dir: str): def _wrap_model(self, model, training=True, dataloader=None): # train/eval could be run multiple-times - if already wrapped, don't re-wrap it again - if unwrap_model(model) is not model: + if self.accelerator.unwrap_model(model) is not model: return model # Note: in torch.distributed mode, there's no point in wrapping the model @@ -640,7 +648,11 @@ def _inner_training_loop( if not delay_optimizer_creation: self.create_optimizer_and_scheduler(num_training_steps=max_steps) - self.state = TrainerState() + self.state = TrainerState( + stateful_callbacks=[ + cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState) + ] + ) self.state.is_hyper_param_search = trial is not None self.state.train_batch_size = self._train_batch_size @@ -724,6 +736,9 @@ def hpu_deepspeed_checkpointing(function, *checkpoint_args): model, self.optimizer, self.lr_scheduler = self.accelerator.prepare( self.model, self.optimizer, self.lr_scheduler ) + elif self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]: + # In this case we are in DDP + LOMO, which should be supported + self.optimizer = self.accelerator.prepare(self.optimizer) if self.is_fsdp_enabled: self.model = self.model_wrapped = model @@ -791,6 +806,7 @@ def hpu_deepspeed_checkpointing(function, *checkpoint_args): ): self.state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME)) self.compare_trainer_and_checkpoint_args(self.args, self.state) + self._load_callback_state() epochs_trained = self.state.global_step // num_update_steps_per_epoch if not args.ignore_data_skip: steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch) @@ -844,27 +860,12 @@ def hpu_deepspeed_checkpointing(function, *checkpoint_args): # _total_loss_scalar is updated everytime .item() has to be called on tr_loss and stores the sum of all losses self._total_loss_scalar = 0.0 self._globalstep_last_logged = self.state.global_step - self._zero_model_grad(model) _grad_norm: Optional[float] = None - self.control = self.callback_handler.on_train_begin(args, self.state, self.control) - # Skip the first epochs_trained epochs to get the random state of the dataloader at the right point. - if not args.ignore_data_skip: - for epoch in range(epochs_trained): - sampler = get_dataloader_sampler(train_dataloader) - sampler_kinds = [RandomSampler, SeedableRandomSampler] - is_random_sampler = isinstance(sampler, tuple(sampler_kinds)) - if not is_random_sampler: - # We just need to begin an iteration to create the randomization of the sampler. - for _ in train_dataloader: - break - else: - # Otherwise we need to call the whooooole sampler cause there is some random operation added - # AT THE VERY END! - sampler = sampler if sampler is not None else [] - _ = list(sampler) + if args.eval_on_start: + self._evaluate(trial, ignore_keys_for_eval, skip_scheduler=True) if self.args.adjust_throughput: self.log_evaluate_save_time = 0 @@ -930,12 +931,17 @@ def hpu_deepspeed_checkpointing(function, *checkpoint_args): "a `main_input_name` attribute to the model class you are using." ) else: - input_device = inputs[main_input_name].device - self.state.num_input_tokens_seen += torch.sum( - self.accelerator.gather( - torch.tensor(inputs[main_input_name].numel(), device=input_device, dtype=torch.int64) + self.state.num_input_tokens_seen += ( + torch.sum( + self.accelerator.gather( + torch.tensor( + inputs[main_input_name].numel(), device=self.args.device, dtype=torch.int64 + ) + ) ) - ).item() + .cpu() + .item() + ) if rng_to_sync: self._load_rng_state(resume_from_checkpoint) rng_to_sync = False @@ -1025,9 +1031,11 @@ def hpu_deepspeed_checkpointing(function, *checkpoint_args): args.max_grad_norm, ) - # Optimizer step optimizer_was_run = True self.optimizer.step() + + self.control = self.callback_handler.on_optimizer_step(args, self.state, self.control) + optimizer_was_run = not self.accelerator.optimizer_step_was_skipped if optimizer_was_run: # Delay optimizer scheduling until metrics are generated @@ -1111,7 +1119,7 @@ def hpu_deepspeed_checkpointing(function, *checkpoint_args): for checkpoint in checkpoints_sorted: if not os.path.samefile(checkpoint, self.state.best_model_checkpoint): logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit") - shutil.rmtree(checkpoint) + shutil.rmtree(checkpoint, ignore_errors=True) self.control = self.callback_handler.on_train_end(args, self.state, self.control) @@ -1157,9 +1165,20 @@ def _load_best_model(self): has_been_loaded = True if _is_peft_model(model): # If train a model using PEFT & LoRA, assume that adapter have been saved properly. - if hasattr(model, "active_adapter") and hasattr(model, "load_adapter"): + # TODO: in the future support only specific min PEFT versions + if (hasattr(model, "active_adapter") or hasattr(model, "active_adapters")) and hasattr( + model, "load_adapter" + ): + # For BC for older PEFT versions + if hasattr(model, "active_adapters"): + active_adapter = model.active_adapters[0] + if len(model.active_adapters) > 1: + logger.warning("Detected multiple active adapters, will only consider the first one") + else: + active_adapter = model.active_adapter + if os.path.exists(best_adapter_model_path) or os.path.exists(best_safe_adapter_model_path): - model.load_adapter(self.state.best_model_checkpoint, model.active_adapter) + model.load_adapter(self.state.best_model_checkpoint, active_adapter) # Load_adapter has no return value present, modify it when appropriate. from torch.nn.modules.module import _IncompatibleKeys @@ -1190,7 +1209,9 @@ def _load_best_model(self): if has_been_loaded: self._issue_warnings_after_load(load_result) - elif os.path.exists(os.path.join(self.state.best_model_checkpoint, WEIGHTS_INDEX_NAME)): + elif os.path.exists(os.path.join(self.state.best_model_checkpoint, SAFE_WEIGHTS_INDEX_NAME)) or os.path.exists( + os.path.join(self.state.best_model_checkpoint, WEIGHTS_INDEX_NAME) + ): load_result = load_sharded_checkpoint(model, self.state.best_model_checkpoint, strict=False) self._issue_warnings_after_load(load_result) else: @@ -1242,15 +1263,7 @@ def _maybe_log_save_evaluate(self, tr_loss, _grad_norm, model, trial, epoch, ign metrics = None if self.control.should_evaluate: - metrics = self.evaluate(ignore_keys=ignore_keys_for_eval) - self._report_to_hp_search(trial, self.state.global_step, metrics) - - # Run delayed LR scheduler now that metrics are populated - if isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau): - metric_to_check = self.args.metric_for_best_model - if not metric_to_check.startswith("eval_"): - metric_to_check = f"eval_{metric_to_check}" - self.lr_scheduler.step(metrics[metric_to_check]) + metrics = self._evaluate(trial, ignore_keys_for_eval) if self.control.should_save: self._save_checkpoint(model, trial, metrics=metrics) @@ -1324,7 +1337,13 @@ def _save_checkpoint(self, model, trial, metrics=None): metric_to_check = self.args.metric_for_best_model if not metric_to_check.startswith("eval_"): metric_to_check = f"eval_{metric_to_check}" - metric_value = metrics[metric_to_check] + try: + metric_value = metrics[metric_to_check] + except KeyError as exc: + raise KeyError( + f"The `metric_for_best_model` training argument is set to '{metric_to_check}', which is not found in the evaluation metrics. " + f"The available evaluation metrics are: {list(metrics.keys())}. Consider changing the `metric_for_best_model` via the TrainingArguments." + ) from exc operator = np.greater if self.args.greater_is_better else np.less if ( @@ -1337,6 +1356,8 @@ def _save_checkpoint(self, model, trial, metrics=None): # Save the Trainer state if self.args.should_save: + # Update the `TrainerControl` state to where we are currently + self.state.stateful_callbacks["TrainerControl"] = self.control.state() self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME)) if self.args.push_to_hub: @@ -1553,6 +1574,13 @@ def training_step(self, model: torch.nn.Module, inputs: Dict[str, Union[torch.Te with self.compute_loss_context_manager(): loss = self.compute_loss(model, inputs) + del inputs + kwargs = {} + + # For LOMO optimizers you need to explicitly use the learnign rate + if self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]: + kwargs["learning_rate"] = self._get_learning_rate() + if self.args.n_gpu > 1: loss = loss.mean() # mean() to average on multi-gpu parallel training @@ -1568,7 +1596,7 @@ def training_step(self, model: torch.nn.Module, inputs: Dict[str, Union[torch.Te self.model.base_model.update_and_allocate(self.state.global_step) self.accelerator.deepspeed_engine_wrapped.engine.step() else: - self.accelerator.backward(loss) + self.accelerator.backward(loss, **kwargs) self.model.base_model.update_and_allocate(self.state.global_step) else: if self.accelerator.state.is_fp8_enabled and self.args.gradient_checkpointing: @@ -1577,9 +1605,9 @@ def training_step(self, model: torch.nn.Module, inputs: Dict[str, Union[torch.Te # in backward does not automatically run with FP8 precision. In order to handle this, # the backward is run in `fp8_autocast` context with FP8ContextWrapper.create_fp8_context(self.accelerator.fp8_recipe_handler): - self.accelerator.backward(loss) + self.accelerator.backward(loss, **kwargs) else: - self.accelerator.backward(loss) + self.accelerator.backward(loss, **kwargs) return loss.detach() / self.args.gradient_accumulation_steps def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False): @@ -1640,8 +1668,8 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None): # Save a trained model and configuration using `save_pretrained()`. # They can then be reloaded using `from_pretrained()` if not isinstance(self.model, supported_classes): - if isinstance(unwrap_model(self.model), supported_classes): - unwrap_model(self.model).save_pretrained( + if isinstance(self.accelerator.unwrap_model(self.model), supported_classes): + self.accelerator.unwrap_model(self.model).save_pretrained( output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors ) else: @@ -1677,12 +1705,13 @@ def evaluate( 2. use throughput_warmup_steps in evaluation throughput calculation """ # handle multipe eval datasets - eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset + override = eval_dataset is not None + eval_dataset = eval_dataset if override else self.eval_dataset if isinstance(eval_dataset, dict): metrics = {} for eval_dataset_name, _eval_dataset in eval_dataset.items(): dataset_metrics = self.evaluate( - eval_dataset=_eval_dataset, + eval_dataset=_eval_dataset if override else eval_dataset_name, ignore_keys=ignore_keys, metric_key_prefix=f"{metric_key_prefix}_{eval_dataset_name}", ) @@ -1711,6 +1740,8 @@ def evaluate( total_batch_size = self.args.eval_batch_size * self.args.world_size if f"{metric_key_prefix}_jit_compilation_time" in output.metrics: start_time += output.metrics[f"{metric_key_prefix}_jit_compilation_time"] + if f"{metric_key_prefix}_model_preparation_time" in output.metrics: + start_time += output.metrics[f"{metric_key_prefix}_model_preparation_time"] num_samples = output.num_samples - self.args.throughput_warmup_steps * total_batch_size num_steps = math.ceil(output.num_samples / total_batch_size) - self.args.throughput_warmup_steps @@ -1755,11 +1786,13 @@ def evaluation_loop( model = self._wrap_model(self.model, training=False, dataloader=dataloader) if len(self.accelerator._models) == 0 and model is self.model: + start_time = time.time() model = ( self.accelerator.prepare(model) if self.is_deepspeed_enabled else self.accelerator.prepare_model(model, evaluation_mode=True) ) + self.model_preparation_time = round(time.time() - start_time, 4) if self.is_fsdp_enabled: self.model = model @@ -1786,17 +1819,15 @@ def evaluation_loop( ) self.already_wrapped_for_hpu_graphs = True - # if full fp16 or bf16 eval is wanted and this ``evaluation`` or ``predict`` isn't called + # if full bf16 eval is wanted and this ``evaluation`` or ``predict`` isn't called # while ``train`` is running, cast it to the right dtype first and then put on device if not self.is_in_train: - if args.fp16_full_eval: - model = model.to(dtype=torch.float16, device=args.device) - elif args.bf16_full_eval: + if args.bf16_full_eval: model = model.to(dtype=torch.bfloat16, device=args.device) batch_size = self.args.eval_batch_size - logger.info(f"***** Running {description} *****") + logger.info(f"\n***** Running {description} *****") if has_length(dataloader): logger.info(f" Num examples = {self.num_examples(dataloader)}") else: @@ -1816,6 +1847,8 @@ def evaluation_loop( all_labels = EvalLoopContainer(self.args.eval_do_concat_batches, padding_index=-100) all_inputs = EvalLoopContainer(self.args.eval_do_concat_batches, padding_index=-100) + metrics = None + # Will be useful when we have an iterable dataset so don't know its length. observed_num_examples = 0 @@ -1848,7 +1881,7 @@ def evaluation_loop( inputs["flash_attention_causal_mask"] = True # Prediction step - loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys) + losses, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys) main_input_name = getattr(self.model, "main_input_name", "input_ids") inputs_decode = self._prepare_input(inputs[main_input_name]) if args.include_inputs_for_metrics else None @@ -1858,17 +1891,19 @@ def evaluation_loop( logits_dtype = get_dtype(logits) # Update containers - if loss is not None: - losses = self.gather_function((loss.repeat(batch_size))) + if losses is not None: + losses = self.gather_function((losses.repeat(batch_size))) all_losses.add(losses) if labels is not None: labels = self.accelerator.pad_across_processes(labels, dim=1, pad_index=-100) labels = self.gather_function((labels)) - all_labels.add(labels) + if not self.args.batch_eval_metrics or description == "Prediction": + all_labels.add(labels) if inputs_decode is not None: inputs_decode = self.accelerator.pad_across_processes(inputs_decode, dim=1, pad_index=-100) inputs_decode = self.gather_function((inputs_decode)) - all_inputs.add(inputs_decode) + if not self.args.batch_eval_metrics or description == "Prediction": + all_inputs.add(inputs_decode) if logits is not None: if args.use_habana and logits_dtype != "float32": logits = to_device_dtype(logits, target_dtype=torch.float32) @@ -1876,17 +1911,36 @@ def evaluation_loop( if self.preprocess_logits_for_metrics is not None: logits = self.preprocess_logits_for_metrics(logits, labels) logits = self.gather_function((logits)) - all_preds.add(logits) + if not self.args.batch_eval_metrics or description == "Prediction": + all_preds.add(logits) self.control = self.callback_handler.on_prediction_step(args, self.state, self.control) + if self.args.batch_eval_metrics: + if self.compute_metrics is not None and logits is not None and labels is not None: + is_last_step = self.accelerator.gradient_state.end_of_dataloader + if args.include_inputs_for_metrics: + metrics = self.compute_metrics( + EvalPrediction(predictions=logits, label_ids=labels, inputs=inputs), + compute_result=is_last_step, + ) + else: + metrics = self.compute_metrics( + EvalPrediction(predictions=logits, label_ids=labels), + compute_result=is_last_step, + ) + + del losses, logits, labels, inputs + # Gather all tensors and put them back on the CPU if we have done enough accumulation steps. - if args.eval_accumulation_steps is not None and (step + 1) % args.eval_accumulation_steps == 0: + elif args.eval_accumulation_steps is not None and (step + 1) % args.eval_accumulation_steps == 0: all_losses.to_cpu_and_numpy() all_preds.to_cpu_and_numpy() all_labels.to_cpu_and_numpy() all_inputs.to_cpu_and_numpy() + del losses, logits, labels, inputs + # nested concat does accumulation on tensors of variable length. # Added mark step here to avoid graph recompile if args.use_lazy_mode: @@ -1924,14 +1978,19 @@ def evaluation_loop( all_preds = convert_into_dtypes(all_preds, logits_dtype) # Metrics! - if self.compute_metrics is not None and all_preds is not None and all_labels is not None: + if ( + self.compute_metrics is not None + and all_preds is not None + and all_labels is not None + and not self.args.batch_eval_metrics + ): if args.include_inputs_for_metrics: metrics = self.compute_metrics( EvalPrediction(predictions=all_preds, label_ids=all_labels, inputs=all_inputs) ) else: metrics = self.compute_metrics(EvalPrediction(predictions=all_preds, label_ids=all_labels)) - else: + elif metrics is None: metrics = {} # To be JSON-serializable, we need to remove numpy types or zero-d tensors @@ -1941,6 +2000,8 @@ def evaluation_loop( metrics[f"{metric_key_prefix}_loss"] = np.concatenate(all_losses).mean().item() elif isinstance(all_losses, np.ndarray): metrics[f"{metric_key_prefix}_loss"] = all_losses.mean().item() + if hasattr(self, "model_preparation_time"): + metrics[f"{metric_key_prefix}_model_preparation_time"] = self.model_preparation_time # Prefix all keys with metric_key_prefix + '_' for key in list(metrics.keys()): @@ -2051,6 +2112,17 @@ def _push_from_checkpoint(self, checkpoint_folder): output_dir = self.args.output_dir # To avoid a new synchronization of all model weights, we just copy the file from the checkpoint folder modeling_files = [CONFIG_NAME, WEIGHTS_NAME, SAFE_WEIGHTS_NAME, GAUDI_CONFIG_NAME] + # Add sharded checkpoints if we have an index + for index_file in [WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_INDEX_NAME]: + index_path = os.path.join(checkpoint_folder, index_file) + if os.path.isfile(index_path): + modeling_files.append(index_file) + with open(index_path) as f: + index = json.loads(f.read()) + shard_files = list(set(index["weight_map"].values())) + modeling_files.extend(shard_files) + if is_peft_available(): + modeling_files.extend([ADAPTER_CONFIG_NAME, ADAPTER_WEIGHTS_NAME, ADAPTER_SAFE_WEIGHTS_NAME]) for modeling_file in modeling_files: if os.path.isfile(os.path.join(checkpoint_folder, modeling_file)): shutil.copy(os.path.join(checkpoint_folder, modeling_file), os.path.join(output_dir, modeling_file)) @@ -2164,7 +2236,7 @@ def prediction_loop( batch_size = dataloader.batch_size num_examples = self.num_examples(dataloader) - logger.info(f"***** Running {description} *****") + logger.info(f"\n***** Running {description} *****") logger.info(f" Num examples = {num_examples}") logger.info(f" Batch size = {batch_size}") @@ -2172,6 +2244,7 @@ def prediction_loop( preds_host: Union[torch.Tensor, List[torch.Tensor]] = None labels_host: Union[torch.Tensor, List[torch.Tensor]] = None inputs_host: Union[torch.Tensor, List[torch.Tensor]] = None + metrics: Optional[dict] = None world_size = max(1, args.world_size) @@ -2211,8 +2284,24 @@ def prediction_loop( ) self.control = self.callback_handler.on_prediction_step(args, self.state, self.control) - # Gather all tensors and put them back on the CPU if we have done enough accumulation steps. - if args.eval_accumulation_steps is not None and (step + 1) % args.eval_accumulation_steps == 0: + if self.args.batch_eval_metrics: + if self.compute_metrics is not None and preds_host is not None and labels_host is not None: + is_last_step = self.accelerator.gradient_state.end_of_dataloader + if args.include_inputs_for_metrics: + metrics = self.compute_metrics( + EvalPrediction(predictions=preds_host, label_ids=labels_host, inputs=inputs_host), + compute_result=is_last_step, + ) + else: + metrics = self.compute_metrics( + EvalPrediction(predictions=preds_host, label_ids=labels_host), + compute_result=is_last_step, + ) + + if self.args.batch_eval_metrics or ( + args.eval_accumulation_steps is not None and (step + 1) % args.eval_accumulation_steps == 0 + ): + # Gather all tensors and put them back on the CPU if we have done enough accumulation steps. eval_losses_gatherer.add_arrays(self._gather_and_numpify(losses_host, "eval_losses")) if not prediction_loss_only: preds_gatherer.add_arrays(self._gather_and_numpify(preds_host, "eval_preds")) @@ -2220,6 +2309,7 @@ def prediction_loop( inputs_gatherer.add_arrays(self._gather_and_numpify(inputs_host, "eval_inputs_ids")) # Set back to None to begin a new accumulation + del losses_host, preds_host, labels_host, inputs_host losses_host, preds_host, labels_host, inputs_host = None, None, None, None # nested concat does accumulation on tensors of variable length. @@ -2243,14 +2333,19 @@ def prediction_loop( label_ids = labels_gatherer.finalize() if not prediction_loss_only else None inputs_ids = inputs_gatherer.finalize() if not prediction_loss_only else None - if self.compute_metrics is not None and preds is not None and label_ids is not None: + if ( + self.compute_metrics is not None + and preds is not None + and label_ids is not None + and not self.args.batch_eval_metrics + ): if args.include_inputs_for_metrics: metrics = self.compute_metrics( EvalPrediction(predictions=preds, label_ids=label_ids, inputs=inputs_ids) ) else: metrics = self.compute_metrics(EvalPrediction(predictions=preds, label_ids=label_ids)) - else: + elif metrics is None: metrics = {} # To be JSON-serializable, we need to remove numpy types or zero-d tensors @@ -2295,6 +2390,18 @@ def create_accelerator_and_postprocess(self): even_batches=accelerator_config.pop("even_batches"), use_seedable_sampler=accelerator_config.pop("use_seedable_sampler"), ) + non_blocking = accelerator_config.pop("non_blocking") + if not is_accelerate_available("0.30.0"): + if non_blocking: + raise ImportError( + "`non_blocking` is only supported in accelerate v0.30.0 and above. Please upgrade accelerate to use this feature." + ) + else: + if non_blocking and not self.args.dataloader_pin_memory: + logger.warning( + "`non_blocking` is enabled but `dataloader_pin_memory` is not. For the best performance, it's recommended to enable both." + ) + dataloader_config.non_blocking = non_blocking # this would have been updated above, no need for it anymore accelerator_config.pop("gradient_accumulation_kwargs") @@ -2313,6 +2420,11 @@ def create_accelerator_and_postprocess(self): # some Trainer classes need to use `gather` instead of `gather_for_metrics`, thus we store a flag self.gather_function = self.accelerator.gather_for_metrics + if "use_gather_object" in inspect.signature(self.gather_function).parameters.keys(): + self.gather_function = functools.partial( + self.gather_function, use_gather_object=self.args.eval_use_gather_object + ) + # deepspeed and accelerate flags covering both trainer args and accelerate launcher self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None diff --git a/optimum/habana/transformers/training_args.py b/optimum/habana/transformers/training_args.py index c280e46888..1ab4cc56b3 100644 --- a/optimum/habana/transformers/training_args.py +++ b/optimum/habana/transformers/training_args.py @@ -362,34 +362,45 @@ def __post_init__(self): if self.disable_tqdm is None: self.disable_tqdm = logger.getEffectiveLevel() > logging.WARN - if isinstance(self.evaluation_strategy, EvaluationStrategy): + if self.evaluation_strategy is not None: warnings.warn( - ( - "using `EvaluationStrategy` for `evaluation_strategy` is deprecated and will be removed in version" - " 5 of 🤗 Transformers. Use `IntervalStrategy` instead" - ), + "`evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead", + FutureWarning, + ) + self.eval_strategy = self.evaluation_strategy + + if isinstance(self.eval_strategy, EvaluationStrategy): + warnings.warn( + "using `EvaluationStrategy` for `eval_strategy` is deprecated and will be removed in version 5" + " of 🤗 Transformers. Use `IntervalStrategy` instead", FutureWarning, ) # Go back to the underlying string or we won't be able to instantiate `IntervalStrategy` on it. - self.evaluation_strategy = self.evaluation_strategy.value + self.eval_strategy = self.eval_strategy.value - self.evaluation_strategy = IntervalStrategy(self.evaluation_strategy) + self.eval_strategy = IntervalStrategy(self.eval_strategy) self.logging_strategy = IntervalStrategy(self.logging_strategy) self.save_strategy = IntervalStrategy(self.save_strategy) self.hub_strategy = HubStrategy(self.hub_strategy) self.lr_scheduler_type = SchedulerType(self.lr_scheduler_type) - if self.do_eval is False and self.evaluation_strategy != IntervalStrategy.NO: + if self.do_eval is False and self.eval_strategy != IntervalStrategy.NO: self.do_eval = True + if self.torch_empty_cache_steps is not None: + if not (isinstance(self.torch_empty_cache_steps, int) or self.torch_empty_cache_steps > 0): + raise ValueError( + f"`torch_empty_cache_steps` must be an integer bigger than 0, got {self.torch_empty_cache_steps}." + ) + # eval_steps has to be defined and non-zero, fallbacks to logging_steps if the latter is non-zero - if self.evaluation_strategy == IntervalStrategy.STEPS and (self.eval_steps is None or self.eval_steps == 0): + if self.eval_strategy == IntervalStrategy.STEPS and (self.eval_steps is None or self.eval_steps == 0): if self.logging_steps > 0: logger.info(f"using `logging_steps` to initialize `eval_steps` to {self.logging_steps}") self.eval_steps = self.logging_steps else: raise ValueError( - f"evaluation strategy {self.evaluation_strategy} requires either non-zero --eval_steps or" + f"evaluation strategy {self.eval_strategy} requires either non-zero --eval_steps or" " --logging_steps" ) @@ -401,7 +412,7 @@ def __post_init__(self): if self.logging_steps != int(self.logging_steps): raise ValueError(f"--logging_steps must be an integer if bigger than 1: {self.logging_steps}") self.logging_steps = int(self.logging_steps) - if self.evaluation_strategy == IntervalStrategy.STEPS and self.eval_steps > 1: + if self.eval_strategy == IntervalStrategy.STEPS and self.eval_steps > 1: if self.eval_steps != int(self.eval_steps): raise ValueError(f"--eval_steps must be an integer if bigger than 1: {self.eval_steps}") self.eval_steps = int(self.eval_steps) @@ -412,12 +423,12 @@ def __post_init__(self): # Sanity checks for load_best_model_at_end: we require save and eval strategies to be compatible. if self.load_best_model_at_end: - if self.evaluation_strategy != self.save_strategy: + if self.eval_strategy != self.save_strategy: raise ValueError( "--load_best_model_at_end requires the save and eval strategy to match, but found\n- Evaluation " - f"strategy: {self.evaluation_strategy}\n- Save strategy: {self.save_strategy}" + f"strategy: {self.eval_strategy}\n- Save strategy: {self.save_strategy}" ) - if self.evaluation_strategy == IntervalStrategy.STEPS and self.save_steps % self.eval_steps != 0: + if self.eval_strategy == IntervalStrategy.STEPS and self.save_steps % self.eval_steps != 0: if self.eval_steps < 1 or self.save_steps < 1: if not (self.eval_steps < 1 and self.save_steps < 1): raise ValueError( @@ -453,12 +464,12 @@ def __post_init__(self): ) and self.metric_for_best_model is None: self.metric_for_best_model = "loss" if self.greater_is_better is None and self.metric_for_best_model is not None: - self.greater_is_better = self.metric_for_best_model not in ["loss", "eval_loss"] + self.greater_is_better = not (self.metric_for_best_model.endswith("loss")) if self.run_name is None: self.run_name = self.output_dir if self.lr_scheduler_type == SchedulerType.REDUCE_ON_PLATEAU: - if self.evaluation_strategy == IntervalStrategy.NO: + if self.eval_strategy == IntervalStrategy.NO: raise ValueError("lr_scheduler_type reduce_lr_on_plateau requires an eval strategy") if not is_torch_available(): raise ValueError("lr_scheduler_type reduce_lr_on_plateau requires torch>=0.2.0") @@ -477,6 +488,42 @@ def __post_init__(self): if version.parse(version.parse(torch.__version__).base_version) < version.parse("2.0.0"): raise ValueError("--optim adamw_torch_fused requires PyTorch 2.0 or higher") + # We need to setup the accelerator config here *before* the first call to `self.device` + if is_accelerate_available(): + if not isinstance(self.accelerator_config, (AcceleratorConfig)): + if self.accelerator_config is None: + self.accelerator_config = AcceleratorConfig() + elif isinstance(self.accelerator_config, dict): + self.accelerator_config = AcceleratorConfig(**self.accelerator_config) + # Check that a user didn't pass in the class instantiator + # such as `accelerator_config = AcceleratorConfig` + elif isinstance(self.accelerator_config, type): + raise NotImplementedError( + "Tried passing in a callable to `accelerator_config`, but this is not supported. " + "Please pass in a fully constructed `AcceleratorConfig` object instead." + ) + else: + self.accelerator_config = AcceleratorConfig.from_json_file(self.accelerator_config) + + if self.dispatch_batches is not None: + warnings.warn( + "Using `--dispatch_batches` is deprecated and will be removed in version 4.41 of 🤗 Transformers. Use" + " `--accelerator_config {'dispatch_batches':VALUE} instead", + FutureWarning, + ) + self.accelerator_config.dispatch_batches = self.dispatch_batches + + if self.split_batches is not None: + warnings.warn( + "Using `--split_batches` is deprecated and will be removed in version 4.41 of 🤗 Transformers. Use" + " `--accelerator_config {'split_batches':VALUE} instead", + FutureWarning, + ) + self.accelerator_config.split_batches = self.split_batches + + if self.dataloader_drop_last: + self.accelerator_config.even_batches = False + if (self.torch_compile_mode is not None or self.torch_compile_backend is not None) and not self.torch_compile: assert get_habana_frameworks_version().minor > 12, "Torch compile is not available" self.torch_compile = True @@ -514,6 +561,13 @@ def __post_init__(self): from transformers.integrations import get_available_reporting_integrations self.report_to = get_available_reporting_integrations() + + if "codecarbon" in self.report_to and torch.version.hip: + logger.warning( + "When using the Trainer, CodeCarbonCallback requires the `codecarbon` package, which is not compatible with AMD ROCm (https://github.com/mlco2/codecarbon/pull/490). Automatically disabling the codecarbon callback. Reference: https://huggingface.co/docs/transformers/v4.39.3/en/main_classes/trainer#transformers.TrainingArguments.report_to." + ) + self.report_to.remove("codecarbon") + elif self.report_to == "none" or self.report_to == ["none"]: self.report_to = [] elif not isinstance(self.report_to, list): @@ -527,10 +581,13 @@ def __post_init__(self): " during training" ) + if not isinstance(self.warmup_steps, int) or self.warmup_steps < 0 or 0 < self.warmup_steps <= 1: + raise ValueError("warmup_steps must be either 0 or > 1") + # Copy of https://github.com/huggingface/transformers/blob/b71f20a7c9f3716d30f6738501559acf863e2c5c/src/transformers/training_args.py#L1563 # except following changes, (1) Remove XLA specific code & (2) change fsdp_backward_prefetch to backward_prefetch if isinstance(self.fsdp, bool): - self.fsdp = "full_shard" if self.fsdp else "" + self.fsdp = [FSDPOption.FULL_SHARD] if self.fsdp else "" if isinstance(self.fsdp, str): self.fsdp = [FSDPOption(s) for s in self.fsdp.split()] if self.fsdp == [FSDPOption.OFFLOAD]: @@ -541,6 +598,15 @@ def __post_init__(self): elif FSDPOption.FULL_SHARD in self.fsdp and FSDPOption.SHARD_GRAD_OP in self.fsdp: raise ValueError("`--fsdp full_shard` is not compatible with `--fsdp shard_grad_op`.") + if self.gradient_checkpointing and ( + FSDPOption.FULL_SHARD in self.fsdp or FSDPOption.HYBRID_SHARD in self.fsdp + ): + logger.warning( + "When using FSDP full shard, instead of using `gradient_checkpointing` in TrainingArguments, please" + " use `activation_checkpointing` in `fsdp_config`. The former introduces a redundant AllGather" + " operation in backward pass. Reference: https://github.com/huggingface/transformers/issues/30404" + ) + if self.fsdp_config is None: self.fsdp_config = {} @@ -616,46 +682,24 @@ def __post_init__(self): ) prefetch_policy = self.fsdp_config.get("backward_prefetch", "NO_PREFETCH") os.environ[f"{prefix}BACKWARD_PREFETCH"] = prefetch_policy.upper() - os.environ[f"{prefix}FORWARD_PREFETCH"] = str(self.fsdp_config.get("forward_prefetch", "false")) - os.environ[f"{prefix}SYNC_MODULE_STATES"] = str(self.fsdp_config.get("sync_module_states", "true")) - os.environ[f"{prefix}USE_ORIG_PARAMS"] = str(self.fsdp_config.get("use_orig_params", "true")) - os.environ[f"{prefix}ACTIVATION_CHECKPOINTING"] = str( - self.fsdp_config.get("activation_checkpointing", "false") - ) + os.environ[f"{prefix}FORWARD_PREFETCH"] = str(self.fsdp_config.get("forward_prefetch", "false")).lower() - if is_accelerate_available(): - if not isinstance(self.accelerator_config, (AcceleratorConfig)): - if self.accelerator_config is None: - self.accelerator_config = AcceleratorConfig() - elif isinstance(self.accelerator_config, dict): - self.accelerator_config = AcceleratorConfig(**self.accelerator_config) - # Check that a user didn't pass in the class instantiator - # such as `accelerator_config = AcceleratorConfig` - elif isinstance(self.accelerator_config, type): - raise NotImplementedError( - "Tried passing in a callable to `accelerator_config`, but this is not supported. " - "Please pass in a fully constructed `AcceleratorConfig` object instead." - ) - else: - self.accelerator_config = AcceleratorConfig.from_json_file(self.accelerator_config) - if self.dispatch_batches is not None: - warnings.warn( - "Using `--dispatch_batches` is deprecated and will be removed in version 4.41 of 🤗 Transformers. Use" - " `--accelerator_config {'dispatch_batches':VALUE} instead", - FutureWarning, - ) - self.accelerator_config.dispatch_batches = self.dispatch_batches + sync_module_states = str(self.fsdp_config.get("sync_module_states", "true")).lower() + cpu_ram_efficient_loading = str(self.fsdp_config.get("cpu_ram_efficient_loading", "false")).lower() - if self.split_batches is not None: - warnings.warn( - "Using `--split_batches` is deprecated and will be removed in version 4.41 of 🤗 Transformers. Use" - " `--accelerator_config {'split_batches':VALUE} instead", - FutureWarning, - ) - self.accelerator_config.split_batches = self.split_batches + if sync_module_states == "false" and cpu_ram_efficient_loading == "true": + # In this case, all the processes except the main process would have random weights leading + # to unexpected behaviour during training, thus throwing error here to prevent it. + raise ValueError('`sync_module_states` must be `"True"` if `cpu_ram_efficient_loading` is `"True"`') - if self.dataloader_drop_last: - self.accelerator_config.even_batches = False + os.environ[f"{prefix}SYNC_MODULE_STATES"] = sync_module_states + os.environ[f"{prefix}CPU_RAM_EFFICIENT_LOADING"] = cpu_ram_efficient_loading + + os.environ[f"{prefix}USE_ORIG_PARAMS"] = str(self.fsdp_config.get("use_orig_params", "true")).lower() + + os.environ[f"{prefix}ACTIVATION_CHECKPOINTING"] = str( + self.fsdp_config.get("activation_checkpointing", "false") + ) if isinstance(self.debug, str): self.debug = [DebugOption(s) for s in self.debug.split()] @@ -747,6 +791,12 @@ def __post_init__(self): FutureWarning, ) + if self.eval_use_gather_object and not is_accelerate_available("0.30.0"): + raise ValueError( + "--eval_use_gather_object requires Accelerate to be version of `accelerate` > 0.30.0." + "This is not supported and we recommend you to update your version." + ) + def __str__(self): self_as_dict = asdict(self) @@ -785,9 +835,30 @@ def _setup_devices(self) -> "torch.device": f"Using the `Trainer` with `PyTorch` requires `accelerate>={ACCELERATE_MIN_VERSION}`: " "Please run `pip install transformers[torch]` or `pip install accelerate -U`" ) - GaudiAcceleratorState._reset_state() - GaudiPartialState._reset_state() - self.distributed_state = None + # We delay the init of `PartialState` to the end for clarity + accelerator_state_kwargs = {"enabled": True, "use_configured_state": False} + if isinstance(self.accelerator_config, AcceleratorConfig): + accelerator_state_kwargs["use_configured_state"] = self.accelerator_config.pop( + "use_configured_state", False + ) + if accelerator_state_kwargs["use_configured_state"]: + if GaudiPartialState._shared_state == {}: + raise ValueError( + "Passing `'use_configured_state':True` to the AcceleratorConfig requires a pre-configured " + "`AcceleratorState` or `PartialState` to be defined before calling `TrainingArguments`. " + ) + # We rely on `PartialState` to yell if there's issues here (which it will) + self.distributed_state = GaudiPartialState(cpu=self.use_cpu) + if self.deepspeed and self.distributed_state.distributed_type != GaudiDistributedType.DEEPSPEED: + raise RuntimeError( + "Tried to use an already configured `Accelerator` or `PartialState` that was not initialized for DeepSpeed, " + "but also passed in a `deepspeed` configuration to the `TrainingArguments`. Please set " + "`use_configured_state:False` instead or setup your `Accelerator` or `PartialState` properly." + ) + else: + GaudiAcceleratorState._reset_state() + GaudiPartialState._reset_state() + self.distributed_state = None # Set the log level here for optimum.utils.logging # otherwise logs are not sent in this method. @@ -796,8 +867,11 @@ def _setup_devices(self) -> "torch.device": if not self.use_ipex and "ACCELERATE_USE_IPEX" not in os.environ: os.environ["ACCELERATE_USE_IPEX"] = "false" + + self._n_gpu = 1 if self.use_cpu or strtobool(os.environ.get("ACCELERATE_USE_CPU", "False")): - self.distributed_state = GaudiPartialState(cpu=True, backend=self.ddp_backend) + accelerator_state_kwargs["cpu"] = True + accelerator_state_kwargs["backend"] = self.ddp_backend self._n_gpu = 0 elif self.use_habana: # Some methods needs to be tweaked to optimally run on Gaudi @@ -816,20 +890,28 @@ def _setup_devices(self) -> "torch.device": ) if self.deepspeed: - # Need to do similar for Accelerator init - os.environ["ACCELERATE_USE_DEEPSPEED"] = "true" - self.distributed_state = GaudiPartialState(timeout=timedelta(seconds=self.ddp_timeout)) - del os.environ["ACCELERATE_USE_DEEPSPEED"] + accelerator_state_kwargs["use_deepspeed"] = True + accelerator_state_kwargs["timeout"] = timedelta(seconds=self.ddp_timeout) else: - self.distributed_state = GaudiPartialState( - backend=self.ddp_backend, timeout=timedelta(seconds=self.ddp_timeout) - ) - self._n_gpu = 1 + accelerator_state_kwargs["backend"] = self.ddp_backend + accelerator_state_kwargs["timeout"] = timedelta(seconds=self.ddp_timeout) else: raise ValueError( "No device has been set. Use either --use_habana to run on HPU or --no_cuda to run on CPU." ) + # Now we pop everything + if accelerator_state_kwargs.pop("enabled", False) and not accelerator_state_kwargs.pop( + "use_configured_state", False + ): + # We need to patch this env var when enabling to detect deepspeed + use_deepspeed = accelerator_state_kwargs.pop("use_deepspeed", False) + if use_deepspeed: + os.environ["ACCELERATE_USE_DEEPSPEED"] = "true" + self.distributed_state = GaudiPartialState(**accelerator_state_kwargs) + if use_deepspeed: + del os.environ["ACCELERATE_USE_DEEPSPEED"] + device = self.distributed_state.device self.local_rank = self.distributed_state.local_process_index if ( diff --git a/setup.py b/setup.py index e71577deb4..cea680353e 100644 --- a/setup.py +++ b/setup.py @@ -29,10 +29,10 @@ INSTALL_REQUIRES = [ - "transformers >= 4.40.0, < 4.41.0", + "transformers >= 4.43.0, < 4.44.0", "optimum", "torch", - "accelerate < 0.28.0", + "accelerate >= 0.33.0, < 0.34.0", "diffusers == 0.29.2", "huggingface_hub >= 0.23.2", "sentence-transformers[train] == 3.0.1", @@ -51,6 +51,7 @@ "scipy", "torchsde", "timm", + "peft", ] QUALITY_REQUIRES = [ diff --git a/tests/baselines/Qwen2_7B.json b/tests/baselines/Qwen2_7B.json index 453d60848a..844f57b729 100644 --- a/tests/baselines/Qwen2_7B.json +++ b/tests/baselines/Qwen2_7B.json @@ -16,7 +16,7 @@ "--packing True", "--gradient_accumulation_steps 8", "--gradient_checkpointing True", - "--evaluation_strategy no", + "--eval_strategy no", "--save_strategy no", "--throughput_warmup_steps 5", "--warmup_ratio 0.03", @@ -53,7 +53,7 @@ "--packing True", "--gradient_accumulation_steps 8", "--gradient_checkpointing True", - "--evaluation_strategy no", + "--eval_strategy no", "--save_strategy no", "--throughput_warmup_steps 5", "--warmup_ratio 0.03", diff --git a/tests/baselines/bridgetower_large_itm_mlm_itc.json b/tests/baselines/bridgetower_large_itm_mlm_itc.json index e188228256..9b2a27509e 100644 --- a/tests/baselines/bridgetower_large_itm_mlm_itc.json +++ b/tests/baselines/bridgetower_large_itm_mlm_itc.json @@ -19,7 +19,8 @@ "--dataloader_num_workers 2", "--logging_steps 10", "--use_hpu_graphs_for_inference", - "--distribution_strategy fast_ddp" + "--distribution_strategy fast_ddp", + "--trust_remote_code True" ] } } diff --git a/tests/baselines/clip_roberta.json b/tests/baselines/clip_roberta.json index e2b14ba68e..18d80762cc 100755 --- a/tests/baselines/clip_roberta.json +++ b/tests/baselines/clip_roberta.json @@ -20,7 +20,8 @@ "--save_strategy epoch", "--use_hpu_graphs_for_training", "--use_hpu_graphs_for_inference", - "--dataloader_num_workers 16" + "--dataloader_num_workers 16", + "--trust_remote_code True" ] } } @@ -49,7 +50,8 @@ "--use_hpu_graphs_for_inference", "--dataloader_num_workers 16", "--distribution_strategy fast_ddp", - "--mediapipe_dataloader" + "--mediapipe_dataloader", + "--trust_remote_code True" ] } } diff --git a/tests/baselines/falcon_40b.json b/tests/baselines/falcon_40b.json index d3563466c5..4a91bf9a7a 100644 --- a/tests/baselines/falcon_40b.json +++ b/tests/baselines/falcon_40b.json @@ -13,7 +13,7 @@ "extra_arguments": [ "--bf16", "--gradient_accumulation_steps 16", - "--evaluation_strategy no", + "--eval_strategy no", "--save_strategy no", "--warmup_ratio 0.03", "--lr_scheduler_type constant", @@ -48,7 +48,7 @@ "extra_arguments": [ "--bf16", "--gradient_accumulation_steps 16", - "--evaluation_strategy no", + "--eval_strategy no", "--save_strategy no", "--warmup_ratio 0.03", "--lr_scheduler_type constant", diff --git a/tests/baselines/llama_7b.json b/tests/baselines/llama_7b.json index ff9f3f1485..1c303c9d9c 100644 --- a/tests/baselines/llama_7b.json +++ b/tests/baselines/llama_7b.json @@ -13,7 +13,7 @@ "extra_arguments": [ "--bf16", "--gradient_accumulation_steps 1", - "--evaluation_strategy no", + "--eval_strategy no", "--save_strategy no", "--warmup_ratio 0.03", "--lr_scheduler_type constant", @@ -74,7 +74,7 @@ "extra_arguments": [ "--bf16", "--gradient_accumulation_steps 1", - "--evaluation_strategy no", + "--eval_strategy no", "--save_strategy no", "--warmup_ratio 0.03", "--lr_scheduler_type constant", @@ -110,7 +110,7 @@ "extra_arguments": [ "--bf16", "--gradient_accumulation_steps 2", - "--evaluation_strategy no", + "--eval_strategy no", "--save_strategy no", "--warmup_ratio 0.03", "--lr_scheduler_type constant", @@ -145,7 +145,7 @@ "extra_arguments": [ "--bf16", "--gradient_accumulation_steps 2", - "--evaluation_strategy no", + "--eval_strategy no", "--save_strategy no", "--warmup_ratio 0.03", "--lr_scheduler_type constant", @@ -179,7 +179,7 @@ "extra_arguments": [ "--bf16 True", "--gradient_accumulation_steps 2", - "--evaluation_strategy no", + "--eval_strategy no", "--save_strategy no", "--warmup_ratio 0.03", "--lr_scheduler_type constant", @@ -218,7 +218,7 @@ "extra_arguments": [ "--bf16", "--gradient_accumulation_steps 2", - "--evaluation_strategy no", + "--eval_strategy no", "--save_strategy no", "--warmup_ratio 0.03", "--lr_scheduler_type constant", @@ -255,7 +255,7 @@ "extra_arguments": [ "--bf16 True", "--gradient_accumulation_steps 2", - "--evaluation_strategy no", + "--eval_strategy no", "--save_strategy no", "--warmup_ratio 0.03", "--lr_scheduler_type constant", @@ -450,7 +450,7 @@ "extra_arguments": [ "--bf16", "--gradient_accumulation_steps 1", - "--evaluation_strategy no", + "--eval_strategy no", "--save_strategy no", "--warmup_ratio 0.03", "--lr_scheduler_type constant", diff --git a/tests/baselines/wav2vec2_large_lv60.json b/tests/baselines/wav2vec2_large_lv60.json index d645ced656..920239618b 100644 --- a/tests/baselines/wav2vec2_large_lv60.json +++ b/tests/baselines/wav2vec2_large_lv60.json @@ -21,7 +21,8 @@ "--layerdrop 0.0", "--freeze_feature_encoder", "--dataloader_num_workers 8", - "--chars_to_ignore ',?.!-;:\"“%‘”'" + "--chars_to_ignore ',?.!-;:\"“%‘”'", + "--trust_remote_code True" ] } } @@ -35,7 +36,7 @@ "multi_card": { "learning_rate": 4e-4, "train_batch_size": 8, - "eval_wer": 0.06120587068623562, + "eval_wer": 0.11090, "train_runtime": 308.8036, "train_samples_per_second": 225.572, "eval_samples_per_second": 196.665, @@ -51,10 +52,11 @@ "--dataloader_num_workers 8", "--chars_to_ignore ',?.!-;:\"“%‘”'", "--use_hpu_graphs_for_training", - "--use_hpu_graphs_for_inference" + "--use_hpu_graphs_for_inference", + "--trust_remote_code True" ] } } } } -} \ No newline at end of file +} diff --git a/tests/baselines/whisper_small.json b/tests/baselines/whisper_small.json index 6f3b01a412..5b44467f71 100644 --- a/tests/baselines/whisper_small.json +++ b/tests/baselines/whisper_small.json @@ -26,7 +26,8 @@ "--predict_with_generate", "--use_hpu_graphs_for_inference", "--label_features_max_length 128", - "--pipelining_fwd_bwd True" + "--pipelining_fwd_bwd True", + "--trust_remote_code True" ] } } @@ -58,7 +59,8 @@ "--dataloader_num_workers 8", "--predict_with_generate", "--use_hpu_graphs_for_inference", - "--label_features_max_length 128" + "--label_features_max_length 128", + "--trust_remote_code True" ] } } diff --git a/tests/example_diff/run_audio_classification.txt b/tests/example_diff/run_audio_classification.txt index d7b474164d..68cea814ae 100644 --- a/tests/example_diff/run_audio_classification.txt +++ b/tests/example_diff/run_audio_classification.txt @@ -33,8 +33,8 @@ < check_min_version("4.44.0.dev0") --- > # Will error if the minimal version of Transformers and Optimum Habana are not installed. Remove at your own risks. -> check_min_version("4.40.0") -> check_optimum_habana_min_version("1.11.0") +> check_min_version("4.43.0") +> check_optimum_habana_min_version("1.12.0") 174,176d175 < freeze_feature_extractor: Optional[bool] = field( < default=None, metadata={"help": "Whether to freeze the feature extractor layers of the model."} diff --git a/tests/example_diff/run_clip.txt b/tests/example_diff/run_clip.txt index 1099d3c94a..4d27570a7e 100644 --- a/tests/example_diff/run_clip.txt +++ b/tests/example_diff/run_clip.txt @@ -28,8 +28,8 @@ < check_min_version("4.44.0.dev0") --- > # Will error if the minimal version of Transformers and Optimum Habana are not installed. Remove at your own risks. -> check_min_version("4.40.0") -> check_optimum_habana_min_version("1.11.0") +> check_min_version("4.43.0") +> check_optimum_habana_min_version("1.12.0") 181a190,192 > mediapipe_dataloader: bool = field( > default=False, metadata={"help": "Turn on MediaPipe hardware-based accelerated data loading."} diff --git a/tests/example_diff/run_clm.txt b/tests/example_diff/run_clm.txt index c91df2d5cd..95d2758bb4 100644 --- a/tests/example_diff/run_clm.txt +++ b/tests/example_diff/run_clm.txt @@ -38,8 +38,8 @@ > 63a64,69 > # Will error if the minimal version of Transformers and Optimum Habana are not installed. Remove at your own risks. -> check_min_version("4.40.0") -> check_optimum_habana_min_version("1.11.0") +> check_min_version("4.43.0") +> check_optimum_habana_min_version("1.12.0") > > require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt") > diff --git a/tests/example_diff/run_glue.txt b/tests/example_diff/run_glue.txt index 46005ba396..12c1867ba6 100644 --- a/tests/example_diff/run_glue.txt +++ b/tests/example_diff/run_glue.txt @@ -27,8 +27,8 @@ > logger = logging.getLogger(__name__) > > # Will error if the minimal version of Transformers and Optimum Habana are not installed. Remove at your own risks. -> check_min_version("4.40.0") -> check_optimum_habana_min_version("1.11.0") +> check_min_version("4.43.0") +> check_optimum_habana_min_version("1.12.0") 67,68d76 < logger = logging.getLogger(__name__) < diff --git a/tests/example_diff/run_image_classification.txt b/tests/example_diff/run_image_classification.txt index 49ab2bb6a1..1a31bd5e7f 100644 --- a/tests/example_diff/run_image_classification.txt +++ b/tests/example_diff/run_image_classification.txt @@ -28,8 +28,8 @@ < check_min_version("4.44.0.dev0") --- > # Will error if the minimal version of Transformers and Optimum Habana are not installed. Remove at your own risks. -> check_min_version("4.40.0") -> check_optimum_habana_min_version("1.11.0") +> check_min_version("4.43.0") +> check_optimum_habana_min_version("1.12.0") 184c192 < parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments)) --- diff --git a/tests/example_diff/run_mlm.txt b/tests/example_diff/run_mlm.txt index 3e4f6c5863..5617bc0cc0 100644 --- a/tests/example_diff/run_mlm.txt +++ b/tests/example_diff/run_mlm.txt @@ -34,8 +34,8 @@ 61a62,69 > > # Will error if the minimal version of Transformers and Optimum Habana are not installed. Remove at your own risks. -> check_min_version("4.40.0") -> check_optimum_habana_min_version("1.11.0") +> check_min_version("4.43.0") +> check_optimum_habana_min_version("1.12.0") > > require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt") > diff --git a/tests/example_diff/run_qa.txt b/tests/example_diff/run_qa.txt index 961785aaac..d845803f0c 100644 --- a/tests/example_diff/run_qa.txt +++ b/tests/example_diff/run_qa.txt @@ -32,8 +32,8 @@ > 58a62,67 > # Will error if the minimal version of Transformers and Optimum Habana are not installed. Remove at your own risks. -> check_min_version("4.40.0") -> check_optimum_habana_min_version("1.11.0") +> check_min_version("4.43.0") +> check_optimum_habana_min_version("1.12.0") > > require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt") > diff --git a/tests/example_diff/run_seq2seq_qa.txt b/tests/example_diff/run_seq2seq_qa.txt index 322661ff62..171b192d39 100644 --- a/tests/example_diff/run_seq2seq_qa.txt +++ b/tests/example_diff/run_seq2seq_qa.txt @@ -24,8 +24,8 @@ > 54a58,63 > # Will error if the minimal version of Transformers and Optimum Habana are not installed. Remove at your own risks. -> check_min_version("4.40.0") -> check_optimum_habana_min_version("1.11.0") +> check_min_version("4.43.0") +> check_optimum_habana_min_version("1.12.0") > > require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt") > diff --git a/tests/example_diff/run_speech_recognition_ctc.txt b/tests/example_diff/run_speech_recognition_ctc.txt index a99ee732b3..d18ad32353 100644 --- a/tests/example_diff/run_speech_recognition_ctc.txt +++ b/tests/example_diff/run_speech_recognition_ctc.txt @@ -25,8 +25,8 @@ > return () 59a61,66 > # Will error if the minimal version of Transformers and Optimum Habana are not installed. Remove at your own risks. -> check_min_version("4.40.0") -> check_optimum_habana_min_version("1.11.0") +> check_min_version("4.43.0") +> check_optimum_habana_min_version("1.12.0") > > require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt") > diff --git a/tests/example_diff/run_speech_recognition_seq2seq.txt b/tests/example_diff/run_speech_recognition_seq2seq.txt index 196d356171..154914ba68 100644 --- a/tests/example_diff/run_speech_recognition_seq2seq.txt +++ b/tests/example_diff/run_speech_recognition_seq2seq.txt @@ -22,8 +22,8 @@ 51c58,59 < check_min_version("4.44.0.dev0") --- -> check_min_version("4.40.0") -> check_optimum_habana_min_version("1.11.0") +> check_min_version("4.43.0") +> check_optimum_habana_min_version("1.12.0") 230a239,242 > label_features_max_length: int = field( > default=None, diff --git a/tests/example_diff/run_summarization.txt b/tests/example_diff/run_summarization.txt index 81868ab221..c2d4b27005 100644 --- a/tests/example_diff/run_summarization.txt +++ b/tests/example_diff/run_summarization.txt @@ -36,8 +36,8 @@ > 60a67,72 > # Will error if the minimal version of Transformers and Optimum Habana are not installed. Remove at your own risks. -> check_min_version("4.40.0") -> check_optimum_habana_min_version("1.11.0") +> check_min_version("4.43.0") +> check_optimum_habana_min_version("1.12.0") > > require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt") > diff --git a/tests/example_diff/run_translation.txt b/tests/example_diff/run_translation.txt index e7038d847c..efc8a4b4da 100644 --- a/tests/example_diff/run_translation.txt +++ b/tests/example_diff/run_translation.txt @@ -28,8 +28,8 @@ > 60a64,69 > # Will error if the minimal version of Transformers and Optimum Habana are not installed. Remove at your own risks. -> check_min_version("4.40.0") -> check_optimum_habana_min_version("1.11.0") +> check_min_version("4.43.0") +> check_optimum_habana_min_version("1.12.0") > > require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/translation/requirements.txt") > diff --git a/tests/resource/sample_text.txt b/tests/resource/sample_text.txt new file mode 100644 index 0000000000..a42812060c --- /dev/null +++ b/tests/resource/sample_text.txt @@ -0,0 +1,33 @@ +This text is included to make sure Unicode is handled properly: 力加勝北区ᴵᴺᵀᵃছজটডণত +Text should be one-sentence-per-line, with empty lines between documents. +This sample text is public domain and was randomly selected from Project Guttenberg. + +The rain had only ceased with the gray streaks of morning at Blazing Star, and the settlement awoke to a moral sense of cleanliness, and the finding of forgotten knives, tin cups, and smaller camp utensils, where the heavy showers had washed away the debris and dust heaps before the cabin doors. +Indeed, it was recorded in Blazing Star that a fortunate early riser had once picked up on the highway a solid chunk of gold quartz which the rain had freed from its incumbering soil, and washed into immediate and glittering popularity. +Possibly this may have been the reason why early risers in that locality, during the rainy season, adopted a thoughtful habit of body, and seldom lifted their eyes to the rifted or india-ink washed skies above them. +"Cass" Beard had risen early that morning, but not with a view to discovery. +A leak in his cabin roof,--quite consistent with his careless, improvident habits,--had roused him at 4 A. M., with a flooded "bunk" and wet blankets. +The chips from his wood pile refused to kindle a fire to dry his bed-clothes, and he had recourse to a more provident neighbor's to supply the deficiency. +This was nearly opposite. +Mr. Cassius crossed the highway, and stopped suddenly. +Something glittered in the nearest red pool before him. +Gold, surely! +But, wonderful to relate, not an irregular, shapeless fragment of crude ore, fresh from Nature's crucible, but a bit of jeweler's handicraft in the form of a plain gold ring. +Looking at it more attentively, he saw that it bore the inscription, "May to Cass." +Like most of his fellow gold-seekers, Cass was superstitious. + +The fountain of classic wisdom, Hypatia herself. +As the ancient sage--the name is unimportant to a monk--pumped water nightly that he might study by day, so I, the guardian of cloaks and parasols, at the sacred doors of her lecture-room, imbibe celestial knowledge. +From my youth I felt in me a soul above the matter-entangled herd. +She revealed to me the glorious fact, that I am a spark of Divinity itself. +A fallen star, I am, sir!' continued he, pensively, stroking his lean stomach--'a fallen star!--fallen, if the dignity of philosophy will allow of the simile, among the hogs of the lower world--indeed, even into the hog-bucket itself. Well, after all, I will show you the way to the Archbishop's. +There is a philosophic pleasure in opening one's treasures to the modest young. +Perhaps you will assist me by carrying this basket of fruit?' And the little man jumped up, put his basket on Philammon's head, and trotted off up a neighbouring street. +Philammon followed, half contemptuous, half wondering at what this philosophy might be, which could feed the self-conceit of anything so abject as his ragged little apish guide; +but the novel roar and whirl of the street, the perpetual stream of busy faces, the line of curricles, palanquins, laden asses, camels, elephants, which met and passed him, and squeezed him up steps and into doorways, as they threaded their way through the great Moon-gate into the ample street beyond, drove everything from his mind but wondering curiosity, and a vague, helpless dread of that great living wilderness, more terrible than any dead wilderness of sand which he had left behind. +Already he longed for the repose, the silence of the Laura--for faces which knew him and smiled upon him; but it was too late to turn back now. +His guide held on for more than a mile up the great main street, crossed in the centre of the city, at right angles, by one equally magnificent, at each end of which, miles away, appeared, dim and distant over the heads of the living stream of passengers, the yellow sand-hills of the desert; +while at the end of the vista in front of them gleamed the blue harbour, through a network of countless masts. +At last they reached the quay at the opposite end of the street; +and there burst on Philammon's astonished eyes a vast semicircle of blue sea, ringed with palaces and towers. +He stopped involuntarily; and his little guide stopped also, and looked askance at the young monk, to watch the effect which that grand panorama should produce on him. diff --git a/tests/test_diffusers.py b/tests/test_diffusers.py index 7380681142..c4b2104511 100755 --- a/tests/test_diffusers.py +++ b/tests/test_diffusers.py @@ -971,11 +971,9 @@ def test_stable_diffusion_xl_euler(self): image_slice = image[-3:, -3:, -1] self.assertEqual(image.shape, (64, 64, 3)) - expected_slice = np.array([0.5552, 0.5569, 0.4725, 0.4348, 0.4994, 0.4632, 0.5142, 0.5012, 0.47]) + expected_slice = np.array([0.5388, 0.5451, 0.4694, 0.4582, 0.5252, 0.4832, 0.5288, 0.5034, 0.4766]) - # The threshold should be 1e-2 below but it started failing - # from Diffusers v0.24. However, generated images still look similar. - self.assertLess(np.abs(image_slice.flatten() - expected_slice).max(), 1e-1) + self.assertLess(np.abs(image_slice.flatten() - expected_slice).max(), 1e-2) def test_stable_diffusion_xl_euler_ancestral(self): device = "cpu" # ensure determinism for the device-dependent torch.Generator @@ -991,7 +989,7 @@ def test_stable_diffusion_xl_euler_ancestral(self): image_slice = image[-3:, -3:, -1] self.assertEqual(image.shape, (64, 64, 3)) - expected_slice = np.array([0.4675, 0.5173, 0.4611, 0.4067, 0.5250, 0.4674, 0.5446, 0.5094, 0.4791]) + expected_slice = np.array([0.4539, 0.5119, 0.4521, 0.4395, 0.5495, 0.49344, 0.5761, 0.5147, 0.4943]) self.assertLess(np.abs(image_slice.flatten() - expected_slice).max(), 1e-2) def test_stable_diffusion_xl_turbo_euler_ancestral(self): @@ -1010,7 +1008,7 @@ def test_stable_diffusion_xl_turbo_euler_ancestral(self): image_slice = image[-3:, -3:, -1] self.assertEqual(image.shape, (64, 64, 3)) - expected_slice = np.array([0.4675, 0.5173, 0.4611, 0.4067, 0.5250, 0.4674, 0.5446, 0.5094, 0.4791]) + expected_slice = np.array([0.4539, 0.5119, 0.4521, 0.4395, 0.5495, 0.49344, 0.5761, 0.5147, 0.4943]) self.assertLess(np.abs(image_slice.flatten() - expected_slice).max(), 1e-2) @parameterized.expand(["pil", "np", "latent"]) @@ -2438,7 +2436,7 @@ def test_stable_video_diffusion_single_video(self): self.assertEqual(len(outputs), 1) self.assertEqual(image.shape, (2, 3, 32, 32)) - expected_slice = np.array([0.5910, 0.5797, 0.5521, 0.6628, 0.6212, 0.6422, 0.5681, 0.5232, 0.5343]) + expected_slice = np.array([0.6208, 0.5780, 0.5447, 0.6462, 0.6285, 0.6288, 0.5334, 0.5287, 0.5165]) self.assertLess(np.abs(image_slice.flatten() - expected_slice).max(), 1e-2) @@ -4731,7 +4729,7 @@ def test_stable_diffusion_xl_inpaint_euler(self): assert image.shape == (1, 64, 64, 3) - expected_slice = np.array([0.8029, 0.5523, 0.5825, 0.6003, 0.6702, 0.7018, 0.6369, 0.5955, 0.5123]) + expected_slice = np.array([0.8279, 0.5673, 0.6088, 0.6156, 0.6923, 0.7347, 0.6547, 0.6108, 0.5198]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 @@ -4833,7 +4831,7 @@ def test_stable_diffusion_xl_refiner(self): assert image.shape == (1, 64, 64, 3) - expected_slice = np.array([0.7045, 0.4838, 0.5454, 0.6270, 0.6168, 0.6717, 0.6484, 0.5681, 0.4922]) + expected_slice = np.array([0.7540, 0.5231, 0.5833, 0.6217, 0.6339, 0.7067, 0.6507, 0.5672, 0.5030]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 diff --git a/tests/test_fp8_examples.py b/tests/test_fp8_examples.py index 54f5a7a63b..27020a2b8f 100644 --- a/tests/test_fp8_examples.py +++ b/tests/test_fp8_examples.py @@ -71,7 +71,7 @@ def _test_fp8_train( if model_name == "mistralai/Mistral-7B-Instruct-v0.2": command += [ "--num_train_epochs 3", - "--evaluation_strategy no", + "--eval_strategy no", "--save_strategy no", "--learning_rate 4e-4", "--warmup_ratio 0.03", diff --git a/tests/test_fsdp_examples.py b/tests/test_fsdp_examples.py index b3af2a05cc..7d8128b765 100644 --- a/tests/test_fsdp_examples.py +++ b/tests/test_fsdp_examples.py @@ -104,7 +104,7 @@ def _test_fsdp( "--bf16 True ", "--gradient_accumulation_steps 2", "--save_strategy 'no'", - "--evaluation_strategy 'no'", + "--eval_strategy 'no'", "--learning_rate 0.0003", "--warmup_ratio 0.03", "--max_grad_norm 0.3", diff --git a/tests/test_text_generation_example.py b/tests/test_text_generation_example.py index c4e2fd340b..c2ac38e873 100644 --- a/tests/test_text_generation_example.py +++ b/tests/test_text_generation_example.py @@ -53,10 +53,10 @@ ("meta-llama/Llama-2-70b-hf", 4, 750, False, 128, 2048, 7422.4), ("meta-llama/Llama-2-70b-hf", 4, 207, False, 2048, 128, 568.5), ("meta-llama/Llama-2-70b-hf", 8, 172, False, 2048, 2048, 4656.2), - ("mistralai/Mistral-7B-Instruct-v0.2", 1, 896, True, 128, 128, 12397.11410288204), - ("mistralai/Mistral-7B-Instruct-v0.2", 1, 120, True, 128, 2048, 5394.675714459493), - ("mistralai/Mistral-7B-Instruct-v0.2", 1, 120, True, 2048, 128, 919.8470890081497), - ("mistralai/Mistral-7B-Instruct-v0.2", 1, 44, True, 2048, 2048, 2471.950758729518), + ("mistralai/Mistral-7B-Instruct-v0.2", 1, 896, True, 128, 128, 17068.965283763682), + ("mistralai/Mistral-7B-Instruct-v0.2", 1, 120, True, 128, 2048, 6979.225194247115), + ("mistralai/Mistral-7B-Instruct-v0.2", 1, 120, True, 2048, 128, 1681.4401450088983), + ("mistralai/Mistral-7B-Instruct-v0.2", 1, 44, True, 2048, 2048, 3393.149396451692), ("mistralai/Mixtral-8x7B-v0.1", 1, 1, True, 128, 128, 39.26845661768185), ("microsoft/phi-2", 1, 1, True, 128, 128, 254.08932787178165), ], @@ -75,6 +75,9 @@ "distributed_tp": [ ("meta-llama/Llama-2-7b-hf", 1345.2369318328463), ], + "contrastive_search": [ + ("gpt2-xl", 1, False, 51.61471298016438), + ], } else: # Gaudi1 CI baselines @@ -106,6 +109,9 @@ "torch_compile": [], "torch_compile_distributed": [], "distributed_tp": [], + "contrastive_search": [ + ("gpt2-xl", 1, False, 34.48141280163397), + ], } @@ -122,6 +128,7 @@ def _test_text_generation( max_input_tokens: int = 0, max_output_tokens: int = 100, parallel_strategy: str = None, + contrastive_search: bool = False, ): command = ["python3"] path_to_example_dir = Path(__file__).resolve().parent.parent / "examples" @@ -177,6 +184,9 @@ def _test_text_generation( if not deepspeed: command.append("--bf16") + if contrastive_search: + command += ["--top_k 4", "--penalty_alpha 0.5"] + if fp8: if "--trim_logits" not in command: command += ["--trim_logits"] @@ -185,6 +195,11 @@ def _test_text_generation( command.insert(-2, "--flash_attention_recompute") command.insert(-2, "--bucket_size 128") command.insert(-2, "--bucket_internal") + if "Mistral" in model_name: + command.insert(-2, "--use_flash_attention") + command.insert(-2, "--flash_attention_recompute") + command.insert(-2, "--attn_softmax_bf16") + command.insert(-2, "--trim_logits") elif "falcon-180b" in model_name.lower(): command.insert(-2, "--flash_attention_recompute") @@ -327,6 +342,13 @@ def test_text_generation_distributed_tp(model_name: str, baseline: float, token: ) +@pytest.mark.parametrize("model_name, batch_size, reuse_cache, baseline", MODELS_TO_TEST["contrastive_search"]) +def test_text_generation_contrastive_search( + model_name: str, baseline: float, batch_size: int, reuse_cache: bool, token: str +): + _test_text_generation(model_name, baseline, token, batch_size, reuse_cache, contrastive_search=True) + + class TextGenPipeline(TestCase): def test_text_generation_pipeline_script(self): path_to_script = ( diff --git a/tests/test_trainer.py b/tests/test_trainer.py index 8be28e79c9..bcbb9521eb 100644 --- a/tests/test_trainer.py +++ b/tests/test_trainer.py @@ -33,7 +33,11 @@ from pytest import mark from requests.exceptions import HTTPError from transformers import ( + AutoModelForCausalLM, + AutoTokenizer, + GPT2LMHeadModel, IntervalStrategy, + LineByLineTextDataset, PretrainedConfig, TrainerCallback, get_polynomial_decay_schedule_with_warmup, @@ -52,6 +56,7 @@ is_staging_test, require_accelerate, require_optuna, + require_peft, require_safetensors, require_sentencepiece, require_tensorboard, @@ -72,6 +77,7 @@ from transformers.utils.hp_naming import TrialShortNamer from optimum.habana import GaudiConfig, GaudiTrainingArguments +from optimum.habana.accelerate import GaudiAccelerator, GaudiAcceleratorState from optimum.utils import logging @@ -81,7 +87,6 @@ from torch import nn from torch.utils.data import IterableDataset from transformers import EarlyStoppingCallback, GPT2Config, PreTrainedModel, TrainerState - from transformers.modeling_utils import unwrap_model from optimum.habana import GaudiTrainer from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi @@ -93,10 +98,11 @@ # for version specific tests in TrainerIntegrationTest require_accelerate_version_min_0_28 = partial(require_accelerate, min_version="0.28") +require_accelerate_version_min_0_30 = partial(require_accelerate, min_version="0.30") GRAD_ACCUM_KWARGS_VERSION_AVAILABLE = is_accelerate_available("0.28") -PATH_SAMPLE_TEXT = f"{get_tests_dir()}/fixtures/sample_text.txt" +PATH_SAMPLE_TEXT = f"{get_tests_dir()}/resource/sample_text.txt" adapt_transformers_to_gaudi() @@ -192,6 +198,27 @@ def __call__(self, eval_pred): return {"accuracy": true.astype(np.float32).mean().item()} +class AlmostAccuracyBatched: + def __init__(self, thresh=0.25): + self.thresh = thresh + self.batch_acc = [] + + def __call__(self, eval_pred, compute_result): + predictions, labels = eval_pred + if isinstance(predictions, tuple): + predictions = predictions[0] + if isinstance(labels, tuple): + labels = labels[0] + batch_size = len(predictions) + true = torch.abs(predictions - labels) <= self.thresh + acc = true.type(torch.FloatTensor).mean().item() + self.batch_acc.extend([acc] * batch_size) + if compute_result: + result = {"accuracy": np.mean(self.batch_acc).item()} + self.batch_acc = [] + return result + + class RegressionModelConfig(PretrainedConfig): def __init__(self, a=0, b=0, double_output=False, random_torch=True, **kwargs): super().__init__(**kwargs) @@ -568,7 +595,9 @@ def test_trainer_with_datasets(self): # Base training. Should have the same results as test_reproducible_training model = RegressionModel() - args = GaudiTrainingArguments("./regression", learning_rate=0.1, use_habana=True, use_lazy_mode=True) + args = GaudiTrainingArguments( + "./regression", learning_rate=0.1, use_habana=True, use_lazy_mode=True, report_to="none" + ) trainer = GaudiTrainer(model, gaudi_config, args, train_dataset=train_dataset) trainer.train() self.check_trained_model(trainer.model) @@ -591,7 +620,9 @@ def test_trainer_with_datasets(self): def test_model_init(self): train_dataset = RegressionDataset() gaudi_config = get_gaudi_config() - args = GaudiTrainingArguments("./regression", learning_rate=0.1, use_habana=True, use_lazy_mode=True) + args = GaudiTrainingArguments( + "./regression", learning_rate=0.1, use_habana=True, use_lazy_mode=True, report_to="none" + ) trainer = GaudiTrainer( gaudi_config=gaudi_config, args=args, train_dataset=train_dataset, model_init=lambda: RegressionModel() ) @@ -617,7 +648,7 @@ def test_gradient_accumulation(self): # The test below is commented because it leads to a core dumped error # when it is run with all other tests. It passes when run alone. - # It seems to be cause by setting `use_reentrant` to False in + # It seems to be caused by setting `use_reentrant` to False in # gradient checkpointing. # def test_gradient_checkpointing(self): # trainer = get_regression_trainer( @@ -662,7 +693,7 @@ def test_custom_optimizer(self): train_dataset = RegressionDataset() gaudi_config = get_gaudi_config() gaudi_config.use_fused_adam = False - args = GaudiTrainingArguments("./regression", use_habana=True, use_lazy_mode=True) + args = GaudiTrainingArguments("./regression", use_habana=True, use_lazy_mode=True, report_to="none") model = RegressionModel() optimizer = torch.optim.SGD(model.parameters(), lr=1.0) lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda x: 1.0) @@ -690,6 +721,7 @@ def test_lr_scheduler_kwargs(self): warmup_steps=num_warmup_steps, use_habana=True, use_lazy_mode=True, + report_to="none", ) gaudi_config = get_gaudi_config() trainer = GaudiTrainer(model, gaudi_config, args, train_dataset=train_dataset) @@ -719,6 +751,7 @@ def test_cosine_with_min_lr_scheduler(self): warmup_steps=num_warmup_steps, use_habana=True, use_lazy_mode=True, + report_to="none", ) trainer = GaudiTrainer(model, gaudi_config=get_gaudi_config(), args=args, train_dataset=train_dataset) trainer.create_optimizer_and_scheduler(num_training_steps=num_steps) @@ -739,10 +772,11 @@ def test_reduce_lr_on_plateau_args(self): gaudi_config.use_fused_adam = False args = GaudiTrainingArguments( "./regression", - evaluation_strategy="epoch", + eval_strategy="epoch", metric_for_best_model="eval_loss", use_habana=True, use_lazy_mode=True, + report_to="none", ) model = RegressionModel() optimizer = torch.optim.SGD(model.parameters(), lr=1.0) @@ -780,10 +814,11 @@ def log(self, logs): args = GaudiTrainingArguments( "./regression", lr_scheduler_type="reduce_lr_on_plateau", - evaluation_strategy="epoch", + eval_strategy="epoch", metric_for_best_model="eval_loss", num_train_epochs=10, learning_rate=0.2, + report_to="none", use_habana=True, use_lazy_mode=True, ) @@ -818,7 +853,7 @@ def test_adafactor_lr_none(self): from transformers.optimization import Adafactor, AdafactorSchedule train_dataset = RegressionDataset() - args = GaudiTrainingArguments("./regression", use_habana=True, use_lazy_mode=True) + args = GaudiTrainingArguments("./regression", use_habana=True, use_lazy_mode=True, report_to="none") gaudi_config = get_gaudi_config() gaudi_config.use_fused_adam = False model = RegressionModel().to("hpu") @@ -888,7 +923,7 @@ def test_trainer_works_with_dict(self): eval_dataset = RegressionDataset() model = RegressionDictModel() gaudi_config = get_gaudi_config() - args = GaudiTrainingArguments("./regression", use_habana=True, use_lazy_mode=True) + args = GaudiTrainingArguments("./regression", use_habana=True, use_lazy_mode=True, report_to="none") trainer = GaudiTrainer(model, gaudi_config, args, train_dataset=train_dataset, eval_dataset=eval_dataset) trainer.train() _ = trainer.evaluate() @@ -899,7 +934,7 @@ def test_evaluation_with_keys_to_drop(self): tiny_gpt2 = GaudiGPT2LMHeadModel(config) x = torch.randint(0, 100, (128,)) eval_dataset = RepeatDataset(x) - args = GaudiTrainingArguments("./test", use_habana=True, use_lazy_mode=True) + args = GaudiTrainingArguments("./test", use_habana=True, use_lazy_mode=True, report_to="none") gaudi_config = get_gaudi_config() trainer = GaudiTrainer(tiny_gpt2, gaudi_config, args, eval_dataset=eval_dataset) # By default the past_key_values are removed @@ -936,6 +971,65 @@ def test_number_of_steps_in_training(self): train_output = trainer.train() self.assertEqual(train_output.global_step, 10) + @require_peft + def test_multiple_peft_adapters(self): + from peft import LoraConfig, get_peft_model + + # Tests if resuming from checkpoint works if the model has multiple adapters + + MODEL_ID = "hf-internal-testing/tiny-random-LlamaForCausalLM" + tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) + tiny_model = AutoModelForCausalLM.from_pretrained(MODEL_ID) + + peft_config = LoraConfig( + r=4, + lora_alpha=16, + lora_dropout=0.05, + bias="none", + task_type="CAUSAL_LM", + ) + tiny_model = get_peft_model(tiny_model, peft_config, "adapter1") + tiny_model.add_adapter("adapter2", peft_config) + + train_dataset = LineByLineTextDataset( + tokenizer=tokenizer, + file_path=PATH_SAMPLE_TEXT, + block_size=tokenizer.max_len_single_sentence, + ) + for example in train_dataset.examples: + example["labels"] = example["input_ids"] + + tokenizer.pad_token = tokenizer.eos_token + + with tempfile.TemporaryDirectory() as tmpdir: + args = GaudiTrainingArguments( + tmpdir, + per_device_train_batch_size=1, + learning_rate=1e-9, + save_steps=5, + logging_steps=5, + max_steps=10, + use_habana=True, + use_lazy_mode=True, + ) + gaudi_config = get_gaudi_config() + trainer = GaudiTrainer(tiny_model, gaudi_config, args, tokenizer=tokenizer, train_dataset=train_dataset) + + trainer.train() + parameters = dict(tiny_model.named_parameters()) + state = dataclasses.asdict(trainer.state) + + # Reinitialize trainer + trainer = GaudiTrainer(tiny_model, gaudi_config, args, tokenizer=tokenizer, train_dataset=train_dataset) + + checkpoint = os.path.join(tmpdir, "checkpoint-5") + + trainer.train(resume_from_checkpoint=checkpoint) + parameters1 = dict(tiny_model.named_parameters()) + state1 = dataclasses.asdict(trainer.state) + self.assertEqual(parameters, parameters1) + self.check_trainer_state_are_the_same(state, state1) + # TODO: investigate why this test fails # def test_neftune(self): # config = GPT2Config(vocab_size=100, n_positions=128, n_embd=32, n_layer=3, n_head=4) @@ -945,7 +1039,7 @@ def test_number_of_steps_in_training(self): # # Trainer without inf/nan filter # args = GaudiTrainingArguments( - # "./test", learning_rate=1e-9, logging_steps=5, logging_nan_inf_filter=False, neftune_noise_alpha=0.4, use_habana=True, use_lazy_mode=True, + # "./test", learning_rate=1e-9, logging_steps=5, logging_nan_inf_filter=False, neftune_noise_alpha=0.4, use_habana=True, use_lazy_mode=True, report_to="none" # ) # gaudi_config = get_gaudi_config() # trainer = GaudiTrainer(tiny_gpt2, gaudi_config, args, train_dataset=train_dataset) @@ -963,7 +1057,7 @@ def test_number_of_steps_in_training(self): # tiny_gpt2 = GPT2LMHeadModel(config) # # Trainer without inf/nan filter # args = GaudiTrainingArguments( - # "./test", learning_rate=1e-9, logging_steps=5, logging_nan_inf_filter=False, neftune_noise_alpha=0.4, use_habana=True, use_lazy_mode=True, + # "./test", learning_rate=1e-9, logging_steps=5, logging_nan_inf_filter=False, neftune_noise_alpha=0.4, use_habana=True, use_lazy_mode=True, report_to="none" # ) # trainer = GaudiTrainer(tiny_gpt2, gaudi_config, args, train_dataset=train_dataset) @@ -997,6 +1091,7 @@ def test_logging_inf_nan_filter(self): logging_nan_inf_filter=False, use_habana=True, use_lazy_mode=True, + report_to="none", ) trainer = GaudiTrainer(tiny_gpt2, gaudi_config, args, train_dataset=train_dataset) trainer.train() @@ -1010,6 +1105,7 @@ def test_logging_inf_nan_filter(self): logging_nan_inf_filter=True, use_habana=True, use_lazy_mode=True, + report_to="none", ) trainer = GaudiTrainer(tiny_gpt2, gaudi_config, args, train_dataset=train_dataset) trainer.train() @@ -1053,16 +1149,120 @@ def test_train_and_eval_dataloaders(self): # tests that we do not require dataloader to have a .dataset attribute def test_dataloader_without_dataset(self): train_dataset = RegressionDataset(length=128) - args = GaudiTrainingArguments(output_dir="tmp_trainer", use_habana=True, use_lazy_mode=True) - trainer = CustomDataloaderTrainer( - model=RegressionModel(), - gaudi_config=get_gaudi_config(), - args=args, + with tempfile.TemporaryDirectory() as tmp_dir: + args = GaudiTrainingArguments(output_dir=tmp_dir, use_habana=True, use_lazy_mode=True, report_to="none") + trainer = CustomDataloaderTrainer( + model=RegressionModel(), + gaudi_config=get_gaudi_config(), + args=args, + train_dataset=train_dataset, + eval_dataset=train_dataset, + ) + trainer.train() + trainer.evaluate() + + def test_get_eval_dataloader_without_persistent_workers(self): + train_dataset = RegressionDataset() + config = GPT2Config(vocab_size=100, n_positions=128, n_embd=32, n_layer=3, n_head=4) + tiny_gpt2 = GPT2LMHeadModel(config) + args = GaudiTrainingArguments( + "./test", + report_to="none", + dataloader_persistent_workers=False, + use_habana=True, + use_lazy_mode=True, + ) + + # Single evaluation dataset + eval_dataset = RegressionDataset() + gaudi_config = get_gaudi_config() + trainer = GaudiTrainer(tiny_gpt2, gaudi_config, args, train_dataset=train_dataset, eval_dataset=eval_dataset) + # Mocking the prepare method to avoid the dataloader changing with each call to get_eval_dataloader + trainer.accelerator.prepare = lambda x: x + + default_dataloader = trainer.get_eval_dataloader() + dataloader_with_dataset = trainer.get_eval_dataloader(eval_dataset) + + self.assertEqual(default_dataloader.dataset, eval_dataset) + self.assertEqual(dataloader_with_dataset.dataset, eval_dataset) + self.assertNotEqual(default_dataloader, dataloader_with_dataset) + + # Multiple evaluation datasets + first_dataset = RegressionDataset() + second_dataset = RegressionDataset() + trainer = GaudiTrainer( + tiny_gpt2, + gaudi_config, + args, train_dataset=train_dataset, - eval_dataset=train_dataset, + eval_dataset={"first": first_dataset, "second": second_dataset}, ) - trainer.train() - trainer.evaluate() + # Mocking the prepare method to avoid the dataloader changing with each call to get_eval_dataloader + trainer.accelerator.prepare = lambda x: x + + first_dataloader = trainer.get_eval_dataloader("first") + first_dataloader_repeated = trainer.get_eval_dataloader("first") + second_dataloader = trainer.get_eval_dataloader("second") + second_dataloader_repeated = trainer.get_eval_dataloader("second") + + self.assertEqual(first_dataset, first_dataloader.dataset) + self.assertEqual(first_dataloader.dataset, first_dataloader_repeated.dataset) + self.assertEqual(second_dataset, second_dataloader.dataset) + self.assertEqual(second_dataloader.dataset, second_dataloader_repeated.dataset) + self.assertNotEqual(first_dataloader, first_dataloader_repeated) + self.assertNotEqual(second_dataloader, second_dataloader_repeated) + + def test_get_eval_dataloader_with_persistent_workers(self): + train_dataset = RegressionDataset() + config = GPT2Config(vocab_size=100, n_positions=128, n_embd=32, n_layer=3, n_head=4) + tiny_gpt2 = GPT2LMHeadModel(config) + args = GaudiTrainingArguments( + "./test", + report_to="none", + dataloader_persistent_workers=True, + dataloader_num_workers=2, + use_habana=True, + use_lazy_mode=True, + ) + + # Single evaluation dataset + eval_dataset = RegressionDataset() + gaudi_config = get_gaudi_config() + trainer = GaudiTrainer(tiny_gpt2, gaudi_config, args, train_dataset=train_dataset, eval_dataset=eval_dataset) + # Mocking the prepare method to avoid the dataloader changing with each call to get_eval_dataloader + trainer.accelerator.prepare = lambda x: x + + default_dataloader = trainer.get_eval_dataloader() + dataloader_with_dataset = trainer.get_eval_dataloader(eval_dataset) + + self.assertEqual(default_dataloader.dataset, eval_dataset) + self.assertEqual(dataloader_with_dataset.dataset, eval_dataset) + self.assertEqual(default_dataloader, dataloader_with_dataset) + + # Multiple evaluation datasets + first_dataset = RegressionDataset() + second_dataset = RegressionDataset() + trainer = GaudiTrainer( + tiny_gpt2, + gaudi_config, + args, + train_dataset=train_dataset, + eval_dataset={"first": first_dataset, "second": second_dataset}, + ) + # Mocking the prepare method to avoid the dataloader changing with each call to get_eval_dataloader + trainer.accelerator.prepare = lambda x: x + + first_dataloader = trainer.get_eval_dataloader("first") + first_dataloader_repeated = trainer.get_eval_dataloader("first") + second_dataloader = trainer.get_eval_dataloader("second") + second_dataloader_repeated = trainer.get_eval_dataloader("second") + + self.assertEqual(first_dataset, first_dataloader.dataset) + self.assertEqual(first_dataloader.dataset, first_dataloader_repeated.dataset) + self.assertEqual(second_dataset, second_dataloader.dataset) + self.assertEqual(second_dataloader.dataset, second_dataloader_repeated.dataset) + self.assertEqual(first_dataloader, first_dataloader_repeated) + self.assertEqual(second_dataloader, second_dataloader_repeated) def test_data_is_not_parallelized_when_model_is_parallel(self): model = RegressionModel() @@ -1075,6 +1275,7 @@ def test_data_is_not_parallelized_when_model_is_parallel(self): per_device_eval_batch_size=16, use_habana=True, use_lazy_mode=True, + report_to="none", ) gaudi_config = get_gaudi_config() trainer = GaudiTrainer( @@ -1128,6 +1329,49 @@ def test_evaluate(self): expected_acc = AlmostAccuracy()((pred + 1, y))["accuracy"] self.assertAlmostEqual(results["eval_accuracy"], expected_acc) + def test_evaluate_with_batch_eval_metrics(self): + trainer = get_regression_trainer( + a=1.5, b=2.5, compute_metrics=AlmostAccuracyBatched(), batch_eval_metrics=True + ) + results = trainer.evaluate() + + x, y = trainer.eval_dataset.x, trainer.eval_dataset.ys[0] + pred = 1.5 * x + 2.5 + expected_loss = ((pred - y) ** 2).mean() + self.assertAlmostEqual(results["eval_loss"], expected_loss) + expected_acc = AlmostAccuracy()((pred, y))["accuracy"] + self.assertAlmostEqual(results["eval_accuracy"], expected_acc) + + # With a number of elements not a round multiple of the batch size + trainer = get_regression_trainer( + a=1.5, b=2.5, eval_len=66, compute_metrics=AlmostAccuracyBatched(), batch_eval_metrics=True + ) + results = trainer.evaluate() + + x, y = trainer.eval_dataset.x, trainer.eval_dataset.ys[0] + pred = 1.5 * x + 2.5 + expected_loss = ((pred - y) ** 2).mean() + self.assertAlmostEqual(results["eval_loss"], expected_loss) + expected_acc = AlmostAccuracy()((pred, y))["accuracy"] + self.assertAlmostEqual(results["eval_accuracy"], expected_acc) + + # With logits preprocess + trainer = get_regression_trainer( + a=1.5, + b=2.5, + compute_metrics=AlmostAccuracyBatched(), + batch_eval_metrics=True, + preprocess_logits_for_metrics=lambda logits, labels: logits + 1, + ) + results = trainer.evaluate() + + x, y = trainer.eval_dataset.x, trainer.eval_dataset.ys[0] + pred = 1.5 * x + 2.5 + expected_loss = ((pred - y) ** 2).mean() + self.assertAlmostEqual(results["eval_loss"], expected_loss) + expected_acc = AlmostAccuracy()((pred + 1, y))["accuracy"] + self.assertAlmostEqual(results["eval_accuracy"], expected_acc) + def test_predict(self): trainer = get_regression_trainer(a=1.5, b=2.5) preds = trainer.predict(trainer.eval_dataset).predictions @@ -1160,6 +1404,58 @@ def test_predict(self): self.assertTrue(np.array_equal(labels[0], trainer.eval_dataset.ys[0])) self.assertTrue(np.array_equal(labels[1], trainer.eval_dataset.ys[1])) + def test_predict_with_batch_eval_metrics(self): + trainer = get_regression_trainer( + a=1.5, b=2.5, compute_metrics=AlmostAccuracyBatched(), batch_eval_metrics=True + ) + results = trainer.predict(trainer.eval_dataset) + preds = results.predictions + x, y = trainer.eval_dataset.x, trainer.eval_dataset.ys[0] + gt = 1.5 * x + 2.5 + self.assertTrue(np.allclose(preds, gt)) + expected_acc = AlmostAccuracy()((preds, y))["accuracy"] + self.assertAlmostEqual(results.metrics["test_accuracy"], expected_acc) + + # With a number of elements not a round multiple of the batch size + trainer = get_regression_trainer( + a=1.5, b=2.5, eval_len=66, compute_metrics=AlmostAccuracyBatched(), batch_eval_metrics=True + ) + results = trainer.predict(trainer.eval_dataset) + preds = results.predictions + x, y = trainer.eval_dataset.x, trainer.eval_dataset.ys[0] + self.assertTrue(np.allclose(preds, 1.5 * x + 2.5)) + expected_acc = AlmostAccuracy()((preds, y))["accuracy"] + self.assertAlmostEqual(results.metrics["test_accuracy"], expected_acc) + + # With more than one output of the model + trainer = get_regression_trainer( + a=1.5, b=2.5, double_output=True, compute_metrics=AlmostAccuracyBatched(), batch_eval_metrics=True + ) + preds = trainer.predict(trainer.eval_dataset).predictions + x = trainer.eval_dataset.x + self.assertEqual(len(preds), 2) + self.assertTrue(np.allclose(preds[0], 1.5 * x + 2.5)) + self.assertTrue(np.allclose(preds[1], 1.5 * x + 2.5)) + + # With more than one output/label of the model + trainer = get_regression_trainer( + a=1.5, + b=2.5, + double_output=True, + label_names=["labels", "labels_2"], + compute_metrics=AlmostAccuracyBatched(), + batch_eval_metrics=True, + ) + outputs = trainer.predict(trainer.eval_dataset) + preds = outputs.predictions + labels = outputs.label_ids + x = trainer.eval_dataset.x + self.assertEqual(len(preds), 2) + self.assertTrue(np.allclose(preds[0], 1.5 * x + 2.5)) + self.assertTrue(np.allclose(preds[1], 1.5 * x + 2.5)) + self.assertTrue(np.array_equal(labels[0], trainer.eval_dataset.ys[0])) + self.assertTrue(np.array_equal(labels[1], trainer.eval_dataset.ys[1])) + def test_dynamic_shapes(self): eval_dataset = DynamicShapesDataset(batch_size=self.batch_size) model = RegressionModel(a=2, b=1) @@ -1204,7 +1500,12 @@ def test_dynamic_shape_feature(self): gaudi_config = get_gaudi_config() gaudi_config.use_dynamic_shapes = True args = GaudiTrainingArguments( - "./regression", use_habana=True, use_lazy_mode=True, per_device_train_batch_size=1, num_train_epochs=1 + "./regression", + use_habana=True, + use_lazy_mode=True, + per_device_train_batch_size=1, + num_train_epochs=1, + report_to="none", ) model = RegressionModel() trainer = GaudiTrainer( @@ -1220,7 +1521,12 @@ def test_dynamic_shape_feature(self): gaudi_config = get_gaudi_config() gaudi_config.use_dynamic_shapes = False args = GaudiTrainingArguments( - "./regression", use_habana=True, use_lazy_mode=True, per_device_train_batch_size=1, num_train_epochs=1 + "./regression", + use_habana=True, + use_lazy_mode=True, + per_device_train_batch_size=1, + num_train_epochs=1, + report_to="none", ) model = RegressionModel() trainer = GaudiTrainer( @@ -1299,6 +1605,56 @@ def test_safe_checkpoints(self): tmpdir, 5, int(self.n_epochs * 64 / self.batch_size), False, safe_weights=save_safetensors ) + def test_load_best_model_with_save(self): + with tempfile.TemporaryDirectory() as tmpdir: + trainer = get_regression_trainer( + output_dir=tmpdir, + save_steps=5, + evaluation_strategy="steps", + eval_steps=5, + max_steps=9, + ) + trainer.train() + # Check that we have the last known step: + assert os.path.exists( + os.path.join(tmpdir, f"checkpoint-{trainer.state.max_steps}") + ), f"Could not find checkpoint-{trainer.state.max_steps}" + # And then check the last step + assert os.path.exists(os.path.join(tmpdir, "checkpoint-9")), "Could not find checkpoint-9" + + # Now test that using a limit works + # Should result in: + # - save at step 5 (but is deleted) + # - save at step 10 (loaded in at the end when `load_best_model=True`) + # - save at step 11 + with tempfile.TemporaryDirectory() as tmpdir: + trainer = get_regression_trainer( + output_dir=tmpdir, + save_steps=5, + evaluation_strategy="steps", + eval_steps=5, + load_best_model_at_end=True, + save_total_limit=2, + max_steps=11, + ) + trainer.train() + # Check that we have the last known step: + assert os.path.exists(os.path.join(tmpdir, "checkpoint-11")), "Could not find checkpoint-11" + # And then check the last multiple + assert os.path.exists(os.path.join(tmpdir, "checkpoint-10")), "Could not find checkpoint-10" + # Finally check that we don't have an old one + assert not os.path.exists(os.path.join(tmpdir, "checkpoint-5")), "Found checkpoint-5, limit not respected" + + # Finally check that the right model was loaded in, checkpoint-10 + # this goes by the last `eval` step check to do so, so it won't be + # the last model *saved* + model_state = trainer.model.state_dict() + final_model_weights = safetensors.torch.load_file( + os.path.join(tmpdir, "checkpoint-10", "model.safetensors") + ) + for k, v in model_state.items(): + assert torch.allclose(v, final_model_weights[k]), f"{k} is not the same" + def test_can_resume_training(self): with tempfile.TemporaryDirectory() as tmpdir: kwargs = { @@ -1620,7 +1976,7 @@ def test_load_best_model_at_end(self): output_dir=tmpdir, learning_rate=0.1, eval_steps=5, - evaluation_strategy="steps", + eval_strategy="steps", save_steps=5, load_best_model_at_end=True, ) @@ -1636,7 +1992,7 @@ def test_load_best_model_at_end(self): output_dir=tmpdir, learning_rate=0.1, eval_steps=5, - evaluation_strategy="steps", + eval_strategy="steps", save_steps=5, load_best_model_at_end=True, metric_for_best_model="accuracy", @@ -1653,7 +2009,7 @@ def test_load_best_model_at_end(self): b=2.5, output_dir=tmpdir, learning_rate=0.1, - evaluation_strategy="epoch", + eval_strategy="epoch", save_strategy="epoch", load_best_model_at_end=True, metric_for_best_model="accuracy", @@ -1672,7 +2028,7 @@ def test_load_best_model_at_end(self): output_dir=tmpdir, learning_rate=0.1, eval_steps=5, - evaluation_strategy="steps", + eval_strategy="steps", save_steps=5, load_best_model_at_end=True, pretrained=False, @@ -1693,7 +2049,7 @@ def test_load_best_model_from_safetensors(self): output_dir=tmpdir, learning_rate=0.1, eval_steps=5, - evaluation_strategy="steps", + eval_strategy="steps", save_steps=5, load_best_model_at_end=True, save_safetensors=save_safetensors, @@ -1809,7 +2165,7 @@ def test_early_stopping_callback(self): gradient_accumulation_steps=1, per_device_train_batch_size=16, load_best_model_at_end=True, - evaluation_strategy=IntervalStrategy.EPOCH, + eval_strategy=IntervalStrategy.EPOCH, save_strategy=IntervalStrategy.EPOCH, compute_metrics=AlmostAccuracy(), metric_for_best_model="accuracy", @@ -1825,7 +2181,7 @@ def test_early_stopping_callback(self): num_train_epochs=20, gradient_accumulation_steps=1, per_device_train_batch_size=16, - evaluation_strategy=IntervalStrategy.EPOCH, + eval_strategy=IntervalStrategy.EPOCH, compute_metrics=AlmostAccuracy(), metric_for_best_model="accuracy", ) @@ -1840,8 +2196,10 @@ def test_flos_extraction(self): trainer = get_regression_trainer(learning_rate=0.1) def assert_flos_extraction(trainer, wrapped_model_to_check): - self.assertEqual(trainer.model, unwrap_model(wrapped_model_to_check)) - self.assertGreaterEqual(getattr(unwrap_model(wrapped_model_to_check).config, "total_flos", 0), 0) + self.assertEqual(trainer.model, trainer.accelerator.unwrap_model(wrapped_model_to_check)) + self.assertGreaterEqual( + getattr(trainer.accelerator.unwrap_model(wrapped_model_to_check).config, "total_flos", 0), 0 + ) # with plain model assert_flos_extraction(trainer, trainer.model) @@ -1869,7 +2227,7 @@ def test_checkpoint_rotation(self): # With best model at end trainer = get_regression_trainer( - output_dir=tmp_dir, evaluation_strategy="steps", load_best_model_at_end=True, save_total_limit=2 + output_dir=tmp_dir, eval_strategy="steps", load_best_model_at_end=True, save_total_limit=2 ) trainer.state.best_model_checkpoint = os.path.join(tmp_dir, "checkpoint-5") self.check_checkpoint_deletion(trainer, tmp_dir, [5, 25]) @@ -1877,7 +2235,7 @@ def test_checkpoint_rotation(self): # Edge case: we don't always honor save_total_limit=1 if load_best_model_at_end=True to be able to resume # from checkpoint trainer = get_regression_trainer( - output_dir=tmp_dir, evaluation_strategy="steps", load_best_model_at_end=True, save_total_limit=1 + output_dir=tmp_dir, eval_strategy="steps", load_best_model_at_end=True, save_total_limit=1 ) trainer.state.best_model_checkpoint = os.path.join(tmp_dir, "checkpoint-25") self.check_checkpoint_deletion(trainer, tmp_dir, [25]) @@ -1915,14 +2273,15 @@ def test_mem_metrics(self): def test_no_wd_param_group(self): model = nn.Sequential(TstLayer(128), nn.ModuleList([TstLayer(128), TstLayer(128)])) gaudi_config = get_gaudi_config() - args = GaudiTrainingArguments(output_dir="./test", use_habana=True, use_lazy_mode=True) - trainer = GaudiTrainer(model=model, gaudi_config=gaudi_config, args=args) - trainer.create_optimizer_and_scheduler(10) - wd_names = ['0.linear1.weight', '0.linear2.weight', '1.0.linear1.weight', '1.0.linear2.weight', '1.1.linear1.weight', '1.1.linear2.weight'] # fmt: skip - wd_params = [p for n, p in model.named_parameters() if n in wd_names] - no_wd_params = [p for n, p in model.named_parameters() if n not in wd_names] - self.assertListEqual(trainer.optimizer.param_groups[0]["params"], wd_params) - self.assertListEqual(trainer.optimizer.param_groups[1]["params"], no_wd_params) + with tempfile.TemporaryDirectory() as tmp_dir: + args = GaudiTrainingArguments(output_dir=tmp_dir, use_habana=True, use_lazy_mode=True, report_to="none") + trainer = GaudiTrainer(model=model, gaudi_config=gaudi_config, args=args) + trainer.create_optimizer_and_scheduler(10) + wd_names = ['0.linear1.weight', '0.linear2.weight', '1.0.linear1.weight', '1.0.linear2.weight', '1.1.linear1.weight', '1.1.linear2.weight'] # fmt: skip + wd_params = [p for n, p in model.named_parameters() if n in wd_names] + no_wd_params = [p for n, p in model.named_parameters() if n not in wd_names] + self.assertListEqual(trainer.optimizer.param_groups[0]["params"], wd_params) + self.assertListEqual(trainer.optimizer.param_groups[1]["params"], no_wd_params) def test_accelerator_config_empty(self): # Checks that a config can be made with the defaults if not passed @@ -2053,9 +2412,12 @@ def test_accelerate_config_from_dataclass_grad_accum(self): config = RegressionModelConfig(a=1.5, b=2.5) model = RegressionPreTrainedModel(config) eval_dataset = SampleIterableDataset() + gaudi_config = get_gaudi_config() with tempfile.TemporaryDirectory() as tmp_dir: - args = RegressionGaudiTrainingArguments(output_dir=tmp_dir, accelerator_config=accelerator_config) - trainer = GaudiTrainer(model=model, args=args, eval_dataset=eval_dataset) + args = RegressionGaudiTrainingArguments( + output_dir=tmp_dir, accelerator_config=accelerator_config, use_habana=True + ) + trainer = GaudiTrainer(model=model, gaudi_config=gaudi_config, args=args, eval_dataset=eval_dataset) self.assertEqual(trainer.accelerator.gradient_state.plugin_kwargs["num_steps"], 10) self.assertEqual(trainer.accelerator.gradient_state.plugin_kwargs["adjust_scheduler"], False) self.assertEqual(trainer.accelerator.gradient_state.plugin_kwargs["sync_with_dataloader"], False) @@ -2139,12 +2501,27 @@ def test_accelerator_config_only_deprecated_args(self): trainer = GaudiTrainer(model=model, gaudi_config=gaudi_config, args=args, eval_dataset=eval_dataset) self.assertEqual(trainer.accelerator.split_batches, True) + def test_accelerator_custom_state(self): + GaudiAcceleratorState._reset_state(reset_partial_state=True) + with tempfile.TemporaryDirectory() as tmp_dir: + with self.assertRaises(ValueError) as cm: + _ = RegressionGaudiTrainingArguments( + output_dir=tmp_dir, use_habana=True, accelerator_config={"use_configured_state": True} + ) + self.assertIn("Please define this beforehand", str(cm.warnings[0].message)) + _ = GaudiAccelerator() + _ = RegressionGaudiTrainingArguments( + output_dir=tmp_dir, use_habana=True, accelerator_config={"use_configured_state": True} + ) + GaudiAcceleratorState._reset_state(reset_partial_state=True) + @require_accelerate_version_min_0_28 def test_accelerator_config_from_dict_grad_accum_num_steps(self): with tempfile.TemporaryDirectory() as tmp_dir: config = RegressionModelConfig(a=1.5, b=2.5) model = RegressionPreTrainedModel(config) eval_dataset = SampleIterableDataset() + gaudi_config = get_gaudi_config() # case - TrainingArguments.gradient_accumulation_steps == 1 # - gradient_accumulation_kwargs['num_steps] == 1 @@ -2157,8 +2534,9 @@ def test_accelerator_config_from_dict_grad_accum_num_steps(self): "num_steps": 1, } }, + use_habana=True, ) - trainer = GaudiTrainer(model=model, args=args, eval_dataset=eval_dataset) + trainer = GaudiTrainer(model=model, gaudi_config=gaudi_config, args=args, eval_dataset=eval_dataset) self.assertEqual(trainer.accelerator.gradient_state.plugin_kwargs["num_steps"], 1) # case - TrainingArguments.gradient_accumulation_steps > 1 @@ -2172,9 +2550,10 @@ def test_accelerator_config_from_dict_grad_accum_num_steps(self): "num_steps": 10, } }, + use_habana=True, ) with self.assertRaises(Exception) as context: - trainer = GaudiTrainer(model=model, args=args, eval_dataset=eval_dataset) + trainer = GaudiTrainer(model=model, gaudi_config=gaudi_config, args=args, eval_dataset=eval_dataset) self.assertTrue("The `AcceleratorConfig`'s `num_steps` is set but" in str(context.exception)) def test_accelerator_config_not_instantiated(self): @@ -2210,6 +2589,48 @@ class CustomTrainingArguments(GaudiTrainingArguments): ) self.assertTrue("Tried passing in a callable to `accelerator_config`" in str(context.exception)) + def test_torch_dtype_to_json(self): + @dataclasses.dataclass + class TorchDtypeTrainingArguments(GaudiTrainingArguments): + torch_dtype: torch.dtype = dataclasses.field( + default=torch.float32, + ) + + for dtype in [ + "float32", + "float64", + "complex64", + "complex128", + "bfloat16", + "uint8", + "int8", + "int16", + "int32", + "int64", + "bool", + ]: + torch_dtype = getattr(torch, dtype) + with tempfile.TemporaryDirectory() as tmp_dir: + args = TorchDtypeTrainingArguments(output_dir=tmp_dir, torch_dtype=torch_dtype, use_habana=True) + + args_dict = args.to_dict() + self.assertIn("torch_dtype", args_dict) + self.assertEqual(args_dict["torch_dtype"], dtype) + + @require_accelerate_version_min_0_30 + def test_eval_use_gather_object(self): + train_dataset = RegressionDataset() + eval_dataset = RegressionDataset() + model = RegressionDictModel() + args = GaudiTrainingArguments( + "./regression", use_habana=True, use_lazy_mode=True, report_to="none", eval_use_gather_object=True + ) + gaudi_config = get_gaudi_config() + trainer = GaudiTrainer(model, gaudi_config, args, train_dataset=train_dataset, eval_dataset=eval_dataset) + trainer.train() + _ = trainer.evaluate() + _ = trainer.predict(eval_dataset) + def test_profiling(self): # 24 total steps and compilation takes place during the 1st three steps trainer = get_regression_trainer(profiling_warmup_steps=3, profiling_steps=21) @@ -2299,47 +2720,60 @@ def get_commit_history(self, repo): def test_push_to_hub_with_saves_each_epoch(self): with tempfile.TemporaryDirectory() as tmp_dir: - trainer = get_regression_trainer( - output_dir=os.path.join(tmp_dir, "test-trainer-epoch"), - push_to_hub=True, - hub_token=self._token, - # To avoid any flakiness if the training goes faster than the uploads. - hub_always_push=True, - save_strategy="epoch", - ) - trainer.train() + with self.assertLogs(level="WARNING") as logs: + trainer = get_regression_trainer( + output_dir=os.path.join(tmp_dir, "test-trainer-epoch"), + push_to_hub=True, + hub_token=self._token, + # To avoid any flakiness if the training goes faster than the uploads. + hub_always_push=True, + save_strategy="epoch", + ) + trainer.train() commits = list_repo_commits(f"{USER}/test-trainer-epoch", token=self._token) commits = [c.title for c in commits] self.assertIn("initial commit", commits) - for i in range(1, 4): - self.assertIn(f"Training in progress, epoch {i}", commits) + self.assertIn("Training in progress, epoch 1", commits) + self.assertIn("Training in progress, epoch 2", commits) + # Epochs 3 and 4 are not guaranteed to be present (empty commits) + self.assertTrue(any("Skipping to prevent empty commit." in record.message for record in logs.records)) def test_push_to_hub_with_saves_each_n_steps(self): num_gpus = max(1, get_gpu_count()) if num_gpus > 2: - return + self.skipTest(reason="More than 2 GPUs available") with tempfile.TemporaryDirectory() as tmp_dir: - trainer = get_regression_trainer( - output_dir=os.path.join(tmp_dir, "test-trainer-step"), - push_to_hub=True, - hub_token=self._token, - # To avoid any flakiness if the training goes faster than the uploads. - hub_always_push=True, - save_strategy="steps", - save_steps=5, - ) - trainer.train() + with self.assertLogs(level="WARNING") as logs: + trainer = get_regression_trainer( + output_dir=os.path.join(tmp_dir, "test-trainer-step"), + push_to_hub=True, + hub_token=self._token, + # To avoid any flakiness if the training goes faster than the uploads. + hub_always_push=True, + save_strategy="steps", + save_steps=5, + ) + trainer.train() commits = list_repo_commits(f"{USER}/test-trainer-step", token=self._token) commits = [c.title for c in commits] self.assertIn("initial commit", commits) + # Some commits are skipped if nothing has changed + # We expect 1 commit per 5 epochs + 1 commit at the end + nb_empty_commits = len( + [record for record in logs.records if "Skipping to prevent empty commit." in record.message] + ) + nb_epoch_commits = len([commit for commit in commits if "Training in progress, step" in commit]) + # max_steps depend on the number of available GPUs max_steps = math.ceil(trainer.args.num_train_epochs * len(trainer.get_train_dataloader())) - for i in range(5, max_steps, 5): - self.assertIn(f"Training in progress, step {i}", commits) + nb_expected_commits = len(range(5, max_steps, 5)) + + # '>=' since final commit might be an empty commit as well (not deterministic) + self.assertGreaterEqual(nb_empty_commits + nb_epoch_commits, nb_expected_commits) @require_tensorboard def test_push_to_hub_with_tensorboard_logs(self): @@ -2423,7 +2857,7 @@ def hp_name(trial): output_dir=tmp_dir, learning_rate=0.1, logging_steps=1, - evaluation_strategy=IntervalStrategy.EPOCH, + eval_strategy=IntervalStrategy.EPOCH, save_strategy=IntervalStrategy.EPOCH, num_train_epochs=4, disable_tqdm=True, @@ -2472,7 +2906,7 @@ def compute_objective(metrics: Dict[str, float]) -> List[float]: output_dir=tmp_dir, learning_rate=0.1, logging_steps=1, - evaluation_strategy=IntervalStrategy.EPOCH, + eval_strategy=IntervalStrategy.EPOCH, save_strategy=IntervalStrategy.EPOCH, num_train_epochs=10, disable_tqdm=True, @@ -2531,7 +2965,7 @@ def compute_objective(metrics: Dict[str, float]) -> List[float]: # output_dir=tmp_dir, # learning_rate=0.1, # logging_steps=1, -# evaluation_strategy=IntervalStrategy.EPOCH, +# eval_strategy=IntervalStrategy.EPOCH, # save_strategy=IntervalStrategy.EPOCH, # num_train_epochs=4, # disable_tqdm=True, @@ -2594,7 +3028,7 @@ def compute_objective(metrics: Dict[str, float]) -> List[float]: # output_dir=tmp_dir, # learning_rate=0.1, # logging_steps=1, -# evaluation_strategy=IntervalStrategy.EPOCH, +# eval_strategy=IntervalStrategy.EPOCH, # save_strategy=IntervalStrategy.EPOCH, # num_train_epochs=4, # disable_tqdm=True, @@ -2709,7 +3143,7 @@ def test_optim_supported(self, name: str, expected_cls, mandatory_kwargs): # output_dir=tmp_dir, # learning_rate=0.1, # logging_steps=1, -# evaluation_strategy=IntervalStrategy.EPOCH, +# eval_strategy=IntervalStrategy.EPOCH, # save_strategy=IntervalStrategy.EPOCH, # num_train_epochs=4, # disable_tqdm=True, @@ -2735,50 +3169,56 @@ def test_hyperparameter_search_backends(self): class OptimizerAndModelInspectionTest(unittest.TestCase): def test_get_num_trainable_parameters(self): model = nn.Sequential(nn.Linear(128, 64), nn.Linear(64, 32)) - args = GaudiTrainingArguments( - output_dir="tmp_trainer", - use_habana=True, - use_lazy_mode=True, - ) # in_features * out_features + bias layer_1 = 128 * 64 + 64 layer_2 = 64 * 32 + 32 - trainer = GaudiTrainer(model=model, gaudi_config=get_gaudi_config(), args=args) - self.assertEqual(trainer.get_num_trainable_parameters(), layer_1 + layer_2) - # Freeze the last layer - for param in model[-1].parameters(): - param.requires_grad = False - self.assertEqual(trainer.get_num_trainable_parameters(), layer_1) + with tempfile.TemporaryDirectory() as tmp_dir: + args = GaudiTrainingArguments( + output_dir=tmp_dir, + use_habana=True, + use_lazy_mode=True, + report_to="none", + ) + trainer = GaudiTrainer(model=model, gaudi_config=get_gaudi_config(), args=args) + self.assertEqual(trainer.get_num_trainable_parameters(), layer_1 + layer_2) + # Freeze the last layer + for param in model[-1].parameters(): + param.requires_grad = False + self.assertEqual(trainer.get_num_trainable_parameters(), layer_1) def test_get_learning_rates(self): model = nn.Sequential(nn.Linear(128, 64)) - args = GaudiTrainingArguments( - output_dir="tmp_trainer", - use_habana=True, - use_lazy_mode=True, - ) - trainer = GaudiTrainer(model=model, gaudi_config=get_gaudi_config(), args=args) - with self.assertRaises(ValueError): - trainer.get_learning_rates() - trainer.create_optimizer() - self.assertEqual(trainer.get_learning_rates(), [5e-05, 5e-05]) + with tempfile.TemporaryDirectory() as tmp_dir: + args = GaudiTrainingArguments( + output_dir=tmp_dir, + use_habana=True, + use_lazy_mode=True, + report_to="none", + ) + trainer = GaudiTrainer(model=model, gaudi_config=get_gaudi_config(), args=args) + with self.assertRaises(ValueError): + trainer.get_learning_rates() + trainer.create_optimizer() + self.assertEqual(trainer.get_learning_rates(), [5e-05, 5e-05]) def test_get_optimizer_group(self): model = nn.Sequential(nn.Linear(128, 64)) - args = GaudiTrainingArguments( - output_dir="tmp_trainer", - use_habana=True, - use_lazy_mode=True, - ) - trainer = GaudiTrainer(model=model, gaudi_config=get_gaudi_config(), args=args) - # ValueError is raised if optimizer is None - with self.assertRaises(ValueError): - trainer.get_optimizer_group() - trainer.create_optimizer() - # Get groups - num_groups = len(trainer.get_optimizer_group()) - self.assertEqual(num_groups, 2) - # Get group of parameter - param = next(model.parameters()) - group = trainer.get_optimizer_group(param) - self.assertIn(param, group["params"]) + with tempfile.TemporaryDirectory() as tmp_dir: + args = GaudiTrainingArguments( + output_dir=tmp_dir, + use_habana=True, + use_lazy_mode=True, + report_to="none", + ) + trainer = GaudiTrainer(model=model, gaudi_config=get_gaudi_config(), args=args) + # ValueError is raised if optimizer is None + with self.assertRaises(ValueError): + trainer.get_optimizer_group() + trainer.create_optimizer() + # Get groups + num_groups = len(trainer.get_optimizer_group()) + self.assertEqual(num_groups, 2) + # Get group of parameter + param = next(model.parameters()) + group = trainer.get_optimizer_group(param) + self.assertIn(param, group["params"]) diff --git a/tests/test_trainer_distributed.py b/tests/test_trainer_distributed.py index 84413f9022..abecf284c4 100644 --- a/tests/test_trainer_distributed.py +++ b/tests/test_trainer_distributed.py @@ -85,6 +85,7 @@ def _test_gaudi_trainer_distributed(self, kwargs={}): command_list += [output_dir] command_list += ["--use_habana"] command_list += ["--use_lazy_mode"] + command_list += ["--report_to none"] for key, value in kwargs.items(): command_list += [f"--{key} {value}"] command = [" ".join(command_list)] diff --git a/tests/test_trainer_seq2seq.py b/tests/test_trainer_seq2seq.py index 165ae0dcee..cb1d5811aa 100644 --- a/tests/test_trainer_seq2seq.py +++ b/tests/test_trainer_seq2seq.py @@ -32,8 +32,8 @@ class GaudiSeq2seqTrainerTester(TestCasePlus): @require_torch def test_finetune_t5(self): - train_dataset = datasets.load_dataset("cnn_dailymail", "3.0.0", split="train[:1%]") - val_dataset = datasets.load_dataset("cnn_dailymail", "3.0.0", split="validation[:1%]") + train_dataset = datasets.load_dataset("abisee/cnn_dailymail", "3.0.0", split="train[:1%]") + val_dataset = datasets.load_dataset("abisee/cnn_dailymail", "3.0.0", split="validation[:1%]") train_dataset = train_dataset.select(range(32)) val_dataset = val_dataset.select(range(16)) @@ -51,6 +51,7 @@ def test_finetune_t5(self): use_habana=True, use_lazy_mode=True, use_hpu_graphs_for_inference=True, + report_to="none", ) model = T5ForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-t5-v1.1") @@ -145,6 +146,7 @@ def test_bad_generation_config_fail_early(self): generation_config=gen_config, use_habana=True, use_lazy_mode=True, + report_to="none", ) with self.assertRaises(ValueError) as exc: _ = GaudiSeq2SeqTrainer( diff --git a/tests/transformers/tests/generation/test_utils.py b/tests/transformers/tests/generation/test_utils.py index 8ffbad89d1..000d4046c0 100644 --- a/tests/transformers/tests/generation/test_utils.py +++ b/tests/transformers/tests/generation/test_utils.py @@ -55,11 +55,13 @@ DisjunctiveConstraint, ForcedBOSTokenLogitsProcessor, ForcedEOSTokenLogitsProcessor, + GenerateBeamDecoderOnlyOutput, + GenerateBeamEncoderDecoderOutput, + GenerateDecoderOnlyOutput, GenerateEncoderDecoderOutput, GreedySearchDecoderOnlyOutput, GreedySearchEncoderDecoderOutput, HammingDiversityLogitsProcessor, - InfNanRemoveLogitsProcessor, LogitsProcessorList, MaxLengthCriteria, MinLengthLogitsProcessor, @@ -86,6 +88,7 @@ class GenerationTesterMixin: model_tester = None all_generative_model_classes = () input_name = "input_ids" + max_new_tokens = 3 def _update_default_model_kwargs(self, model_kwargs): model_kwargs["limit_hpu_graphs"] = False @@ -278,7 +281,6 @@ def _greedy_generate( max_length=max_length, ) - kwargs = {} model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {} model.generation_config.static_shapes = self._get_static_shapes() output_generate = model.generate( @@ -295,31 +297,7 @@ def _greedy_generate( **model_kwargs, ) - if model.config.is_encoder_decoder: - encoder_outputs, input_ids, attention_mask = self._get_encoder_outputs( - model, - input_ids, - attention_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - ) - kwargs["encoder_outputs"] = encoder_outputs - - with torch.no_grad(): - model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {} - self._update_default_model_kwargs(model_kwargs) - output_greedy = model.greedy_search( - input_ids, - max_length=max_length, - logits_processor=logits_processor, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - output_scores=output_scores, - return_dict_in_generate=return_dict_in_generate, - **kwargs, - **model_kwargs, - ) - return output_greedy, output_generate + return output_generate def _sample_generate( self, @@ -357,41 +335,7 @@ def _sample_generate( **model_kwargs, ) - torch.manual_seed(0) - kwargs = {} - if model.config.is_encoder_decoder: - encoder_outputs, input_ids, attention_mask = self._get_encoder_outputs( - model, - input_ids, - attention_mask, - num_interleave=num_return_sequences, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - ) - kwargs["encoder_outputs"] = encoder_outputs - elif attention_mask is not None: - attention_mask = attention_mask.repeat_interleave(num_return_sequences, dim=0) - - # prevent flaky generation test failures - logits_processor.append(InfNanRemoveLogitsProcessor()) - - with torch.no_grad(): - model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {} - self._update_default_model_kwargs(model_kwargs) - output_sample = model.sample( - input_ids.repeat_interleave(num_return_sequences, dim=0), - max_length=max_length, - logits_processor=logits_processor, - logits_warper=logits_warper, - output_scores=output_scores, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict_in_generate=return_dict_in_generate, - **kwargs, - **model_kwargs, - ) - - return output_sample, output_generate + return output_generate def _beam_search_generate( self, @@ -425,37 +369,7 @@ def _beam_search_generate( **model_kwargs, ) - # beam_search does not automatically interleave `batch_size` dim for `num_beams` - kwargs = {} - if model.config.is_encoder_decoder: - encoder_outputs, input_ids, attention_mask = self._get_encoder_outputs( - model, - input_ids, - attention_mask, - num_interleave=beam_scorer.num_beams, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - ) - kwargs["encoder_outputs"] = encoder_outputs - elif attention_mask is not None: - attention_mask = attention_mask.repeat_interleave(beam_scorer.num_beams, dim=0) - - with torch.no_grad(): - model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {} - self._update_default_model_kwargs(model_kwargs) - output_beam_search = model.beam_search( - input_ids.repeat_interleave(beam_scorer.num_beams, dim=0), - beam_scorer, - max_length=max_length, - logits_processor=logits_processor, - output_scores=output_scores, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict_in_generate=return_dict_in_generate, - **kwargs, - **model_kwargs, - ) - return output_generate, output_beam_search + return output_generate def _beam_sample_generate( self, @@ -488,44 +402,7 @@ def _beam_sample_generate( **logits_warper_kwargs, **model_kwargs, ) - # beam_search does not automatically interleave `batch_size` dim for `num_beams` - torch.manual_seed(0) - kwargs = {} - if model.config.is_encoder_decoder: - encoder_outputs, input_ids, attention_mask = self._get_encoder_outputs( - model, - input_ids, - attention_mask, - num_interleave=beam_scorer.num_beams, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - ) - kwargs["encoder_outputs"] = encoder_outputs - elif attention_mask is not None: - attention_mask = attention_mask.repeat_interleave(beam_scorer.num_beams, dim=0) - - # prevent flaky generation test failures - logits_processor = LogitsProcessorList() - logits_processor.append(InfNanRemoveLogitsProcessor()) - - with torch.no_grad(): - model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {} - self._update_default_model_kwargs(model_kwargs) - output_beam_sample = model.beam_sample( - input_ids.repeat_interleave(beam_scorer.num_beams, dim=0), - beam_scorer, - max_length=max_length, - logits_warper=logits_warper, - logits_processor=logits_processor, - output_scores=output_scores, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict_in_generate=return_dict_in_generate, - **kwargs, - **model_kwargs, - ) - - return output_generate, output_beam_sample + return output_generate def _group_beam_search_generate( self, @@ -558,37 +435,7 @@ def _group_beam_search_generate( **model_kwargs, ) - # group_beam_search does not automatically interleave `batch_size` dim for `num_beams` - kwargs = {} - if model.config.is_encoder_decoder: - encoder_outputs, input_ids, attention_mask = self._get_encoder_outputs( - model, - input_ids, - attention_mask, - num_interleave=beam_scorer.num_beams, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - ) - kwargs["encoder_outputs"] = encoder_outputs - elif attention_mask is not None: - attention_mask = attention_mask.repeat_interleave(beam_scorer.num_beams, dim=0) - - with torch.no_grad(): - model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {} - self._update_default_model_kwargs(model_kwargs) - output_group_beam_search = model.group_beam_search( - input_ids.repeat_interleave(beam_scorer.num_beams, dim=0), - beam_scorer, - max_length=max_length, - logits_processor=logits_processor, - output_scores=output_scores, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict_in_generate=return_dict_in_generate, - **kwargs, - **model_kwargs, - ) - return output_generate, output_group_beam_search + return output_generate def _constrained_beam_search_generate( self, @@ -624,37 +471,7 @@ def _constrained_beam_search_generate( **model_kwargs, ) - # group_beam_search does not automatically interleave `batch_size` dim for `num_beams` - kwargs = {} - if model.config.is_encoder_decoder: - encoder_outputs, input_ids, attention_mask = self._get_encoder_outputs( - model, - input_ids, - attention_mask, - num_interleave=constrained_beam_scorer.num_beams, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - ) - kwargs["encoder_outputs"] = encoder_outputs - elif attention_mask is not None: - attention_mask = attention_mask.repeat_interleave(constrained_beam_scorer.num_beams, dim=0) - - with torch.no_grad(): - model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {} - self._update_default_model_kwargs(model_kwargs) - output_group_beam_search = model.constrained_beam_search( - input_ids.repeat_interleave(constrained_beam_scorer.num_beams, dim=0), - constrained_beam_scorer, - max_length=max_length, - logits_processor=logits_processor, - output_scores=output_scores, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict_in_generate=return_dict_in_generate, - **kwargs, - **model_kwargs, - ) - return output_generate, output_group_beam_search + return output_generate def _contrastive_generate( self, @@ -682,7 +499,6 @@ def _contrastive_generate( max_length=max_length, ) - kwargs = {} model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {} self._update_default_model_kwargs(model_kwargs) model.generation_config.static_shapes = self._get_static_shapes() @@ -701,33 +517,7 @@ def _contrastive_generate( **contrastive_search_kwargs, ) - if model.config.is_encoder_decoder: - encoder_outputs, input_ids, attention_mask = self._get_encoder_outputs( - model, - input_ids, - attention_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - ) - kwargs["encoder_outputs"] = encoder_outputs - - with torch.no_grad(): - model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {} - self._update_default_model_kwargs(model_kwargs) - stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)]) - output_contrastive = model.contrastive_search( - input_ids, - stopping_criteria=stopping_criteria, - logits_processor=logits_processor, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - output_scores=output_scores, - return_dict_in_generate=return_dict_in_generate, - **kwargs, - **model_kwargs, - **contrastive_search_kwargs, - ) - return output_contrastive, output_generate + return output_generate def test_greedy_generate(self): # check `generate()` and `greedy_search()` are equal @@ -735,10 +525,13 @@ def test_greedy_generate(self): config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() # test old generation output for backwards compatibility model = model_class(config).to(torch_device).eval() - output_greedy, output_generate = self._greedy_generate( + output_generate = self._greedy_generate( model=model, input_ids=input_ids, attention_mask=attention_mask, max_length=max_length ) - self.assertListEqual(output_greedy.tolist(), output_generate.tolist()) + if model.config.is_encoder_decoder: + self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1) + else: + self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + input_ids.shape[-1]) def test_greedy_generate_dict_outputs(self): for model_class in self.all_generative_model_classes: @@ -746,7 +539,7 @@ def test_greedy_generate_dict_outputs(self): config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() config.use_cache = False model = model_class(config).to(torch_device).eval() - output_greedy, output_generate = self._greedy_generate( + output_generate = self._greedy_generate( model=model, input_ids=input_ids, attention_mask=attention_mask, @@ -758,16 +551,17 @@ def test_greedy_generate_dict_outputs(self): ) if model.config.is_encoder_decoder: - self.assertIsInstance(output_greedy, GreedySearchEncoderDecoderOutput) + self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1) + self.assertIsInstance(output_generate, GenerateEncoderDecoderOutput) + # Retrocompatibility check self.assertIsInstance(output_generate, GreedySearchEncoderDecoderOutput) else: - self.assertIsInstance(output_greedy, GreedySearchDecoderOnlyOutput) + self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + input_ids.shape[-1]) + self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput) + # Retrocompatibility check self.assertIsInstance(output_generate, GreedySearchDecoderOnlyOutput) - self.assertListEqual(output_generate.sequences.tolist(), output_greedy.sequences.tolist()) - - for output in (output_greedy, output_generate): - self._check_outputs(output, input_ids, model.config) + self._check_outputs(output_generate, input_ids, model.config) def test_greedy_generate_dict_outputs_use_cache(self): for model_class in self.all_generative_model_classes: @@ -781,7 +575,7 @@ def test_greedy_generate_dict_outputs_use_cache(self): config.use_cache = True config.is_decoder = True model = model_class(config).to(torch_device).eval() - output_greedy, output_generate = self._greedy_generate( + output_generate = self._greedy_generate( model=model, input_ids=input_ids, attention_mask=attention_mask, @@ -792,10 +586,7 @@ def test_greedy_generate_dict_outputs_use_cache(self): return_dict_in_generate=True, ) - self.assertListEqual(output_generate.sequences.tolist(), output_greedy.sequences.tolist()) - - for output in (output_greedy, output_generate): - self._check_outputs(output, input_ids, model.config, use_cache=True) + self._check_outputs(output_generate, input_ids, model.config, use_cache=True) def test_sample_generate(self): for model_class in self.all_generative_model_classes: @@ -814,8 +605,7 @@ def test_sample_generate(self): ) logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=2) - # check `generate()` and `sample()` are equal - output_sample, output_generate = self._sample_generate( + output_generate = self._sample_generate( model=model, input_ids=input_ids, attention_mask=attention_mask, @@ -826,21 +616,11 @@ def test_sample_generate(self): logits_warper_kwargs=logits_warper_kwargs, process_kwargs=process_kwargs, ) - self.assertListEqual(output_sample.tolist(), output_generate.tolist()) - # check `generate()` and `sample()` yield equal results for `num_return_sequences` - output_sample, output_generate = self._sample_generate( - model=model, - input_ids=input_ids, - attention_mask=attention_mask, - max_length=max_length, - num_return_sequences=3, - logits_processor=logits_processor, - logits_warper=logits_warper, - logits_warper_kwargs=logits_warper_kwargs, - process_kwargs=process_kwargs, - ) - self.assertListEqual(output_sample.tolist(), output_generate.tolist()) + if model.config.is_encoder_decoder: + self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1) + else: + self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + input_ids.shape[-1]) def test_sample_generate_dict_output(self): for model_class in self.all_generative_model_classes: @@ -860,7 +640,7 @@ def test_sample_generate_dict_output(self): ) logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=1) - output_sample, output_generate = self._sample_generate( + output_generate = self._sample_generate( model=model, input_ids=input_ids, attention_mask=attention_mask, @@ -877,16 +657,17 @@ def test_sample_generate_dict_output(self): ) if model.config.is_encoder_decoder: - self.assertIsInstance(output_sample, SampleEncoderDecoderOutput) + self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1) + self.assertIsInstance(output_generate, GenerateEncoderDecoderOutput) + # Retrocompatibility check self.assertIsInstance(output_generate, SampleEncoderDecoderOutput) else: - self.assertIsInstance(output_sample, SampleDecoderOnlyOutput) + self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + input_ids.shape[-1]) + self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput) + # Retrocompatibility check self.assertIsInstance(output_generate, SampleDecoderOnlyOutput) - self.assertListEqual(output_generate.sequences.tolist(), output_sample.sequences.tolist()) - - for output in (output_sample, output_generate): - self._check_outputs(output, input_ids, model.config, num_return_sequences=2) + self._check_outputs(output_generate, input_ids, model.config, num_return_sequences=2) def test_beam_search_generate(self): for model_class in self.all_generative_model_classes: @@ -911,8 +692,7 @@ def test_beam_search_generate(self): ) beam_kwargs, beam_scorer = self._get_beam_scorer_and_kwargs(input_ids.shape[0], max_length) - # check `generate()` and `beam_search()` are equal - output_generate, output_beam_search = self._beam_search_generate( + output_generate = self._beam_search_generate( model=model, input_ids=input_ids, attention_mask=attention_mask, @@ -923,13 +703,16 @@ def test_beam_search_generate(self): logits_processor=logits_processor, ) - self.assertListEqual(output_generate.tolist(), output_beam_search.tolist()) + if model.config.is_encoder_decoder: + self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1) + else: + self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + input_ids.shape[-1]) if model.config.is_encoder_decoder: max_length = 4 beam_kwargs, beam_scorer = self._get_beam_scorer_and_kwargs(input_ids.shape[0], max_length) - output_generate, output_beam_search = self._beam_search_generate( + output_generate = self._beam_search_generate( model=model, input_ids=input_ids, attention_mask=attention_mask, @@ -939,7 +722,10 @@ def test_beam_search_generate(self): logits_process_kwargs=logits_process_kwargs, logits_processor=logits_processor, ) - self.assertListEqual(output_generate.tolist(), output_beam_search.tolist()) + if model.config.is_encoder_decoder: + self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1) + else: + self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + input_ids.shape[-1]) def test_beam_search_generate_dict_output(self): for model_class in self.all_generative_model_classes: @@ -966,7 +752,7 @@ def test_beam_search_generate_dict_output(self): max_length, ) beam_kwargs, beam_scorer = self._get_beam_scorer_and_kwargs(input_ids.shape[0], max_length) - output_generate, output_beam_search = self._beam_search_generate( + output_generate = self._beam_search_generate( model=model, input_ids=input_ids, attention_mask=attention_mask, @@ -981,21 +767,27 @@ def test_beam_search_generate_dict_output(self): return_dict_in_generate=True, ) if model.config.is_encoder_decoder: - self.assertIsInstance(output_beam_search, BeamSearchEncoderDecoderOutput) self.assertIsInstance(output_generate, BeamSearchEncoderDecoderOutput) else: - self.assertIsInstance(output_beam_search, BeamSearchDecoderOnlyOutput) self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput) - self.assertListEqual(output_generate.sequences.tolist(), output_beam_search.sequences.tolist()) - self.assertTrue( - torch.allclose(output_generate["sequences_scores"], output_beam_search["sequences_scores"], atol=1e-3) - ) self.assertTrue(output_generate["sequences_scores"].shape == (output_generate["sequences"].shape[0],)) self.assertTrue((output_generate["sequences_scores"] < 0).all().item()) - for output in (output_beam_search, output_generate): - self._check_outputs(output, input_ids, model.config, num_return_sequences=beam_scorer.num_beams) + if model.config.is_encoder_decoder: + self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1) + self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput) + # Retrocompatibility check + self.assertIsInstance(output_generate, BeamSearchEncoderDecoderOutput) + else: + self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + input_ids.shape[-1]) + self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput) + # Retrocompatibility check + self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput) + + self._check_outputs( + output_generate, input_ids, model.config, num_return_sequences=beam_kwargs["num_beams"] + ) def test_beam_search_generate_dict_outputs_use_cache(self): for model_class in self.all_generative_model_classes: @@ -1029,7 +821,7 @@ def test_beam_search_generate_dict_outputs_use_cache(self): config.use_cache = True config.is_decoder = True model = model_class(config).to(torch_device).eval() - output_beam, output_generate = self._beam_search_generate( + output_generate = self._beam_search_generate( model=model, input_ids=input_ids, attention_mask=attention_mask, @@ -1044,12 +836,13 @@ def test_beam_search_generate_dict_outputs_use_cache(self): return_dict_in_generate=True, ) - self.assertListEqual(output_generate.sequences.tolist(), output_beam.sequences.tolist()) - - for output in (output_beam, output_generate): - self._check_outputs( - output, input_ids, model.config, use_cache=True, num_return_sequences=beam_scorer.num_beams - ) + if model.config.is_encoder_decoder: + self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1) + else: + self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + input_ids.shape[-1]) + self._check_outputs( + output_generate, input_ids, model.config, use_cache=True, num_return_sequences=beam_scorer.num_beams + ) @pytest.mark.skip("Beam search sampling is not supported by optimum-habana yet") def test_beam_sample_generate(self): @@ -1071,7 +864,7 @@ def test_beam_sample_generate(self): max_length = 4 beam_kwargs, beam_scorer = self._get_beam_scorer_and_kwargs(input_ids.shape[0], max_length) - output_generate, output_beam_sample = self._beam_sample_generate( + output_generate = self._beam_sample_generate( model=model, input_ids=input_ids, attention_mask=attention_mask, @@ -1081,7 +874,23 @@ def test_beam_sample_generate(self): logits_warper=logits_warper, logits_warper_kwargs=logits_warper_kwargs, ) - self.assertListEqual(output_generate.tolist(), output_beam_sample.tolist()) + + if model.config.is_encoder_decoder: + self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1) + else: + self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + input_ids.shape[-1]) + if "inputs_embeds" in set(inspect.signature(model.prepare_inputs_for_generation).parameters): + input_embeds = model.get_input_embeddings()(input_ids) + beam_kwargs.update({"inputs_embeds": input_embeds}) + output_generate2 = self._beam_sample_generate( + model=model, + input_ids=None, + attention_mask=attention_mask, + beam_kwargs=beam_kwargs, + logits_warper_kwargs=logits_warper_kwargs, + ) + + torch.testing.assert_close(output_generate[:, input_embeds.shape[1] :], output_generate2) @pytest.mark.skip("Beam search sampling is not supported by optimum-habana yet") def test_beam_sample_generate_dict_output(self): @@ -1104,7 +913,7 @@ def test_beam_sample_generate_dict_output(self): max_length = 4 beam_kwargs, beam_scorer = self._get_beam_scorer_and_kwargs(input_ids.shape[0], max_length) - output_beam_sample, output_generate = self._beam_sample_generate( + output_generate = self._beam_sample_generate( model=model, input_ids=input_ids, attention_mask=attention_mask, @@ -1119,22 +928,23 @@ def test_beam_sample_generate_dict_output(self): return_dict_in_generate=True, ) + self.assertTrue(output_generate["sequences_scores"].shape == (output_generate["sequences"].shape[0],)) + self.assertTrue((output_generate["sequences_scores"] < 0).all().item()) + if model.config.is_encoder_decoder: - self.assertIsInstance(output_beam_sample, BeamSampleEncoderDecoderOutput) + self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1) + self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput) + # Retrocompatibility check self.assertIsInstance(output_generate, BeamSampleEncoderDecoderOutput) else: - self.assertIsInstance(output_beam_sample, BeamSampleDecoderOnlyOutput) + self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + input_ids.shape[-1]) + self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput) + # Retrocompatibility check self.assertIsInstance(output_generate, BeamSampleDecoderOnlyOutput) - self.assertListEqual(output_generate.sequences.tolist(), output_beam_sample.sequences.tolist()) - self.assertTrue( - torch.allclose(output_generate["sequences_scores"], output_beam_sample["sequences_scores"], atol=1e-3) + self._check_outputs( + output_generate, input_ids, model.config, num_return_sequences=beam_kwargs["num_beams"] ) - self.assertTrue(output_generate["sequences_scores"].shape == (output_generate["sequences"].shape[0],)) - self.assertTrue((output_generate["sequences_scores"] < 0).all().item()) - - for output in (output_beam_sample, output_generate): - self._check_outputs(output, input_ids, model.config, num_return_sequences=beam_scorer.num_beams) def test_generate_without_input_ids(self): config, _, _, max_length = self._get_input_ids_and_config() @@ -1310,7 +1120,7 @@ def test_constrained_beam_search_generate(self): beam_kwargs, beam_scorer = self._get_constrained_beam_scorer_and_kwargs( input_ids.shape[0], max_length, constraints, num_return_sequences=1 ) - output_generate, output_beam_search = self._constrained_beam_search_generate( + output_generate = self._constrained_beam_search_generate( model=model, input_ids=input_ids, attention_mask=attention_mask, @@ -1321,11 +1131,12 @@ def test_constrained_beam_search_generate(self): logits_processor=logits_processor, logits_process_kwargs=logits_process_kwargs, ) - self.assertListEqual(output_generate.tolist(), output_beam_search.tolist()) + self.assertTrue(output_generate.shape[-1] == max_length) + for generation_output in output_generate: self._check_sequence_inside_sequence(force_tokens, generation_output) - # check `generate()` and `constrained_beam_search()` are equal for `num_return_sequences` + # check`constrained_beam_search` for higher than 1 `num_return_sequences` # Sample constraints force_tokens = torch.randint(min_id, max_id, (1, 2)).tolist()[0] constraints = [ @@ -1339,7 +1150,7 @@ def test_constrained_beam_search_generate(self): input_ids.shape[0], max_length, constraints, num_return_sequences=num_return_sequences ) - output_generate, output_beam_search = self._constrained_beam_search_generate( + output_generate = self._constrained_beam_search_generate( model=model, input_ids=input_ids, attention_mask=attention_mask, @@ -1350,7 +1161,7 @@ def test_constrained_beam_search_generate(self): logits_processor=logits_processor, logits_process_kwargs=logits_process_kwargs, ) - self.assertListEqual(output_generate.tolist(), output_beam_search.tolist()) + self.assertTrue(output_generate.shape[-1] == max_length) for generation_output in output_generate: self._check_sequence_inside_sequence(force_tokens, generation_output) @@ -1391,7 +1202,7 @@ def test_constrained_beam_search_generate_dict_output(self): beam_kwargs, beam_scorer = self._get_constrained_beam_scorer_and_kwargs( input_ids.shape[0], max_length, constraints, num_return_sequences=1 ) - output_generate, output_beam_search = self._constrained_beam_search_generate( + output_generate = self._constrained_beam_search_generate( model=model, input_ids=input_ids, attention_mask=attention_mask, @@ -1406,24 +1217,23 @@ def test_constrained_beam_search_generate_dict_output(self): output_attentions=True, return_dict_in_generate=True, ) - + self.assertTrue(output_generate.sequences.shape[-1] == max_length) if model.config.is_encoder_decoder: - self.assertIsInstance(output_beam_search, BeamSearchEncoderDecoderOutput) + self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput) + # Retrocompatibility check self.assertIsInstance(output_generate, BeamSearchEncoderDecoderOutput) else: - self.assertIsInstance(output_beam_search, BeamSearchDecoderOnlyOutput) + self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput) + # Retrocompatibility check self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput) - self.assertListEqual(output_generate.sequences.tolist(), output_beam_search.sequences.tolist()) - self.assertTrue( - torch.allclose(output_generate["sequences_scores"], output_beam_search["sequences_scores"], atol=1e-3) + self._check_outputs( + output_generate, input_ids, model.config, num_return_sequences=beam_kwargs["num_beams"] ) + self.assertTrue(output_generate["sequences_scores"].shape == (output_generate["sequences"].shape[0],)) self.assertTrue((output_generate["sequences_scores"] < 0).all().item()) - for output in (output_beam_search, output_generate): - self._check_outputs(output, input_ids, model.config, num_return_sequences=beam_scorer.num_beams) - def test_contrastive_generate(self): # check `generate()` and `contrastive_search()` are equal for model_class in self.all_generative_model_classes: @@ -1441,10 +1251,13 @@ def test_contrastive_generate(self): # test old generation output for backwards compatibility model = model_class(config).to(torch_device).eval() - output_contrastive, output_generate = self._contrastive_generate( + output_generate = self._contrastive_generate( model=model, input_ids=input_ids, attention_mask=attention_mask, max_length=max_length ) - self.assertListEqual(output_contrastive.tolist(), output_generate.tolist()) + if model.config.is_encoder_decoder: + self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1) + else: + self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + input_ids.shape[-1]) def test_contrastive_generate_dict_outputs_use_cache(self): for model_class in self.all_generative_model_classes: @@ -1462,7 +1275,7 @@ def test_contrastive_generate_dict_outputs_use_cache(self): config.is_decoder = True model = model_class(config).to(torch_device).eval() - output_contrastive, output_generate = self._contrastive_generate( + output_generate = self._contrastive_generate( model=model, input_ids=input_ids, attention_mask=attention_mask, @@ -1473,10 +1286,11 @@ def test_contrastive_generate_dict_outputs_use_cache(self): return_dict_in_generate=True, ) - self.assertListEqual(output_generate.sequences.tolist(), output_contrastive.sequences.tolist()) - - for output in (output_contrastive, output_generate): - self._check_outputs(output, input_ids, model.config, use_cache=True) + if model.config.is_encoder_decoder: + self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1) + else: + self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + input_ids.shape[-1]) + self._check_outputs(output_generate, input_ids, model.config, use_cache=True) def test_contrastive_generate_low_memory(self): # Check that choosing 'low_memory' does not change the model output @@ -1519,8 +1333,47 @@ def test_contrastive_generate_low_memory(self): ) self.assertListEqual(low_output.tolist(), high_output.tolist()) - return + def test_contrastive_generate_dynamic_shapes(self): + # Check that choosing dynamic shapes does not change the model output + for model_class in self.all_generative_model_classes: + # won't fix: FSMT, Reformer, gptbigcode, and speech2text have a different cache variable type (and format). + if any( + model_name in model_class.__name__.lower() + for model_name in ["fsmt", "reformer", "gptbigcode", "speech2text"] + ): + return + + config, input_ids, attention_mask, max_length = self._get_input_ids_and_config(batch_size=1) + + # NOTE: contrastive search only works with cache on at the moment. + if not hasattr(config, "use_cache"): + return + + config.use_cache = True + config.is_decoder = True + + # test output equality of dynamic vs. static shapes + model = model_class(config).to(torch_device).eval() + model.generation_config.static_shapes = False + dynamic_output = model.generate( + input_ids, + top_k=4, + penalty_alpha=0.6, + max_length=max_length, + attention_mask=attention_mask, + ) + + model.generation_config.static_shapes = True + static_output = model.generate( + input_ids, + top_k=4, + penalty_alpha=0.6, + max_length=max_length, + attention_mask=attention_mask, + ) + self.assertListEqual(dynamic_output.tolist(), static_output.tolist()) + # TODO [sasarkar] it is supported now. Enable this test, or delete it if its not applicable @pytest.mark.skip(reason="Assisted decoding not yet supported by optimum-habana") @slow # TODO(Joao): remove this. Some models (e.g. data2vec, xcom, roberta) have an error rate between 1 and 10%. def test_assisted_decoding_matches_greedy_search(self): @@ -1594,6 +1447,7 @@ def test_assisted_decoding_matches_greedy_search(self): for output in (output_greedy, output_assisted): self._check_outputs(output, input_ids, model.config, use_cache=True) + # TODO [sasarkar] it is supported now. Enable this test, or delete it if its not applicable @pytest.mark.skip(reason="Assisted decoding not yet supported by optimum-habana") def test_assisted_decoding_sample(self): # In this test we don't check assisted vs non-assisted output -- seeded assisted decoding with sample will not diff --git a/tests/transformers/tests/models/albert/test_modeling_albert.py b/tests/transformers/tests/models/albert/test_modeling_albert.py index 224fb35056..88a1dfb11f 100644 --- a/tests/transformers/tests/models/albert/test_modeling_albert.py +++ b/tests/transformers/tests/models/albert/test_modeling_albert.py @@ -41,7 +41,6 @@ AlbertForTokenClassification, AlbertModel, ) - from transformers.models.albert.modeling_albert import ALBERT_PRETRAINED_MODEL_ARCHIVE_LIST class AlbertModelTester: @@ -325,9 +324,9 @@ def test_model_various_embeddings(self): @slow def test_model_from_pretrained(self): - for model_name in ALBERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: - model = AlbertModel.from_pretrained(model_name) - self.assertIsNotNone(model) + model_name = "albert/albert-base-v1" + model = AlbertModel.from_pretrained(model_name) + self.assertIsNotNone(model) @require_torch diff --git a/tests/transformers/tests/models/bert/test_modeling_bert.py b/tests/transformers/tests/models/bert/test_modeling_bert.py index d0639e1347..0deee635a8 100644 --- a/tests/transformers/tests/models/bert/test_modeling_bert.py +++ b/tests/transformers/tests/models/bert/test_modeling_bert.py @@ -45,7 +45,6 @@ BertModel, logging, ) - from transformers.models.bert.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_LIST class BertModelTester: @@ -599,9 +598,9 @@ def test_for_warning_if_padding_and_no_attention_mask(self): @slow def test_model_from_pretrained(self): - for model_name in BERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: - model = BertModel.from_pretrained(model_name) - self.assertIsNotNone(model) + model_name = "google-bert/bert-base-uncased" + model = BertModel.from_pretrained(model_name) + self.assertIsNotNone(model) @slow @require_torch_gpu diff --git a/tests/transformers/tests/models/bridgetower/test_modeling_bridgetower.py b/tests/transformers/tests/models/bridgetower/test_modeling_bridgetower.py index 453e9ad140..db339f6c1b 100644 --- a/tests/transformers/tests/models/bridgetower/test_modeling_bridgetower.py +++ b/tests/transformers/tests/models/bridgetower/test_modeling_bridgetower.py @@ -55,7 +55,6 @@ BridgeTowerForMaskedLM, BridgeTowerModel, ) - from transformers.models.bridgetower.modeling_bridgetower import BRIDGETOWER_PRETRAINED_MODEL_ARCHIVE_LIST if is_vision_available(): from PIL import Image @@ -364,9 +363,9 @@ def test_for_masked_language_modeling(self): @slow def test_model_from_pretrained(self): - for model_name in BRIDGETOWER_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: - model = BridgeTowerModel.from_pretrained(model_name) - self.assertIsNotNone(model) + model_name = "BridgeTower/bridgetower-base" + model = BridgeTowerModel.from_pretrained(model_name) + self.assertIsNotNone(model) @slow def test_save_load_fast_init_from_base(self): diff --git a/tests/transformers/tests/models/distilbert/test_modeling_distilbert.py b/tests/transformers/tests/models/distilbert/test_modeling_distilbert.py index e54b2a4c8f..e00cc54cfa 100644 --- a/tests/transformers/tests/models/distilbert/test_modeling_distilbert.py +++ b/tests/transformers/tests/models/distilbert/test_modeling_distilbert.py @@ -31,7 +31,6 @@ if is_torch_available(): import torch from transformers import ( - DISTILBERT_PRETRAINED_MODEL_ARCHIVE_LIST, DistilBertForMaskedLM, DistilBertForMultipleChoice, DistilBertForQuestionAnswering, @@ -262,9 +261,9 @@ def test_for_multiple_choice(self): @slow def test_model_from_pretrained(self): - for model_name in DISTILBERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: - model = DistilBertModel.from_pretrained(model_name) - self.assertIsNotNone(model) + model_name = "distilbert-base-uncased" + model = DistilBertModel.from_pretrained(model_name) + self.assertIsNotNone(model) @slow @require_torch_gpu diff --git a/tests/transformers/tests/models/gpt2/test_modeling_gpt2.py b/tests/transformers/tests/models/gpt2/test_modeling_gpt2.py index 20047d6cb5..a978dfa98c 100644 --- a/tests/transformers/tests/models/gpt2/test_modeling_gpt2.py +++ b/tests/transformers/tests/models/gpt2/test_modeling_gpt2.py @@ -38,7 +38,6 @@ if is_torch_available(): import torch from transformers import ( - GPT2_PRETRAINED_MODEL_ARCHIVE_LIST, GPT2DoubleHeadsModel, GPT2ForQuestionAnswering, GPT2ForSequenceClassification, @@ -694,9 +693,9 @@ def test_batch_generation_2heads(self): @slow def test_model_from_pretrained(self): - for model_name in GPT2_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: - model = GPT2Model.from_pretrained(model_name) - self.assertIsNotNone(model) + model_name = "openai-community/gpt2" + model = GPT2Model.from_pretrained(model_name) + self.assertIsNotNone(model) @require_torch diff --git a/tests/transformers/tests/models/gptj/test_modeling_gptj.py b/tests/transformers/tests/models/gptj/test_modeling_gptj.py index 4271079915..b416c02720 100644 --- a/tests/transformers/tests/models/gptj/test_modeling_gptj.py +++ b/tests/transformers/tests/models/gptj/test_modeling_gptj.py @@ -37,7 +37,6 @@ if is_torch_available(): import torch from transformers import ( - GPTJ_PRETRAINED_MODEL_ARCHIVE_LIST, AutoTokenizer, GPTJForCausalLM, GPTJForQuestionAnswering, @@ -523,9 +522,9 @@ def test_batch_generation(self): @slow def test_model_from_pretrained(self): - for model_name in GPTJ_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: - model = GPTJModel.from_pretrained(model_name, revision="float16", torch_dtype=torch.float16) - self.assertIsNotNone(model) + model_name = "EleutherAI/gpt-j-6B" + model = GPTJModel.from_pretrained(model_name, revision="float16", torch_dtype=torch.float16) + self.assertIsNotNone(model) @require_torch @@ -631,7 +630,7 @@ def test_contrastive_search_gptj(self): tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B") model = GPTJForCausalLM.from_pretrained( - "EleutherAI/gpt-j-6B", revision="float16", torch_dtype=torch.float16 + "EleutherAI/gpt-j-6B", revision="float16", torch_dtype=torch.bfloat16 ).to(torch_device) input_ids = tokenizer(article, return_tensors="pt").input_ids.to(torch_device) @@ -642,17 +641,17 @@ def test_contrastive_search_gptj(self): generated_text, [ "DeepMind Technologies is a British artificial intelligence subsidiary of Alphabet Inc. and research " - "laboratory founded in 2010. DeepMind was acquired by Google in 2014. The company is based in London, " - "United Kingdom with offices in Mountain View, San Francisco, New York City, Paris, Tokyo, Seoul, " - "Beijing, Singapore, Tel Aviv, Dublin, Sydney, and Melbourne.[1]\n\nContents\n\nIn 2010, Google's " - "parent company, Alphabet, announced a $500 million investment in DeepMind, with the aim of creating " - "a company that would apply deep learning to problems in healthcare, energy, transportation, and " - "other areas.[2]\n\nOn April 23, 2014, Google announced that it had acquired DeepMind for $400 " - "million in cash and stock.[3] The acquisition was seen as a way for Google to enter the " - "fast-growing field of artificial intelligence (AI), which it had so far avoided due to concerns " - 'about ethical and social implications.[4] Google co-founder Sergey Brin said that he was "thrilled" ' - 'to have acquired DeepMind, and that it would "help us push the boundaries of AI even further."' - "[5]\n\nDeepMind's founders, Demis Hassabis and Mustafa Suleyman, were joined by a number of Google " - "employees" + "laboratory founded in 2010. DeepMind was acquired by Google in 2014. The company is based in London " + "and has offices in New York, San Francisco, Cambridge, London, Paris, Tokyo, Beijing, Seoul, " + "Singapore, Sydney, and Mountain View.[1]\n\nContents\n\nIn 2010, Google's parent company, " + "Alphabet, announced a $500 million investment in DeepMind, with the aim of creating a company that " + "would apply deep learning to problems in healthcare, energy, transportation, and other areas.[2] " + "The investment was led by Founders Fund, a venture capital firm that invests in early-stage " + "start-ups, and the London-based venture capital firm Atomico.[3]\n\nOn April 23, 2014, Google " + "announced that it had acquired DeepMind for $400 million in cash and stock.[4] The acquisition was " + "seen as a way for Google to gain access to the company's expertise in machine learning and " + "artificial intelligence (AI), which it could apply to a range of products and services at Google.[5] " + 'Google CEO Larry Page said that the acquisition would "make Google a leader in this new field and ' + "help answer some of the most challenging questions we face as a society—" ], ) diff --git a/tests/transformers/tests/models/llama/test_modeling_llama.py b/tests/transformers/tests/models/llama/test_modeling_llama.py index 2c505b6811..500c3df7c8 100644 --- a/tests/transformers/tests/models/llama/test_modeling_llama.py +++ b/tests/transformers/tests/models/llama/test_modeling_llama.py @@ -17,10 +17,11 @@ import unittest from parameterized import parameterized -from transformers import LlamaConfig, is_torch_available +from transformers import is_torch_available from transformers.testing_utils import require_torch, slow from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi +from optimum.habana.transformers.models.llama.configuration_llama import LlamaConfig from optimum.habana.utils import set_seed from ...generation.test_utils import GenerationTesterMixin diff --git a/tests/transformers/tests/models/roberta/test_modeling_roberta.py b/tests/transformers/tests/models/roberta/test_modeling_roberta.py index 3d17c13d6a..197a83509c 100644 --- a/tests/transformers/tests/models/roberta/test_modeling_roberta.py +++ b/tests/transformers/tests/models/roberta/test_modeling_roberta.py @@ -41,7 +41,6 @@ RobertaModel, ) from transformers.models.roberta.modeling_roberta import ( - ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST, RobertaEmbeddings, create_position_ids_from_input_ids, ) @@ -480,9 +479,9 @@ def test_for_question_answering(self): @slow def test_model_from_pretrained(self): - for model_name in ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: - model = RobertaModel.from_pretrained(model_name) - self.assertIsNotNone(model) + model_name = "FacebookAI/roberta-base" + model = RobertaModel.from_pretrained(model_name) + self.assertIsNotNone(model) def test_create_position_ids_respects_padding_index(self): """Ensure that the default position ids only assign a sequential . This is a regression diff --git a/tests/transformers/tests/models/swin/test_modeling_swin.py b/tests/transformers/tests/models/swin/test_modeling_swin.py index 6ab27dc4d3..da5f6204b1 100644 --- a/tests/transformers/tests/models/swin/test_modeling_swin.py +++ b/tests/transformers/tests/models/swin/test_modeling_swin.py @@ -35,7 +35,6 @@ import torch from torch import nn from transformers import SwinBackbone, SwinForImageClassification, SwinForMaskedImageModeling, SwinModel - from transformers.models.swin.modeling_swin import SWIN_PRETRAINED_MODEL_ARCHIVE_LIST if is_vision_available(): @@ -460,9 +459,9 @@ def test_hidden_states_output_with_padding(self): @slow def test_model_from_pretrained(self): - for model_name in SWIN_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: - model = SwinModel.from_pretrained(model_name) - self.assertIsNotNone(model) + model_name = "microsoft/swin-tiny-patch4-window7-224" + model = SwinModel.from_pretrained(model_name) + self.assertIsNotNone(model) def test_initialization(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() diff --git a/tests/transformers/tests/models/t5/test_modeling_t5.py b/tests/transformers/tests/models/t5/test_modeling_t5.py index caf2915acf..9883907f6e 100644 --- a/tests/transformers/tests/models/t5/test_modeling_t5.py +++ b/tests/transformers/tests/models/t5/test_modeling_t5.py @@ -49,7 +49,6 @@ T5Model, T5Tokenizer, ) - from transformers.models.t5.modeling_t5 import T5_PRETRAINED_MODEL_ARCHIVE_LIST torch_device = "hpu" @@ -816,9 +815,9 @@ def test_v1_1_resize_embeddings(self): @slow def test_model_from_pretrained(self): - for model_name in T5_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: - model = T5Model.from_pretrained(model_name) - self.assertIsNotNone(model) + model_name = "google-t5/t5-small" + model = T5Model.from_pretrained(model_name) + self.assertIsNotNone(model) @unittest.skip("Test has a segmentation fault on torch 1.8.0") def test_export_to_onnx(self): @@ -1422,7 +1421,10 @@ def test_translation_en_to_ro(self): translation = tok.decode(output[0], skip_special_tokens=True, clean_up_tokenization_spaces=False) self.assertEqual(translation, expected_translation) + # contrastive search is not supported and expected to fail + # In earlier versions it was passing because it was going down default implementation, and it just happened to pass @slow + @pytest.mark.xfail(reason="contrastive search is not implemented", raises=NotImplementedError) def test_contrastive_search_t5(self): article = ( " New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County, New York. A" diff --git a/tests/transformers/tests/models/vit/test_modeling_vit.py b/tests/transformers/tests/models/vit/test_modeling_vit.py index 6f8e7ad41d..291abecef2 100644 --- a/tests/transformers/tests/models/vit/test_modeling_vit.py +++ b/tests/transformers/tests/models/vit/test_modeling_vit.py @@ -37,7 +37,6 @@ import torch from torch import nn from transformers import ViTForImageClassification, ViTForMaskedImageModeling, ViTModel - from transformers.models.vit.modeling_vit import VIT_PRETRAINED_MODEL_ARCHIVE_LIST if is_vision_available(): @@ -245,9 +244,9 @@ def test_for_image_classification(self): @slow def test_model_from_pretrained(self): - for model_name in VIT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: - model = ViTModel.from_pretrained(model_name) - self.assertIsNotNone(model) + model_name = "google/vit-base-patch16-224" + model = ViTModel.from_pretrained(model_name) + self.assertIsNotNone(model) # We will verify our results on an image of cute cats diff --git a/tests/transformers/tests/models/wav2vec2/test_modeling_wav2vec2.py b/tests/transformers/tests/models/wav2vec2/test_modeling_wav2vec2.py index 62c3b075c5..ce8e18d78a 100644 --- a/tests/transformers/tests/models/wav2vec2/test_modeling_wav2vec2.py +++ b/tests/transformers/tests/models/wav2vec2/test_modeling_wav2vec2.py @@ -1257,12 +1257,8 @@ def get_logits(model, input_features): pt_filepath = os.path.join(tempdir, WAV2VEC2_ADAPTER_PT_FILE.format("eng")) torch.save(adapter_weights, pt_filepath) - # model.load_adapter is broken in transformers - # since adapter_weights fails to load with weights_only=True - with self.assertRaises(OSError): - model.load_adapter("eng") - with self.assertRaises(OSError): - model.load_adapter("eng", use_safetensors=False) + model.load_adapter("eng") + model.load_adapter("eng", use_safetensors=False) # we will load adapter_weights directly while model.load_adapter fails state_dict = torch.load(pt_filepath) state_dict = {k: v.to(adapter_weights[k]) for k, v in state_dict.items()} diff --git a/tests/transformers/tests/test_modeling_common.py b/tests/transformers/tests/test_modeling_common.py index 900466aaa2..e981ed2855 100755 --- a/tests/transformers/tests/test_modeling_common.py +++ b/tests/transformers/tests/test_modeling_common.py @@ -853,7 +853,11 @@ def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=Fa ): model.config.problem_type = "single_label_classification" - traced_model = symbolic_trace(model, input_names) + from optimum.habana.transformers.models import GaudiGPT2DoubleHeadsModel + + traced_model = symbolic_trace( + model, input_names, disable_check=isinstance(model, GaudiGPT2DoubleHeadsModel) + ) traced_output = traced_model(**filtered_inputs) model_output = model(**filtered_inputs)