diff --git a/src/adapters/heads/model_mixin.py b/src/adapters/heads/model_mixin.py
index 743fc940c0..4e0dfde84b 100644
--- a/src/adapters/heads/model_mixin.py
+++ b/src/adapters/heads/model_mixin.py
@@ -28,12 +28,26 @@
 logger = logging.getLogger(__name__)
 
 
+MODEL_HEAD_MAP = {
+    "classification": ClassificationHead,
+    "multilabel_classification": MultiLabelClassificationHead,
+    "tagging": TaggingHead,
+    "multiple_choice": MultipleChoiceHead,
+    "question_answering": QuestionAnsweringHead,
+    "dependency_parsing": BiaffineParsingHead,
+    "masked_lm": BertStyleMaskedLMHead,
+    "causal_lm": CausalLMHead,
+    "seq2seq_lm": Seq2SeqLMHead,
+    "image_classification": ImageClassificationHead,
+}
+
+
 class ModelWithFlexibleHeadsAdaptersMixin(ModelWithHeadsAdaptersMixin):
     """
     Adds flexible prediction heads to a model class. Implemented by the XModelWithHeads classes.
     """
 
-    head_types: dict = {}
+    head_types: list = []
     use_pooler: bool = False
 
     def __init__(self, *args, **kwargs):
@@ -157,7 +171,7 @@ def add_prediction_head_from_config(
             config["id2label"] = id2label
 
         if head_type in self.head_types:
-            head_class = self.head_types[head_type]
+            head_class = MODEL_HEAD_MAP[head_type]
             head = head_class(self, head_name, **config)
             self.add_prediction_head(head, overwrite_ok=overwrite_ok, set_active=set_active)
         elif head_type in self.config.custom_heads:
diff --git a/src/adapters/models/albert/adapter_model.py b/src/adapters/models/albert/adapter_model.py
index ea2835a493..8f6c07d47c 100644
--- a/src/adapters/models/albert/adapter_model.py
+++ b/src/adapters/models/albert/adapter_model.py
@@ -7,15 +7,7 @@
 from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward
 
 from ...context import AdapterSetup
-from ...heads import (
-    BertStyleMaskedLMHead,
-    ClassificationHead,
-    ModelWithFlexibleHeadsAdaptersMixin,
-    MultiLabelClassificationHead,
-    MultipleChoiceHead,
-    QuestionAnsweringHead,
-    TaggingHead,
-)
+from ...heads import ModelWithFlexibleHeadsAdaptersMixin
 from ...model_mixin import EmbeddingAdaptersWrapperMixin
 from ...wrappers import init
 
@@ -25,6 +17,16 @@
     ALBERT_START_DOCSTRING,
 )
 class AlbertAdapterModel(EmbeddingAdaptersWrapperMixin, ModelWithFlexibleHeadsAdaptersMixin, AlbertPreTrainedModel):
+    head_types = [
+        "classification",
+        "multilabel_classification",
+        "tagging",
+        "multiple_choice",
+        "question_answering",
+        "masked_lm",
+    ]
+    use_pooler = True
+
     def __init__(self, config):
         super().__init__(config)
 
@@ -102,13 +104,3 @@ def forward(
         else:
             # in case no head is used just return the output of the base model (including pooler output)
             return outputs
-
-    head_types = {
-        "classification": ClassificationHead,
-        "multilabel_classification": MultiLabelClassificationHead,
-        "tagging": TaggingHead,
-        "multiple_choice": MultipleChoiceHead,
-        "question_answering": QuestionAnsweringHead,
-        "masked_lm": BertStyleMaskedLMHead,
-    }
-    use_pooler = True
diff --git a/src/adapters/models/bart/adapter_model.py b/src/adapters/models/bart/adapter_model.py
index 8c81794d4b..dec5a838c2 100644
--- a/src/adapters/models/bart/adapter_model.py
+++ b/src/adapters/models/bart/adapter_model.py
@@ -10,13 +10,7 @@
 )
 from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward
 
-from ...heads import (
-    ClassificationHead,
-    ModelWithFlexibleHeadsAdaptersMixin,
-    MultiLabelClassificationHead,
-    QuestionAnsweringHead,
-    Seq2SeqLMHead,
-)
+from ...heads import ModelWithFlexibleHeadsAdaptersMixin
 from ...model_mixin import EmbeddingAdaptersWrapperMixin
 from ...wrappers import init
 
