-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathtrainer.py
620 lines (519 loc) · 26.8 KB
/
trainer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
import logging
import math
import os
import pickle
from dataclasses import dataclass
from typing import Any, List, Optional

import imageio
import lpips
import matplotlib.pyplot as plt
import numpy as np
import PIL
import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader

from dataset import LD3Dataset
from utils import move_tensor_to_device, compute_distance_between_two, compute_distance_between_two_L1
def save_gif(snapshot_path: str):
    """Assemble every ``log_best_<n>.png`` plot in *snapshot_path* into ``gif.gif``.

    Files are selected by the substring ``"log_best"`` and ordered by the
    integer that follows the last underscore in the file name.

    Args:
        snapshot_path: Directory that is scanned for frames and that receives
            the resulting ``gif.gif``.
    """
    def _frame_index(file_name):
        # "log_best_12.png" -> 12
        return int(file_name.split("_")[-1].replace(".png", ""))

    frame_names = sorted(
        (entry for entry in os.listdir(snapshot_path) if "log_best" in entry),
        key=_frame_index,
    )
    frames = [imageio.imread(os.path.join(snapshot_path, entry)) for entry in frame_names]
    gif_path = os.path.join(snapshot_path, "gif.gif")
    imageio.mimsave(gif_path, frames, duration=100.)
    print(f"Saved gif to {os.path.join(snapshot_path, 'gif.gif')}")
def visual(input_, name="test.png", img_resolution=32, img_channels=3):
    """Tile a batch of images into one grid and save it as an image file.

    Values are shifted from ``[-1, 1]`` to ``[0, 1]``, scaled to ``[0, 255]``,
    clipped and cast to ``uint8``. The grid height is the largest divisor of
    the batch size that does not exceed its square root.

    Args:
        input_: Batch tensor; the reshape/permute below assumes a layout of
            ``(batch, channels, img_resolution, img_resolution)`` -- TODO confirm
            against callers.
        name: Output file path.
        img_resolution: Side length of each (square) image in pixels.
        img_channels: Number of channels per image. The image is always saved
            in ``'RGB'`` mode, so values other than 3 are presumably unsupported.
    """
    # A bare `import PIL` does not load the `PIL.Image` submodule; import it
    # explicitly instead of relying on another library having done so.
    from PIL import Image

    input_ = (input_ + 1.) / 2.
    batch_size = input_.shape[0]

    # Largest divisor of batch_size that is <= sqrt(batch_size).
    gridh = int(math.sqrt(batch_size))
    for candidate in range(1, gridh + 1):
        if batch_size % candidate == 0:
            gridh = candidate
    gridw = batch_size // gridh

    image = (input_ * 255.).clip(0, 255).to(torch.uint8)
    # (gridh, gridw, C, H, W) -> (gridh, H, gridw, W, C), then merge into one canvas.
    image = image.reshape(gridh, gridw, *image.shape[1:]).permute(0, 3, 1, 4, 2)
    image = image.reshape(gridh * img_resolution, gridw * img_resolution, img_channels)
    image = image.cpu().numpy()
    Image.fromarray(image, 'RGB').save(name)
def custom_collate_fn(batch):
    """Collate a list of sample tuples field by field, tolerating ``None`` fields.

    Args:
        batch: Sequence of samples, each an equally long tuple of fields.

    Returns:
        A list with one entry per field. A field becomes ``None`` as soon as
        any sample holds ``None`` at that position; otherwise it is the result
        of PyTorch's default collate on that field's values.
    """
    default_collate = torch.utils.data._utils.collate.default_collate
    return [
        None if any(entry is None for entry in column) else default_collate(column)
        for column in zip(*batch)
    ]
@dataclass
class TrainingConfig:
    """Optimization settings consumed by ``LD3Trainer``."""

    # Datasets wrapped in DataLoaders by the trainer.
    train_data: Any
    valid_data: Any
    train_batch_size: int
    valid_batch_size: int
    # Initial learning rates of the RMSprop (params1) and SGD (params2) optimizers.
    lr_time_1: float
    lr_time_2: float
    # Learning rate of the per-batch SGD step applied to the latents.
    shift_lr: float
    # Factor applied to shift_lr after a round in which no validation latent moved.
    shift_lr_decay: float = 0.5
    # Lower bounds for the decayed learning rates.
    min_lr_time_1: float = 5e-5
    min_lr_time_2: float = 1e-6
    # Forwarded as `window_rate` to discretize_model_wrapper.
    win_rate: float = 0.5
    # Number of non-improving validations before reloading the best checkpoint
    # and decaying the learning rate(s).
    patient: int = 5
    # Training stops once lr2 has been found at its minimum this many times.
    lr2_patient: int = 5
    # Multiplicative decay applied to the time learning rates.
    lr_time_decay: float = 0.8
    # RMSprop hyper-parameters for params1.
    momentum_time_1: float = 0.9
    weight_decay_time_1: float = 0.0
    # One of "LPIPS", "L2", "L1".
    loss_type: str = "LPIPS"
    # Save image grids of outputs/targets during validation.
    visualize: bool = False
    # When set, params1 is frozen and only params2 is trained.
    no_v1: bool = False
    # Optional target used by the prior-matching warm-up.
    prior_timesteps: Optional[List[float]] = None
    # Run the prior-matching warm-up at the start of train().
    match_prior: bool = False
