self_supervised_contrastive_loss #17

Open
timtian12 opened this issue Nov 21, 2023 · 0 comments
Comments

@timtian12

def self_supervised_contrastive_loss(self, features):
    '''Compute the self-supervised VPCL loss.

    Parameters
    ----------
    features: torch.Tensor
        the encoded features of multiple partitions of input tables, with shape ``(bs, n_partition, proj_dim)``.

    Returns
    -------
    loss: torch.Tensor
        the computed self-supervised VPCL loss.
    '''
    batch_size = features.shape[0]
    labels = torch.arange(batch_size, dtype=torch.long, device=self.device).view(-1, 1)
    # same-sample indicator over the batch (an identity matrix here, since labels are row indices)
    mask = torch.eq(labels, labels.T).float().to(labels.device)
    contrast_count = features.shape[1]
    # flatten the partition dimension: [[0,1],[2,3]] -> [0,2,1,3]
    contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0)
    anchor_feature = contrast_feature
    anchor_count = contrast_count
    anchor_dot_contrast = torch.div(torch.matmul(anchor_feature, contrast_feature.T), self.temperature)
    # subtract the per-row max for numerical stability
    logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
    logits = anchor_dot_contrast - logits_max.detach()

    # tile the same-sample mask over all (anchor partition, contrast partition) pairs
    mask = mask.repeat(anchor_count, contrast_count)
    # logits_mask zeroes out each anchor's comparison with itself (the diagonal)
    logits_mask = torch.scatter(
        torch.ones_like(mask), 1,
        torch.arange(batch_size * anchor_count).view(-1, 1).to(features.device), 0)
    # positives: same sample, different partition
    mask = mask * logits_mask
    # compute log_prob
    exp_logits = torch.exp(logits) * logits_mask
    log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))
    # compute mean of log-likelihood over positives
    mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1)
    loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos
    loss = loss.view(anchor_count, batch_size).mean()
    return loss
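
For reference, a minimal standalone sketch (not from the repository; bs, n_partition and proj_dim are made-up values) of how the unbind/cat step rearranges the partitions:

import torch

bs, n_partition, proj_dim = 2, 2, 3
features = torch.arange(bs * n_partition * proj_dim, dtype=torch.float32).view(bs, n_partition, proj_dim)

# unbind(dim=1) splits off each partition, cat stacks them:
# row order becomes [sample0-partition0, sample1-partition0,
#                    sample0-partition1, sample1-partition1],
# i.e. the "[[0,1],[2,3]] -> [0,2,1,3]" comment in the function above.
contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0)
print(contrast_feature.shape)  # torch.Size([4, 3])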

I have a question about the function above. When computing the final loss, partitions of different samples are multiplied together: through the mask, each row is restricted to the products between the j-th partition of the i-th sample and the partitions of the other samples, and the final loss is obtained from these terms. In contrastive learning, the objective that is usually optimized has the form exp(z_i · z_j) / Σ_k exp(z_i · z_k), where the numerator is the similarity between different partitions of the same sample and the denominator involves the similarities with the other samples' partitions; the goal is to make the numerator large relative to the denominator. In the function above, however, only the denominator seems to be present, and I cannot see the numerator. Is this a misunderstanding on my part, or is there a problem?
That is, the loss is computed per sample i.
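
For what it's worth, below is a minimal standalone sketch (not from the repository; the shapes, temperature value, and seed are made up, base_temperature is taken equal to temperature, and device handling is dropped) that computes the loss both ways: once with the mask, following the same steps as the function above, and once with the explicit exp(z_i · z_j) / Σ_k exp(z_i · z_k) form written as a loop, where the numerator is the positive pair. Comparing the two may help locate where the numerator term enters the masked version.

import torch
import torch.nn.functional as F

torch.manual_seed(0)
bs, n_partition, proj_dim = 4, 2, 8
temperature = 0.07
features = F.normalize(torch.randn(bs, n_partition, proj_dim), dim=-1)

# masked formulation, same steps as the function above
# (temperature / base_temperature factor taken as 1)
labels = torch.arange(bs).view(-1, 1)
mask = torch.eq(labels, labels.T).float()
contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0)   # (bs * P, d)
logits = contrast_feature @ contrast_feature.T / temperature
logits = logits - logits.max(dim=1, keepdim=True).values.detach()
mask = mask.repeat(n_partition, n_partition)
logits_mask = 1.0 - torch.eye(bs * n_partition)      # drop self-comparisons
mask = mask * logits_mask                            # positives: same sample, other partition
log_prob = logits - torch.log((torch.exp(logits) * logits_mask).sum(1, keepdim=True))
loss_masked = -((mask * log_prob).sum(1) / mask.sum(1)).mean()

# explicit numerator / denominator form from the question
z = contrast_feature
terms = []
for i in range(bs * n_partition):
    sim = z @ z[i] / temperature                     # similarities of anchor i to every row
    sim = sim - sim.max().detach()
    denom = torch.exp(sim)[torch.arange(len(sim)) != i].sum()   # all rows except the anchor itself
    for j in range(bs * n_partition):
        if j != i and mask[i, j] > 0:                # positive: same sample, different partition
            terms.append(-torch.log(torch.exp(sim[j]) / denom))
loss_explicit = torch.stack(terms).mean()

print(loss_masked.item(), loss_explicit.item())

If the sketch mirrors the function correctly, the two printed values agree up to floating-point error, with the positive-pair (numerator) logits entering the masked version through mask * log_prob.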