@@ -30,6 +24,13 @@ class BartAdapterModel(EmbeddingAdaptersWrapperMixin, ModelWithFlexibleHeadsAdap
         "decoder.embed_tokens.weight",
     ]
 
+    head_types = [
+        "classification",
+        "multilabel_classification",
+        "question_answering",
+        "seq2seq_lm",
+    ]
+
     def __init__(self, config: BartConfig, **kwargs):
         super().__init__(config, **kwargs)
         self.model = BartModel(config)
@@ -159,10 +160,3 @@ def _reorder_cache(past, beam_idx):
                 tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
             )
         return reordered_past
-
-    head_types = {
-        "classification": ClassificationHead,
-        "multilabel_classification": MultiLabelClassificationHead,
-        "question_answering": QuestionAnsweringHead,
-        "seq2seq_lm": Seq2SeqLMHead,
-    }
diff --git a/src/adapters/models/beit/adapter_model.py b/src/adapters/models/beit/adapter_model.py
index 2bf3eabb76..5667fa098d 100644
--- a/src/adapters/models/beit/adapter_model.py
+++ b/src/adapters/models/beit/adapter_model.py
@@ -11,7 +11,7 @@
 from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward
 
 from ...context import AdapterSetup
-from ...heads import ImageClassificationHead, ModelWithFlexibleHeadsAdaptersMixin
+from ...heads import ModelWithFlexibleHeadsAdaptersMixin
 from ...wrappers import init
 
 
@@ -20,6 +20,11 @@
     BEIT_START_DOCSTRING,
 )
 class BeitAdapterModel(ModelWithFlexibleHeadsAdaptersMixin, BeitPreTrainedModel):
+    head_types = [
+        "image_classification",
+    ]
+    use_pooler = True
+
     def __init__(self, config):
         super().__init__(config)
 
@@ -82,8 +87,3 @@ def forward(
         else:
             # in case no head is used just return the output of the base model (including pooler output)
             return outputs
-
-    head_types = {
-        "image_classification": ImageClassificationHead,
-    }
-    use_pooler = True
diff --git a/src/adapters/models/bert/adapter_model.py b/src/adapters/models/bert/adapter_model.py
index a51bf25107..0b8e189436 100644
--- a/src/adapters/models/bert/adapter_model.py
+++ b/src/adapters/models/bert/adapter_model.py
@@ -7,17 +7,7 @@
 from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward
 
 from ...context import AdapterSetup
-from ...heads import (
-    BertStyleMaskedLMHead,
-    BiaffineParsingHead,
-    CausalLMHead,
-    ClassificationHead,
-    ModelWithFlexibleHeadsAdaptersMixin,
-    MultiLabelClassificationHead,
-    MultipleChoiceHead,
-    QuestionAnsweringHead,
-    TaggingHead,
-)
+from ...heads import ModelWithFlexibleHeadsAdaptersMixin
 from ...model_mixin import EmbeddingAdaptersWrapperMixin
 from ...wrappers import init
 
@@ -27,6 +17,18 @@
     BERT_START_DOCSTRING,
 )
 class BertAdapterModel(EmbeddingAdaptersWrapperMixin, ModelWithFlexibleHeadsAdaptersMixin, BertPreTrainedModel):
+
+    head_types = [
+        "classification",
+        "multilabel_classification",
+        "tagging",
+        "multiple_choice",
+        "question_answering",
+        "dependency_parsing",
+        "masked_lm",
+        "causal_lm",
+    ]
+
     def __init__(self, config):
         super().__init__(config)
 
@@ -121,14 +123,3 @@ def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=Non
             "past_key_values": past,
             "adapter_input_parallelized": model_kwargs.pop("adapter_input_parallelized", False),
         }
-
-    head_types = {
-        "classification": ClassificationHead,
-        "multilabel_classification": MultiLabelClassificationHead,
-        "tagging": TaggingHead,
-        "multiple_choice": MultipleChoiceHead,
-        "question_answering": QuestionAnsweringHead,
-        "dependency_parsing": BiaffineParsingHead,
-        "masked_lm": BertStyleMaskedLMHead,
-        "causal_lm": CausalLMHead,
-    }
diff --git a/src/adapters/models/bert_generation/adapter_model.py b/src/adapters/models/bert_generation/adapter_model.py
index 66b38bba93..072c1b099a 100644
--- a/src/adapters/models/bert_generation/adapter_model.py
+++ b/src/adapters/models/bert_generation/adapter_model.py
@@ -7,7 +7,7 @@
 from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward
 
 from ...context import AdapterSetup
