-
Notifications
You must be signed in to change notification settings - Fork 48
/
Copy pathsampling.py
149 lines (108 loc) · 4.31 KB
/
sampling.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
import abc
import torch
import torch.nn.functional as F
from catsample import sample_categorical
from model import utils as mutils
# Registry mapping a predictor name to the class registered under it.
_PREDICTORS = {}


def register_predictor(cls=None, *, name=None):
    """Register a predictor class in the module-level registry.

    Usable bare (``@register_predictor``), in which case the class is stored
    under ``cls.__name__``, or with a keyword
    (``@register_predictor(name="euler")``) to choose the key explicitly.

    Args:
        cls: The class to register, or None when only keyword arguments are
            given (decorator-factory form).
        name: Optional registry key; defaults to the class name.

    Returns:
        The class itself when ``cls`` is given, otherwise a decorator that
        registers the class it receives and returns it unchanged.

    Raises:
        ValueError: If the chosen name is already present in the registry.
    """

    def _register(cls):
        local_name = cls.__name__ if name is None else name
        if local_name in _PREDICTORS:
            # This registry stores predictors, so the message names them
            # (the previous text said "model").
            raise ValueError(
                f'Already registered predictor with name: {local_name}')
        _PREDICTORS[local_name] = cls
        return cls

    if cls is None:
        return _register
    return _register(cls)
def get_predictor(name):
    """Look up a registered predictor class by name.

    Args:
        name: Key the predictor was registered under.

    Returns:
        The class stored in the registry for ``name``.

    Raises:
        KeyError: If nothing is registered under ``name``; the message lists
            the names that are currently registered.
    """
    try:
        return _PREDICTORS[name]
    except KeyError:
        # Same exception type as a plain dict lookup, but with a message
        # that tells the caller which names are valid.
        available = ', '.join(repr(key) for key in _PREDICTORS)
        raise KeyError(
            f'Unknown predictor {name!r}; registered predictors: '
            f'[{available}]') from None
class Predictor(abc.ABC):
    """Abstract base class shared by all predictor algorithms.

    Keeps the ``graph`` and ``noise`` objects passed to the constructor as
    attributes so that concrete ``update_fn`` implementations can use them.
    """

    def __init__(self, graph, noise):
        """Store the graph and noise objects on the instance."""
        super().__init__()
        self.graph, self.noise = graph, noise

    @abc.abstractmethod
    def update_fn(self, score_fn, x, t, step_size):
        """Perform one update of the predictor.

        Args:
            score_fn: Score function.
            x: A PyTorch tensor representing the current state.
            t: A PyTorch tensor representing the current time step.
            step_size: Step size handed to the concrete update rule.

        Returns:
            A PyTorch tensor of the next state.
        """
@register_predictor(name="euler")
class EulerPredictor(Predictor):
    """Predictor registered under the name ``"euler"``.

    Each update samples the next state from the graph's reverse rate,
    scaled by ``step_size`` and by the second value the noise object returns.
    """

    def update_fn(self, score_fn, x, t, step_size):
        """Sample the next state from the scaled reverse rate.

        Args:
            score_fn: Callable invoked as ``score_fn(x, sigma)``.
            x: Current state.
            t: Current time, passed to ``self.noise``.
            step_size: Multiplier applied to the reverse rate.

        Returns:
            The result of ``self.graph.sample_rate`` for the scaled rate.
        """
        # self.noise(t) yields a pair; the second entry is called dsigma
        # upstream (presumably the time derivative of sigma - TODO confirm).
        sigma_t, dsigma_t = self.noise(t)
        model_score = score_fn(x, sigma_t)
        # A trailing axis is added so the scale broadcasts against the rate.
        scale = step_size * dsigma_t[..., None]
        reverse_rate = scale * self.graph.reverse_rate(x, model_score)
        return self.graph.sample_rate(x, reverse_rate)
@register_predictor(name="none")
class NonePredictor(Predictor):
    """Predictor registered under ``"none"`` that leaves the state as is."""

    def update_fn(self, score_fn, x, t, step_size):
        """Return ``x`` unchanged; every other argument is ignored."""
        return x
