diff --git a/docs/classes/models/plbart.rst b/docs/classes/models/plbart.rst new file mode 100644 index 0000000000..69227f8f2e --- /dev/null +++ b/docs/classes/models/plbart.rst @@ -0,0 +1,20 @@ +PLBART +====== + +The PLBART model was proposed in `Unified Pre-training for Program Understanding and Generation <https://arxiv.org/abs/2103.06333>`_ by Wasi Uddin Ahmad, Saikat Chakraborty, Baishakhi Ray, and Kai-Wei Chang. +This is a BART-like model that can be used for code summarization, code generation, and code translation tasks. The pre-trained model ``plbart-base`` has been trained using a multilingual denoising task +on Java, Python, and English. + +According to the abstract, + +- PLBART is a sequence-to-sequence model capable of performing a broad spectrum of program and language understanding and generation tasks. +- PLBART is pre-trained on an extensive collection of Java and Python functions and associated NL text via denoising autoencoding. +- PLBART learns program syntax, style (e.g., identifier naming convention) and logical flow. + + +PLBartAdapterModel +~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: adapters.PLBartAdapterModel + :members: + :inherited-members: PLBartPreTrainedModel diff --git a/docs/index.rst b/docs/index.rst index b78a249c64..29ef772fdf 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -82,6 +82,7 @@ Currently, we support the PyTorch versions of all models as listed on the `Model classes/models/llama classes/models/mbart classes/models/mt5 + classes/models/plbart classes/models/roberta classes/models/t5 classes/models/vit diff --git a/docs/model_overview.md b/docs/model_overview.md index 38e9cbb681..4cc164bc8f 100644 --- a/docs/model_overview.md +++ b/docs/model_overview.md @@ -28,6 +28,7 @@ The table below further shows which model architectures support which adaptation | [Llama](classes/models/llama.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | ✅ | | [MBart](classes/models/mbart.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | ✅ | | [MT5](classes/models/mt5.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | ✅ | +| [PLBart](classes/models/plbart.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | ✅ | | [RoBERTa](classes/models/roberta.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | [T5](classes/models/t5.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | ✅ | | [ViT](classes/models/vit.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | diff --git a/examples/pytorch/adapterfusion/run_fusion_glue.py b/examples/pytorch/adapterfusion/run_fusion_glue.py index 089bee41a2..fbca2daa81 100644 --- a/examples/pytorch/adapterfusion/run_fusion_glue.py +++ b/examples/pytorch/adapterfusion/run_fusion_glue.py @@ -13,7 +13,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" Finetuning the library models for sequence classification on +"""Finetuning the library models for sequence classification on GLUE (Bert, XLM, XLNet, RoBERTa, Albert, XLM-RoBERTa).""" diff --git a/examples/pytorch/dependency-parsing/preprocessing.py b/examples/pytorch/dependency-parsing/preprocessing.py index 2188aab4fb..90f1012549 100644 --- a/examples/pytorch/dependency-parsing/preprocessing.py +++ b/examples/pytorch/dependency-parsing/preprocessing.py @@ -3,6 +3,7 @@ Credits: "How Good is Your Tokenizer?
On the Monolingual Performance of Multilingual Language Models" (Rust et al., 2021) https://arxiv.org/abs/2012.15613 """ + from collections import defaultdict from typing import List diff --git a/examples/pytorch/dependency-parsing/run_udp.py b/examples/pytorch/dependency-parsing/run_udp.py index 43bcaff6a4..2a9cddb8b1 100644 --- a/examples/pytorch/dependency-parsing/run_udp.py +++ b/examples/pytorch/dependency-parsing/run_udp.py @@ -3,6 +3,7 @@ Credits: "How Good is Your Tokenizer? On the Monolingual Performance of Multilingual Language Models" (Rust et al., 2021) https://arxiv.org/abs/2012.15613 """ + import logging import os import sys @@ -156,9 +157,11 @@ def main(): use_fast=model_args.use_fast, do_lower_case=model_args.do_lower_case, add_prefix_space=True, # Used e.g. for RoBERTa - mecab_kwargs={"mecab_option": f"-r {model_args.mecab_dir} -d {model_args.mecab_dic_dir}"} - if model_args.is_japanese - else None, + mecab_kwargs=( + {"mecab_option": f"-r {model_args.mecab_dir} -d {model_args.mecab_dic_dir}"} + if model_args.is_japanese + else None + ), ) # The task name (with prefix) @@ -244,9 +247,11 @@ def main(): if adapter_args.train_adapter: adapter_config = AdapterConfig.load(adapter_args.adapter_config, **adapter_config_kwargs) model.load_adapter( - os.path.join(training_args.output_dir, "best_model", task_name) - if training_args.do_train - else adapter_args.load_adapter, + ( + os.path.join(training_args.output_dir, "best_model", task_name) + if training_args.do_train + else adapter_args.load_adapter + ), config=adapter_config, load_as=task_name, **adapter_load_kwargs, @@ -254,9 +259,11 @@ def main(): if adapter_args.load_lang_adapter: lang_adapter_config = AdapterConfig.load(adapter_args.lang_adapter_config, **adapter_config_kwargs) lang_adapter_name = model.load_adapter( - os.path.join(training_args.output_dir, "best_model", lang_adapter_name) - if training_args.do_train - else adapter_args.load_lang_adapter, + ( + os.path.join(training_args.output_dir, "best_model", lang_adapter_name) + if training_args.do_train + else adapter_args.load_lang_adapter + ), config=lang_adapter_config, load_as=lang_adapter_name, **adapter_load_kwargs, diff --git a/examples/pytorch/dependency-parsing/utils_udp.py b/examples/pytorch/dependency-parsing/utils_udp.py index 3424638319..0eaa4f5d3d 100644 --- a/examples/pytorch/dependency-parsing/utils_udp.py +++ b/examples/pytorch/dependency-parsing/utils_udp.py @@ -3,6 +3,7 @@ Credits: "How Good is Your Tokenizer? On the Monolingual Performance of Multilingual Language Models" (Rust et al., 2021) https://arxiv.org/abs/2012.15613 """ + import collections import logging import os diff --git a/examples/pytorch/language-modeling/run_clm.py b/examples/pytorch/language-modeling/run_clm.py index 3094738317..17c28b88f6 100644 --- a/examples/pytorch/language-modeling/run_clm.py +++ b/examples/pytorch/language-modeling/run_clm.py @@ -541,9 +541,9 @@ def compute_metrics(eval_preds): # Data collator will default to DataCollatorWithPadding, so we change it. 
data_collator=default_data_collator, compute_metrics=compute_metrics if training_args.do_eval and not is_torch_tpu_available() else None, - preprocess_logits_for_metrics=preprocess_logits_for_metrics - if training_args.do_eval and not is_torch_tpu_available() - else None, + preprocess_logits_for_metrics=( + preprocess_logits_for_metrics if training_args.do_eval and not is_torch_tpu_available() else None + ), ) # Training diff --git a/examples/pytorch/language-modeling/run_mlm.py b/examples/pytorch/language-modeling/run_mlm.py index bf6de170ab..8dc451350b 100644 --- a/examples/pytorch/language-modeling/run_mlm.py +++ b/examples/pytorch/language-modeling/run_mlm.py @@ -557,9 +557,9 @@ def compute_metrics(eval_preds): tokenizer=tokenizer, data_collator=data_collator, compute_metrics=compute_metrics if training_args.do_eval and not is_torch_tpu_available() else None, - preprocess_logits_for_metrics=preprocess_logits_for_metrics - if training_args.do_eval and not is_torch_tpu_available() - else None, + preprocess_logits_for_metrics=( + preprocess_logits_for_metrics if training_args.do_eval and not is_torch_tpu_available() else None + ), ) # Training diff --git a/examples/pytorch/text-classification/run_glue.py b/examples/pytorch/text-classification/run_glue.py index 4ed66c8aae..9f6a742862 100644 --- a/examples/pytorch/text-classification/run_glue.py +++ b/examples/pytorch/text-classification/run_glue.py @@ -13,7 +13,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" Finetuning the library models for sequence classification on GLUE.""" +"""Finetuning the library models for sequence classification on GLUE.""" # You can also adapt this script on your own text classification task. Pointers for this are left as comments. import logging diff --git a/examples/pytorch/text-generation/run_generation.py b/examples/pytorch/text-generation/run_generation.py index 42a8d1fca4..fb42de6bc5 100644 --- a/examples/pytorch/text-generation/run_generation.py +++ b/examples/pytorch/text-generation/run_generation.py @@ -14,8 +14,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-""" Conditional text generation with the auto-regressive models of the library (GPT/GPT-2/CTRL/Transformer-XL/XLNet) -""" +"""Conditional text generation with the auto-regressive models of the library (GPT/GPT-2/CTRL/Transformer-XL/XLNet)""" import argparse diff --git a/src/adapters/__init__.py b/src/adapters/__init__.py index ddfb5b1895..b440221945 100644 --- a/src/adapters/__init__.py +++ b/src/adapters/__init__.py @@ -110,6 +110,7 @@ "models.llama": ["LlamaAdapterModel"], "models.mbart": ["MBartAdapterModel"], "models.mt5": ["MT5AdapterModel"], + "models.plbart": ["PLBartAdapterModel"], "models.roberta": ["RobertaAdapterModel"], "models.t5": ["T5AdapterModel"], "models.vit": ["ViTAdapterModel"], @@ -217,6 +218,7 @@ from .models.llama import LlamaAdapterModel from .models.mbart import MBartAdapterModel from .models.mt5 import MT5AdapterModel + from .models.plbart import PLBartAdapterModel from .models.roberta import RobertaAdapterModel from .models.t5 import T5AdapterModel from .models.vit import ViTAdapterModel diff --git a/src/adapters/composition.py b/src/adapters/composition.py index 3c4f96830c..4c65f4a072 100644 --- a/src/adapters/composition.py +++ b/src/adapters/composition.py @@ -129,6 +129,7 @@ def __init__( "bart", "mbart", "mt5", + "plbart", "gpt2", "gptj", "t5", diff --git a/src/adapters/head_utils.py b/src/adapters/head_utils.py index ec78430e02..079334c1e2 100644 --- a/src/adapters/head_utils.py +++ b/src/adapters/head_utils.py @@ -369,6 +369,27 @@ }, "layers": ["lm_head"], }, + # PLBART + "PLBartForSequenceClassification": { + "config": { + "head_type": "classification", + "layers": 2, + "activation_function": "tanh", + }, + "layers": [ + None, + "classification_head.dense", + None, + None, + "classification_head.out_proj", + ], + }, + "PLBartForConditionalGeneration": { + "config": { + "head_type": "seq2seq_lm", + }, + "layers": ["lm_head"], + }, # MT5 "MT5ForConditionalGeneration": { "config": { diff --git a/src/adapters/heads/dependency_parsing.py b/src/adapters/heads/dependency_parsing.py index d568f356b0..5d33820f45 100644 --- a/src/adapters/heads/dependency_parsing.py +++ b/src/adapters/heads/dependency_parsing.py @@ -2,6 +2,7 @@ Code taken and modified from: https://github.com/Adapter-Hub/hgiyt. Credits: "How Good is Your Tokenizer? 
On the Monolingual Performance of Multilingual Language Models" (Rust et al., 2021) https://arxiv.org/abs/2012.15613 """ + from dataclasses import dataclass from typing import Optional, Tuple diff --git a/src/adapters/methods/bottleneck.py b/src/adapters/methods/bottleneck.py index b3125c6965..6d0a6257e2 100644 --- a/src/adapters/methods/bottleneck.py +++ b/src/adapters/methods/bottleneck.py @@ -195,9 +195,11 @@ def pad_and_concat(self, states: List[BottleneckState]) -> BottleneckState: torch.cat([state.input_tensor for state in states], dim=0), torch.cat([state.adapter_residual for state in states], dim=0), states[0].layer_norm, - torch.cat([state.bottleneck_up for state in states], dim=0) - if states[0].bottleneck_up is not None - else None, + ( + torch.cat([state.bottleneck_up for state in states], dim=0) + if states[0].bottleneck_up is not None + else None + ), states[-1].last, ) diff --git a/src/adapters/methods/lora.py b/src/adapters/methods/lora.py index e54042b557..a4b38f145e 100644 --- a/src/adapters/methods/lora.py +++ b/src/adapters/methods/lora.py @@ -408,9 +408,11 @@ def repeat(self, state: LoRAState, channels: int) -> LoRAState: def mean(self, states: List[LoRAState], weights: torch.Tensor) -> LoRAState: return LoRAState( states[0].layer_input, - torch.mean(torch.stack([s.hidden_states for s in states], dim=0) * weights, dim=0) - if states[0].hidden_states is not None - else None, + ( + torch.mean(torch.stack([s.hidden_states for s in states], dim=0) * weights, dim=0) + if states[0].hidden_states is not None + else None + ), states[0].layer_output, states[-1].last, ) diff --git a/src/adapters/models/__init__.py b/src/adapters/models/__init__.py index ff19c38f3b..55bff73de8 100644 --- a/src/adapters/models/__init__.py +++ b/src/adapters/models/__init__.py @@ -18,6 +18,12 @@ from .gpt2.mixin_gpt2 import GPT2ModelAdapterMixin from .gptj.mixin_gptj import GPTJMLPAdaptersMixin, GPTJModelAdapterMixin from .llama.mixin_llama import LlamaForQuestionAnsweringAdapterMixin, LlamaModelAdapterMixin +from .plbart.mixin_plbart import ( + PLBartDecoderAdaptersMixin, + PLBartDecoderWrapperAdaptersMixin, + PLBartEncoderAdaptersMixin, + PLBartModelAdaptersMixin, +) from .t5.mixin_t5 import ( T5BlockAdaptersMixin, T5ForCondiditionalGenerationWithHeadsMixin, @@ -33,8 +39,8 @@ "AlbertModel": AlbertModelAdaptersMixin, "BartEncoder": BartEncoderAdaptersMixin, "BartDecoder": BartDecoderAdaptersMixin, - "BartModel": BartModelAdaptersMixin, "BartDecoderWrapper": BartDecoderWrapperAdaptersMixin, + "BartModel": BartModelAdaptersMixin, "BeitIntermediate": BeitIntermediateAdaptersMixin, "BeitOutput": BeitOutputAdaptersMixin, "BeitModel": BeitModelAdaptersMixin, @@ -60,6 +66,10 @@ "MT5ForConditionalGeneration": T5ForCondiditionalGenerationWithHeadsMixin, "MT5ForQuestionAnswering": T5ForQuestionAnsweringWithHeadsMixin, "MT5EncoderModel": T5ModelAdaptersMixin, + "PLBartEncoder": PLBartEncoderAdaptersMixin, + "PLBartDecoder": PLBartDecoderAdaptersMixin, + "PLBartModel": PLBartModelAdaptersMixin, + "PLBartDecoderWrapper": PLBartDecoderWrapperAdaptersMixin, "GPT2Model": GPT2ModelAdapterMixin, "GPTJMLP": GPTJMLPAdaptersMixin, "GPTJModel": GPTJModelAdapterMixin, diff --git a/src/adapters/models/auto/adapter_model.py b/src/adapters/models/auto/adapter_model.py index 2d59c6da44..31dfa00cff 100644 --- a/src/adapters/models/auto/adapter_model.py +++ b/src/adapters/models/auto/adapter_model.py @@ -24,6 +24,7 @@ ("llama", "LlamaAdapterModel"), ("mbart", "MBartAdapterModel"), ("mt5", "MT5AdapterModel"), + ("plbart", 
"PLBartAdapterModel"), ("roberta", "RobertaAdapterModel"), ("t5", "T5AdapterModel"), ("vit", "ViTAdapterModel"), diff --git a/src/adapters/models/bart/modeling_bart.py b/src/adapters/models/bart/modeling_bart.py index b347fddf07..080455b497 100644 --- a/src/adapters/models/bart/modeling_bart.py +++ b/src/adapters/models/bart/modeling_bart.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" PyTorch BART model.""" +"""PyTorch BART model.""" from typing import Optional, Tuple import torch diff --git a/src/adapters/models/beit/modeling_beit.py b/src/adapters/models/beit/modeling_beit.py index 1ed5082beb..bc67120d13 100644 --- a/src/adapters/models/beit/modeling_beit.py +++ b/src/adapters/models/beit/modeling_beit.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" PyTorch BEiT model.""" +"""PyTorch BEiT model.""" import math diff --git a/src/adapters/models/clip/modeling_clip.py b/src/adapters/models/clip/modeling_clip.py index b74a0308ef..fecbb105c8 100644 --- a/src/adapters/models/clip/modeling_clip.py +++ b/src/adapters/models/clip/modeling_clip.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" PyTorch CLIP model.""" +"""PyTorch CLIP model.""" from typing import Optional, Tuple diff --git a/src/adapters/models/deberta/modeling_deberta.py b/src/adapters/models/deberta/modeling_deberta.py index 1feca72b4a..4380b5e038 100644 --- a/src/adapters/models/deberta/modeling_deberta.py +++ b/src/adapters/models/deberta/modeling_deberta.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" PyTorch DeBERTa model.""" +"""PyTorch DeBERTa model.""" import torch import torch.utils.checkpoint diff --git a/src/adapters/models/deberta_v2/modeling_deberta_v2.py b/src/adapters/models/deberta_v2/modeling_deberta_v2.py index 56d6fec448..bc41ae82af 100644 --- a/src/adapters/models/deberta_v2/modeling_deberta_v2.py +++ b/src/adapters/models/deberta_v2/modeling_deberta_v2.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" PyTorch DeBERTa-v2 model.""" +"""PyTorch DeBERTa-v2 model.""" import torch import torch.utils.checkpoint diff --git a/src/adapters/models/distilbert/modeling_distilbert.py b/src/adapters/models/distilbert/modeling_distilbert.py index cbd501942c..e59aa1ad50 100644 --- a/src/adapters/models/distilbert/modeling_distilbert.py +++ b/src/adapters/models/distilbert/modeling_distilbert.py @@ -14,8 +14,8 @@ # limitations under the License. 
""" - PyTorch DistilBERT model adapted in part from Facebook, Inc XLM model (https://github.com/facebookresearch/XLM) and in - part from HuggingFace PyTorch version of Google AI Bert model (https://github.com/google-research/bert) +PyTorch DistilBERT model adapted in part from Facebook, Inc XLM model (https://github.com/facebookresearch/XLM) and in +part from HuggingFace PyTorch version of Google AI Bert model (https://github.com/google-research/bert) """ diff --git a/src/adapters/models/encoder_decoder/modeling_encoder_decoder.py b/src/adapters/models/encoder_decoder/modeling_encoder_decoder.py index 43178898f6..1572087d98 100644 --- a/src/adapters/models/encoder_decoder/modeling_encoder_decoder.py +++ b/src/adapters/models/encoder_decoder/modeling_encoder_decoder.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" Classes to support Encoder-Decoder architectures""" +"""Classes to support Encoder-Decoder architectures""" from transformers.models.encoder_decoder.modeling_encoder_decoder import EncoderDecoderModel diff --git a/src/adapters/models/gptj/modeling_gptj.py b/src/adapters/models/gptj/modeling_gptj.py index 700e919a17..3880df12c0 100644 --- a/src/adapters/models/gptj/modeling_gptj.py +++ b/src/adapters/models/gptj/modeling_gptj.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" PyTorch GPT-J model.""" +"""PyTorch GPT-J model.""" from typing import Optional, Tuple, Union diff --git a/src/adapters/models/llama/modeling_llama.py b/src/adapters/models/llama/modeling_llama.py index 7c99f286e4..d9d8b2ebcc 100644 --- a/src/adapters/models/llama/modeling_llama.py +++ b/src/adapters/models/llama/modeling_llama.py @@ -17,7 +17,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" PyTorch LLaMA model.""" +"""PyTorch LLaMA model.""" import math import warnings from typing import Optional, Tuple diff --git a/src/adapters/models/mbart/modeling_mbart.py b/src/adapters/models/mbart/modeling_mbart.py index 0f8f0d5335..45bdceae25 100644 --- a/src/adapters/models/mbart/modeling_mbart.py +++ b/src/adapters/models/mbart/modeling_mbart.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" PyTorch MBART model.""" +"""PyTorch MBART model.""" from typing import Optional, Tuple import torch diff --git a/src/adapters/models/mt5/modeling_mt5.py b/src/adapters/models/mt5/modeling_mt5.py index 12ad630a74..b982d34d62 100644 --- a/src/adapters/models/mt5/modeling_mt5.py +++ b/src/adapters/models/mt5/modeling_mt5.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-""" PyTorch MT5 model.""" +"""PyTorch MT5 model.""" import torch from torch import nn diff --git a/src/adapters/models/plbart/__init__.py b/src/adapters/models/plbart/__init__.py new file mode 100644 index 0000000000..1160ba151b --- /dev/null +++ b/src/adapters/models/plbart/__init__.py @@ -0,0 +1,39 @@ +# flake8: noqa +# There's no way to ignore "F401 '...' imported but unused" warnings in this +# module, but to preserve other warnings. So, don't check this module at all. + +# Copyright 2020 The Adapter-Hub Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING + +from transformers.utils import _LazyModule + + +_import_structure = { + "adapter_model": ["PLBartAdapterModel"], +} + + +if TYPE_CHECKING: + from .adapter_model import PLBartAdapterModel + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + ) diff --git a/src/adapters/models/plbart/adapter_model.py b/src/adapters/models/plbart/adapter_model.py new file mode 100644 index 0000000000..83f02183d0 --- /dev/null +++ b/src/adapters/models/plbart/adapter_model.py @@ -0,0 +1,162 @@ +import torch + +from transformers.models.plbart.modeling_plbart import ( + PLBART_INPUTS_DOCSTRING, + PLBART_START_DOCSTRING, + PLBartConfig, + PLBartModel, + PLBartPreTrainedModel, + shift_tokens_right, +) +from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward + +from ...heads import ModelWithFlexibleHeadsAdaptersMixin +from ...model_mixin import EmbeddingAdaptersWrapperMixin +from ...wrappers import init + + +@add_start_docstrings( + "PLBART Model with the option to add multiple flexible prediction heads on top.", PLBART_START_DOCSTRING +) +class PLBartAdapterModel(EmbeddingAdaptersWrapperMixin, ModelWithFlexibleHeadsAdaptersMixin, PLBartPreTrainedModel): + _tied_weights_keys = [ + "encoder.embed_tokens.weight", + "decoder.embed_tokens.weight", + ] + + head_types = [ + "classification", + "multilabel_classification", + "question_answering", + "seq2seq_lm", + ] + + def __init__(self, config: PLBartConfig, **kwargs): + super().__init__(config, **kwargs) + self.model = PLBartModel(config) + init(self.model) + + self._init_head_modules() + + self.post_init() + + def get_encoder(self): + return self.model.get_encoder() + + def get_decoder(self): + return self.model.get_decoder() + + @add_start_docstrings_to_model_forward(PLBART_INPUTS_DOCSTRING) + def forward( + self, + input_ids=None, + attention_mask=None, + decoder_input_ids=None, + decoder_attention_mask=None, + head_mask=None, + decoder_head_mask=None, + cross_attn_head_mask=None, + encoder_outputs=None, + inputs_embeds=None, + decoder_inputs_embeds=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + past_key_values=None, + head=None, + output_adapter_gating_scores=False, + output_adapter_fusion_attentions=False, + **kwargs + ): + r""" + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, 
`optional`): + Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ..., + config.num_labels - 1]`. If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if "labels" in kwargs or "start_positions" in kwargs and "end_positions" in kwargs: + use_cache = False + + outputs, context = self.model( + input_ids, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + head_mask=head_mask, + decoder_head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + encoder_outputs=encoder_outputs, + inputs_embeds=inputs_embeds, + decoder_inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + past_key_values=past_key_values, + output_adapter_gating_scores=output_adapter_gating_scores, + output_adapter_fusion_attentions=output_adapter_fusion_attentions, + adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False), + output_context=True, + ) + # required e.g. for prompt tuning in all models + kwargs["context"] = context + + head_outputs = self.forward_head( + outputs, + head_name=head, + attention_mask=attention_mask, + return_dict=return_dict, + get_cls_from_eos_tokens=True, + # `get_cls_from_eos_tokens` requires passing eos mask + eos_mask=input_ids.eq(self.config.eos_token_id) if input_ids is not None else None, + **kwargs, + ) + + return head_outputs + + # Copied from PLBartForConditionalGeneration + def prepare_inputs_for_generation( + self, + decoder_input_ids, + past=None, + attention_mask=None, + head_mask=None, + decoder_head_mask=None, + cross_attn_head_mask=None, + use_cache=None, + encoder_outputs=None, + **kwargs + ): + # cut decoder_input_ids if past is used + if past is not None: + decoder_input_ids = decoder_input_ids[:, -1:] + + return { + "input_ids": None, # encoder_outputs is defined. 
input_ids not needed + "encoder_outputs": encoder_outputs, + "past_key_values": past, + "decoder_input_ids": decoder_input_ids, + "attention_mask": attention_mask, + "head_mask": head_mask, + "decoder_head_mask": decoder_head_mask, + "cross_attn_head_mask": cross_attn_head_mask, + "use_cache": use_cache, # change this to avoid caching (presumably for debugging) + "adapter_input_parallelized": kwargs.pop("adapter_input_parallelized", False), + } + + # Copied from PLBartForConditionalGeneration + def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): + return shift_tokens_right(labels, self.config.pad_token_id) # , self.config.decoder_start_token_id) + + # Copied from PLBartForConditionalGeneration + @staticmethod + def _reorder_cache(past, beam_idx): + reordered_past = () + for layer_past in past: + # cached cross_attention states don't have to be reordered -> they are always the same + reordered_past += ( + tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:], + ) + return reordered_past diff --git a/src/adapters/models/plbart/mixin_plbart.py b/src/adapters/models/plbart/mixin_plbart.py new file mode 100644 index 0000000000..bd02e04dea --- /dev/null +++ b/src/adapters/models/plbart/mixin_plbart.py @@ -0,0 +1,109 @@ +from typing import Iterable, Optional, Tuple + +import torch +import torch.nn as nn + +from ...composition import adjust_tensors_for_parallel +from ...methods.bottleneck import BottleneckLayer +from ...methods.lora import LoRALinear +from ...methods.prefix_tuning import PrefixTuningLayer +from ...model_mixin import ( + EmbeddingAdaptersMixin, + EmbeddingAdaptersWrapperMixin, + InvertibleAdaptersMixin, + InvertibleAdaptersWrapperMixin, + ModelBaseAdaptersMixin, +) + + +class PLBartAttentionAdaptersMixin: + """Adds adapters to the BartAttention module.""" + + def init_adapters(self, model_config, adapters_config): + # Wrap layers for LoRA + self.k_proj = LoRALinear.wrap(self.k_proj, "selfattn", model_config, adapters_config, attn_key="k") + self.v_proj = LoRALinear.wrap(self.v_proj, "selfattn", model_config, adapters_config, attn_key="v") + self.q_proj = LoRALinear.wrap(self.q_proj, "selfattn", model_config, adapters_config, attn_key="q") + + self.prefix_tuning = PrefixTuningLayer( + self.location_key + "_prefix" if self.location_key else None, model_config, adapters_config + ) + + +class PLBartEncoderLayerAdaptersMixin: + """Adds adapters to the PLBartEncoderLayer module of PLBART.""" + + def init_adapters(self, model_config, adapters_config): + self.adapters_config = adapters_config + # Wrap layers for LoRA + self.fc1 = LoRALinear.wrap(self.fc1, "intermediate", model_config, adapters_config) + self.fc2 = LoRALinear.wrap(self.fc2, "output", model_config, adapters_config) + + # Set attention layer location key for prefix tuning + self.self_attn.location_key = "encoder" + self.attention_adapters = BottleneckLayer("mh_adapter") + self.output_adapters = BottleneckLayer("output_adapter") + + +class PLBartDecoderLayerAdaptersMixin(PLBartEncoderLayerAdaptersMixin): + """Adds adapters to the PLBartDecoderLayer module of PLBART.""" + + def init_adapters(self, model_config, adapters_config): + super().init_adapters(model_config, adapters_config) + # Set attention layer location key for prefix tuning + self.self_attn.location_key = "self" + self.encoder_attn.location_key = "cross" + self.cross_attention_adapters = BottleneckLayer("cross_adapter") + + +class PLBartEncoderAdaptersMixin(InvertibleAdaptersMixin): + """Adds adapters to 
the PLBartEncoder module of PLBART.""" + + pass + + +class PLBartDecoderAdaptersMixin: + """Adds adapters to the PLBartDecoder module of PLBART.""" + + def forward( + self, input_ids: torch.LongTensor = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, **kwargs + ): + (input_ids,) = adjust_tensors_for_parallel(encoder_hidden_states, input_ids) + return super().forward(input_ids=input_ids, encoder_hidden_states=encoder_hidden_states, **kwargs) + + +class PLBartModelAdaptersMixin(EmbeddingAdaptersMixin, InvertibleAdaptersWrapperMixin, ModelBaseAdaptersMixin): + """Adds adapters to the PLBartModel class.""" + + invertible_adapters_base_name = "encoder" + support_prompt_tuning = False + + def init_adapters(self, model_config, adapters_config): + super().init_adapters(model_config, adapters_config) + self.encoder.layernorm_embedding.register_forward_hook(self.post_embedding_forward) + + def iter_layers(self) -> Iterable[Tuple[int, nn.Module]]: + if hasattr(self, "encoder"): + for i, layer in enumerate(self.encoder.layers): + yield i, layer + for i, layer in enumerate(self.decoder.layers, start=len(self.encoder.layers)): + yield i, layer + else: + for i, layer in enumerate(self.decoder.layers): + yield i, layer + + def post_embedding_forward(self, module, args, embedding_output): + embedding_output = self.invertible_adapters_forward(embedding_output) + # Prompt tuning not yet supported + return embedding_output + + +class PLBartDecoderWrapperAdaptersMixin(EmbeddingAdaptersWrapperMixin, ModelBaseAdaptersMixin): + """Adds adapters to the PLBartDecoderWrapper class.""" + + def iter_layers(self) -> Iterable[Tuple[int, nn.Module]]: + for i, layer in enumerate(self.decoder.layers): + yield i, layer + + def get_input_embeddings(self): + return self.decoder.get_input_embeddings() diff --git a/src/adapters/models/plbart/modeling_plbart.py b/src/adapters/models/plbart/modeling_plbart.py new file mode 100644 index 0000000000..2d812cae1d --- /dev/null +++ b/src/adapters/models/plbart/modeling_plbart.py @@ -0,0 +1,537 @@ +# coding=utf-8 +# Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""PyTorch PLBART model.""" +from typing import Optional, Tuple + +import torch +import torch.utils.checkpoint +from torch import nn + +from transformers.models.plbart.modeling_plbart import PLBartAttention, PLBartDecoderLayer, PLBartEncoderLayer +from transformers.utils import logging + +from ...composition import adjust_tensors_for_parallel, adjust_tensors_for_parallel_, match_attn_matrices_for_parallel +from .mixin_plbart import ( + PLBartAttentionAdaptersMixin, + PLBartDecoderLayerAdaptersMixin, + PLBartEncoderLayerAdaptersMixin, +) + + +logger = logging.get_logger(__name__) + + +class PLBartAttentionWithAdapters(PLBartAttentionAdaptersMixin, PLBartAttention): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + + bsz, tgt_len, _ = hidden_states.size() + + # get query proj + query_states = self.q_proj(hidden_states) * self.scaling + # get key, value proj + # `past_key_value[0].shape[2] == key_value_states.shape[1]` + # is checking that the `sequence_length` of the `past_key_value` is the same as + # the provided `key_value_states` to support prefix tuning + if ( + is_cross_attention + and past_key_value is not None + and past_key_value[0].shape[2] == key_value_states.shape[1] + ): + # reuse k,v, cross_attentions + key_states = past_key_value[0] + value_states = past_key_value[1] + elif is_cross_attention: + # cross_attentions + key_states = self._shape(self.k_proj(key_value_states), -1, bsz) + value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + else: + # self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + + query_states, key_states, value_states = match_attn_matrices_for_parallel( + query_states, key_states, value_states + ) + (attention_mask,) = adjust_tensors_for_parallel(query_states, attention_mask) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. 
Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states, value_states) + + key_states, value_states, attention_mask = self.prefix_tuning( + key_states, value_states, hidden_states, attention_mask + ) + (query_states,) = adjust_tensors_for_parallel(key_states, query_states) + bsz = query_states.size(0) + + proj_shape = (bsz * self.num_heads, -1, self.head_dim) + query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) + key_states = key_states.reshape(*proj_shape) + value_states = value_states.reshape(*proj_shape) + + src_len = key_states.size(1) + attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) + + if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): + raise ValueError( + f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, tgt_len, src_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if layer_head_mask is not None: + if layer_head_mask.size() != (self.num_heads,): + raise ValueError( + f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" + f" {layer_head_mask.size()}" + ) + attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + if output_attentions: + # this operation is a bit awkward, but it's required to + # make sure that attn_weights keeps its gradient. + # In order to do so, attn_weights have to be reshaped + # twice and have to be reused in the following + attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) + else: + attn_weights_reshaped = None + + attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + + attn_output = torch.bmm(attn_probs, value_states) + + if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) + attn_output = attn_output.transpose(1, 2) + + # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be + # partitioned across GPUs when using tensor-parallelism. 
+ attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights_reshaped, past_key_value + + +class PLBartFlashAttention2WithAdapters(PLBartAttentionAdaptersMixin, PLBartAttention): + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + # PLBartFlashAttention2 attention does not support output_attentions + if output_attentions: + raise ValueError("PLBartFlashAttention2 attention does not support output_attentions") + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + + bsz, q_len, _ = hidden_states.size() + + # get query proj + query_states = self._reshape(self.q_proj(hidden_states), -1, bsz) + # get key, value proj + # `past_key_value[0].shape[2] == key_value_states.shape[1]` + # is checking that the `sequence_length` of the `past_key_value` is the same as + # the provided `key_value_states` to support prefix tuning + if ( + is_cross_attention + and past_key_value is not None + and past_key_value[0].shape[2] == key_value_states.shape[1] + ): + # reuse k,v, cross_attentions + key_states = past_key_value[0].transpose(1, 2) + value_states = past_key_value[1].transpose(1, 2) + elif is_cross_attention: + # cross_attentions + key_states = self._reshape(self.k_proj(key_value_states), -1, bsz) + value_states = self._reshape(self.v_proj(key_value_states), -1, bsz) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self._reshape(self.k_proj(hidden_states), -1, bsz) + value_states = self._reshape(self.v_proj(hidden_states), -1, bsz) + key_states = torch.cat([past_key_value[0].transpose(1, 2), key_states], dim=1) + value_states = torch.cat([past_key_value[1].transpose(1, 2), value_states], dim=1) + else: + # self_attention + key_states = self._reshape(self.k_proj(hidden_states), -1, bsz) + value_states = self._reshape(self.v_proj(hidden_states), -1, bsz) + + query_states, key_states, value_states = match_attn_matrices_for_parallel( + query_states, key_states, value_states + ) + (attention_mask,) = adjust_tensors_for_parallel(query_states, attention_mask) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. 
Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states.transpose(1, 2), value_states.transpose(1, 2)) + + key_states, value_states, attention_mask = self.prefix_tuning( + key_states, value_states, hidden_states, attention_mask + ) + (query_states,) = adjust_tensors_for_parallel(key_states, query_states) + bsz = query_states.size(0) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (LlamaRMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + # Handle the case where the model is quantized + if hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + "The input hidden states seems to be silently casted in float32, this might be related to the fact" + " you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = self._flash_attention_forward( + query_states, key_states, value_states, attention_mask, q_len, dropout=self.dropout + ) + + attn_output = attn_output.reshape(bsz, q_len, -1) + attn_output = self.out_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class PLBartSdpaAttentionWithAdapters(PLBartAttentionAdaptersMixin, PLBartAttention): + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + if output_attentions or layer_head_mask is not None: + # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "PLBartModel is using PLBartSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does" + " not support `output_attentions=True` or `layer_head_mask` not None. Falling back to the manual" + " attention implementation, but specifying the manual implementation will be required from" + " Transformers version v5.0.0 onwards. This warning can be removed using the argument" + ' `attn_implementation="eager"` when loading the model.' 
+ ) + return super().forward( + hidden_states, + key_value_states=key_value_states, + past_key_value=past_key_value, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, + ) + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + + bsz, tgt_len, _ = hidden_states.size() + + # get query proj + query_states = self.q_proj(hidden_states) + # get key, value proj + # `past_key_value[0].shape[2] == key_value_states.shape[1]` + # is checking that the `sequence_length` of the `past_key_value` is the same as + # the provided `key_value_states` to support prefix tuning + if ( + is_cross_attention + and past_key_value is not None + and past_key_value[0].shape[2] == key_value_states.shape[1] + ): + # reuse k,v, cross_attentions + key_states = past_key_value[0] + value_states = past_key_value[1] + elif is_cross_attention: + # cross_attentions + key_states = self._shape(self.k_proj(key_value_states), -1, bsz) + value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + else: + # self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + + query_states, key_states, value_states = match_attn_matrices_for_parallel( + query_states, key_states, value_states + ) + (attention_mask,) = adjust_tensors_for_parallel(query_states, attention_mask) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states, value_states) + + key_states, value_states, attention_mask = self.prefix_tuning( + key_states, value_states, hidden_states, attention_mask + ) + (query_states,) = adjust_tensors_for_parallel(key_states, query_states) + bsz = query_states.size(0) + + query_states = self._shape(query_states, tgt_len, bsz) + + # NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask, + # but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577 + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=attention_mask, + dropout_p=self.dropout if self.training else 0.0, + # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1. 
+ is_causal=self.is_causal and attention_mask is None and tgt_len > 1, + ) + + if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2) + + # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be + # partitioned across GPUs when using tensor-parallelism. + attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, None, past_key_value + + +class PLBartEncoderLayerWithAdapters(PLBartEncoderLayerAdaptersMixin, PLBartEncoderLayer): + def forward( + self, + hidden_states: torch.FloatTensor, + attention_mask: torch.FloatTensor, + layer_head_mask: torch.FloatTensor, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)` + attention_mask (`torch.FloatTensor`): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size + `(encoder_attention_heads,)`. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + """ + adjust_tensors_for_parallel_(hidden_states, attention_mask) + + residual = hidden_states + hidden_states, attn_weights, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, + ) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = self.attention_adapters(hidden_states, residual, self.self_attn_layer_norm) + + residual = hidden_states + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = self.output_adapters(hidden_states, residual, self.final_layer_norm) + + if hidden_states.dtype == torch.float16 and ( + torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any() + ): + clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + +class PLBartDecoderLayerWithAdapters(PLBartDecoderLayerAdaptersMixin, PLBartDecoderLayer): + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + cross_attn_layer_head_mask: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = True, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states 
(`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + encoder_hidden_states (`torch.FloatTensor`): + cross attention input to the layer of shape `(batch, seq_len, embed_dim)` + encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size + `(encoder_attention_heads,)`. + cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of + size `(decoder_attention_heads,)`. + past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + """ + adjust_tensors_for_parallel_(hidden_states, attention_mask, encoder_attention_mask) + + residual = hidden_states + + # Self Attention + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None + # add present self-attn cache to positions 1,2 of present_key_value tuple + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + past_key_value=self_attn_past_key_value, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, + ) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = self.attention_adapters(hidden_states, residual, self.self_attn_layer_norm) + + # Cross-Attention Block + cross_attn_present_key_value = None + cross_attn_weights = None + if encoder_hidden_states is not None: + residual = hidden_states + + # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple + cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None + hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn( + hidden_states=hidden_states, + key_value_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + layer_head_mask=cross_attn_layer_head_mask, + past_key_value=cross_attn_past_key_value, + output_attentions=output_attentions, + ) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = self.cross_attention_adapters(hidden_states, residual, self.encoder_attn_layer_norm) + + # add cross-attn to positions 3,4 of present_key_value tuple + present_key_value = present_key_value + cross_attn_present_key_value + + # Fully Connected + residual = hidden_states + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = self.output_adapters(hidden_states, residual, self.final_layer_norm) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights, cross_attn_weights) + + if use_cache: + outputs += (present_key_value,) + + return outputs diff 
--git a/src/adapters/models/t5/modeling_t5.py b/src/adapters/models/t5/modeling_t5.py index 03d9f27972..c98cfa477a 100644 --- a/src/adapters/models/t5/modeling_t5.py +++ b/src/adapters/models/t5/modeling_t5.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" PyTorch T5 model.""" +"""PyTorch T5 model.""" import torch from torch import nn diff --git a/src/adapters/models/vit/modeling_vit.py b/src/adapters/models/vit/modeling_vit.py index f8c02bd931..0a9d7a1b3e 100644 --- a/src/adapters/models/vit/modeling_vit.py +++ b/src/adapters/models/vit/modeling_vit.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" PyTorch ViT model.""" +"""PyTorch ViT model.""" import math diff --git a/src/adapters/utils.py b/src/adapters/utils.py index 7338f4c3ac..7e87c618d1 100644 --- a/src/adapters/utils.py +++ b/src/adapters/utils.py @@ -853,9 +853,9 @@ def get_adapter_info(adapter_id: str, source: str = "ah") -> Optional[AdapterInf return AdapterInfo( source="hf", adapter_id=model_info.modelId, - model_name=model_info.config.get("adapter_transformers", {}).get("model_name") - if model_info.config - else None, + model_name=( + model_info.config.get("adapter_transformers", {}).get("model_name") if model_info.config else None + ), username=model_info.modelId.split("/")[0], sha1_checksum=model_info.sha, ) diff --git a/src/adapters/wrappers/configuration.py b/src/adapters/wrappers/configuration.py index c49f3b8b7c..ed224cd600 100644 --- a/src/adapters/wrappers/configuration.py +++ b/src/adapters/wrappers/configuration.py @@ -46,6 +46,12 @@ "hidden_dropout_prob": "dropout", "attention_probs_dropout_prob": "attention_dropout", }, + "plbart": { + "num_attention_heads": "encoder_attention_heads", + "hidden_size": "d_model", + "hidden_dropout_prob": "dropout", + "attention_probs_dropout_prob": "attention_dropout", + }, "roberta": {}, "t5": { "hidden_size": "d_model", diff --git a/tests/fixtures/samples/cifar10/cifar10.py b/tests/fixtures/samples/cifar10/cifar10.py index cd00f02603..052a203dff 100644 --- a/tests/fixtures/samples/cifar10/cifar10.py +++ b/tests/fixtures/samples/cifar10/cifar10.py @@ -1,6 +1,7 @@ """ CIFAR-10 demo data, adapted from https://huggingface.co/datasets/cifar10. 
""" + import os import pickle diff --git a/tests/models/test_plbart.py b/tests/models/test_plbart.py new file mode 100644 index 0000000000..7fbbfc38df --- /dev/null +++ b/tests/models/test_plbart.py @@ -0,0 +1,12 @@ +# flake8: noqa: F403,F405 +from adapters import PLBartAdapterModel +from hf_transformers.tests.models.plbart.test_modeling_plbart import * +from transformers.testing_utils import require_torch + +from .base import AdapterModelTesterMixin + + +@require_torch +class PLBartAdapterModelTest(AdapterModelTesterMixin, PLBartModelTest): + all_model_classes = (PLBartAdapterModel,) + fx_compatible = False diff --git a/tests/test_plbart.py b/tests/test_plbart.py new file mode 100644 index 0000000000..969b3b0fde --- /dev/null +++ b/tests/test_plbart.py @@ -0,0 +1,66 @@ +import unittest + +from tests.methods.test_config_union import ConfigUnionAdapterTest +from transformers import PLBartConfig +from transformers.testing_utils import require_torch + +from .composition.test_parallel import ParallelAdapterInferenceTestMixin, ParallelTrainingMixin +from .methods import ( + BottleneckAdapterTestMixin, + CompacterTestMixin, + IA3TestMixin, + LoRATestMixin, + PrefixTuningTestMixin, + UniPELTTestMixin, +) +from .test_adapter import AdapterTestBase, make_config +from .test_adapter_backward_compability import CompabilityTestMixin +from .test_adapter_conversion import ModelClassConversionTestMixin +from .test_adapter_embeddings import EmbeddingTestMixin +from .test_adapter_fusion_common import AdapterFusionModelTestMixin +from .test_adapter_heads import PredictionHeadModelTestMixin + + +class PLBartAdapterTestBase(AdapterTestBase): + config_class = PLBartConfig + config = make_config( + PLBartConfig, + d_model=16, + encoder_layers=2, + decoder_layers=2, + encoder_attention_heads=4, + decoder_attention_heads=4, + encoder_ffn_dim=4, + decoder_ffn_dim=4, + ) + tokenizer_name = "uclanlp/plbart-base" + + +@require_torch +class PLBartAdapterTest( + BottleneckAdapterTestMixin, + CompacterTestMixin, + IA3TestMixin, + LoRATestMixin, + PrefixTuningTestMixin, + UniPELTTestMixin, + AdapterFusionModelTestMixin, + CompabilityTestMixin, + EmbeddingTestMixin, + PredictionHeadModelTestMixin, + ParallelAdapterInferenceTestMixin, + ParallelTrainingMixin, + ConfigUnionAdapterTest, + PLBartAdapterTestBase, + unittest.TestCase, +): + pass + + +@require_torch +class PLBartClassConversionTest( + ModelClassConversionTestMixin, + PLBartAdapterTestBase, + unittest.TestCase, +): + pass diff --git a/utils/back_comp/Utils.py b/utils/back_comp/Utils.py index 8ed482130c..21c15545f7 100644 --- a/utils/back_comp/Utils.py +++ b/utils/back_comp/Utils.py @@ -29,6 +29,7 @@ GPT2Config, GPTJConfig, MBartConfig, + PLBartConfig, RobertaConfig, T5Config, ViTConfig, @@ -130,6 +131,7 @@ def get_model_names(): "gpt2", "gptj", "mbart", + "plbart", "roberta", "t5", "vit", @@ -283,6 +285,19 @@ def create_model(model_name: str, model_class: Any) -> Any: ) model = model_class.from_config(mbart_config) + elif model_name == "plbart": + plbart_config = PLBartConfig( + d_model=16, + encoder_layers=2, + decoder_layers=2, + encoder_attention_heads=4, + decoder_attention_heads=4, + encoder_ffn_dim=4, + decoder_ffn_dim=4, + vocab_size=50005, + ) + model = model_class.from_config(plbart_config) + elif model_name == "roberta": roberta_config = RobertaConfig( hidden_size=32, diff --git a/utils/convert_xmod_checkpoint.py b/utils/convert_xmod_checkpoint.py index 30ca0ede74..b3744fece6 100644 --- a/utils/convert_xmod_checkpoint.py +++ 
b/utils/convert_xmod_checkpoint.py @@ -1,6 +1,7 @@ """ This script can be used to convert an Xmod checkpoints (including adapters) from the HF format to the Adapters format. """ + import argparse import os import re
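For reference, a minimal usage sketch of the newly added PLBartAdapterModel follows. The adapter and head name ("code_sum") is made up for illustration, and "seq_bn" is assumed to be an available bottleneck adapter config in this version of the library:

    from transformers import AutoTokenizer
    from adapters import PLBartAdapterModel

    tokenizer = AutoTokenizer.from_pretrained("uclanlp/plbart-base")
    model = PLBartAdapterModel.from_pretrained("uclanlp/plbart-base")

    # Add a bottleneck adapter plus a matching seq2seq LM head and put them into training mode
    # (train_adapter freezes the pre-trained weights and activates the adapter).
    model.add_adapter("code_sum", config="seq_bn")  # "code_sum" is an illustrative name
    model.add_seq2seq_lm_head("code_sum")
    model.train_adapter("code_sum")

    # PLBartModel derives decoder_input_ids by shifting input_ids when none are passed,
    # so a simple denoising-style forward pass with labels works as a smoke test.
    inputs = tokenizer("def add(a, b): return a + b", return_tensors="pt")
    outputs = model(**inputs, labels=inputs["input_ids"])
    outputs.loss.backward()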
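Because "plbart" is also added to the architectures supporting the Parallel composition block (see the change to composition.py), two adapters with equally named heads can be evaluated in one forward pass. Again a sketch with illustrative adapter names, reusing the model and inputs from above:

    import adapters.composition as ac

    model.add_adapter("task_a", config="seq_bn")
    model.add_adapter("task_b", config="seq_bn")
    model.add_classification_head("task_a", num_labels=2)
    model.add_classification_head("task_b", num_labels=4)

    # Activate both adapters in parallel; heads named like the adapters are picked up per branch.
    model.set_active_adapters(ac.Parallel("task_a", "task_b"))
    outputs = model(**inputs)  # one set of logits per parallel branch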