-from ...heads import BertStyleMaskedLMHead, CausalLMHead, ModelWithFlexibleHeadsAdaptersMixin
+from ...heads import ModelWithFlexibleHeadsAdaptersMixin
 from ...model_mixin import EmbeddingAdaptersWrapperMixin
 from ...wrappers import init
 
@@ -21,6 +21,11 @@ class BertGenerationAdapterModel(
 ):
     _keys_to_ignore_on_load_unexpected = [r"lm_head.bias"]
 
+    head_types = [
+        "masked_lm",
+        "causal_lm",
+    ]
+
     def __init__(self, config):
         super().__init__(config)
 
@@ -118,8 +123,3 @@ def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=Non
             "past_key_values": past,
             "adapter_input_parallelized": model_kwargs.pop("adapter_input_parallelized", False),
         }
-
-    head_types = {
-        "masked_lm": BertStyleMaskedLMHead,
-        "causal_lm": CausalLMHead,
-    }
diff --git a/src/adapters/models/deberta/adapter_model.py b/src/adapters/models/deberta/adapter_model.py
index 1554b9dbd9..32ec9cd45f 100644
--- a/src/adapters/models/deberta/adapter_model.py
+++ b/src/adapters/models/deberta/adapter_model.py
@@ -2,14 +2,7 @@
 from transformers.models.deberta.modeling_deberta import DebertaModel, DebertaPreTrainedModel
 
 from ...context import AdapterSetup
-from ...heads import (
-    BertStyleMaskedLMHead,
-    ClassificationHead,
-    ModelWithFlexibleHeadsAdaptersMixin,
-    MultiLabelClassificationHead,
-    QuestionAnsweringHead,
-    TaggingHead,
-)
+from ...heads import ModelWithFlexibleHeadsAdaptersMixin
 from ...model_mixin import EmbeddingAdaptersWrapperMixin
 from ...wrappers import init
 
@@ -20,6 +13,15 @@
 class DebertaAdapterModel(EmbeddingAdaptersWrapperMixin, ModelWithFlexibleHeadsAdaptersMixin, DebertaPreTrainedModel):
     _keys_to_ignore_on_load_unexpected = [r"cls.predictions.bias"]
 
+    head_types = [
+        "classification",
+        "multilabel_classification",
+        "tagging",
+        "question_answering",
+        "multiple_choice",
+        "masked_lm",
+    ]
+
     def __init__(self, config):
         super().__init__(config)
 
