Merge branch 'main' of github.com:mackelab/labproject
coschroeder committed Feb 5, 2024
2 parents 5789482 + 1e25c45 commit 96e63f8
Showing 12 changed files with 536 additions and 74 deletions.
6 changes: 5 additions & 1 deletion .github/workflows/run_experiments.yml
@@ -33,7 +33,11 @@ jobs:
run: |
cd ${{ github.workspace }}
python3 -m pip install --upgrade pip
pip install -e ".[docs]"
pip install -e ".[docs,dev]"
- name: Check formatting
run: |
cd ${{ github.workspace }}
black --check labproject/
- name: Run experiments and generate plots
shell: bash
run: |
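For context: black --check only reports violations (it exits non-zero without rewriting files), so this new step fails the workflow whenever code under labproject/ is not black-formatted; running black labproject/ locally fixes it.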
4 changes: 4 additions & 0 deletions .github/workflows/test_experiments.yml
@@ -34,6 +34,10 @@ jobs:
cd ${{ github.workspace }}
python3 -m pip install --upgrade pip
pip install -e ".[docs]"
- name: Check formatting
run: |
cd ${{ github.workspace }}
black --check labproject/
- name: Run experiments and generate plots
shell: bash
run: |
1 change: 1 addition & 0 deletions .gitignore
@@ -8,6 +8,7 @@ __pycache__/

# Distribution / packaging
.Python
data/
plots/
build/
develop-eggs/
4 changes: 4 additions & 0 deletions docs/api.md
@@ -5,6 +5,10 @@ Here all functions will be documented that are part of the public API of the lab

## Metrics

::: labproject.metrics
options:
heading_level: 3

### Gaussian KL divergence

::: labproject.metrics.gaussian_kl
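The ::: blocks are mkdocstrings directives: the new ::: labproject.metrics entry renders API documentation for the whole metrics module, and the heading_level: 3 option nests the generated headings one level below the Metrics section (assuming the docs are built with mkdocs and mkdocstrings, as the .[docs] extra suggests).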
169 changes: 115 additions & 54 deletions labproject/data.py
@@ -10,10 +10,14 @@
from torchvision.datasets import CIFAR10
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset, Subset
from torchvision.models import inception_v3

# from torchvision.models import inception_v3
from labproject.external.inception_v3 import InceptionV3

from labproject.embeddings import FIDEmbeddingNet

import warnings


STORAGEBOX_URL = os.getenv("HETZNER_STORAGEBOX_URL")
HETZNER_STORAGEBOX_USERNAME = os.getenv("HETZNER_STORAGEBOX_USERNAME")
@@ -23,6 +27,7 @@
## Hetzner Storage Box API functions ----

DATASETS = {}
DISTRIBUTIONS = {}


def upload_file(local_path: str, remote_path: str):
@@ -97,7 +102,10 @@ def decorator(func):
def wrapper(n: int, d: Optional[int] = None, **kwargs):

assert n > 0, "n must be a positive integer"
assert d > 0, "d must be a positive integer"
if d is not None:
assert d > 0, "d must be a positive integer"
else:
warnings.warn("d is not specified, make sure you know what you're doing!")

# Call the original function
dataset = func(n, d, **kwargs)
@@ -130,6 +138,47 @@ def get_dataset(name: str) -> torch.Tensor:
return DATASETS[name]


def register_distribution(name: str) -> callable:
r"""This decorator wraps a function that should return a distribution object and registers it under the given name.
Args:
name (str): Name under which to register the distribution
Returns:
callable: Distribution generator function wrapper
Example:
>>> @register_distribution("normal")
>>> def normal_distribution():
>>> return torch.distributions.Normal(0, 1)
"""

def decorator(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
# Call the original function
distribution = func(*args, **kwargs)
return distribution

DISTRIBUTIONS[name] = wrapper
return wrapper

return decorator


def get_distribution(name: str) -> callable:
r"""Get a registered distribution generator by name
Args:
name (str): Name of the distribution
Returns:
callable: Distribution generator function
"""
assert name in DISTRIBUTIONS, f"Distribution {name} not found, please register it first"
return DISTRIBUTIONS[name]
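
For orientation, a minimal usage sketch of the new registry (the "my_laplace" name is a hypothetical example, not part of this diff):

import torch
from labproject.data import register_distribution, get_distribution

# Register a distribution generator under a name (hypothetical example).
@register_distribution("my_laplace")
def my_laplace():
    return torch.distributions.Laplace(0.0, 1.0)

# Look the generator up by name, instantiate it, and sample from it.
dist = get_distribution("my_laplace")()
samples = dist.sample((100,))  # tensor of shape (100,)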


def load_cifar10(
n: int, save_path="data", train=True, batch_size=100, shuffle=False, num_workers=1, device="cpu"
) -> torch.Tensor:
@@ -149,27 +198,18 @@ def load_cifar10(
"""
transform = transforms.Compose(
[
transforms.Resize((299, 299)),
transforms.ToTensor(),
# normalize specific to inception model
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
# Move to GPU if available
transforms.Lambda(lambda x: x.to(device if torch.cuda.is_available() else "cpu")),
]
)
cifar10 = CIFAR10(root=save_path, train=train, download=True, transform=transform)
dataloader = torch.utils.data.DataLoader(
cifar10, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers
)
dataloader_subset = Subset(dataloader.dataset, range(n))
dataset_subset = Subset(dataloader.dataset, range(n))
dataloader = DataLoader(
dataloader_subset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers
dataset_subset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers
)
model = inception_v3(pretrained=True)
model.fc = torch.nn.Identity() # replace the classifier with identity to get features
model.eval()
model = model.to(device if torch.cuda.is_available() else "cpu")
net = FIDEmbeddingNet(model)
net = FIDEmbeddingNet(device=device)
embeddings = net.get_embeddings(dataloader)
return embeddings
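
A sketch of the intended call (assumes CIFAR-10 can be downloaded into data/ on first use; the 2048-dim width follows the asserts further down):

from labproject.data import load_cifar10

# InceptionV3 embeddings of the first 200 CIFAR-10 training images.
embeddings = load_cifar10(200, save_path="data", train=True, device="cpu")
print(embeddings.shape)  # expected: torch.Size([200, 2048])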

@@ -186,53 +226,69 @@ def random_dataset(n=1000, d=10):
return torch.randn(n, d)


@register_dataset("toy_2d")
def toy_mog_2D(n=1000, d=2):
"""Generate samples from a 2D mixture of 4 Gaussians that look funky.
Args:
n (int): number of samples to generate
d (int): dimensionality of the samples, always 2. Changing it does nothing.
Returns:
tensor: samples of shape (num_samples, 2)
"""
means = torch.tensor(
[
[0.0, 0.5],
[-3.0, -0.5],
[0.0, -1.0],
[-4.0, -3.0],
]
)
covariances = torch.tensor(
[
[[1.0, 0.8], [0.8, 1.0]],
[[1.0, -0.5], [-0.5, 1.0]],
[[1.0, 0.0], [0.0, 1.0]],
[[0.5, 0.0], [0.0, 0.5]],
]
)
weights = torch.tensor([0.2, 0.3, 0.3, 0.2])

# Create a list of 2D Gaussian distributions
gaussians = [
MultivariateNormal(mean, covariance) for mean, covariance in zip(means, covariances)
]

# Sample from the mixture
categorical = Categorical(weights)
sample_indices = categorical.sample([n])
samples = torch.stack([gaussians[i].sample() for i in sample_indices])
return samples
@register_distribution("normal")
def normal_distribution():
return torch.distributions.Normal(0, 1)


@register_distribution("toy_2d")
def toy_mog_2d():
class Toy2D:
def __init__(self):
self.means = torch.tensor(
[
[0.0, 0.5],
[-3.0, -0.5],
[0.0, -1.0],
[-4.0, -3.0],
]
)
self.covariances = torch.tensor(
[
[[1.0, 0.8], [0.8, 1.0]],
[[1.0, -0.5], [-0.5, 1.0]],
[[1.0, 0.0], [0.0, 1.0]],
[[0.5, 0.0], [0.0, 0.5]],
]
)
self.weights = torch.tensor([0.2, 0.3, 0.3, 0.2])

# Create a list of 2D Gaussian distributions
self.gaussians = [
MultivariateNormal(mean, covariance)
for mean, covariance in zip(self.means, self.covariances)
]

def sample(self, sample_shape):
if isinstance(sample_shape, int):
sample_shape = (sample_shape,)
# Sample from the mixture
categorical = Categorical(self.weights)
sample_indices = categorical.sample(sample_shape)
return torch.stack([self.gaussians[i].sample() for i in sample_indices])

def log_prob(self, input):
probs = torch.stack([g.log_prob(input).exp() for g in self.gaussians])
probs = probs.T * self.weights
return torch.sum(probs, dim=1).log()

return Toy2D()
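
A quick sketch of sampling and density evaluation with the registered mixture (names exactly as defined above):

toy = get_distribution("toy_2d")()  # instantiate the Toy2D mixture
x = toy.sample(500)                 # tensor of shape (500, 2)
logp = toy.log_prob(x)              # tensor of shape (500,)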


@register_dataset("cifar10_train")
def cifar10_train(n=1000, d=10, save_path="data", device="cpu"):
def cifar10_train(n=1000, d=2048, save_path="data", device="cpu"):

assert d is None or d == 2048, "The dimensionality of the embeddings must be 2048"

embeddings = load_cifar10(n, save_path=save_path, train=True, device=device)
# to cpu if necessary
if device == "cuda":
embeddings = embeddings.cpu()

max_n = embeddings.shape[0]

@@ -246,7 +302,12 @@ def cifar10_test(n=1000, d=2048, save_path="data", device="cpu"):

assert d == 2048, "The dimensionality of the embeddings must be 2048"

assert d is None or d == 2048, "The dimensionality of the embeddings must be 2048"

embeddings = load_cifar10(n, save_path=save_path, train=False, device=device)
# to cpu if necessary
if device == "cuda":
embeddings = embeddings.cpu()

max_n = embeddings.shape[0]

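To illustrate how the registered datasets are consumed (a sketch; the wrapper signature n, d follows register_dataset above):

emb = get_dataset("cifar10_train")(n=500, d=2048)  # triggers the CIFAR-10 download on first use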
28 changes: 19 additions & 9 deletions labproject/embeddings.py
@@ -7,30 +7,40 @@
import torch.nn as nn
import numpy as np

from labproject.external.inception_v3 import InceptionV3, get_inception_v3_activations


# Abstract class for embedding nets
class EmbeddingNet(nn.Module):
def __init__(self):
super(EmbeddingNet, self).__init__()
self.embedding = None
self.embedding_net = None

def forward(self, x):
return self.embedding(x)
return self.embedding_net(x)

def get_embedding(self, x):
return self.forward(x).detach().cpu().numpy()
raise NotImplementedError("Subclasses must implement this method")

def get_embeddings(self, dataloader):
embeddings = []
for batch in dataloader:
embeddings.append(self.get_embedding(batch[0]))
return np.concatenate(embeddings)
raise NotImplementedError("Subclasses must implement this method")


class FIDEmbeddingNet(EmbeddingNet):
def __init__(self, image_model):
def __init__(self, image_model=None, device="cpu"):
super(FIDEmbeddingNet, self).__init__()
self.embedding = image_model
self.embedding_net = image_model if image_model is not None else InceptionV3()
self.device = device
self.embedding_net = self.embedding_net.to(device)

def get_embeddings(self, dataloader):
embeddings = []
for batch in dataloader:
embeddings += [
get_inception_v3_activations(self.embedding_net, batch[0].to(self.device))
]
embeddings = torch.cat(embeddings, dim=0)
return embeddings

# optional override of original methods
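
A minimal sketch of the intended use (assuming a DataLoader that yields (images, labels) batches preprocessed as in load_cifar10):

net = FIDEmbeddingNet(device="cpu")          # no model passed: falls back to the bundled InceptionV3
embeddings = net.get_embeddings(dataloader)  # tensor of shape (N, 2048)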

6 changes: 3 additions & 3 deletions labproject/experiments.py
@@ -1,6 +1,7 @@
import torch
from .metrics import sliced_wasserstein_distance, gaussian_kl_divergence
from .plotting import plot_scaling_metric_dimensionality
from metrics import sliced_wasserstein_distance, gaussian_kl_divergence
from plotting import plot_scaling_metric_dimensionality
from metrics.gaussian_squared_wasserstein import gaussian_squared_w2_distance
import pickle


@@ -19,7 +20,6 @@ def log_results(self, results, log_path):


class ScaleDim(Experiment):

def __init__(self, metric_name, metric_fn, min_dim=1, max_dim=1000, step=100):
self.metric_name = metric_name
self.metric_fn = metric_fn
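
A sketch of instantiating such an experiment (constructor arguments as shown; the remaining methods of ScaleDim are cut off in this view):

exp = ScaleDim("sliced_wasserstein", sliced_wasserstein_distance, min_dim=1, max_dim=1000, step=100)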