-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #12 from mackelab/manuel_changes
Adding the configs, per run and seeding and so on
- Loading branch information
Showing
11 changed files
with
180 additions
and
66 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,6 +8,7 @@ __pycache__/ | |
|
||
# Distribution / packaging | ||
.Python | ||
plots/ | ||
build/ | ||
develop-eggs/ | ||
dist/ | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
|
||
data: "random" | ||
experiments: ["ScaleDimKL"] | ||
n: 10000 | ||
d: 100 | ||
|
||
seed: 0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,34 +1,40 @@ | ||
import torch | ||
from metrics.sliced_wasserstein import sliced_wasserstein_distance | ||
from metrics.gaussian_kl import gaussian_kl_divergence | ||
from data import random_dataset | ||
from metrics import sliced_wasserstein_distance, gaussian_kl_divergence | ||
from plotting import plot_scaling_metric_dimensionality | ||
|
||
|
||
def scaling_sliced_wasserstein_samples(dataset1, dataset2):
    """Evaluate the sliced Wasserstein distance on growing column prefixes.

    For every d in ``range(1, 1000, 100)`` the metric is computed on the first
    d columns of both datasets.

    Returns:
        tuple: ``(dimensionality, distances)`` — the list of d values and the
        list of metric results, in matching order.
    """
    dimensionality = list(range(1, 1000, 100))
    distances = [
        sliced_wasserstein_distance(dataset1[:, :d], dataset2[:, :d])
        for d in dimensionality
    ]
    return dimensionality, distances
|
||
|
||
def scaling_kl_samples(dataset1, dataset2):
    """Evaluate the Gaussian KL divergence on growing column prefixes.

    For every d in ``range(2, 1000, 98)`` the metric is computed on the first
    d columns of both datasets.

    Returns:
        tuple: ``(dimensionality, distances)`` — the list of d values and the
        list of metric results, in matching order.
    """
    dimensionality = list(range(2, 1000, 98))
    distances = [
        gaussian_kl_divergence(dataset1[:, :d], dataset2[:, :d])
        for d in dimensionality
    ]
    return dimensionality, distances
|
||
|
||
def run_metric_on_datasets(dataset1, dataset2, metric):
    """Apply ``metric`` to the two datasets and return its result unchanged."""
    result = metric(dataset1, dataset2)
    return result
|
||
|
||
class Experiment:
    """Base class for experiments that compare two datasets.

    Both ``run_experiment`` and ``plot_experiment`` raise
    ``NotImplementedError`` here; subclasses are expected to override them.
    """

    def __init__(self):
        pass

    # NOTE(review): ScaleDim in this file overrides this method as
    # run_experiment(self, dataset1, dataset2), i.e. without ``metric`` —
    # confirm which signature is the intended contract.
    def run_experiment(self, metric, dataset1, dataset2):
        """Run the experiment; must be implemented by subclasses.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError("Subclasses must implement this method")

    def plot_experiment(self):
        """Plot the experiment results; must be implemented by subclasses.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError("Subclasses must implement this method")
|
||
|
||
class ScaleDim(Experiment):
    """Experiment that evaluates a metric on growing column prefixes of two datasets."""

    def __init__(self, metric_name, metric_fn, min_dim=1, max_dim=1000, step=100):
        """Store the metric and the dimensionalities to evaluate.

        Args:
            metric_name: Label passed on to the plotting function.
            metric_fn: Callable invoked as ``metric_fn(slice1, slice2)``.
            min_dim: First dimensionality evaluated.
            max_dim: Exclusive upper bound on the dimensionality.
            step: Increment between consecutive dimensionalities.
        """
        super().__init__()
        self.metric_name = metric_name
        self.metric_fn = metric_fn
        self.dimensionality = list(range(min_dim, max_dim, step))

    def run_experiment(self, dataset1, dataset2):
        """Compute the metric on the first d columns of both datasets for each d.

        Returns:
            tuple: ``(dimensionality, distances)`` in matching order.
        """
        distances = [
            self.metric_fn(dataset1[:, :d], dataset2[:, :d])
            for d in self.dimensionality
        ]
        return self.dimensionality, distances

    def plot_experiment(self, dimensionality, distances, dataset_name):
        """Forward the results and this experiment's metric name to the plotting helper."""
        plot_scaling_metric_dimensionality(
            dimensionality, distances, self.metric_name, dataset_name
        )