@@ -93,11 +95,3 @@ def forward(
         else:
             # in case no head is used just return the output of the base model (including pooler output)
             return outputs
-
-    head_types = {
-        "classification": ClassificationHead,
-        "multilabel_classification": MultiLabelClassificationHead,
-        "tagging": TaggingHead,
-        "question_answering": QuestionAnsweringHead,
-        "masked_lm": BertStyleMaskedLMHead,
-    }
diff --git a/src/adapters/models/deberta_v2/adapter_model.py b/src/adapters/models/deberta_v2/adapter_model.py
index 5eec025f79..c306f8f475 100644
--- a/src/adapters/models/deberta_v2/adapter_model.py
+++ b/src/adapters/models/deberta_v2/adapter_model.py
@@ -2,15 +2,7 @@
 from transformers.models.deberta_v2.modeling_deberta_v2 import DebertaV2Model, DebertaV2PreTrainedModel
 
 from ...context import AdapterSetup
-from ...heads import (
-    BertStyleMaskedLMHead,
-    ClassificationHead,
-    ModelWithFlexibleHeadsAdaptersMixin,
-    MultiLabelClassificationHead,
-    MultipleChoiceHead,
-    QuestionAnsweringHead,
-    TaggingHead,
-)
+from ...heads import ModelWithFlexibleHeadsAdaptersMixin
 from ...model_mixin import EmbeddingAdaptersWrapperMixin
 from ...wrappers import init
 
@@ -23,6 +15,15 @@ class DebertaV2AdapterModel(
 ):
     _keys_to_ignore_on_load_unexpected = [r"cls.predictions.bias"]
 
+    head_types = [
+        "classification",
+        "multilabel_classification",
+        "tagging",
+        "question_answering",
+        "multiple_choice",
+        "masked_lm",
+    ]
+
     def __init__(self, config):
         super().__init__(config)
 
@@ -96,12 +97,3 @@ def forward(
         else:
             # in case no head is used just return the output of the base model (including pooler output)
             return outputs
-
-    head_types = {
-        "classification": ClassificationHead,
-        "multilabel_classification": MultiLabelClassificationHead,
-        "tagging": TaggingHead,
-        "question_answering": QuestionAnsweringHead,
-        "multiple_choice": MultipleChoiceHead,
-        "masked_lm": BertStyleMaskedLMHead,
-    }
diff --git a/src/adapters/models/distilbert/adapter_model.py b/src/adapters/models/distilbert/adapter_model.py
index 54d3c2bc67..c28f124408 100644
--- a/src/adapters/models/distilbert/adapter_model.py
+++ b/src/adapters/models/distilbert/adapter_model.py
@@ -8,17 +8,7 @@
 )
 from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward
 
-from ...heads import (
-    BertStyleMaskedLMHead,
-    BiaffineParsingHead,
-    CausalLMHead,
-    ClassificationHead,
-    ModelWithFlexibleHeadsAdaptersMixin,
-    MultiLabelClassificationHead,
-    MultipleChoiceHead,
-    QuestionAnsweringHead,
-    TaggingHead,
-)
+from ...heads import ModelWithFlexibleHeadsAdaptersMixin
 from ...model_mixin import EmbeddingAdaptersWrapperMixin
 from ...wrappers import init
 
@@ -30,6 +20,17 @@
 class DistilBertAdapterModel(
     EmbeddingAdaptersWrapperMixin, ModelWithFlexibleHeadsAdaptersMixin, DistilBertPreTrainedModel
 ):
+    head_types = [
+        "classification",
+        "multilabel_classification",
+        "tagging",
+        "question_answering",
+        "multiple_choice",
+        "dependency_parsing",
+        "masked_lm",
+        "causal_lm",
+    ]
+
     def __init__(self, config):
         super().__init__(config)
         self.distilbert = DistilBertModel(config)
@@ -124,14 +125,3 @@ def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=Non
             "past_key_values": past,
             "adapter_input_parallelized": model_kwargs.pop("adapter_input_parallelized", False),
         }
-
-    head_types = {
-        "classification": ClassificationHead,
-        "multilabel_classification": MultiLabelClassificationHead,
-        "tagging": TaggingHead,
-        "multiple_choice": MultipleChoiceHead,
-        "question_answering": QuestionAnsweringHead,
-        "dependency_parsing": BiaffineParsingHead,
-        "masked_lm": BertStyleMaskedLMHead,
-        "causal_lm": CausalLMHead,
-    }
diff --git a/src/adapters/models/electra/adapter_model.py b/src/adapters/models/electra/adapter_model.py
index 8f96b1b3bb..dbccce40d2 100644
--- a/src/adapters/models/electra/adapter_model.py
+++ b/src/adapters/models/electra/adapter_model.py
@@ -7,17 +7,7 @@
 from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward
 
 from ...context import AdapterSetup
-from ...heads import (
-    BertStyleMaskedLMHead,
-    BiaffineParsingHead,
-    CausalLMHead,
-    ClassificationHead,
-    ModelWithFlexibleHeadsAdaptersMixin,
-    MultiLabelClassificationHead,
-    MultipleChoiceHead,
-    QuestionAnsweringHead,
-    TaggingHead,
-)
+from ...heads import ModelWithFlexibleHeadsAdaptersMixin
 from ...model_mixin import EmbeddingAdaptersWrapperMixin
 from ...wrappers import init
 
@@ -27,6 +17,18 @@
     ELECTRA_START_DOCSTRING,
 )
 class ElectraAdapterModel(EmbeddingAdaptersWrapperMixin, ModelWithFlexibleHeadsAdaptersMixin, ElectraPreTrainedModel):
