Skip to content

Commit

Permalink
distance-based loss
Browse files Browse the repository at this point in the history
  • Loading branch information
adelmemariani committed Oct 24, 2023
1 parent 2f52248 commit c2947a6
Showing 1 changed file with 19 additions and 3 deletions.
22 changes: 19 additions & 3 deletions chebai/models/electra.py
Original file line number Diff line number Diff line change
Expand Up @@ -518,9 +518,12 @@ def forward(self, data, **kwargs):
membership_per_dim = torch.max(torch.stack((nn.functional.relu(l - p), nn.functional.relu(p - r))), dim=0)[0]
# min might be replaced
#m = torch.min(membership_per_dim, dim=-1)[0]
m = torch.mean(membership_per_dim, dim=-1)
s = 2 - ( 2 * (torch.sigmoid(m)) )
logits = torch.logit( (s * 0.99) + 0.001)
#m = torch.mean(membership_per_dim, dim=-1)
#s = 2 - ( 2 * (torch.sigmoid(m)) )
#logits = torch.logit( (s * 0.99) + 0.001)

m = torch.sum(membership_per_dim, dim=-1)
logits = m

return dict(
boxes=b,
Expand All @@ -530,6 +533,19 @@ def forward(self, data, **kwargs):
target_mask=data.get("target_mask"),
)

class BoxLoss(pl.LightningModule):
    """Distance-based loss for box-embedding membership predictions.

    ``outputs`` is expected to contain non-negative per-dimension
    membership distances (as produced by the model's ``forward``, which
    sums ReLU-based box violations), and ``targets`` the binary
    membership labels.
    """

    def __init__(self, theta=0.00004, **kwargs):
        """Initialize the loss.

        :param theta: distance threshold separating "inside the box"
            from "outside the box"; generalized from the previously
            hard-coded constant (the default preserves old behavior).
        """
        super().__init__(**kwargs)
        self.theta = theta

    # NOTE(review): defining __call__ on an nn.Module subclass bypasses
    # the usual forward()/hook dispatch machinery; kept as __call__ for
    # caller compatibility, but consider renaming to forward().
    def __call__(self, outputs, targets):
        d = outputs
        t = targets
        theta = self.theta
        # Positives (t == 1) are penalized when their distance exceeds
        # theta; negatives (t == 0) are penalized when they lie within
        # the theta threshold (i.e. look like members).
        loss = (torch.sqrt(d) * (d > theta) * t) + ((d <= theta) * (1 - d) * (1 - t))

        # Worst-case aggregation: only the single largest violation
        # contributes gradient. NOTE(review): confirm that mean/sum
        # reduction was not intended here.
        scalar_loss = torch.max(loss)
        return scalar_loss

def softabs(x, eps=0.01):
    """Smooth, differentiable approximation of ``abs(x)``.

    Uses ``sqrt(x^2 + eps)`` shifted by ``sqrt(eps)`` so the result is
    exactly 0 at ``x == 0``. Works on plain numbers and on tensors.
    """
    offset = eps ** 0.5
    return (x * x + eps) ** 0.5 - offset

Expand Down

0 comments on commit c2947a6

Please sign in to comment.