Skip to content

Commit

Permalink
Merge branch 'main' of github.com:mackelab/labproject into main
Browse files Browse the repository at this point in the history
  • Loading branch information
a-darcher committed Feb 6, 2024
2 parents 933969e + 50c6acb commit f68c085
Show file tree
Hide file tree
Showing 3 changed files with 179 additions and 17 deletions.
121 changes: 104 additions & 17 deletions docs/notebooks/wasserstein_intuition.ipynb

Large diffs are not rendered by default.

74 changes: 74 additions & 0 deletions labproject/metrics/MMD.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import numpy as np
from sklearn import metrics
import random


def mmd_linear(X, Y):
"""MMD using linear kernel (i.e., k(x,y) = <x,y>)
Note that this is not the original linear MMD, only the reformulated and faster version.
The original version is:
def mmd_linear(X, Y):
XX = np.dot(X, X.T)
YY = np.dot(Y, Y.T)
XY = np.dot(X, Y.T)
return XX.mean() + YY.mean() - 2 * XY.mean()
Arguments:
X {[n_sample1, dim]} -- [X matrix]
Y {[n_sample2, dim]} -- [Y matrix]
Returns:
[scalar] -- [MMD value]
"""
delta = X.mean(0) - Y.mean(0)
return delta.dot(delta.T)


def mmd_rbf(X, Y, gamma=1.0):
"""MMD using rbf (gaussian) kernel (i.e., k(x,y) = exp(-gamma * ||x-y||^2 / 2))
Arguments:
X {[n_sample1, dim]} -- [X matrix]
Y {[n_sample2, dim]} -- [Y matrix]
Keyword Arguments:
gamma {float} -- [kernel parameter] (default: {1.0})
Returns:
[scalar] -- [MMD value]
"""
XX = metrics.pairwise.rbf_kernel(X, X, gamma)
YY = metrics.pairwise.rbf_kernel(Y, Y, gamma)
XY = metrics.pairwise.rbf_kernel(X, Y, gamma)
return XX.mean() + YY.mean() - 2 * XY.mean()


def mmd_poly(X, Y, degree=2, gamma=1, coef0=0):
"""MMD using polynomial kernel (i.e., k(x,y) = (gamma <X, Y> + coef0)^degree)
Arguments:
X {[n_sample1, dim]} -- [X matrix]
Y {[n_sample2, dim]} -- [Y matrix]
Keyword Arguments:
degree {int} -- [degree] (default: {2})
gamma {int} -- [gamma] (default: {1})
coef0 {int} -- [constant item] (default: {0})
Returns:
[scalar] -- [MMD value]
"""
XX = metrics.pairwise.polynomial_kernel(X, X, degree, gamma, coef0)
YY = metrics.pairwise.polynomial_kernel(Y, Y, degree, gamma, coef0)
XY = metrics.pairwise.polynomial_kernel(X, Y, degree, gamma, coef0)
return XX.mean() + YY.mean() - 2 * XY.mean()


# a = np.arange(1, 10).reshape(3, 3)
# b = [[7, 6, 5], [4, 3, 2], [1, 1, 8], [0, 2, 5]]
# b = np.array(b)
# print('a:', a)
# print('b:', b)
# print(mmd_linear(a, b)) # 6.0
# print(mmd_rbf(a, b)) # 0.5822
# print(mmd_poly(a, b)) # 2436.5
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ dependencies = [
"torch",
"OmegaConf",
"torchvision",
"seaborn"
]

[project.optional-dependencies]
Expand Down

0 comments on commit f68c085

Please sign in to comment.