Skip to content

Commit

Permalink
Add use_pooler
Browse files Browse the repository at this point in the history
  • Loading branch information
calpt committed Jan 28, 2024
1 parent 9e81655 commit 0a95fc8
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 3 deletions.
7 changes: 4 additions & 3 deletions src/adapters/heads/model_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ class ModelWithFlexibleHeadsAdaptersMixin(ModelWithHeadsAdaptersMixin):
"""

head_types: dict = {}
use_pooler: bool = False

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
Expand Down Expand Up @@ -288,7 +289,7 @@ def add_classification_head(
overwrite_ok=False,
multilabel=False,
id2label=None,
use_pooler=False,
use_pooler=use_pooler,
):
"""
Adds a sequence classification head on top of the model.
Expand Down Expand Up @@ -320,7 +321,7 @@ def add_image_classification_head(
overwrite_ok=False,
multilabel=False,
id2label=None,
use_pooler=True,
use_pooler=use_pooler,
):
"""
Adds an image classification head on top of the model.
Expand Down Expand Up @@ -355,7 +356,7 @@ def add_multiple_choice_head(
activation_function="tanh",
overwrite_ok=False,
id2label=None,
use_pooler=False,
use_pooler=use_pooler,
):
"""
Adds a multiple choice head on top of the model.
Expand Down
1 change: 1 addition & 0 deletions src/adapters/models/albert/adapter_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,3 +111,4 @@ def forward(
"question_answering": QuestionAnsweringHead,
"masked_lm": BertStyleMaskedLMHead,
}
use_pooler = True
1 change: 1 addition & 0 deletions src/adapters/models/beit/adapter_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,3 +86,4 @@ def forward(
head_types = {
"image_classification": ImageClassificationHead,
}
use_pooler = True

0 comments on commit 0a95fc8

Please sign in to comment.