Skip to content

Commit

Permalink
take mean for distance-based loss
Browse files Browse the repository at this point in the history
  • Loading branch information
adelmemariani committed Oct 24, 2023
1 parent c2947a6 commit 324b9a6
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion chebai/models/electra.py
Original file line number Diff line number Diff line change
Expand Up @@ -543,7 +543,7 @@ def __call__(self, outputs, targets):
theta = 0.00004
loss = ( ( torch.sqrt(d) * (d > theta) * t ) + ( (d <= theta) * (1 - d) * (1 - t) ) )

scalar_loss = torch.max(loss)
scalar_loss = torch.mean(loss)
return scalar_loss

def softabs(x, eps=0.01):
Expand Down

0 comments on commit 324b9a6

Please sign in to comment.