Merge pull request #12 from mackelab/manuel_changes
Adding the configs, per-run seeding, and so on
manuelgloeckler authored Feb 2, 2024
2 parents 3bdde36 + c7d5e00 commit a0eb2e6
Showing 11 changed files with 180 additions and 66 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -8,6 +8,7 @@ __pycache__/

# Distribution / packaging
.Python
plots/
build/
develop-eggs/
dist/
7 changes: 7 additions & 0 deletions configs/conf_default.yaml
@@ -0,0 +1,7 @@

data: "random"
experiments: ["ScaleDimKL"]
n: 10000
d: 100

seed: 0
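
For reference, this config can be loaded with OmegaConf, the same library the new get_cfg helper in labproject/utils.py uses (a minimal sketch; field names are taken from the file above, and the path assumes the repo root as working directory):

    from omegaconf import OmegaConf

    # Load the default config; fields support attribute access
    cfg = OmegaConf.load("configs/conf_default.yaml")
    print(cfg.data, cfg.n, cfg.d, cfg.seed)  # -> random 10000 100 0
    print(cfg.experiments)                   # -> ['ScaleDimKL']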
51 changes: 50 additions & 1 deletion labproject/data.py
@@ -2,15 +2,17 @@
import requests
from requests.auth import HTTPBasicAuth
import os
import functools


STORAGEBOX_URL = os.getenv("HETZNER_STORAGEBOX_URL")
HETZNER_STORAGEBOX_USERNAME = os.getenv("HETZNER_STORAGEBOX_USERNAME")
HETZNER_STORAGEBOX_PASSWORD = os.getenv("HETZNER_STORAGEBOX_PASSWORD")

torch.manual_seed(0)

## Hetzner Storage Box API functions ----

DATASETS = {}

def upload_file(local_path, remote_path):
"""
Expand Down Expand Up @@ -56,10 +58,57 @@ def download_file(remote_path, local_path):
return True
return False

def register_dataset(name: str) -> callable:
    """This decorator wraps a function that should return a dataset and ensures that the dataset is a PyTorch tensor with the correct shape (n, d).

    Args:
        name (str): Name under which the dataset generator is registered

    Returns:
        callable: Decorator that registers and wraps the generator function
    """

    def decorator(func):
        @functools.wraps(func)
        def wrapper(n: int, d: int, **kwargs):

            assert n > 0, "n must be a positive integer"
            assert d > 0, "d must be a positive integer"

            # Call the original function
            dataset = func(n, d, **kwargs)

            # Convert the result to a PyTorch tensor if it is not one already
            dataset = torch.Tensor(dataset) if not isinstance(dataset, torch.Tensor) else dataset

            assert dataset.shape == (n, d), f"Dataset shape must be {(n, d)}"

            return dataset

        DATASETS[name] = wrapper
        return wrapper

    return decorator

def get_dataset(name: str) -> callable:
    """Get a registered dataset generator by name.

    Args:
        name (str): Name of the dataset

    Returns:
        callable: Dataset generator that takes (n, d) and returns a torch.Tensor
    """
    assert name in DATASETS, f"Dataset {name} not found, please register it first"
    return DATASETS[name]

# ------------------------------


## Data functions ----
# Any function returning a numeric array can be registered as a dataset as follows:

@register_dataset("random")
def random_dataset(n=1000, d=10):
    return torch.randn(n, d)
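
Taken together, registration and lookup work like this (a small usage sketch based on the functions above):

    from labproject.data import get_dataset

    # get_dataset returns the registered generator; the wrapper added by
    # register_dataset asserts the result is a torch.Tensor of shape (n, d)
    dataset_fn = get_dataset("random")
    x = dataset_fn(100, 10)
    assert x.shape == (100, 10)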
58 changes: 32 additions & 26 deletions labproject/experiments.py
@@ -1,34 +1,40 @@
import torch
from metrics.sliced_wasserstein import sliced_wasserstein_distance
from metrics.gaussian_kl import gaussian_kl_divergence
from data import random_dataset
from metrics import sliced_wasserstein_distance, gaussian_kl_divergence
from plotting import plot_scaling_metric_dimensionality


def scaling_sliced_wasserstein_samples(dataset1, dataset2):
    distances = []
    dimensionality = list(range(1, 1000, 100))
    for d in dimensionality:
        distances.append(sliced_wasserstein_distance(dataset1[:, :d], dataset2[:, :d]))
    return dimensionality, distances


def scaling_kl_samples(dataset1, dataset2):
    distances = []
    dimensionality = list(range(2, 1000, 98))
    for d in dimensionality:
        distances.append(gaussian_kl_divergence(dataset1[:, :d], dataset2[:, :d]))
    # print(distances)
    return dimensionality, distances


def run_metric_on_datasets(dataset1, dataset2, metric):
    return metric(dataset1, dataset2)


class Experiment:
    def __init__(self):
        pass

    def run_experiment(self, dataset1, dataset2, experiment_fn):
        return experiment_fn(dataset1, dataset2)
    def run_experiment(self, metric, dataset1, dataset2):
        raise NotImplementedError("Subclasses must implement this method")

    def plot_experiment(self):
        raise NotImplementedError("Subclasses must implement this method")


class ScaleDim(Experiment):

    def __init__(self, metric_name, metric_fn, min_dim=1, max_dim=1000, step=100):
        self.metric_name = metric_name
        self.metric_fn = metric_fn
        self.dimensionality = list(range(min_dim, max_dim, step))
        super().__init__()

    def run_experiment(self, dataset1, dataset2):
        distances = []
        for d in self.dimensionality:
            distances.append(self.metric_fn(dataset1[:, :d], dataset2[:, :d]))
        return self.dimensionality, distances

    def plot_experiment(self, dimensionality, distances, dataset_name):
        plot_scaling_metric_dimensionality(dimensionality, distances, self.metric_name, dataset_name)


class ScaleDimKL(ScaleDim):
    def __init__(self):
        super().__init__("KL", gaussian_kl_divergence, min_dim=2)


class ScaleDimSW(ScaleDim):
    def __init__(self):
        super().__init__("Sliced Wasserstein", sliced_wasserstein_distance)
2 changes: 2 additions & 0 deletions labproject/metrics/__init__.py
@@ -0,0 +1,2 @@
from labproject.metrics.gaussian_kl import gaussian_kl_divergence
from labproject.metrics.sliced_wasserstein import sliced_wasserstein_distance
5 changes: 0 additions & 5 deletions labproject/metrics/sliced_wasserstein.py
@@ -1,13 +1,8 @@
# STOLEN from Julius: https://github.com/mackelab/wasserstein_source/blob/main/wasser/sliced_wasserstein.py

import numpy as np

np.random.seed(0)
import torch

torch.manual_seed(0)


def rand_projections(embedding_dim, num_samples: int):
    """
    This function generates num_samples random samples from the latent space's unit sphere.
7 changes: 5 additions & 2 deletions labproject/plotting.py
@@ -1,7 +1,10 @@
import matplotlib.pyplot as plt
import os

plots_path = "plots/"
# Load matplotlibrc file
plt.style.use("../matplotlibrc")

PLOT_PATH = "../plots/"


def plot_scaling_metric_dimensionality(dimensionality, distances, metric_name, dataset_name):
@@ -11,7 +14,7 @@ def plot_scaling_metric_dimensionality(dimensionality, distances, metric_name, d
    plt.title(f"{metric_name} with increasing dimensionality for {dataset_name}")
    plt.savefig(
        os.path.join(
            plots_path,
            PLOT_PATH,
            f"{metric_name.lower().replace(' ', '_')}_dimensionality_{dataset_name.lower().replace(' ', '_')}.png",
        )
    )
32 changes: 0 additions & 32 deletions labproject/run.py

This file was deleted.

41 changes: 41 additions & 0 deletions labproject/run_default.py
@@ -0,0 +1,41 @@

from labproject.utils import set_seed, get_cfg
from labproject.data import get_dataset
from labproject.experiments import *


import time


if __name__ == "__main__":

    print("Running experiments...")
    cfg = get_cfg()
    seed = cfg.seed

    set_seed(seed)
    print(f"Seed: {seed}")
    print(f"Experiments: {cfg.experiments}")
    print(f"Data: {cfg.data}")

    dataset_fn = get_dataset(cfg.data)

    for exp_name in cfg.experiments:
        experiment = globals()[exp_name]()
        time_start = time.time()
        dataset1 = dataset_fn(cfg.n, cfg.d)
        dataset2 = dataset_fn(cfg.n, cfg.d)

        output = experiment.run_experiment(
            dataset1=dataset1, dataset2=dataset2
        )
        time_end = time.time()
        print(f"Experiment {exp_name} finished in {time_end - time_start:.3f} s")
        experiment.plot_experiment(*output, cfg.data)

    print("Finished running experiments.")
41 changes: 41 additions & 0 deletions labproject/utils.py
@@ -0,0 +1,41 @@
import torch
import numpy as np
import random
import inspect

from omegaconf import OmegaConf


def set_seed(seed: int) -> None:
    """Set seed for reproducibility

    Args:
        seed (int): Integer seed
    """
    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def get_cfg() -> OmegaConf:
    """Return the configuration for the current experiment run.

    The configuration file is expected to be located at ../configs/conf_{name}.yaml, where {name} matches the name of the calling run_{name}.py file.

    Raises:
        FileNotFoundError: If the configuration file is not found

    Returns:
        OmegaConf: Dictionary with the configuration parameters
    """
    caller_frame = inspect.currentframe().f_back
    filename = caller_frame.f_code.co_filename
    name = filename.split("/")[-1].split(".")[0].split("_")[-1]
    try:
        config = OmegaConf.load(f"../configs/conf_{name}.yaml")
    except FileNotFoundError:
        msg = f"Config file not found for {name}. Please create a config file at ../configs/conf_{name}.yaml"
        raise FileNotFoundError(msg)
    return config
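
In other words, script and config names are coupled through the last underscore-separated chunk of the caller's filename (a sketch of the lookup, reusing the parsing line above):

    # e.g. filename = "/home/user/labproject/run_default.py"
    name = filename.split("/")[-1].split(".")[0].split("_")[-1]  # -> "default"
    # so run_default.py loads ../configs/conf_default.yaml; a hypothetical
    # run_cluster.py would load ../configs/conf_cluster.yaml. Note that a
    # multi-underscore name like run_my_experiment.py resolves to "experiment",
    # since only the last chunk is kept.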
1 change: 1 addition & 0 deletions pyproject.toml
@@ -11,6 +11,7 @@ dependencies = [
"scipy",
"matplotlib",
"torch",
"OmegaConf",
]

[project.optional-dependencies]