Skip to content

Commit

Permalink
cifar
Browse files Browse the repository at this point in the history
  • Loading branch information
coschroeder committed Feb 5, 2024
1 parent 96e63f8 commit 33f4a4c
Showing 1 changed file with 15 additions and 0 deletions.
15 changes: 15 additions & 0 deletions labproject/experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,3 +53,18 @@ def __init__(self, min_dim=2, **kwargs):
class ScaleDimSW(ScaleDim):
def __init__(self, min_dim=2, **kwargs):
super().__init__("Sliced Wasserstein", sliced_wasserstein_distance, **kwargs)

class CIFAR10_FID_Train_Test(Experiment):
def __init__(self):
super().__init__()

def run_experiment(self, dataset1, dataset2):
fid_metric = gaussian_squared_w2_distance(dataset1, dataset2)
return fid_metric

def log_results(self, fid_metric, log_path):
with open(log_path, "wb") as f:
pickle.dump(fid_metric, f)

def plot_experiment(self, fid_metric, dataset_name):
pass

0 comments on commit 33f4a4c

Please sign in to comment.