@register_predictor(name="analytic")
class AnalyticPredictor(Predictor):
    """Predictor registered under the name ``"analytic"``.

    Each update draws the next state with ``sample_categorical`` from the
    product of the graph's staggered score and transposed transition, both
    evaluated at the sigma difference between ``t`` and ``t - step_size``.
    """

    def update_fn(self, score_fn, x, t, step_size):
        """Sample the next state for the step from ``t`` to ``t - step_size``.

        Args:
            score_fn: Callable invoked as ``score_fn(x, sigma)``.
            x: Current state.
            t: Current time, passed to ``self.noise``.
            step_size: Amount subtracted from ``t`` to get the next time.

        Returns:
            A sample drawn by ``sample_categorical`` from the computed
            probabilities.
        """
        # Only the first value returned by the noise object is needed here.
        sigma_now = self.noise(t)[0]
        sigma_next = self.noise(t - step_size)[0]
        sigma_gap = sigma_now - sigma_next

        model_score = score_fn(x, sigma_now)
        staggered = self.graph.staggered_score(model_score, sigma_gap)
        transition = self.graph.transp_transition(x, sigma_gap)
        return sample_categorical(staggered * transition)
class Denoiser:
    """Draws a state at a single time using the full sigma at that time.

    Unlike the predictors, the staggered score and transposed transition are
    evaluated at ``sigma`` itself rather than at a difference of two sigmas.
    """

    def __init__(self, graph, noise):
        """Store the graph and noise objects on the instance."""
        self.graph, self.noise = graph, noise

    def update_fn(self, score_fn, x, t):
        """Sample a state from the probabilities computed at time ``t``.

        Args:
            score_fn: Callable invoked as ``score_fn(x, sigma)``.
            x: Current state.
            t: Time passed to ``self.noise``.

        Returns:
            A sample drawn by ``sample_categorical``.
        """
        # Only the first value returned by the noise object is needed here.
        sigma = self.noise(t)[0]
        model_score = score_fn(x, sigma)
        staggered = self.graph.staggered_score(model_score, sigma)
        probs = staggered * self.graph.transp_transition(x, sigma)

        if self.graph.absorb:
            # Truncate probabilities: drop the last entry of the final axis
            # (presumably the absorbing state - TODO confirm).
            probs = probs[..., :-1]

        # The result is sampled rather than taken as the argmax.
        return sample_categorical(probs)
def get_sampling_fn(config, graph, noise, batch_dims, eps, device):
    """Build a sampler from the ``sampling`` section of ``config``.

    Args:
        config: Object whose ``sampling`` attribute provides ``predictor``,
            ``steps`` and ``noise_removal``.
        graph: Forwarded to ``get_pc_sampler``.
        noise: Forwarded to ``get_pc_sampler``.
        batch_dims: Forwarded to ``get_pc_sampler``.
        eps: Forwarded to ``get_pc_sampler``.
        device: Forwarded to ``get_pc_sampler``.

    Returns:
        The sampling function returned by ``get_pc_sampler``.
    """
    sampling_cfg = config.sampling
    return get_pc_sampler(
        graph=graph,
        noise=noise,
        batch_dims=batch_dims,
        predictor=sampling_cfg.predictor,
        steps=sampling_cfg.steps,
        denoise=sampling_cfg.noise_removal,
        eps=eps,
        device=device,
    )
def get_pc_sampler(graph, noise, batch_dims, predictor, steps, denoise=True,
                   eps=1e-5, device=torch.device('cpu'),
                   proj_fun=lambda x: x):
    """Create a sampling function that runs a predictor for ``steps`` steps.

    Args:
        graph: Object providing ``sample_limit``; also handed to the
            predictor and the denoiser.
        noise: Handed to the predictor and the denoiser.
        batch_dims: Unpacked into ``graph.sample_limit`` for the start state.
        predictor: Name of a registered predictor.
        steps: Number of predictor updates.
        denoise: Whether to apply a final ``Denoiser`` update.
        eps: Final time; times run from 1 down to ``eps``.
        device: Device for the state and the time tensors.
        proj_fun: Applied to the state before every update (identity by
            default).

    Returns:
        A function ``pc_sampler(model)`` that returns the sampled state and
        runs under ``torch.no_grad()``.
    """
    predictor = get_predictor(predictor)(graph, noise)
    denoiser = Denoiser(graph, noise)
    projector = proj_fun

    def _time_column(value, state):
        # One row per leading entry of `state`, each holding `value`.
        return value * torch.ones(state.shape[0], 1, device=device)

    @torch.no_grad()
    def pc_sampler(model):
        score_fn = mutils.get_score_fn(model, train=False, sampling=True)
        x = graph.sample_limit(*batch_dims).to(device)
        timesteps = torch.linspace(1, eps, steps + 1, device=device)
        dt = (1 - eps) / steps

        # The last time point is reserved for the optional denoising step.
        for step_time in timesteps[:-1]:
            t = _time_column(step_time, x)
            x = predictor.update_fn(score_fn, projector(x), t, dt)

        if not denoise:
            return x

        # Denoising step at the final time.
        x = projector(x)
        t = _time_column(timesteps[-1], x)
        return denoiser.update_fn(score_fn, x, t)

    return pc_sampler