Refactor head_types in model classes
calpt committed Jan 31, 2024
1 parent 0a95fc8 commit d9afa76
Showing 19 changed files with 192 additions and 285 deletions.
18 changes: 16 additions & 2 deletions src/adapters/heads/model_mixin.py
@@ -28,12 +28,26 @@
logger = logging.getLogger(__name__)


+ MODEL_HEAD_MAP = {
+ "classification": ClassificationHead,
+ "multilabel_classification": MultiLabelClassificationHead,
+ "tagging": TaggingHead,
+ "multiple_choice": MultipleChoiceHead,
+ "question_answering": QuestionAnsweringHead,
+ "dependency_parsing": BiaffineParsingHead,
+ "masked_lm": BertStyleMaskedLMHead,
+ "causal_lm": CausalLMHead,
+ "seq2seq_lm": Seq2SeqLMHead,
+ "image_classification": ImageClassificationHead,
+ }


class ModelWithFlexibleHeadsAdaptersMixin(ModelWithHeadsAdaptersMixin):
"""
Adds flexible prediction heads to a model class. Implemented by the XModelWithHeads classes.
"""

- head_types: dict = {}
+ head_types: list = []
use_pooler: bool = False

def __init__(self, *args, **kwargs):
@@ -157,7 +171,7 @@ def add_prediction_head_from_config(
config["id2label"] = id2label

if head_type in self.head_types:
- head_class = self.head_types[head_type]
+ head_class = MODEL_HEAD_MAP[head_type]
head = head_class(self, head_name, **config)
self.add_prediction_head(head, overwrite_ok=overwrite_ok, set_active=set_active)
elif head_type in self.config.custom_heads:
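In short, this first hunk centralizes the mapping from head-type strings to head classes: MODEL_HEAD_MAP owns the lookup, and each model class only declares which head-type strings it supports. A minimal sketch of the resulting resolution path; resolve_head_class is a hypothetical helper (not library API), the map excerpt repeats two entries from the hunk above, and only the branch visible in the diff is reproduced:

from adapters.heads import ClassificationHead, TaggingHead  # head classes shipped by the package

# Excerpt of the shared map added in this commit (the full map lives in heads/model_mixin.py).
MODEL_HEAD_MAP = {
    "classification": ClassificationHead,
    "tagging": TaggingHead,
}

def resolve_head_class(model, head_type: str):
    """Hypothetical helper mirroring the lookup in add_prediction_head_from_config."""
    if head_type in model.head_types:  # per-model allow-list, now a plain list of strings
        return MODEL_HEAD_MAP[head_type]  # class resolution is centralized here
    raise ValueError(f"{type(model).__name__} does not support head type {head_type!r}")

Adding a new head type after this refactor is thus a two-step change: register the class in MODEL_HEAD_MAP once, then append its string key to the head_types list of every model that supports it.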
30 changes: 11 additions & 19 deletions src/adapters/models/albert/adapter_model.py
@@ -7,15 +7,7 @@
from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward

from ...context import AdapterSetup
- from ...heads import (
- BertStyleMaskedLMHead,
- ClassificationHead,
- ModelWithFlexibleHeadsAdaptersMixin,
- MultiLabelClassificationHead,
- MultipleChoiceHead,
- QuestionAnsweringHead,
- TaggingHead,
- )
+ from ...heads import ModelWithFlexibleHeadsAdaptersMixin
from ...model_mixin import EmbeddingAdaptersWrapperMixin
from ...wrappers import init

@@ -25,6 +17,16 @@
ALBERT_START_DOCSTRING,
)
class AlbertAdapterModel(EmbeddingAdaptersWrapperMixin, ModelWithFlexibleHeadsAdaptersMixin, AlbertPreTrainedModel):
+ head_types = [
+ "classification",
+ "multilabel_classification",
+ "tagging",
+ "multiple_choice",
+ "question_answering",
+ "masked_lm",
+ ]
+ use_pooler = True

def __init__(self, config):
super().__init__(config)

@@ -102,13 +104,3 @@ def forward(
else:
# in case no head is used just return the output of the base model (including pooler output)
return outputs

- head_types = {
- "classification": ClassificationHead,
- "multilabel_classification": MultiLabelClassificationHead,
- "tagging": TaggingHead,
- "multiple_choice": MultipleChoiceHead,
- "question_answering": QuestionAnsweringHead,
- "masked_lm": BertStyleMaskedLMHead,
- }
- use_pooler = True
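Illustrative usage, not part of the diff: since head_types is now a plain list of strings, checking whether a model supports a head type is a simple membership test, and the pre-existing convenience methods for adding heads keep working. The checkpoint and head names below are placeholders.

from adapters import AlbertAdapterModel  # assumes the package layout of this commit

model = AlbertAdapterModel.from_pretrained("albert-base-v2")

# The class-level allow-list defined above is readable on instances.
assert "classification" in model.head_types
assert "seq2seq_lm" not in model.head_types  # ALBERT declares no seq2seq head

# Existing convenience API for adding a head; the arguments shown are illustrative.
model.add_classification_head("example_task", num_labels=2)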
22 changes: 8 additions & 14 deletions src/adapters/models/bart/adapter_model.py
@@ -10,13 +10,7 @@
)
from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward

- from ...heads import (
- ClassificationHead,
- ModelWithFlexibleHeadsAdaptersMixin,
- MultiLabelClassificationHead,
- QuestionAnsweringHead,
- Seq2SeqLMHead,
- )
+ from ...heads import ModelWithFlexibleHeadsAdaptersMixin
from ...model_mixin import EmbeddingAdaptersWrapperMixin
from ...wrappers import init

@@ -30,6 +24,13 @@ class BartAdapterModel(EmbeddingAdaptersWrapperMixin, ModelWithFlexibleHeadsAdap
"decoder.embed_tokens.weight",
]

+ head_types = [
+ "classification",
+ "multilabel_classification",
+ "question_answering",
+ "seq2seq_lm",
+ ]

def __init__(self, config: BartConfig, **kwargs):
super().__init__(config, **kwargs)
self.model = BartModel(config)
@@ -159,10 +160,3 @@ def _reorder_cache(past, beam_idx):
tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
)
return reordered_past

- head_types = {
- "classification": ClassificationHead,
- "multilabel_classification": MultiLabelClassificationHead,
- "question_answering": QuestionAnsweringHead,
- "seq2seq_lm": Seq2SeqLMHead,
- }
12 changes: 6 additions & 6 deletions src/adapters/models/beit/adapter_model.py
@@ -11,7 +11,7 @@
from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward

from ...context import AdapterSetup
- from ...heads import ImageClassificationHead, ModelWithFlexibleHeadsAdaptersMixin
+ from ...heads import ModelWithFlexibleHeadsAdaptersMixin
from ...wrappers import init


@@ -20,6 +20,11 @@
BEIT_START_DOCSTRING,
)
class BeitAdapterModel(ModelWithFlexibleHeadsAdaptersMixin, BeitPreTrainedModel):
+ head_types = [
+ "image_classification",
+ ]
+ use_pooler = True

def __init__(self, config):
super().__init__(config)

@@ -82,8 +87,3 @@ def forward(
else:
# in case no head is used just return the output of the base model (including pooler output)
return outputs

- head_types = {
- "image_classification": ImageClassificationHead,
- }
- use_pooler = True
35 changes: 13 additions & 22 deletions src/adapters/models/bert/adapter_model.py
@@ -7,17 +7,7 @@
from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward

from ...context import AdapterSetup
- from ...heads import (
- BertStyleMaskedLMHead,
- BiaffineParsingHead,
- CausalLMHead,
- ClassificationHead,
- ModelWithFlexibleHeadsAdaptersMixin,
- MultiLabelClassificationHead,
- MultipleChoiceHead,
- QuestionAnsweringHead,
- TaggingHead,
- )
+ from ...heads import ModelWithFlexibleHeadsAdaptersMixin
from ...model_mixin import EmbeddingAdaptersWrapperMixin
from ...wrappers import init

@@ -27,6 +17,18 @@
BERT_START_DOCSTRING,
)
class BertAdapterModel(EmbeddingAdaptersWrapperMixin, ModelWithFlexibleHeadsAdaptersMixin, BertPreTrainedModel):

+ head_types = [
+ "classification",
+ "multilabel_classification",
+ "tagging",
+ "multiple_choice",
+ "question_answering",
+ "dependency_parsing",
+ "masked_lm",
+ "causal_lm",
+ ]

def __init__(self, config):
super().__init__(config)

@@ -121,14 +123,3 @@ def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=Non
"past_key_values": past,
"adapter_input_parallelized": model_kwargs.pop("adapter_input_parallelized", False),
}

- head_types = {
- "classification": ClassificationHead,
- "multilabel_classification": MultiLabelClassificationHead,
- "tagging": TaggingHead,
- "multiple_choice": MultipleChoiceHead,
- "question_answering": QuestionAnsweringHead,
- "dependency_parsing": BiaffineParsingHead,
- "masked_lm": BertStyleMaskedLMHead,
- "causal_lm": CausalLMHead,
- }
12 changes: 6 additions & 6 deletions src/adapters/models/bert_generation/adapter_model.py
@@ -7,7 +7,7 @@
from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward

from ...context import AdapterSetup
- from ...heads import BertStyleMaskedLMHead, CausalLMHead, ModelWithFlexibleHeadsAdaptersMixin
+ from ...heads import ModelWithFlexibleHeadsAdaptersMixin
from ...model_mixin import EmbeddingAdaptersWrapperMixin
from ...wrappers import init

@@ -21,6 +21,11 @@ class BertGenerationAdapterModel(
):
_keys_to_ignore_on_load_unexpected = [r"lm_head.bias"]

+ head_types = [
+ "masked_lm",
+ "causal_lm",
+ ]

def __init__(self, config):
super().__init__(config)

@@ -118,8 +123,3 @@ def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=Non
"past_key_values": past,
"adapter_input_parallelized": model_kwargs.pop("adapter_input_parallelized", False),
}

- head_types = {
- "masked_lm": BertStyleMaskedLMHead,
- "causal_lm": CausalLMHead,
- }
26 changes: 10 additions & 16 deletions src/adapters/models/deberta/adapter_model.py
@@ -2,14 +2,7 @@
from transformers.models.deberta.modeling_deberta import DebertaModel, DebertaPreTrainedModel

from ...context import AdapterSetup
- from ...heads import (
- BertStyleMaskedLMHead,
- ClassificationHead,
- ModelWithFlexibleHeadsAdaptersMixin,
- MultiLabelClassificationHead,
- QuestionAnsweringHead,
- TaggingHead,
- )
+ from ...heads import ModelWithFlexibleHeadsAdaptersMixin
from ...model_mixin import EmbeddingAdaptersWrapperMixin
from ...wrappers import init

@@ -20,6 +13,15 @@
class DebertaAdapterModel(EmbeddingAdaptersWrapperMixin, ModelWithFlexibleHeadsAdaptersMixin, DebertaPreTrainedModel):
_keys_to_ignore_on_load_unexpected = [r"cls.predictions.bias"]

+ head_types = [
+ "classification",
+ "multilabel_classification",
+ "tagging",
+ "question_answering",
+ "multiple_choice",
+ "masked_lm",
+ ]

def __init__(self, config):
super().__init__(config)

@@ -93,11 +95,3 @@ def forward(
else:
# in case no head is used just return the output of the base model (including pooler output)
return outputs

- head_types = {
- "classification": ClassificationHead,
- "multilabel_classification": MultiLabelClassificationHead,
- "tagging": TaggingHead,
- "question_answering": QuestionAnsweringHead,
- "masked_lm": BertStyleMaskedLMHead,
- }
28 changes: 10 additions & 18 deletions src/adapters/models/deberta_v2/adapter_model.py
@@ -2,15 +2,7 @@
from transformers.models.deberta_v2.modeling_deberta_v2 import DebertaV2Model, DebertaV2PreTrainedModel

from ...context import AdapterSetup
- from ...heads import (
- BertStyleMaskedLMHead,
- ClassificationHead,
- ModelWithFlexibleHeadsAdaptersMixin,
- MultiLabelClassificationHead,
- MultipleChoiceHead,
- QuestionAnsweringHead,
- TaggingHead,
- )
+ from ...heads import ModelWithFlexibleHeadsAdaptersMixin
from ...model_mixin import EmbeddingAdaptersWrapperMixin
from ...wrappers import init

@@ -23,6 +15,15 @@ class DebertaV2AdapterModel(
):
_keys_to_ignore_on_load_unexpected = [r"cls.predictions.bias"]

+ head_types = [
+ "classification",
+ "multilabel_classification",
+ "tagging",
+ "question_answering",
+ "multiple_choice",
+ "masked_lm",
+ ]

def __init__(self, config):
super().__init__(config)

@@ -96,12 +97,3 @@ def forward(
else:
# in case no head is used just return the output of the base model (including pooler output)
return outputs

- head_types = {
- "classification": ClassificationHead,
- "multilabel_classification": MultiLabelClassificationHead,
- "tagging": TaggingHead,
- "question_answering": QuestionAnsweringHead,
- "multiple_choice": MultipleChoiceHead,
- "masked_lm": BertStyleMaskedLMHead,
- }
34 changes: 12 additions & 22 deletions src/adapters/models/distilbert/adapter_model.py
@@ -8,17 +8,7 @@
)
from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward

- from ...heads import (
- BertStyleMaskedLMHead,
- BiaffineParsingHead,
- CausalLMHead,
- ClassificationHead,
- ModelWithFlexibleHeadsAdaptersMixin,
- MultiLabelClassificationHead,
- MultipleChoiceHead,
- QuestionAnsweringHead,
- TaggingHead,
- )
+ from ...heads import ModelWithFlexibleHeadsAdaptersMixin
from ...model_mixin import EmbeddingAdaptersWrapperMixin
from ...wrappers import init

@@ -30,6 +20,17 @@
class DistilBertAdapterModel(
EmbeddingAdaptersWrapperMixin, ModelWithFlexibleHeadsAdaptersMixin, DistilBertPreTrainedModel
):
+ head_types = [
+ "classification",
+ "multilabel_classification",
+ "tagging",
+ "question_answering",
+ "multiple_choice",
+ "dependency_parsing",
+ "masked_lm",
+ "causal_lm",
+ ]

def __init__(self, config):
super().__init__(config)
self.distilbert = DistilBertModel(config)
@@ -124,14 +125,3 @@ def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=Non
"past_key_values": past,
"adapter_input_parallelized": model_kwargs.pop("adapter_input_parallelized", False),
}

- head_types = {
- "classification": ClassificationHead,
- "multilabel_classification": MultiLabelClassificationHead,
- "tagging": TaggingHead,
- "multiple_choice": MultipleChoiceHead,
- "question_answering": QuestionAnsweringHead,
- "dependency_parsing": BiaffineParsingHead,
- "masked_lm": BertStyleMaskedLMHead,
- "causal_lm": CausalLMHead,
- }