@dataclass
class ModelConfig:
    """Model, solver and output settings consumed by ``LD3Trainer``."""

    # Passed to the solver as `model_fn`.
    net: Any
    # Applied to the solver output before the loss is computed.
    decoding_fn: Any
    # Must provide lambda_min, lambda_max, inverse_lambda, marginal_lambda
    # and prior_transformation (all are used by the trainer).
    noise_schedule: Any
    # Must provide sample_simple and get_time_steps.
    solver: Any
    solver_name: str
    order: int
    # Number of solver steps; the learnable vectors have steps + 1 entries.
    steps: int
    # Maximum L2 distance a latent may move from its original value;
    # 0.0 disables the validation-latent update pass.
    prior_bound: float
    # Latents are reshaped to (batch, channels, resolution, resolution).
    resolution: int
    channels: int
    # "time" selects the time parameterization, anything else the lambda one.
    time_mode: str
    # Extra keyword arguments forwarded to solver.sample_simple.
    solver_extra_params: Optional[dict] = None
    # Directory for checkpoints, plots and pickled datasets.
    snapshot_path: str = "logs"
    # Defaults to CUDA when available, otherwise CPU.
    device: Optional[str] = None
class LD3Trainer:
    """Learns a solver time discretization and, optionally, shifted input latents.

    Two vectors of length ``steps + 1`` are optimized: ``params1`` (RMSprop)
    and ``params2`` (SGD). ``discretize_model_wrapper`` turns them into two
    time grids that are passed to ``solver.sample_simple``. Training runs in
    rounds: rounds with index below ``training_rounds_v1`` are "Ver1" (only
    ``params1`` trainable), later rounds are "Ver2" (both trainable). Each
    batch's latents also receive one SGD step and are kept only where the
    per-sample loss improved, after being projected back into an L2 ball of
    radius ``prior_bound`` around the original latents.
    """

    def __init__(
        self, model_config: "ModelConfig", training_config: "TrainingConfig"
    ) -> None:
        """Copy the configuration, build loaders/optimizers and baseline grids."""
        # Model parameters
        self.net = model_config.net
        self.decoding_fn = model_config.decoding_fn
        self.noise_schedule = model_config.noise_schedule
        self.solver = model_config.solver
        self.solver_name = model_config.solver_name
        self.order = model_config.order
        self.steps = model_config.steps
        self.prior_bound = model_config.prior_bound
        self.resolution = model_config.resolution
        self.channels = model_config.channels
        self.time_mode = model_config.time_mode

        # Learning rate parameters
        self.lr_time_1 = training_config.lr_time_1
        self.lr_time_2 = training_config.lr_time_2
        self.shift_lr = training_config.shift_lr
        self.shift_lr_decay = training_config.shift_lr_decay
        self.min_lr_time_1 = training_config.min_lr_time_1
        self.min_lr_time_2 = training_config.min_lr_time_2
        self.lr_time_decay = training_config.lr_time_decay
        self.momentum_time_1 = training_config.momentum_time_1
        self.weight_decay_time_1 = training_config.weight_decay_time_1

        # Training data and batch sizes
        self.train_data = training_config.train_data
        self.valid_data = training_config.valid_data
        self.train_batch_size = training_config.train_batch_size
        self.valid_batch_size = training_config.valid_batch_size
        self._create_valid_loaders()
        self._create_train_loader()

        # Training state
        self.cur_iter = 0
        self.cur_round = 0
        self.count_worse = 0
        self.count_min_lr1_hit = 0
        self.count_min_lr2_hit = 0
        self.best_loss = float("inf")
        # Overwritten by train(); defined here so that version checks (and the
        # log prefixes that depend on them) work before train() is called.
        self.training_rounds_v1 = 0

        # Other parameters
        self.patient = training_config.patient
        self.lr2_patient = training_config.lr2_patient
        self.no_v1 = training_config.no_v1
        self.win_rate = training_config.win_rate
        self.snapshot_path = model_config.snapshot_path
        os.makedirs(self.snapshot_path, exist_ok=True)
        self.visualize = training_config.visualize

        # Device and optimizer setup
        self._set_device(model_config.device)
        self.params1, self.params2 = self._initialize_params()
        self.optimizer_lamb1 = torch.optim.RMSprop(
            [self.params1],
            lr=training_config.lr_time_1,
            momentum=training_config.momentum_time_1,
            weight_decay=training_config.weight_decay_time_1,
        )
        self.optimizer_lamb2 = torch.optim.SGD(
            [self.params2], lr=training_config.lr_time_2
        )
        self.prior_timesteps = training_config.prior_timesteps
        self.match_prior = training_config.match_prior

        # Additional attributes
        self.solver_extra_params = model_config.solver_extra_params or {}
        self.lambda_min = self.noise_schedule.lambda_min
        self.lambda_max = self.noise_schedule.lambda_max
        self.time_max = self.noise_schedule.inverse_lambda(self.lambda_min)
        self.time_min = self.noise_schedule.inverse_lambda(self.lambda_max)

        # Initialize baseline
        self._compute_baseline()

        # Initialize loss function
        self.loss_type = training_config.loss_type
        self.loss_fn = self._initialize_loss_fn()
        self.loss_vector = None

    def _train_to_match_prior(self, prior_timesteps=None):
        """Fit ``params1`` so that the first time grid matches a given prior.

        Falls back to ``self.prior_timesteps`` and returns immediately when no
        prior is available. NOTE(review): the loop has no iteration cap; it
        only ends once the mean squared error drops to 1e-3 or below.
        """
        if prior_timesteps is None:
            prior_timesteps = self.prior_timesteps
        if prior_timesteps is None:
            return
        logging.info("Matching prior timesteps")
        prior_timesteps = self.noise_schedule.inverse_lambda(-np.log(prior_timesteps)).to(self.device).float()
        dis_model = discretize_model_wrapper(
            self.params1,
            self.params2,
            self.lambda_max,
            self.lambda_min,
            self.noise_schedule,
            self.time_mode,
            self.win_rate,
        )
        self.params1.requires_grad = True
        self.params2.requires_grad = False
        loss_time = float("inf")
        while loss_time > 1e-3:
            self.optimizer_lamb1.zero_grad()
            self.optimizer_lamb2.zero_grad()
            times1, times2 = dis_model()
            loss_time = (times1 - prior_timesteps).pow(2).mean()
            logging.info(f"Loss time: {loss_time}")
            loss_time.backward()
            self.optimizer_lamb1.step()

    def _initialize_loss_fn(self):
        """Return the loss callable selected by ``self.loss_type``.

        Raises:
            NotImplementedError: For any value other than 'LPIPS', 'L2', 'L1'.
        """
        if self.loss_type == 'LPIPS':
            return lpips.LPIPS(net='vgg').to(self.device)
        elif self.loss_type == 'L2':
            return lambda x, y: compute_distance_between_two(x, y, self.channels, self.resolution)
        elif self.loss_type == 'L1':
            return lambda x, y: compute_distance_between_two_L1(x, y, self.channels, self.resolution)
        else:
            raise NotImplementedError

    def _initialize_params(self):
        """Create the two learnable vectors (ones and zeros) on ``self.device``."""
        # Allocate on the configured device instead of hard-coding CUDA, so the
        # trainer also works on CPU-only hosts and honors an explicit device.
        params1 = torch.nn.Parameter(
            torch.ones(self.steps + 1, dtype=torch.float32, device=self.device),
            requires_grad=True,
        )
        params2 = torch.nn.Parameter(
            torch.zeros(self.steps + 1, dtype=torch.float32, device=self.device),
            requires_grad=True,
        )
        return params1, params2

    def _set_device(self, device):
        """Use *device* if given, else CUDA when available, else CPU."""
        self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def _create_valid_loaders(self):
        """Build both validation loaders from ``self.valid_data``."""
        # valid_loader is iterated by _train_one_round (latent updates) and uses
        # the training batch size; valid_only_loader is used for evaluation.
        self.valid_loader = DataLoader(self.valid_data, batch_size=self.train_batch_size, shuffle=False, collate_fn=custom_collate_fn)
        self.valid_only_loader = DataLoader(self.valid_data, batch_size=self.valid_batch_size, shuffle=False, collate_fn=custom_collate_fn)

    def _create_train_loader(self):
        """Build the shuffled training loader from ``self.train_data``."""
        self.train_loader = DataLoader(self.train_data, batch_size=self.train_batch_size, shuffle=True, collate_fn=custom_collate_fn)

    def _solve_ode(self, timesteps=None, img=None, latent=None, condition=None, uncondition=None, valid=False):
        """Run the solver from *latent* and score the result against *img*.

        Args:
            timesteps: Optional fixed grid used for both time sequences; when
                ``None`` the grids come from the learnable parameters.
            img: Target images.
            latent: Latents, reshaped to (batch, channels, resolution, resolution).
            condition: Passed to the solver as ``condition``.
            uncondition: Passed to the solver as ``unconditional_condition``.
            valid: When False (and *timesteps* is None) the current grids are
                also written to ``t_steps.pt`` in the snapshot directory.

        Returns:
            Tuple ``(mean loss, decoded output as float, img as float)``. The
            per-sample losses are kept in ``self.loss_vector``.
        """
        batch_size = latent.shape[0]
        latent = latent.reshape(batch_size, self.channels, self.resolution, self.resolution)
        dis_model = discretize_model_wrapper(
            self.params1,
            self.params2,
            self.lambda_max,
            self.lambda_min,
            self.noise_schedule,
            self.time_mode,
            self.win_rate,
        )
        if timesteps is None:
            timesteps1, timesteps2 = dis_model()
        else:
            timesteps1 = timesteps
            timesteps2 = timesteps

        if not valid and timesteps is None:
            tst = torch.cat([timesteps1, timesteps2], dim=0).detach().cpu()
            torch.save(tst, os.path.join(self.snapshot_path, "t_steps.pt"))

        # Always recorded: _save_checkpoint and _visual_times read these after
        # validation passes as well.
        self.t_steps1 = timesteps1.detach()
        self.t_steps2 = timesteps2.detach()
        lamb1 = self.noise_schedule.marginal_lambda(timesteps1)
        lamb2 = self.noise_schedule.marginal_lambda(timesteps2)
        self.logSNR1 = lamb1.detach().cpu()
        self.logSNR2 = lamb2.detach().cpu()

        x_next_ = self.noise_schedule.prior_transformation(latent)  # bs x 3 x 32 x 32
        x_next_ = self.solver.sample_simple(
            model_fn=self.net,
            x=x_next_,
            timesteps=timesteps1,
            timesteps2=timesteps2,
            order=self.order,
            NFEs=self.steps,
            condition=condition,
            unconditional_condition=uncondition,
            **self.solver_extra_params,
        )
        x_next_ = self.decoding_fn(x_next_)
        self.loss_vector = self.loss_fn(img.float(), x_next_.float()).squeeze()
        loss = self.loss_vector.mean()
        logging.info(f"{self._current_version} Loss: {loss.item()}")
        return loss, x_next_.float(), img.float()

    @property
    def _current_version(self):
        """Log prefix: 'Ver1' during the first phase, 'Ver2' afterwards."""
        return 'Ver1' if self._is_in_version_1() else 'Ver2'

    def _is_in_version_1(self):
        """Return True while the current round index is below ``training_rounds_v1``."""
        return self.cur_round < self.training_rounds_v1

    def _compute_baseline(self):
        """Precompute the baseline discretizations drawn by ``_visual_times``."""
        self.straight_line = torch.linspace(self.lambda_min, self.lambda_max, self.steps + 1)
        self.time_logSNR = self.noise_schedule.inverse_lambda(self.straight_line).to(self.device)

        time_max = self.noise_schedule.inverse_lambda(self.lambda_min)
        time_min = self.noise_schedule.inverse_lambda(self.lambda_max)

        # Uniform in time.
        self.time_s = torch.linspace(time_max.item(), time_min.item(), 1000)
        self.time_straight = torch.linspace(time_max.item(), time_min.item(), self.steps + 1)
        self.time_straight = self.time_straight.to(self.device)
        self.straight_time = self.noise_schedule.marginal_lambda(self.time_s)

        # Uniform in t**(1/t_order), mapped back by raising to t_order.
        t_order = 2
        self.time_q = torch.linspace((time_max**(1/t_order)).item(), (time_min**(1/t_order)).item(), 1000)**t_order
        self.quadratic_time = torch.linspace((time_max**(1/t_order)).item(), (time_min**(1/t_order)).item(), self.steps + 1)**t_order
        self.quadratic_time = self.quadratic_time.to(self.device)
        self.time_quadratic = self.noise_schedule.marginal_lambda(self.time_q)

        # time_edm
        self.time_edm = self.solver.get_time_steps('edm', time_max.item(), time_min.item(), 999, self.device)
        self.lambda_edm = self.noise_schedule.marginal_lambda(self.time_edm)

    def _run_validation(self):
        """Evaluate the current parameters on ``valid_only_loader`` without gradients.

        Returns:
            Tuple ``(mean of the per-batch losses, all outputs, all targets)``.
        """
        total_loss = 0.
        count = 0
        outputs = list()
        targets = list()
        with torch.no_grad():
            for img, latent, ori_latent, condition, uncondition in self.valid_only_loader:
                img = img.to(self.device)
                latent = latent.to(self.device).reshape(latent.shape[0], -1)
                if condition is not None:
                    condition = condition.to(self.device)
                if uncondition is not None:
                    uncondition = uncondition.to(self.device)
                loss, output, target = self._solve_ode(img=img, latent=latent, condition=condition, uncondition=uncondition, valid=True)
                total_loss += loss.item()
                count += 1
                outputs.append(output)
                targets.append(target)
        output = torch.cat(outputs, dim=0)
        target = torch.cat(targets, dim=0)
        return total_loss / count, output, target

    def _visual_times(self) -> None:
        """
        Visualize time discretization of baselines and ours
        """
        log_path = os.path.join(self.snapshot_path, f"log_best_{self.cur_iter}.png")
        plt.plot(self.logSNR1.cpu().numpy(), 'o', label="Our discretization1")
        plt.plot(self.logSNR2.cpu().numpy(), 'x', label="Our discretization2")

        x_axis = np.linspace(0, self.steps, self.steps + 1)
        plt.plot(x_axis, self.straight_line.cpu().numpy(), label="Baseline logSNR")

        x_axis = np.linspace(0, self.steps, 1000)
        plt.plot(x_axis, self.straight_time.cpu().numpy(), label="Baseline time uniform")
        plt.plot(x_axis, self.time_quadratic.cpu().numpy(), label="Baseline time quadratic")
        plt.plot(x_axis, self.lambda_edm.cpu().numpy(), label="Baseline time edm")

        plt.xlabel("Reverse step i")
        plt.ylabel("LogSNR(t_i)")
        plt.legend()
        plt.tight_layout()
        plt.savefig(log_path)
        plt.close()

    def _save_checkpoint(self):
        """Write the current parameters, time grids and datasets to the snapshot dir."""
        snapshot = {}
        snapshot["params1"] = self.params1.data
        snapshot["params2"] = self.params2.data
        snapshot["best_t_steps"] = torch.cat([self.t_steps1, self.t_steps2], dim=0)
        if self._is_in_version_1():
            torch.save(snapshot, os.path.join(self.snapshot_path, "best_v1.pt"))
        # best_v2.pt is written in both phases; _load_checkpoint reads it once
        # the Ver1 rounds are over.
        torch.save(snapshot, os.path.join(self.snapshot_path, "best_v2.pt"))
        torch.save(snapshot, os.path.join(self.snapshot_path, f"best_t_steps_{self.cur_iter}.pt"))

        # Persist the datasets (which carry the current latents) next to the
        # parameters; `with` guarantees the files are flushed and closed.
        with open(os.path.join(self.snapshot_path, "train_data.pkl"), "wb") as handle:
            pickle.dump(self.train_data, handle)
        with open(os.path.join(self.snapshot_path, "valid_data.pkl"), "wb") as handle:
            pickle.dump(self.valid_data, handle)

    def _load_checkpoint(self, reload_data: bool):
        """Restore the best parameters of the current phase.

        Args:
            reload_data: Also restore the pickled datasets and rebuild the loaders.
        """
        file_name = "best_v1.pt" if self._is_in_version_1() else "best_v2.pt"
        snapshot = torch.load(os.path.join(self.snapshot_path, file_name), map_location=self.device)
        # Follow the configured device rather than assuming CUDA.
        self.params1.data = snapshot["params1"].to(self.device)
        self.params2.data = snapshot["params2"].to(self.device)
        if reload_data:
            # NOTE: pickle executes arbitrary code on load; these files are
            # expected to be the ones written by _save_checkpoint only.
            with open(os.path.join(self.snapshot_path, "train_data.pkl"), "rb") as handle:
                self.train_data = pickle.load(handle)
            with open(os.path.join(self.snapshot_path, "valid_data.pkl"), "rb") as handle:
                self.valid_data = pickle.load(handle)
            self._create_train_loader()
            self._create_valid_loaders()

    def _examine_checkpoint(self, iter: int) -> None:
        """Validate, checkpoint on improvement, and decay learning rates on plateau.

        After ``patient`` consecutive non-improving validations the best
        checkpoint (including data) is reloaded and the learning rate of
        ``params1`` is decayed; outside of Ver1 the learning rate of
        ``params2`` is decayed as well. Hits of the minimum learning rates are
        counted.
        """
        logging.info(f"{self._current_version} Saving snapshot at iter {iter}")
        total_loss, output, target = self._run_validation()

        if (iter % 5 == 0 or total_loss < self.best_loss) and self.visualize:
            visual(torch.cat([output[:8], target[:8]], dim=0), os.path.join(self.snapshot_path, f"learned_newnoise_ep{iter}.png"), img_resolution=self.resolution)

        if total_loss < self.best_loss:  # validation latents do not change during training
            self.best_loss = total_loss
            self.count_worse = 0
            self._save_checkpoint()
            self._visual_times()
            save_gif(self.snapshot_path)
        else:
            self.count_worse += 1
            logging.info(f"{self._current_version} Count worse: {self.count_worse}")

        logging.info(f"{self._current_version} Validation loss: {total_loss}, best loss: {self.best_loss}")
        logging.info(f"{self._current_version} Iter {iter} snapshot saved!")

        if self.count_worse >= self.patient:
            logging.info(f"{self._current_version} Loading best model")
            self._load_checkpoint(reload_data=True)
            self.count_worse = 0

            group1 = self.optimizer_lamb1.param_groups[0]
            group1['lr'] = max(self.lr_time_decay * group1['lr'], self.min_lr_time_1)
            logging.info(f"{self._current_version} Decay time1 lr to {self.optimizer_lamb1.param_groups[0]['lr']}")

            if self._is_in_version_1():
                if group1['lr'] <= self.min_lr_time_1:
                    self.count_min_lr1_hit += 1
            else:
                group2 = self.optimizer_lamb2.param_groups[0]
                group2['lr'] = max(self.lr_time_decay * group2['lr'], self.min_lr_time_2)
                logging.info(f"{self._current_version} Decay time2 lr to {self.optimizer_lamb2.param_groups[0]['lr']}")
                if group2['lr'] <= self.min_lr_time_2:
                    self.count_min_lr2_hit += 1

    def _set_trainable_params(self, is_train: bool, is_no_v1: bool) -> None:
        """Toggle ``requires_grad`` on the two learnable vectors.

        Args:
            is_train: False freezes both vectors (validation latent pass).
            is_no_v1: When training, freeze ``params1`` and train only ``params2``.
        """
        if is_train:
            self.params1.requires_grad = True
            self.params2.requires_grad = not self._is_in_version_1()
            if is_no_v1:
                self.params1.requires_grad = False
                self.params2.requires_grad = True
        else:
            self.params1.requires_grad = False
            self.params2.requires_grad = False

    def _log_valid_distance(self, ori_latent: torch.Tensor, latent: torch.Tensor):
        """Log the per-sample L2 distance between current and original latents."""
        assert ori_latent.shape == latent.shape, "Shape of ori_latent and latent mismatched"
        sq = (latent.reshape(latent.shape[0], -1) - ori_latent.reshape(latent.shape[0], -1)).pow(2)
        distances = sq.sum(dim=1).sqrt().detach().cpu().numpy()
        logging.info(f"{self._current_version} Distance: {distances}")

    def _update_dataloader(self, ori_latents: List[torch.Tensor],
                           latents: List[torch.Tensor],
                           targets: List[torch.Tensor],
                           conditions: List[Optional[torch.Tensor]],
                           unconditions: List[Optional[torch.Tensor]],
                           is_train: bool):
        """Replace the train (or validation) dataset with the updated samples."""
        custom_train_dataset = LD3Dataset(ori_latents, latents, targets, conditions, unconditions)
        if is_train:
            self.train_data = custom_train_dataset
            self._create_train_loader()
        else:
            self.valid_data = custom_train_dataset
            self._create_valid_loaders()

    def _update_latents(self, latent, condition, uncondition, ori_latent, img, latent_params, loss_vector_ref, prior_bound):
        """Accept the stepped latents where they lowered the per-sample loss.

        The stepped latents are first projected back onto the L2 ball of
        radius *prior_bound* around *ori_latent*, then re-evaluated.

        Returns:
            Tuple ``(latent with accepted samples overwritten, boolean mask of
            accepted samples)``.
        """
        parameter_data_detached = latent_params.detach()
        cloned_ori_latent = ori_latent.clone()
        diff = parameter_data_detached.data - cloned_ori_latent
        diff_norm = diff.norm(dim=1, keepdim=True)
        # Samples that left the ball are pulled back onto its surface.
        pass_bound = diff_norm > prior_bound
        pass_bound = pass_bound.flatten()
        parameter_data_detached.data[pass_bound] = cloned_ori_latent[pass_bound] + prior_bound * diff[pass_bound] / diff_norm[pass_bound]

        # Re-evaluate so that self.loss_vector reflects the projected latents.
        _, _, _ = self._solve_ode(img=img, latent=parameter_data_detached.data, condition=condition, uncondition=uncondition, valid=False)
        to_update_mask = self.loss_vector < loss_vector_ref
        parameter_data_detached.data = parameter_data_detached.data.reshape(-1, self.channels, self.resolution, self.resolution)
        latent[to_update_mask] = parameter_data_detached.data[to_update_mask]
        return latent, to_update_mask

    def _train_one_round(self):
        """Run one pass over the training data and one latent pass over validation data.

        Returns:
            Tuple ``(no_change, should_stop)``: ``no_change`` is True when no
            validation latent was updated; ``should_stop`` is True when the
            lr2 minimum was hit ``lr2_patient`` times.
        """
        no_change = True
        logging.info(f"{self._current_version} Round {self.cur_round}")
        if self.cur_round > 0:
            self._load_checkpoint(reload_data=False)
        self.count_worse = 0
        self._examine_checkpoint(self.cur_iter)  # run evaluation current latent and time steps

        for loader_idx, loader in enumerate([self.train_loader, self.valid_loader]):
            is_train_pass = loader_idx == 0
            # With a zero bound latents cannot move, so the validation pass is skipped.
            if not is_train_pass and self.prior_bound == 0.0:
                continue
            self._set_trainable_params(is_train=is_train_pass, is_no_v1=self.no_v1)
            ori_latents, latents, targets, conditions, unconditions = [], [], [], [], []

            for img, latent, ori_latent, condition, uncondition in loader:
                img, latent, ori_latent, condition, uncondition = move_tensor_to_device(img, latent, ori_latent, condition, uncondition, device=self.device)
                if not is_train_pass:
                    self._log_valid_distance(ori_latent, latent)

                # Flatten latents
                batch_size = ori_latent.shape[0]
                ori_latent = ori_latent.reshape(batch_size, -1)
                latent_to_update = latent.clone().detach().reshape(batch_size, -1).to(self.device)
                latent_params = torch.nn.Parameter(latent_to_update)
                latent_params.requires_grad = True
                latent_optimizer = torch.optim.SGD([latent_params], lr=self.shift_lr)

                if img.device != latent_params.device:
                    # Fail loudly instead of dropping into a debugger, which
                    # would hang a non-interactive training job.
                    raise RuntimeError(
                        f"Device mismatch: images on {img.device}, latents on {latent_params.device}"
                    )

                loss, _, _ = self._solve_ode(img=img, latent=latent_params, condition=condition, uncondition=uncondition, valid=False)
                loss_vector_ref = self.loss_vector.clone().detach()
                loss.backward()
                logging.info(f"{self._current_version} Iter {self.cur_iter} {'Train' if loader_idx == 0 else 'Val'} Loss: {loss.item()}")
                latent_optimizer.step()
                latent_optimizer.zero_grad()

                if is_train_pass:
                    torch.nn.utils.clip_grad_norm_(self.params1, 1.0)
                    torch.nn.utils.clip_grad_norm_(self.params2, 1.0)
                    self.optimizer_lamb1.step()
                    self.optimizer_lamb1.zero_grad()
                    self.optimizer_lamb2.step()
                    self.optimizer_lamb2.zero_grad()
                    self.cur_iter += 1
                    self._examine_checkpoint(self.cur_iter)  # evaluate
                    if self.count_min_lr2_hit >= self.lr2_patient:
                        logging.info(f"{self._current_version} Reach min lr2 {self.lr2_patient} times. Stop training.")
                        return no_change, True

                with torch.no_grad():
                    latent, to_update_mask = self._update_latents(latent, condition, uncondition, ori_latent, img, latent_params, loss_vector_ref, self.prior_bound)
                    if not is_train_pass and to_update_mask.sum().item() > 0:
                        # at least one validation latent moved
                        no_change = False

                ori_latent = ori_latent.reshape(-1, self.channels, self.resolution, self.resolution).detach().cpu()
                latent = latent.reshape(-1, self.channels, self.resolution, self.resolution).detach().cpu()
                img = img.detach().cpu()
                condition = condition.detach().cpu() if condition is not None else None
                uncondition = uncondition.detach().cpu() if uncondition is not None else None
                for j in range(latent.shape[0]):
                    ori_latents.append(ori_latent[j])
                    targets.append(img[j])
                    latents.append(latent[j])
                    conditions.append(condition[j] if condition is not None else None)
                    unconditions.append(uncondition[j] if uncondition is not None else None)

            # update dataset
            if self.prior_bound > 0:
                self._update_dataloader(ori_latents, latents, targets, conditions, unconditions, is_train=is_train_pass)
        return no_change, False

    def train(self, training_rounds_v1: int, training_rounds_v2: int) -> None:
        """Run ``training_rounds_v1`` Ver1 rounds followed by ``training_rounds_v2`` Ver2 rounds.

        Stops early when a round reports that lr2 reached its minimum often
        enough. ``shift_lr`` is decayed after rounds in which no validation
        latent changed (only when ``prior_bound > 0``).
        """
        total_round = training_rounds_v1 + training_rounds_v2
        self.training_rounds_v1 = training_rounds_v1
        if self.match_prior:
            self._train_to_match_prior()
        while self.cur_round < total_round:
            no_latent_change, should_stop = self._train_one_round()
            if should_stop:
                return
            self.cur_round += 1
            if no_latent_change and self.prior_bound > 0:
                self.shift_lr *= self.shift_lr_decay
        logging.info(f"{self._current_version} Max round reached, stopping")
