How can I add the importance_weights of each class to the corn loss? #39
Yes, they could be added. We omitted them for simplicity in the CORN paper.
Thanks for your kind reply. I haven't run the code yet; does the following look correct?

```python
def corn_loss(logits, y_train, num_classes, importance_weights):
    sets = []
    for i in range(num_classes-1):
        label_mask = y_train > i-1
        label_tensor = (y_train[label_mask] > i).to(torch.int64)
        sets.append((label_mask, label_tensor))

    num_examples = 0
    losses = 0.
    for task_index, s in enumerate(sets):
        train_examples = s[0]
        train_labels = s[1]

        if len(train_labels) < 1:
            continue

        num_examples += len(train_labels)
        pred = logits[train_examples, task_index]

        loss = -torch.sum(F.logsigmoid(pred)*train_labels
                          + (F.logsigmoid(pred) - pred)*(1-train_labels))
        # losses += loss
        losses += importance_weights[task_index] * loss
    return losses / num_examples
```
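To make the conditional-subset step concrete, here is a torch-free sketch of the same masking logic. The function name `build_corn_sets` and the example labels are illustrative only, not part of the library; for `K` classes there are `K-1` binary tasks, task `i` keeps only examples with label greater than `i-1`, and the binary target is whether the label exceeds `i`.

```python
# Pure-Python sketch of the conditional training subsets used in corn_loss.
def build_corn_sets(y_train, num_classes):
    sets = []
    for i in range(num_classes - 1):
        # examples kept for task i: label must exceed the previous threshold
        label_mask = [y > i - 1 for y in y_train]
        # binary target for the kept examples: does the label exceed threshold i?
        label_tensor = [int(y > i) for y, m in zip(y_train, label_mask) if m]
        sets.append((label_mask, label_tensor))
    return sets

# Example with 4 classes (labels 0..3); the label values are made up.
y = [0, 1, 2, 3, 2]
sets = build_corn_sets(y, num_classes=4)
for i, (mask, targets) in enumerate(sets):
    print(i, mask, targets)
```

Note how each successive task sees fewer examples: task 0 uses all five labels, while task 2 only uses the three labels greater than 1.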
teinhonglo changed the title from “Could I add the importance_weights of each class to the corn loss?” to “How can I add the importance_weights of each class to the corn loss?” on Aug 6, 2023.
Yes, this looks correct to me. You can also add a default argument so that it behaves as before if someone doesn't specify the importance weights:

```python
def corn_loss(logits, y_train, num_classes, importance_weights=None):
    sets = []
    for i in range(num_classes-1):
        label_mask = y_train > i-1
        label_tensor = (y_train[label_mask] > i).to(torch.int64)
        sets.append((label_mask, label_tensor))

    num_examples = 0
    losses = 0.

    if importance_weights is None:
        importance_weights = torch.ones(len(sets))

    for task_index, s in enumerate(sets):
        train_examples = s[0]
        train_labels = s[1]

        if len(train_labels) < 1:
            continue

        num_examples += len(train_labels)
        pred = logits[train_examples, task_index]

        loss = -torch.sum(F.logsigmoid(pred)*train_labels
                          + (F.logsigmoid(pred) - pred)*(1-train_labels))
        # losses += loss
        losses += importance_weights[task_index] * loss
    return losses / num_examples
```
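As a sanity check on the default-argument behavior, here is a dependency-free sketch of the per-task weighting: with importance weights of all ones, the weighted sum reduces to the unweighted loss. The helper names (`logsigmoid`, `task_loss`, `weighted_loss`) and the example logits are illustrative stand-ins for the torch code above, not library functions.

```python
import math

def logsigmoid(x):
    # numerically stable log(sigmoid(x)) = -log(1 + exp(-x))
    return -math.log1p(math.exp(-x)) if x >= 0 else x - math.log1p(math.exp(x))

def task_loss(preds, targets):
    # binary cross-entropy with logits, summed over one task's examples
    return -sum(logsigmoid(p) * t + (logsigmoid(p) - p) * (1 - t)
                for p, t in zip(preds, targets))

def weighted_loss(task_preds, task_targets, importance_weights=None):
    # default: all tasks weighted equally, matching the original corn_loss
    if importance_weights is None:
        importance_weights = [1.0] * len(task_preds)
    num_examples = sum(len(t) for t in task_targets)
    total = sum(w * task_loss(p, t)
                for w, p, t in zip(importance_weights, task_preds, task_targets))
    return total / num_examples

# Two binary tasks with made-up logits and targets.
preds = [[0.3, -1.2, 2.0], [0.5, -0.7]]
targets = [[1, 0, 1], [1, 0]]
unweighted = weighted_loss(preds, targets)
explicit_ones = weighted_loss(preds, targets, importance_weights=[1.0, 1.0])
```

One design note: the weighted sum is still divided by the plain example count, so nonuniform weights rescale the loss magnitude; if that matters for comparing runs, you could normalize by the weighted example count instead.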
Hi,
Thanks for sharing the code.
I noticed that the coral loss has an importance_weights argument.
Could I add the importance_weights of each class to the corn loss?
Many thanks,
Tien-Hong