+
+    head_types = [
+        "classification",
+        "multilabel_classification",
+        "tagging",
+        "question_answering",
+        "multiple_choice",
+        "dependency_parsing",
+        "masked_lm",
+        "causal_lm",
+    ]
+
     def __init__(self, config):
         super().__init__(config)
 
@@ -116,14 +118,3 @@ def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=Non
             "past_key_values": past,
             "adapter_input_parallelized": model_kwargs.pop("adapter_input_parallelized", False),
         }
-
-    head_types = {
-        "classification": ClassificationHead,
-        "multilabel_classification": MultiLabelClassificationHead,
-        "tagging": TaggingHead,
-        "multiple_choice": MultipleChoiceHead,
-        "question_answering": QuestionAnsweringHead,
-        "dependency_parsing": BiaffineParsingHead,
-        "masked_lm": BertStyleMaskedLMHead,
-        "causal_lm": CausalLMHead,
-    }
diff --git a/src/adapters/models/gpt2/adapter_model.py b/src/adapters/models/gpt2/adapter_model.py
index 83b2ee1d93..c15e5c9959 100644
--- a/src/adapters/models/gpt2/adapter_model.py
+++ b/src/adapters/models/gpt2/adapter_model.py
@@ -2,18 +2,11 @@
 
 import torch
 
-from adapters.heads.base import QuestionAnsweringHead
 from transformers.models.gpt2.modeling_gpt2 import GPT2_START_DOCSTRING, GPT2Model, GPT2PreTrainedModel
 from transformers.utils import add_start_docstrings
 
 from ...composition import adjust_tensors_for_parallel
-from ...heads import (
-    CausalLMHead,
-    ClassificationHead,
-    ModelWithFlexibleHeadsAdaptersMixin,
-    MultiLabelClassificationHead,
-    TaggingHead,
-)
+from ...heads import ModelWithFlexibleHeadsAdaptersMixin
 from ...model_mixin import EmbeddingAdaptersWrapperMixin
 from ...wrappers import init
 
@@ -35,6 +28,14 @@
 class GPT2AdapterModel(EmbeddingAdaptersWrapperMixin, ModelWithFlexibleHeadsAdaptersMixin, GPT2PreTrainedModel):
     _tied_weights_keys = []  # needs to be empty since GPT2 does not yet support prompt tuning
 
+    head_types = [
+        "classification",
+        "multilabel_classification",
+        "tagging",
+        "question_answering",
+        "causal_lm",
+    ]
+
     def __init__(self, config):
         super().__init__(config)
         self.transformer = GPT2Model(config)
@@ -150,11 +151,3 @@ def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
             "token_type_ids": token_type_ids,
             "adapter_input_parallelized": kwargs.pop("adapter_input_parallelized", False),
         }
-
-    head_types = {
-        "classification": ClassificationHead,
-        "multilabel_classification": MultiLabelClassificationHead,
-        "causal_lm": CausalLMHead,
-        "tagging": TaggingHead,
-        "question_answering": QuestionAnsweringHead,
-    }
diff --git a/src/adapters/models/gptj/adapter_model.py b/src/adapters/models/gptj/adapter_model.py
index b404effaac..3c585ba226 100644
--- a/src/adapters/models/gptj/adapter_model.py
+++ b/src/adapters/models/gptj/adapter_model.py
@@ -6,14 +6,7 @@
 from transformers.utils import add_start_docstrings
 
 from ...composition import adjust_tensors_for_parallel
-from ...heads import (
-    CausalLMHead,
-    ClassificationHead,
-    ModelWithFlexibleHeadsAdaptersMixin,
-    MultiLabelClassificationHead,
-    QuestionAnsweringHead,
-    TaggingHead,
-)
+from ...heads import ModelWithFlexibleHeadsAdaptersMixin
 from ...model_mixin import EmbeddingAdaptersWrapperMixin
 from ...wrappers import init
 
@@ -35,6 +28,14 @@
 class GPTJAdapterModel(EmbeddingAdaptersWrapperMixin, ModelWithFlexibleHeadsAdaptersMixin, GPTJPreTrainedModel):
     _tied_weights_keys = []  # needs to be empty since GPT-J does not yet support prompt tuning
 
+    head_types = [
+        "classification",
+        "multilabel_classification",
+        "tagging",
+        "question_answering",
+        "causal_lm",
+    ]
+
     def __init__(self, config):
         super().__init__(config)
         self.transformer = GPTJModel(config)
