Skip to content

Commit

Permalink
Add alpha for label relaxation as a model parameter
Browse files Browse the repository at this point in the history
  • Loading branch information
adelmemariani committed Dec 19, 2024
1 parent 741b129 commit 461528d
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 2 deletions.
4 changes: 2 additions & 2 deletions dicee/models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,10 +158,10 @@ def __init__(self, args: dict):
if self.args["loss_fn"] == "BCELoss":
self.loss = torch.nn.BCEWithLogitsLoss()
if self.args["loss_fn"] == "LRLoss":
self.loss = LabelRelaxationLoss()
self.loss = LabelRelaxationLoss(alpha=self.args["label_relaxation_alpha"])
else:
self.loss = torch.nn.BCEWithLogitsLoss()

if self.byte_pair_encoding and self.args['model'] != "BytE":
self.token_embeddings = torch.nn.Embedding(self.num_tokens, self.embedding_dim)
self.param_init(self.token_embeddings.weight.data)
Expand Down
2 changes: 2 additions & 0 deletions dicee/scripts/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,8 @@ def get_default_arguments(description=None):
help='degree for polynomial embeddings')
parser.add_argument('--loss_fn', type=str, default="BCELoss",
help='The loss function used in the model')
parser.add_argument('--label_relaxation_alpha', type=float, default=0.1,
help='The alpha value for label relaxation')

if description is None:
return parser.parse_args()
Expand Down

0 comments on commit 461528d

Please sign in to comment.