|
||
class ScaleDimKL(ScaleDim):
    """ScaleDim experiment using the Gaussian KL divergence, starting at 2 dimensions."""

    def __init__(self):
        super().__init__(
            metric_name="KL",
            metric_fn=gaussian_kl_divergence,
            min_dim=2,
        )
|
||
class ScaleDimSW(ScaleDim):
    """ScaleDim experiment using the sliced Wasserstein distance."""

    def __init__(self):
        super().__init__(
            metric_name="Sliced Wasserstein",
            metric_fn=sliced_wasserstein_distance,
        )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
from labproject.metrics.gaussian_kl import gaussian_kl_divergence | ||
from labproject.metrics.sliced_wasserstein import sliced_wasserstein_distance |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
|
||
from labproject.utils import set_seed, get_cfg | ||
from labproject.data import get_dataset | ||
from labproject.experiments import * | ||
|
||
|
||
import time | ||
|
||
|
||
if __name__ == "__main__":
    # Entry point: load the run configuration, seed the RNGs, then run and
    # plot every experiment named in the config.
    print("Running experiments...")
    cfg = get_cfg()
    seed = cfg.seed

    set_seed(seed)
    print(f"Seed: {seed}")
    print(f"Experiments: {cfg.experiments}")
    print(f"Data: {cfg.data}")

    dataset_fn = get_dataset(cfg.data)

    for exp_name in cfg.experiments:
        # Experiment classes are looked up by name among the module globals,
        # which the star import from labproject.experiments populates.
        experiment_cls = globals()[exp_name]
        experiment = experiment_cls()

        # The measured time covers dataset creation as well as the run itself.
        time_start = time.time()
        dataset1 = dataset_fn(cfg.n, cfg.d)
        dataset2 = dataset_fn(cfg.n, cfg.d)
        output = experiment.run_experiment(dataset1=dataset1, dataset2=dataset2)
        time_end = time.time()

        print(f"Experiment {exp_name} finished in {time_end - time_start}")
        experiment.plot_experiment(*output, cfg.data)

    print("Finished running experiments.")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
import torch | ||
import numpy as np | ||
import random | ||
import inspect | ||
|
||
from omegaconf import OmegaConf | ||
|
||
|
||
def set_seed(seed: int) -> int:
    """Set seed for reproducibility

    Seeds PyTorch, Python's ``random`` module and NumPy's global RNG, and
    sets cuDNN to deterministic mode with benchmarking disabled.

    Args:
        seed (int): Integer seed

    Returns:
        int: The seed that was passed in, unchanged.
    """
    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    return seed
|
||
def get_cfg() -> OmegaConf:
    """This function returns the configuration file for the current experiment run.

    The config name is derived from the file name of the calling script: for
    run_{name}.py the file ../configs/conf_{name}.yaml is loaded. The path is
    relative, so it is presumably resolved against the current working
    directory rather than against the script's location.

    Raises:
        FileNotFoundError: If the configuration file is not found

    Returns:
        OmegaConf: Configuration object produced by ``OmegaConf.load``
    """
    import os  # local import: only needed to split the caller's path portably

    caller_frame = inspect.currentframe().f_back
    filename = caller_frame.f_code.co_filename
    # basename honours the platform's path separators, so this also works when
    # the caller's path uses backslashes; "run_foo.py" -> "run_foo" -> "foo".
    name = os.path.basename(filename).split(".")[0].split("_")[-1]
    config_path = f"../configs/conf_{name}.yaml"
    try:
        config = OmegaConf.load(config_path)
    except FileNotFoundError as err:
        msg = f"Config file not found for {name}. Please create a config file at {config_path}"
        # Chain the original error so the underlying cause stays visible.
        raise FileNotFoundError(msg) from err
    return config
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -11,6 +11,7 @@ dependencies = [ | |
"scipy", | ||
"matplotlib", | ||
"torch", | ||
"OmegaConf", | ||
] | ||
|
||
[project.optional-dependencies] | ||
|