@@ -146,11 +147,3 @@ def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
             "token_type_ids": token_type_ids,
             "adapter_input_parallelized": kwargs.pop("adapter_input_parallelized", False),
         }
-
-    head_types = {
-        "classification": ClassificationHead,
-        "multilabel_classification": MultiLabelClassificationHead,
-        "causal_lm": CausalLMHead,
-        "question_answering": QuestionAnsweringHead,
-        "tagging": TaggingHead,
-    }
diff --git a/src/adapters/models/llama/adapter_model.py b/src/adapters/models/llama/adapter_model.py
index c3569cc013..97cc0c4e3f 100644
--- a/src/adapters/models/llama/adapter_model.py
+++ b/src/adapters/models/llama/adapter_model.py
@@ -6,7 +6,7 @@
 from transformers.utils import add_start_docstrings
 
 from ...composition import adjust_tensors_for_parallel
-from ...heads import CausalLMHead, ClassificationHead, ModelWithFlexibleHeadsAdaptersMixin, TaggingHead
+from ...heads import ModelWithFlexibleHeadsAdaptersMixin
 from ...model_mixin import EmbeddingAdaptersWrapperMixin
 from ...wrappers import init
 
@@ -28,6 +28,14 @@
 class LlamaAdapterModel(EmbeddingAdaptersWrapperMixin, ModelWithFlexibleHeadsAdaptersMixin, LlamaPreTrainedModel):
     _tied_weights_keys = []  # needs to be empty since LLaMA does not yet support prompt tuning
 
+    head_types = [
+        "classification",
+        "multilabel_classification",
+        "tagging",
+        "question_answering",
+        "causal_lm",
+    ]
+
     def __init__(self, config):
         super().__init__(config)
         self.model = LlamaModel(config)
@@ -141,9 +149,3 @@ def prepare_inputs_for_generation(
             }
         )
         return model_inputs
-
-    head_types = {
-        "causal_lm": CausalLMHead,
-        "tagging": TaggingHead,
-        "classification": ClassificationHead,
-    }
diff --git a/src/adapters/models/mbart/adapter_model.py b/src/adapters/models/mbart/adapter_model.py
index 6fc3f787d4..186aef5c09 100644
--- a/src/adapters/models/mbart/adapter_model.py
+++ b/src/adapters/models/mbart/adapter_model.py
@@ -11,13 +11,7 @@
 from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward
 
 from ...composition import adjust_tensors_for_parallel
-from ...heads import (
-    ClassificationHead,
-    ModelWithFlexibleHeadsAdaptersMixin,
-    MultiLabelClassificationHead,
-    QuestionAnsweringHead,
-    Seq2SeqLMHead,
-)
+from ...heads import ModelWithFlexibleHeadsAdaptersMixin
 from ...model_mixin import EmbeddingAdaptersWrapperMixin
 from ...wrappers import init
 
@@ -31,6 +25,13 @@ class MBartAdapterModel(EmbeddingAdaptersWrapperMixin, ModelWithFlexibleHeadsAda
         "decoder.embed_tokens.weight",
     ]
 
+    head_types = [
+        "classification",
+        "multilabel_classification",
+        "question_answering",
+        "seq2seq_lm",
+    ]
+
     def __init__(self, config: MBartConfig, **kwargs):
         super().__init__(config, **kwargs)
         self.model = MBartModel(config)
@@ -168,10 +169,3 @@ def _reorder_cache(past, beam_idx):
                 tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
             )
         return reordered_past
-
-    head_types = {
-        "classification": ClassificationHead,
-        "multilabel_classification": MultiLabelClassificationHead,
-        "question_answering": QuestionAnsweringHead,
-        "seq2seq_lm": Seq2SeqLMHead,
-    }
diff --git a/src/adapters/models/roberta/adapter_model.py b/src/adapters/models/roberta/adapter_model.py
index 5e6a9edd98..87858566b3 100644
--- a/src/adapters/models/roberta/adapter_model.py
+++ b/src/adapters/models/roberta/adapter_model.py
@@ -7,17 +7,7 @@
 from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward
 
 from ...context import AdapterSetup
