Commit

Merge branch 'main' of github.com:mackelab/labproject into main
jaivardhankapoor committed Feb 5, 2024
2 parents 7fec415 + f8d0834 commit de0dca9
Showing 6 changed files with 210 additions and 3 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -8,6 +8,7 @@ __pycache__/

# Distribution / packaging
.Python
data/
plots/
build/
develop-eggs/
100 changes: 100 additions & 0 deletions labproject/data.py
Expand Up @@ -5,6 +5,8 @@
import os
import functools

from torch.distributions import MultivariateNormal, Categorical

from torchvision.datasets import CIFAR10
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset, Subset
@@ -25,6 +27,7 @@
## Hetzner Storage Box API functions ----

DATASETS = {}
DISTRIBUTIONS = {}


def upload_file(local_path: str, remote_path: str):
@@ -135,6 +138,47 @@ def get_dataset(name: str) -> torch.Tensor:
return DATASETS[name]


def register_distribution(name: str) -> callable:
    r"""This decorator wraps a function that returns a probability distribution and registers it in the global DISTRIBUTIONS dictionary under the given name.
    Args:
        name (str): Name under which the distribution is registered
    Returns:
        callable: Distribution generator function wrapper
    Example:
        >>> @register_distribution("normal")
        >>> def normal_distribution():
        >>>     return torch.distributions.Normal(0, 1)
    """

def decorator(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
# Call the original function
distribution = func(*args, **kwargs)
return distribution

DISTRIBUTIONS[name] = wrapper
return wrapper

return decorator


def get_distribution(name: str) -> callable:
    r"""Get a registered distribution generator by name
    Args:
        name (str): Name of the distribution
    Returns:
        callable: Distribution generator function
    """
    assert name in DISTRIBUTIONS, f"Distribution {name} not found, please register it first"
    return DISTRIBUTIONS[name]
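
# Illustrative usage sketch, not part of this commit: register_distribution stores the
# generator function in DISTRIBUTIONS and get_distribution retrieves it by name.
# Assumes the "normal" distribution registered further down in this file.
dist = get_distribution("normal")()  # -> torch.distributions.Normal(0, 1)
samples = dist.sample((5,))
print(samples.shape)  # torch.Size([5])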


def load_cifar10(
n: int, save_path="data", train=True, batch_size=100, shuffle=False, num_workers=1, device="cpu"
) -> torch.Tensor:
@@ -182,6 +226,60 @@ def random_dataset(n=1000, d=10):
return torch.randn(n, d)


@register_distribution("normal")
def normal_distribution():
return torch.distributions.Normal(0, 1)


@register_distribution("toy_2d")
def toy_mog_2d():
class Toy2D:
def __init__(self):
self.means = torch.tensor(
[
[0.0, 0.5],
[-3.0, -0.5],
[0.0, -1.0],
[-4.0, -3.0],
]
)
self.covariances = torch.tensor(
[
[[1.0, 0.8], [0.8, 1.0]],
[[1.0, -0.5], [-0.5, 1.0]],
[[1.0, 0.0], [0.0, 1.0]],
[[0.5, 0.0], [0.0, 0.5]],
]
)
self.weights = torch.tensor([0.2, 0.3, 0.3, 0.2])

# Create a list of 2D Gaussian distributions
self.gaussians = [
MultivariateNormal(mean, covariance)
for mean, covariance in zip(self.means, self.covariances)
]

def sample(self, sample_shape):
if isinstance(sample_shape, int):
sample_shape = (sample_shape,)
# Sample from the mixture
categorical = Categorical(self.weights)
sample_indices = categorical.sample(sample_shape)
return torch.stack([self.gaussians[i].sample() for i in sample_indices])

def log_prob(self, input):
probs = torch.stack([g.log_prob(input).exp() for g in self.gaussians])
probs = probs.T * self.weights
return torch.sum(probs, dim=1).log()

return Toy2D()
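
# Illustrative sanity check, not part of this commit: sample from the registered "toy_2d"
# mixture and evaluate its log-density, log sum_k w_k N(x; mu_k, Sigma_k).
toy = get_distribution("toy_2d")()
x = toy.sample(1000)   # shape (1000, 2)
lp = toy.log_prob(x)   # shape (1000,)
print(x.shape, lp.shape)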


@register_dataset("cifar10_train")
def cifar10_train(n=1000, d=2048, save_path="data", device="cpu"):

@@ -202,6 +300,8 @@ def cifar10_train(n=1000, d=2048, save_path="data", device="cpu"):
@register_dataset("cifar10_test")
def cifar10_test(n=1000, d=2048, save_path="data", device="cpu"):

assert d == 2048, "The dimensionality of the embeddings must be 2048"

assert d is None or d == 2048, "The dimensionality of the embeddings must be 2048"

embeddings = load_cifar10(n, save_path=save_path, train=False, device=device)
17 changes: 16 additions & 1 deletion labproject/metrics/c2st.py
@@ -2,6 +2,7 @@

import numpy as np
import torch
import inspect
from torch import ones, zeros, eye, sum, Tensor, tensor, allclose, manual_seed
from torch.distributions import MultivariateNormal, Normal
from sklearn.ensemble import RandomForestClassifier
@@ -332,7 +333,9 @@ def c2st_scores(
X = X.cpu().numpy()
Y = Y.cpu().numpy()

-    clf = clf_class(random_state=seed, **clf_kwargs)
+    if "random_state" in inspect.signature(clf_class.__init__).parameters.keys():
+        clf_kwargs["random_state"] = seed
+    clf = clf_class(**clf_kwargs)

# prepare data
data = np.concatenate((X, Y))
@@ -382,3 +385,15 @@ def test_optimal_c2st():
c2st = c2st_optimal(d1, d2, 100_000)
target = Normal(0.0, 1.0).cdf(tensor(mean_diff // 2))
assert allclose(c2st, target, atol=1e-3)


if __name__ == "__main__":
# Generate random samples
samples1 = torch.randn(100, 2)
samples2 = torch.randn(100, 2)

# Compute C2ST scores with the different classifiers
c2st_nn_score = c2st_nn(samples1, samples2)
c2st_knn_score = c2st_knn(samples1, samples2)
c2st_rf_score = c2st_rf(samples1, samples2)
print(f"C2ST-NN: {c2st_nn_score}\nC2ST-KNN: {c2st_knn_score}\nC2ST-RF: {c2st_rf_score}")
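
For context on the random_state change in c2st_scores above, a minimal standalone sketch (illustrative only, not part of the commit; the classifier choices are arbitrary):

import inspect
from sklearn.ensemble import RandomForestClassifier
from sklearn.neighbors import KNeighborsClassifier

def make_classifier(clf_class, clf_kwargs, seed=42):
    # only pass random_state to classifiers whose __init__ accepts it
    if "random_state" in inspect.signature(clf_class.__init__).parameters:
        clf_kwargs["random_state"] = seed
    return clf_class(**clf_kwargs)

print(make_classifier(RandomForestClassifier, {"n_estimators": 10}))  # random_state is set
print(make_classifier(KNeighborsClassifier, {"n_neighbors": 5}))      # has no random_state argument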
69 changes: 69 additions & 0 deletions labproject/metrics/wasserstein_sinkhorn.py
@@ -0,0 +1,69 @@
import torch


def sinkhorn_loss(x: torch.Tensor, y: torch.Tensor, epsilon: float, niter: int = 100, p: int = 2):
    r"""Compute the Sinkhorn approximation to the Wasserstein-p distance between two sets of samples.
    The Sinkhorn algorithm adds a small entropic regularization term to the empirical optimal transport problem.
    Hence this function solves the regularized problem:
    $$ \min_{\pi \in \Pi(a, b)} \sum\limits_{ij} \pi_{ij} c_{ij} + \epsilon \sum\limits_{ij} \pi_{ij} \log \pi_{ij}
    \quad \text{s.t.} \quad \pi 1 = a, \; \pi^T 1 = b,
    $$
    where $\{c_{ij}\}$ is the cost matrix and $\Pi(a, b)$ is the set of joint distributions with marginals $a$ and $b$.
    In the sample-based setting, all weights in $a$ and $b$ are equal to $1/n$.
    Args:
        x (torch.Tensor): tensor of samples from one distribution
        y (torch.Tensor): tensor of samples from another distribution
        epsilon (float): entropy regularization strength
        niter (int): maximum number of Sinkhorn iterations
        p (int): order of the ground distance metric
    Source: https://personal.math.ubc.ca/~geoff/courses/W2019T1/Lecture13.pdf
    Code adapted from https://github.com/gpeyre/SinkhornAutoDiff
    """

assert len(x.shape) == 2 and len(y.shape) == 2, "x and y must be 2D"
n, d = x.shape

    # Compute the pairwise p-norm cost matrix (n x n); cdist accepts 2D inputs directly,
    # so no batch dimension is needed and the dual updates below keep consistent shapes
    cost_matrix = torch.cdist(x.double(), y.double(), p=p)

K = torch.exp(-cost_matrix / epsilon)
a = torch.ones(n, dtype=torch.double) / n
b = torch.ones(n, dtype=torch.double) / n

def MC(u, v):
r"""Modified cost for logarithmic updates on u,v
$M_{ij} = (-c_{ij} + u_i + v_j) / \epsilon$"""
return (-cost_matrix + u.unsqueeze(1) + v.unsqueeze(0)) / epsilon

err = 1e6
actual_niter = 0 # count number of iterations
thresh = 1e-2
u, v = torch.zeros(n, dtype=torch.double), torch.zeros(n, dtype=torch.double)

# Sinkhorn loop
for actual_niter in range(niter):
u1 = u
u = epsilon * (torch.log(a) - torch.logsumexp(MC(u, v), dim=1)) + u
v = epsilon * (torch.log(b) - torch.logsumexp(MC(u, v).T, dim=1)) + v
        err = (u - u1).abs().sum()  # change in the dual variable u since the last iteration
actual_niter += 1
if err < thresh:
break

U, V = u, v
transport = torch.exp(MC(U, V)) # Transport plan pi = diag(a)*K*diag(b)
cost = torch.sum(transport * cost_matrix) # Sinkhorn cost

return cost


if __name__ == "__main__":
# example usage
real_samples = torch.randn(100, 2) # 100 samples, 2-dimensional
fake_samples = torch.randn(100, 2) # 100 samples, 2-dimensional

    w2_dist = sinkhorn_loss(real_samples, fake_samples, epsilon=0.1)
print(w2_dist)
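
As a rough illustration of the epsilon parameter (a sketch, not part of the commit; assumes torch and sinkhorn_loss as defined above), smaller regularization should move the cost toward the unregularized optimal-transport cost at the price of more iterations:

x = torch.randn(200, 2)
y = torch.randn(200, 2) + 1.0
for eps in (1.0, 0.1, 0.01):
    cost = sinkhorn_loss(x, y, epsilon=eps, niter=500)
    print(f"epsilon={eps}: Sinkhorn cost {cost.item():.3f}")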
6 changes: 4 additions & 2 deletions labproject/plotting.py
@@ -2,10 +2,12 @@
import os

# Load matplotlibrc file
-STYLE_PATH = os.path.abspath(os.path.join(os.path.dirname(os.path.dirname(__file__)), "matplotlibrc"))  # Necessary for GitHub Actions/ Calling from other directories
+STYLE_PATH = os.path.abspath(
+    os.path.join(os.path.dirname(os.path.dirname(__file__)), "matplotlibrc")
+)  # Necessary for GitHub Actions/ Calling from other directories
 plt.style.use(STYLE_PATH)

-PLOT_PATH = os.path.abspath(os.path.join(os.path.dirname(os.path.dirname(__file__)), "plots"))
+PLOT_PATH = os.path.abspath(os.path.join(os.path.dirname(os.path.dirname(__file__)), "plots"))


def plot_scaling_metric_dimensionality(dimensionality, distances, metric_name, dataset_name):
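
The diff is truncated at this function's signature; purely as an illustration (not part of the commit, values made up), a call consistent with that signature might look like:

dimensionality = [2, 10, 100, 1000]
distances = [0.1, 0.4, 1.2, 3.5]
plot_scaling_metric_dimensionality(dimensionality, distances, "Sinkhorn", "CIFAR10")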
20 changes: 20 additions & 0 deletions labproject/utils.py
@@ -3,6 +3,7 @@
import random
import inspect
import os
import datetime

from omegaconf import OmegaConf

@@ -46,3 +47,22 @@ def get_cfg() -> OmegaConf:
msg = f"Config file not found for {name}. Please create a config file at ../configs/conf_{name}.yaml"
raise FileNotFoundError(msg)
return config


def get_log_path(cfg):
"""
Get the log path for the current experiment run.
This log path is then used to save the numerical results of the experiment.
Import this function in the run_{name}.py file and call it to get the log path.
"""

    # get datetime string
    now = datetime.datetime.now()
    timestamp = now.strftime("%Y-%m-%d_%H-%M-%S")
    if "exp_log_name" not in cfg:
        exp_log_name = timestamp
    else:
        # append the timestamp to the configured experiment name
        exp_log_name = cfg.exp_log_name + "_" + timestamp
    log_path = os.path.join("results", cfg.running_user, f"{exp_log_name}.pkl")
return log_path
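
A minimal usage sketch for a run_{name}.py script (illustrative only, not part of the commit; the config values are made up):

from omegaconf import OmegaConf

cfg = OmegaConf.create({"running_user": "alice", "exp_log_name": "sinkhorn_scaling"})
print(get_log_path(cfg))  # e.g. results/alice/sinkhorn_scaling_2024-02-05_12-00-00.pkl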
