From 09e06c5f160c52abbc53f15a5924004aba084880 Mon Sep 17 00:00:00 2001 From: Sean Lee Date: Sat, 27 Apr 2024 16:32:02 +0800 Subject: [PATCH] bugfix --- src/billm/modeling_openelm.py | 4 ++-- tests/test_modeling_openelm.py | 10 +++++----- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/billm/modeling_openelm.py b/src/billm/modeling_openelm.py index 8d7a399..0cff591 100644 --- a/src/billm/modeling_openelm.py +++ b/src/billm/modeling_openelm.py @@ -996,7 +996,7 @@ def __init__(self, config): super().__init__(config) self.num_labels = config.num_labels self.model = OpenELMModel(config) - self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + self.score = nn.Linear(config.model_dim, self.num_labels, bias=False) # Initialize weights and apply final processing self.post_init() @@ -1110,7 +1110,7 @@ def __init__(self, config): else: classifier_dropout = 0.1 self.dropout = nn.Dropout(classifier_dropout) - self.classifier = nn.Linear(config.hidden_size, config.num_labels) + self.classifier = nn.Linear(config.model_dim, config.num_labels) # Initialize weights and apply final processing self.post_init() diff --git a/tests/test_modeling_openelm.py b/tests/test_modeling_openelm.py index 2afa405..12126d8 100644 --- a/tests/test_modeling_openelm.py +++ b/tests/test_modeling_openelm.py @@ -8,7 +8,7 @@ def test_openelm_model(): from billm import OpenELMModel, OpenELMConfig model = OpenELMModel(OpenELMConfig(vocab_size=128, - head_size=32, + head_dim=32, num_transformer_layers=2)) assert model is not None @@ -20,7 +20,7 @@ def test_biopenelm_model(): from billm import OpenELMModel, OpenELMConfig model = OpenELMModel(OpenELMConfig(vocab_size=128, - head_size=32, + head_dim=32, num_transformer_layers=2)) assert model is not None @@ -33,7 +33,7 @@ def test_biopenelm_lm(): from billm import OpenELMForCausalLM, OpenELMConfig model = OpenELMForCausalLM(OpenELMConfig(vocab_size=128, - head_size=32, + head_dim=32, num_transformer_layers=2)) assert model is not None assert len(model.model.bidirectionas) > 0 @@ -46,7 +46,7 @@ def test_biopenelm_seq_clf(): from billm import OpenELMForSequenceClassification, OpenELMConfig model = OpenELMForSequenceClassification(OpenELMConfig(vocab_size=128, - head_size=32, + head_dim=32, num_transformer_layers=2)) assert model is not None assert len(model.model.bidirectionas) > 0 @@ -59,7 +59,7 @@ def test_biopenelm_token_clf(): from billm import OpenELMForTokenClassification, OpenELMConfig model = OpenELMForTokenClassification(OpenELMConfig(vocab_size=128, - head_size=32, + head_dim=32, num_transformer_layers=2)) assert model is not None assert len(model.model.bidirectionas) > 0