def discretize_model_wrapper(input1, input2, lambda_max, lambda_min, noise_schedule, mode, window_rate=0.5):
    """Build a zero-argument callable mapping the learnable vectors to two time grids.

    Args:
        input1: Vector whose softmax/cumsum defines the first grid.
        input2: Vector of per-step offsets; clamped to a window and zeroed at
            both endpoints before being added to form the second grid.
        lambda_max: Upper end of the lambda range.
        lambda_min: Lower end of the lambda range.
        noise_schedule: Object providing ``inverse_lambda``.
        mode: ``'time'`` builds the grid directly in time; any other value
            builds it in lambda and converts with ``inverse_lambda``.
        window_rate: Fraction of the smallest adjacent gap used as clamp window.

    Returns:
        A function returning ``(time_grid_1, time_grid_2)``.
    """

    def _zero_endpoints(reference):
        # Ones everywhere except the first and last entry, so the offsets
        # never move the two boundary points.
        keep = torch.ones_like(reference)
        keep[0] = 0.
        keep[-1] = 0.
        return keep

    def model_time_fn():
        device = input1.device
        t_max = noise_schedule.inverse_lambda(lambda_min).to(device)
        t_min = noise_schedule.inverse_lambda(lambda_max).to(device)

        # Positive increments -> cumulative sum, flipped and rescaled so the
        # grid runs from t_max down to t_min.
        weights = F.softmax(input1, dim=0)
        flipped = torch.cumsum(weights, dim=0).flip(0)
        unit = (flipped - flipped[-1]) / (flipped[0] - flipped[-1])
        grid = unit * (t_max - t_min) + t_min

        # The clamp window comes from a detached copy of the grid itself.
        frozen = grid.clone().detach()
        limit = (frozen[1:] - frozen[:-1]).abs().min().item() * window_rate
        offset = torch.clamp(input2, min=-limit, max=limit)
        return grid, grid + (offset * _zero_endpoints(unit))

    def model_lambda_fn():
        weights = F.softmax(input1, dim=0)
        running = torch.cumsum(weights, dim=0)
        unit = (running - running.min()) / (running.max() - running.min())
        lamb_grid = unit * (lambda_max - lambda_min) + lambda_min

        # Here the clamp window comes from the raw (detached) input1 values,
        # not from the lambda grid.
        raw = input1.clone().detach()
        limit = (raw[1:] - raw[:-1]).abs().min().item() * window_rate
        offset = torch.clamp(input2, min=-limit, max=limit)
        shifted = lamb_grid + offset * _zero_endpoints(lamb_grid)

        return noise_schedule.inverse_lambda(lamb_grid), noise_schedule.inverse_lambda(shifted)

    if mode == 'time':
        return model_time_fn
    return model_lambda_fn