-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmnist_noise_leaf.py
110 lines (96 loc) · 3.35 KB
/
mnist_noise_leaf.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import argparse
import random
import sys
from functools import partial
from pathlib import Path
from typing import Union, Tuple
import numpy as np
import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import mnist_networks
from model_manifold.inspect import (
constant_direction_kernel,
path_tangent,
domain_projection,
)
from model_manifold.plot import denormalize, to_gif, save_strip
def _resolve_index(idx: int, size: int) -> int:
    """Turn a user-supplied dataset index into a concrete, non-negative one.

    ``-1`` is the sentinel for "pick at random": it is replaced by
    ``random.randrange(size)`` (stdlib ``random``, so it follows
    ``random.seed``). Any other value must already be a valid index in
    ``[0, size)``; other negative values are rejected instead of being
    silently interpreted as Python negative indexing.

    Raises:
        IndexError: if ``idx`` is neither ``-1`` nor in ``[0, size)``.
    """
    if idx == -1:
        return random.randrange(size)
    if not 0 <= idx < size:
        raise IndexError(
            f"index {idx} is out of range for a dataset of {size} images "
            f"(use -1 for a random image)"
        )
    return idx


def mnist_noise_path(
    checkpoint_path: Union[str, Path],
    start_idx: int = -1,
    end_idx: int = -1,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int]:
    """Compute a path from a noise-perturbed MNIST test image to another one.

    The MNIST test split is loaded from ``data`` (downloaded if missing) and
    normalized with mean 0.1307 / std 0.3081. Starting from the start image,
    ``constant_direction_kernel`` is run for 250 steps with a random Gaussian
    direction ``v``; its last point is the noisy start. From there,
    ``path_tangent`` is run for 10000 steps with the end image as target.
    Both runs use ``domain_projection`` (with the same normalization) as
    post-processing.

    Args:
        checkpoint_path: Checkpoint passed to ``mnist_networks.medium_cnn``.
        start_idx: Index of the start image in the test split, or ``-1`` to
            pick one at random with the ``random`` module.
        end_idx: Index of the end image, with the same ``-1`` convention.

    Returns:
        ``(data_path, prob_path, pred_path, start_idx, end_idx)``: the image
        path produced by ``path_tangent`` mapped back through ``denormalize``,
        the other two outputs of ``path_tangent`` (presumably per-step
        probabilities and predictions -- confirm in
        ``model_manifold.inspect``), and the indices actually used.

    Raises:
        IndexError: if an index is neither ``-1`` nor in
            ``[0, len(test set))``.
    """
    normalize = transforms.Normalize((0.1307,), (0.3081,))
    test_mnist = datasets.MNIST(
        "data",
        train=False,
        download=True,
        transform=transforms.Compose([transforms.ToTensor(), normalize]),
    )
    network = mnist_networks.medium_cnn(checkpoint_path)
    # Work on whatever device the network's parameters live on.
    device = next(network.parameters()).device

    # Resolve start before end so the sequence of `random` draws matches
    # the seeded behaviour of earlier versions of this script.
    start_idx = _resolve_index(start_idx, len(test_mnist))
    start_image = test_mnist[start_idx][0].to(device)
    end_idx = _resolve_index(end_idx, len(test_mnist))
    end_image = test_mnist[end_idx][0].to(device)
    print(f"Compute path from a noisy {start_idx} to {end_idx}.")

    # Shared post-processing for both path computations.
    projection = partial(domain_projection, normalization=normalize)

    # Random Gaussian direction drawn from torch's RNG (seeded by the caller).
    v = torch.randn_like(start_image)
    # noinspection PyTypeChecker
    noise_path, _, _ = constant_direction_kernel(
        network,
        start_image,
        v,
        steps=250,
        post_processing=projection,
    )
    noisy_start = noise_path[-1]
    # noinspection PyTypeChecker
    data_path, prob_path, pred_path = path_tangent(
        network,
        noisy_start.to(device),
        end_image,
        steps=10000,
        post_processing=projection,
    )
    data_path = denormalize(data_path, normalize)
    return data_path, prob_path, pred_path, start_idx, end_idx
if __name__ == "__main__":
    # Command-line entry point: compute the noise-leaf path between two MNIST
    # test images and write it out both as a .gif and as a .png strip.
    parser = argparse.ArgumentParser(
        description="Export the path on a noise leaf that minimizes the distance "
        "from a valid image in the MNIST test set as a .gif and a .png strip",
        usage="python3 mnist_noise_leaf.py CHECKPOINT "
        "[--start START --end END --seed SEED --output-dir OUTPUT-DIR]",
    )
    parser.add_argument("checkpoint", type=str, help="Path to checkpoint model")
    parser.add_argument(
        "--start",
        type=int,
        default=-1,
        help="Index of the starting image (-1 picks one at random)",
    )
    parser.add_argument(
        "--end",
        type=int,
        default=-1,
        help="Index of the ending image (-1 picks one at random)",
    )
    parser.add_argument("--seed", type=int, default=100, help="Random seed")
    parser.add_argument(
        "--output-dir",
        type=str,
        default="outputs",
        help="Output directory",
    )
    args = parser.parse_args()

    # Seed the Python, NumPy and torch RNGs so that randomly chosen indices
    # and the random noise direction are reproducible for a given --seed.
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    image_path, probability_path, prediction_path, start, end = mnist_noise_path(
        args.checkpoint, args.start, args.end
    )

    output_dir = Path(args.output_dir).expanduser()
    output_dir.mkdir(parents=True, exist_ok=True)
    # File stem encodes the start/end indices, zero-padded to 5 digits.
    filename = f"{start:05d}_noise_{end:05d}"
    # step / scale_factor are forwarded to to_gif; presumably frame
    # subsampling and upscaling -- see model_manifold.plot.to_gif.
    to_gif(
        image_path,
        output_dir / f"{filename}.gif",
        step=100,
        scale_factor=10.0,
    )
    save_strip(
        image_path, output_dir / f"{filename}.png", probability_path, prediction_path
    )