Skip to content

Commit

Permalink
sliced Wasserstein pth root
Browse files Browse the repository at this point in the history
  • Loading branch information
gmoss13 committed Feb 5, 2024
1 parent a79f044 commit 1e25c45
Showing 1 changed file with 1 addition and 5 deletions.
6 changes: 1 addition & 5 deletions labproject/metrics/sliced_wasserstein.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,7 @@ def sliced_wasserstein_distance(

wasserstein_distance = torch.pow(torch.abs(wasserstein_distance), p)

# NOTE: currently computes the "squared" wasserstein distance
# No p-th root is applied

# return torch.pow(torch.mean(wasserstein_distance, dim=(-2, -1)), 1 / p)
return torch.mean(wasserstein_distance, dim=(-2, -1))
return torch.pow(torch.mean(wasserstein_distance, dim=(-2, -1)), 1 / p)


def rand_projections(embedding_dim: int, num_samples: int):
Expand Down

0 comments on commit 1e25c45

Please sign in to comment.