Skip to content

Commit

Permalink
bugfix
Browse files Browse the repository at this point in the history
  • Loading branch information
SeanLee97 committed Apr 27, 2024
1 parent 073e443 commit 09e06c5
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 7 deletions.
4 changes: 2 additions & 2 deletions src/billm/modeling_openelm.py
Original file line number Diff line number Diff line change
Expand Up @@ -996,7 +996,7 @@ def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.model = OpenELMModel(config)
self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
self.score = nn.Linear(config.model_dim, self.num_labels, bias=False)

# Initialize weights and apply final processing
self.post_init()
Expand Down Expand Up @@ -1110,7 +1110,7 @@ def __init__(self, config):
else:
classifier_dropout = 0.1
self.dropout = nn.Dropout(classifier_dropout)
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
self.classifier = nn.Linear(config.model_dim, config.num_labels)

# Initialize weights and apply final processing
self.post_init()
Expand Down
10 changes: 5 additions & 5 deletions tests/test_modeling_openelm.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ def test_openelm_model():
from billm import OpenELMModel, OpenELMConfig

model = OpenELMModel(OpenELMConfig(vocab_size=128,
head_size=32,
head_dim=32,
num_transformer_layers=2))
assert model is not None

Expand All @@ -20,7 +20,7 @@ def test_biopenelm_model():
from billm import OpenELMModel, OpenELMConfig

model = OpenELMModel(OpenELMConfig(vocab_size=128,
head_size=32,
head_dim=32,
num_transformer_layers=2))
assert model is not None

Expand All @@ -33,7 +33,7 @@ def test_biopenelm_lm():
from billm import OpenELMForCausalLM, OpenELMConfig

model = OpenELMForCausalLM(OpenELMConfig(vocab_size=128,
head_size=32,
head_dim=32,
num_transformer_layers=2))
assert model is not None
assert len(model.model.bidirectionas) > 0
Expand All @@ -46,7 +46,7 @@ def test_biopenelm_seq_clf():
from billm import OpenELMForSequenceClassification, OpenELMConfig

model = OpenELMForSequenceClassification(OpenELMConfig(vocab_size=128,
head_size=32,
head_dim=32,
num_transformer_layers=2))
assert model is not None
assert len(model.model.bidirectionas) > 0
Expand All @@ -59,7 +59,7 @@ def test_biopenelm_token_clf():
from billm import OpenELMForTokenClassification, OpenELMConfig

model = OpenELMForTokenClassification(OpenELMConfig(vocab_size=128,
head_size=32,
head_dim=32,
num_transformer_layers=2))
assert model is not None
assert len(model.model.bidirectionas) > 0

0 comments on commit 09e06c5

Please sign in to comment.