-from ...heads import (
-    BertStyleMaskedLMHead,
-    BiaffineParsingHead,
-    CausalLMHead,
-    ClassificationHead,
-    ModelWithFlexibleHeadsAdaptersMixin,
-    MultiLabelClassificationHead,
-    MultipleChoiceHead,
-    QuestionAnsweringHead,
-    TaggingHead,
-)
+from ...heads import ModelWithFlexibleHeadsAdaptersMixin
 from ...model_mixin import EmbeddingAdaptersWrapperMixin
 from ...wrappers import init
 
@@ -27,6 +17,17 @@
     ROBERTA_START_DOCSTRING,
 )
 class RobertaAdapterModel(EmbeddingAdaptersWrapperMixin, ModelWithFlexibleHeadsAdaptersMixin, RobertaPreTrainedModel):
+    head_types = [
+        "classification",
+        "multilabel_classification",
+        "tagging",
+        "multiple_choice",
+        "question_answering",
+        "dependency_parsing",
+        "masked_lm",
+        "causal_lm",
+    ]
+
     def __init__(self, config):
         super().__init__(config)
 
@@ -121,14 +122,3 @@ def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=Non
             "past_key_values": past,
             "adapter_input_parallelized": model_kwargs.pop("adapter_input_parallelized", False),
         }
-
-    head_types = {
-        "classification": ClassificationHead,
-        "multilabel_classification": MultiLabelClassificationHead,
-        "tagging": TaggingHead,
-        "multiple_choice": MultipleChoiceHead,
-        "question_answering": QuestionAnsweringHead,
-        "dependency_parsing": BiaffineParsingHead,
-        "masked_lm": BertStyleMaskedLMHead,
-        "causal_lm": CausalLMHead,
-    }
diff --git a/src/adapters/models/t5/adapter_model.py b/src/adapters/models/t5/adapter_model.py
index b1a9620976..b544252ce3 100644
--- a/src/adapters/models/t5/adapter_model.py
+++ b/src/adapters/models/t5/adapter_model.py
@@ -6,13 +6,7 @@
 from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward
 
 from ...composition import adjust_tensors_for_parallel
-from ...heads import (
-    ClassificationHead,
-    ModelWithFlexibleHeadsAdaptersMixin,
-    MultiLabelClassificationHead,
-    QuestionAnsweringHead,
-    Seq2SeqLMHead,
-)
+from ...heads import ModelWithFlexibleHeadsAdaptersMixin, Seq2SeqLMHead
 from ...model_mixin import EmbeddingAdaptersWrapperMixin
 from ...wrappers import init
 
@@ -31,6 +25,13 @@ class T5AdapterModel(EmbeddingAdaptersWrapperMixin, ModelWithFlexibleHeadsAdapte
         r"decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight",
     ]
 
+    head_types = [
+        "classification",
+        "multilabel_classification",
+        "question_answering",
+        "seq2seq_lm",
+    ]
+
     def __init__(self, config):
         super().__init__(config)
 
@@ -199,10 +200,3 @@ def _reorder_cache(self, past, beam_idx):
             reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,)
 
         return reordered_decoder_past
-
-    head_types = {
-        "seq2seq_lm": Seq2SeqLMHead,
-        "question_answering": QuestionAnsweringHead,
-        "classification": ClassificationHead,
-        "multilabel_classification": MultiLabelClassificationHead,
-    }
diff --git a/src/adapters/models/vit/adapter_model.py b/src/adapters/models/vit/adapter_model.py
index 96b8b18377..ece9ec5214 100644
--- a/src/adapters/models/vit/adapter_model.py
+++ b/src/adapters/models/vit/adapter_model.py
@@ -11,7 +11,7 @@
 from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward
 
 from ...context import AdapterSetup
-from ...heads import ImageClassificationHead, ModelWithFlexibleHeadsAdaptersMixin
+from ...heads import ModelWithFlexibleHeadsAdaptersMixin
 from ...wrappers import init
 
 
@@ -20,6 +20,11 @@
     VIT_START_DOCSTRING,
 )
 class ViTAdapterModel(ModelWithFlexibleHeadsAdaptersMixin, ViTPreTrainedModel):
+
+    head_types = [
+        "image_classification",
+    ]
+
     def __init__(self, config):
         super().__init__(config)
 
