From 0a95fc82813924807f2f097758139e8428b45705 Mon Sep 17 00:00:00 2001 From: calpt Date: Sun, 28 Jan 2024 13:09:38 +0100 Subject: [PATCH] Add `use_pooler` --- src/adapters/heads/model_mixin.py | 7 ++++--- src/adapters/models/albert/adapter_model.py | 1 + src/adapters/models/beit/adapter_model.py | 1 + 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/src/adapters/heads/model_mixin.py b/src/adapters/heads/model_mixin.py index d0754f9086..743fc940c0 100644 --- a/src/adapters/heads/model_mixin.py +++ b/src/adapters/heads/model_mixin.py @@ -34,6 +34,7 @@ class ModelWithFlexibleHeadsAdaptersMixin(ModelWithHeadsAdaptersMixin): """ head_types: dict = {} + use_pooler: bool = False def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -288,7 +289,7 @@ def add_classification_head( overwrite_ok=False, multilabel=False, id2label=None, - use_pooler=False, + use_pooler=use_pooler, ): """ Adds a sequence classification head on top of the model. @@ -320,7 +321,7 @@ def add_image_classification_head( overwrite_ok=False, multilabel=False, id2label=None, - use_pooler=True, + use_pooler=use_pooler, ): """ Adds an image classification head on top of the model. @@ -355,7 +356,7 @@ def add_multiple_choice_head( activation_function="tanh", overwrite_ok=False, id2label=None, - use_pooler=False, + use_pooler=use_pooler, ): """ Adds a multiple choice head on top of the model. diff --git a/src/adapters/models/albert/adapter_model.py b/src/adapters/models/albert/adapter_model.py index 30949e56b3..ea2835a493 100644 --- a/src/adapters/models/albert/adapter_model.py +++ b/src/adapters/models/albert/adapter_model.py @@ -111,3 +111,4 @@ def forward( "question_answering": QuestionAnsweringHead, "masked_lm": BertStyleMaskedLMHead, } + use_pooler = True diff --git a/src/adapters/models/beit/adapter_model.py b/src/adapters/models/beit/adapter_model.py index e85dfe6d1e..2bf3eabb76 100644 --- a/src/adapters/models/beit/adapter_model.py +++ b/src/adapters/models/beit/adapter_model.py @@ -86,3 +86,4 @@ def forward( head_types = { "image_classification": ImageClassificationHead, } + use_pooler = True