@@ -81,7 +86,3 @@ def forward(
         else:
             # in case no head is used just return the output of the base model (including pooler output)
             return outputs
-
-    head_types = {
-        "image_classification": ImageClassificationHead,
-    }
diff --git a/src/adapters/models/xlm_roberta/adapter_model.py b/src/adapters/models/xlm_roberta/adapter_model.py
index e3355e1b4b..8acfde792f 100644
--- a/src/adapters/models/xlm_roberta/adapter_model.py
+++ b/src/adapters/models/xlm_roberta/adapter_model.py
@@ -7,17 +7,7 @@
 from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward
 
 from ...context import AdapterSetup
-from ...heads import (
-    BertStyleMaskedLMHead,
-    BiaffineParsingHead,
-    CausalLMHead,
-    ClassificationHead,
-    ModelWithFlexibleHeadsAdaptersMixin,
-    MultiLabelClassificationHead,
-    MultipleChoiceHead,
-    QuestionAnsweringHead,
-    TaggingHead,
-)
+from ...heads import ModelWithFlexibleHeadsAdaptersMixin
 from ...model_mixin import EmbeddingAdaptersWrapperMixin
 from ...wrappers import init
 
@@ -29,6 +19,18 @@
 class XLMRobertaAdapterModel(
     EmbeddingAdaptersWrapperMixin, ModelWithFlexibleHeadsAdaptersMixin, XLMRobertaPreTrainedModel
 ):
+
+    head_types = [
+        "classification",
+        "multilabel_classification",
+        "tagging",
+        "multiple_choice",
+        "question_answering",
+        "dependency_parsing",
+        "masked_lm",
+        "causal_lm",
+    ]
+
     def __init__(self, config):
         super().__init__(config)
 
@@ -123,14 +125,3 @@ def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=Non
             "past_key_values": past,
             "adapter_input_parallelized": model_kwargs.pop("adapter_input_parallelized", False),
         }
-
-    head_types = {
-        "classification": ClassificationHead,
-        "multilabel_classification": MultiLabelClassificationHead,
-        "tagging": TaggingHead,
-        "multiple_choice": MultipleChoiceHead,
-        "question_answering": QuestionAnsweringHead,
-        "dependency_parsing": BiaffineParsingHead,
-        "masked_lm": BertStyleMaskedLMHead,
-        "causal_lm": CausalLMHead,
-    }
diff --git a/src/adapters/models/xmod/adapter_model.py b/src/adapters/models/xmod/adapter_model.py
index 54b38ccdf3..94cc43f71f 100644
--- a/src/adapters/models/xmod/adapter_model.py
+++ b/src/adapters/models/xmod/adapter_model.py
@@ -11,17 +11,7 @@
 from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward
 
 from ...context import AdapterSetup
-from ...heads import (
-    BertStyleMaskedLMHead,
-    BiaffineParsingHead,
-    CausalLMHead,
-    ClassificationHead,
-    ModelWithFlexibleHeadsAdaptersMixin,
-    MultiLabelClassificationHead,
-    MultipleChoiceHead,
-    QuestionAnsweringHead,
-    TaggingHead,
-)
+from ...heads import ModelWithFlexibleHeadsAdaptersMixin
 from ...model_mixin import EmbeddingAdaptersWrapperMixin
 from ...wrappers import init
 
@@ -31,6 +21,18 @@
     XMOD_START_DOCSTRING,
 )
 class XmodAdapterModel(EmbeddingAdaptersWrapperMixin, ModelWithFlexibleHeadsAdaptersMixin, XmodPreTrainedModel):
+
+    head_types = [
+        "classification",
+        "multilabel_classification",
+        "tagging",
+        "multiple_choice",
+        "question_answering",
+        "dependency_parsing",
+        "masked_lm",
+        "causal_lm",
+    ]
+
     def __init__(self, config):
         super().__init__(config)
 
@@ -129,14 +131,3 @@ def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=Non
             "past_key_values": past,
             "adapter_input_parallelized": model_kwargs.pop("adapter_input_parallelized", False),
         }
-
-    head_types = {
-        "classification": ClassificationHead,
-        "multilabel_classification": MultiLabelClassificationHead,
-        "tagging": TaggingHead,
-        "multiple_choice": MultipleChoiceHead,
-        "question_answering": QuestionAnsweringHead,
-        "dependency_parsing": BiaffineParsingHead,
-        "masked_lm": BertStyleMaskedLMHead,
-        "causal_lm": CausalLMHead,
-    }