import os
import sys
import datetime
import numpy as np
import tensorflow as tf
import tensorflow.contrib.slim as slim
import tensorflow.contrib.distributions as distributions
# from tensorflow.python.training.moving_averages import weighted_moving_average
from reparameterizations import gumbel_reparmeterization, gaussian_reparmeterization
from encoders import forward, DenseEncoder, CNNEncoder, copy_layer, reinit_last_layer
from decoders import CNNDecoder
from utils import *
sg = tf.contrib.bayesflow.stochastic_graph
st = tf.contrib.bayesflow.stochastic_tensor
sys.setrecursionlimit(200)
# Global variables
GLOBAL_ITER = 0 # keeps track of the iteration ACROSS models
TRAIN_ITER = 0 # the iteration of the current model
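# The class below implements a single "lifelong" VAE:
#   * the latent code concatenates a gaussian part with a discrete part that is
#     relaxed via the gumbel-softmax trick (the temperature `tau` is annealed
#     during training in partial_fit);
#   * each new model keeps a handle to its predecessor (`vae_tm1`), replays data
#     generated by it, and adds a categorical KL "consistency" term between the
#     teacher's and the student's posteriors (_create_constraints);
#   * the reconstruction likelihood p(x|z) may be Bernoulli, Normal or Logistic.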
class VAE(object):
""" Online Variational Autoencoder with consistent sampling.
See "Auto-Encoding Variational Bayes" by Kingma and Welling
for more details on the original work.
Note: reparam_type is fixed; provided in ctor due to compatibility
with the vanilla VAE
"""
def __init__(self, sess, x, input_size, batch_size, latent_size,
encoder, decoder, is_training, discrete_size, activation=tf.nn.elu,
reconstr_loss_type="binary_cross_entropy", learning_rate=1e-4,
submodel=0, total_true_models=0, vae_tm1=None,
p_x_given_z_func=distributions.Bernoulli,
base_dir=".", mutual_info_reg=0.0, img_shape=[28, 28, 1]):
self.x = x
self.activation = activation
self.learning_rate = learning_rate
self.is_training = is_training
self.encoder_model = encoder
self.decoder_model = decoder
self.vae_tm1 = vae_tm1
self.p_x_given_z_func = p_x_given_z_func
self.global_iter_base = GLOBAL_ITER
self.input_size = input_size
self.latent_size = latent_size
self.batch_size = batch_size
self.img_shape = img_shape
self.iteration = 0
self.test_epoch = 0
self.submodel = submodel
self.total_true_models = total_true_models
self.mutual_info_reg = mutual_info_reg
self.reconstr_loss_type = reconstr_loss_type
self.num_discrete = discrete_size
print 'latent size = ', self.latent_size, ' | disc size = ', self.num_discrete
self.base_dir = base_dir # dump all our stuff into this dir
# gumbel params
self.tau0 = 1.0
self.tau_host = self.tau0
self.anneal_rate = 0.00003
# self.anneal_rate = 0.0003 #1e-5
self.min_temp = 0.5
# sess & graph
self.sess = sess
# self.graph = tf.Graph()
# create these in scope
self._create_variables(x)
# Create autoencoder network
self._create_network()
# Define loss function based variational upper-bound and
# corresponding optimizer
self._create_loss_optimizer()
# create the required directories to hold data for this specific model
self._create_local_directories()
# Create all the summaries and their corresponding ops
self._create_summaries()
# Check for NaN's
# self.check_op = tf.add_check_numerics_ops()
# collect variables & build saver
self.vae_vars = [v for v in tf.global_variables()
if v.name.startswith(self.get_name())]
self.vae_local_vars = [v for v in tf.local_variables()
if v.name.startswith(self.get_name())]
self.saver = tf.train.Saver(tf.global_variables()) # XXX: use local
self.init_op = tf.variables_initializer(self.vae_vars
+ self.vae_local_vars)
print 'model: ', self.get_name()
# print 'there are ', len(self.vae_vars), ' vars in ', \
# tf.get_variable_scope().name, ' out of a total of ', \
# len(tf.global_variables()), ' with %d total trainable vars' \
# % len(tf.trainable_variables())
    '''
    Helper to create the local directories under base_dir:
      1) <base_dir>/models
      2) <base_dir>/imgs
      3) <base_dir>/logs
    '''
def _create_local_directories(self):
models_dir = '%s/models' % (self.base_dir)
if not os.path.exists(models_dir):
os.makedirs(models_dir)
imgs_dir = '%s/imgs' % (self.base_dir)
if not os.path.exists(imgs_dir):
os.makedirs(imgs_dir)
logs_dir = '%s/logs' % (self.base_dir)
if not os.path.exists(logs_dir):
os.makedirs(logs_dir)
def _create_variables(self, x_placeholder):
with tf.variable_scope(self.get_name()):
# Create the placeholders if we are at the first model
# Else simply pull the references
# if self.submodel == 0:
# self.x = tf.placeholder(tf.float32, shape=[self.batch_size,
# self.input_size],
# name="input_placeholder")
# else:
# self.x = self.vae_tm1.x
# gpu iteration count
self.iteration_gpu = tf.Variable(0.0, trainable=False)
self.iteration_gpu_op = self.iteration_gpu.assign_add(1.0)
# gumbel related
self.tau = tf.Variable(5.0, trainable=False, dtype=tf.float32,
name="temperature")
# self.ema = tf.train.ExponentialMovingAverage(decay=0.9999)
'''
A helper function to create all the summaries.
Adds things like image_summary, histogram_summary, etc.
'''
def _create_summaries(self):
# Summaries and saver
summaries = [tf.summary.scalar("vae_loss_mean", self.cost_mean),
tf.summary.scalar("vae_negative_elbo", self.elbo_mean),
tf.summary.scalar("vae_latent_loss_mean", self.latent_loss_mean),
tf.summary.scalar("vae_grad_norm", self.grad_norm),
tf.summary.scalar("bits_per_dim", self.generate_bits_per_dim()),
tf.summary.scalar("vae_selected_class", tf.argmax(tf.reduce_sum(self.z_pre_gumbel, 0), 0)),
tf.summary.scalar("vae_selected_class_xtm1", tf.argmax(tf.reduce_sum(self.z_pre_gumbel[self.num_current_data:], 0), 0)),
tf.summary.histogram("vae_kl_normal", self.kl_normal),
tf.summary.histogram("vae_kl_discrete", self.kl_discrete),
tf.summary.histogram("vae_latent_dist", self.latent_kl),
tf.summary.scalar("vae_latent_loss_max", tf.reduce_max(self.latent_kl)),
tf.summary.scalar("vae_latent_loss_min", tf.reduce_min(self.latent_kl)),
tf.summary.scalar("vae_reconstr_loss_mean", self.reconstr_loss_mean),
tf.summary.scalar("vae_reconstr_loss_max", tf.reduce_max(self.reconstr_loss)),
tf.summary.scalar("vae_reconstr_loss_min", tf.reduce_min(self.reconstr_loss)),
tf.summary.histogram("z_dist", self.z)]
# Display image summaries : i.e. samples from P(X|Z=z_i)
# Visualize:
# 1) augmented images;
# 2) original images[current distribution]
# 3) reconstructed images
dimensions = len(shp(self.x))
if dimensions == 2:
x_orig, x_aug, x_reconstr = shuffle_jointly(self.x, self.x_augmented, # noqa
self.p_x_given_z.mean())
else:
# TODO: modify shuffle jointly for 3d images
x_orig, x_aug, x_reconstr = [self.x,
self.x_augmented,
self.p_x_given_z.mean()]
img_shp = [self.batch_size] + self.img_shape
image_summaries = [tf.summary.image("x_augmented_t", tf.reshape(x_aug, img_shp), # noqa
max_outputs=self.batch_size),
tf.summary.image("x_t", tf.reshape(x_orig, img_shp),
max_outputs=self.batch_size),
tf.summary.image("x_reconstr_mean_activ_t",
tf.reshape(x_reconstr, img_shp),
max_outputs=self.batch_size)]
# In addition show the following if they exist:
# 4) Images from previous interval
# 5) Distilled KL Divergence
if hasattr(self, 'xhat_tm1'):
with tf.variable_scope(self.get_name()): # accuracy operator
# selected_classes_for_xtm1 = tf.argmax(self.z_discrete[self.num_current_data:], 0)
# selected_classes_by_vae_tm1 = tf.argmax(self.q_z_t_given_x_t, 0)
selected_classes_for_xtm1 = self.z_pre_gumbel[self.num_current_data:] # self.z_discrete[self.num_current_data:]
selected_classes_by_vae_tm1 = self.q_z_t_given_x_t
correct_prediction = tf.equal(tf.argmax(selected_classes_by_vae_tm1, 1),
tf.argmax(selected_classes_for_xtm1, 1))
self.accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
num_xhat_tm1 = self.xhat_tm1.get_shape().as_list()
image_summaries += [tf.summary.image("xhat_tm1",
tf.reshape(self.xhat_tm1, img_shp),
max_outputs=num_xhat_tm1[0])]
summaries += [tf.summary.scalar("vae_tm1_selected_class", tf.argmax(tf.reduce_sum(self.q_z_t_given_x_t, 0), 0)),
tf.summary.scalar("consistency_accuracy", self.accuracy),
tf.summary.scalar("vae_kl_distill_mean",
tf.reduce_mean(self.kl_consistency))]
# Merge all the summaries, but ensure we are post-activation
# keep the image summaries separate, but also include the regular
# summaries in them
#with tf.control_dependencies([self.p_x_given_z_logits]):
self.summaries = tf.summary.merge(summaries)
self.image_summaries = tf.summary.merge(image_summaries
+ summaries)
        # Write train/test summaries under <base_dir>/logs/{train,test}
# current_summary_name = self.get_name() + self.get_formatted_datetime()
self.train_summary_writer = tf.summary.FileWriter("%s/logs/train" % self.base_dir,
self.sess.graph,
flush_secs=60)
self.test_summary_writer = tf.summary.FileWriter("%s/logs/test" % self.base_dir,
self.sess.graph,
flush_secs=60)
    def generate_bits_per_dim(self):
        # bits/dim = per-sample negative ELBO (nats) / (ln(2) * pixels per sample);
        # img_shape does not include the batch dimension and elbo_mean is already
        # a batch average.
        num_pixels = np.prod(self.img_shape)
        return self.elbo_mean / (np.log(2.) * num_pixels)
'''
A helper function to format the name as a function of the hyper-parameters
'''
def get_name(self):
if self.submodel == 0:
full_hash_str = self.activation.__name__ \
+ '_enc' + str(self.encoder_model.get_sizing()) \
+ '_dec' + str(self.decoder_model.get_sizing()) \
+ "_learningrate" + str(self.learning_rate) \
+ "_latentsize" + str(self.latent_size) \
+ "_discsize" + str(self.num_discrete) \
+ "_mutintoreg" + str(self.mutual_info_reg)
full_hash_str = full_hash_str.strip().lower().replace('[', '') \
.replace(']', '') \
.replace(' ', '') \
.replace('{', '') \
.replace('}', '') \
.replace(',', '_') \
.replace(':', '') \
.replace('(', '') \
.replace(')', '') \
.replace('\'', '')
return 'vae%d_' % self.submodel + full_hash_str
else:
vae_tm1_name = self.vae_tm1.get_name()
indexof = vae_tm1_name.find('_')
return 'vae%d_' % self.submodel + vae_tm1_name[indexof+1:]
def get_formatted_datetime(self):
return str(datetime.datetime.now()).replace(" ", "_") \
.replace("-", "_") \
.replace(":", "_")
def save(self):
model_filename = "%s/models/%s.cpkt" % (self.base_dir, self.get_name())
print 'saving vae model to %s...' % model_filename
self.saver.save(self.sess, model_filename)
def restore(self):
model_filename = "%s/models/%s.cpkt" % (self.base_dir, self.get_name())
print 'into restore, model name = ', model_filename
if os.path.isfile(model_filename):
print 'restoring vae model from %s...' % model_filename
self.saver.restore(self.sess, model_filename)
@staticmethod
def kl_categorical(p=None, q=None, p_logits=None, q_logits=None, eps=1e-6):
'''
        Given p and q (either both as logits or both as softmax probabilities),
        return the KL divergence between them.
        An eps is used to avoid divide-by-zero / log(0) issues.
'''
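        # distributions.kl_divergence(P, Q) computes KL(P || Q) = E_P[log P - log Q]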
if p_logits is not None and q_logits is not None:
Q = distributions.Categorical(logits=q_logits, dtype=tf.float32)
P = distributions.Categorical(logits=p_logits, dtype=tf.float32)
elif p is not None and q is not None:
print 'p shp = ', p.get_shape().as_list(), \
' | q shp = ', q.get_shape().as_list()
Q = distributions.Categorical(probs=q+eps, dtype=tf.float32)
P = distributions.Categorical(probs=p+eps, dtype=tf.float32)
else:
raise Exception("please provide either logits or dists")
return distributions.kl_divergence(P, Q)
@staticmethod
def zero_pad_smaller_cat(cat1, cat2):
c1shp = cat1.get_shape().as_list()
c2shp = cat2.get_shape().as_list()
diff = abs(c1shp[1] - c2shp[1])
# blend in extra zeros appropriately
if c1shp[1] > c2shp[1]:
cat2 = tf.concat([cat2, tf.zeros([c2shp[0], diff])], axis=1)
elif c2shp[1] > c1shp[1]:
cat1 = tf.concat([cat1, tf.zeros([c1shp[0], diff])], axis=1)
return [cat1, cat2]
def _create_constraints(self):
# 0.) add in a kl term between the old model's posterior
# and the current model's posterior using the
# data generated from the previous model [for the discrete ONLY]
#
# Recall data is : [current_data ; old_data]
if hasattr(self, 'xhat_tm1'):
# First we encode the generated data w/the student
# Note: encode returns z, z_normal, z_discrete,
# kl_normal, kl_discrete
# Note2: discrete dimension is self.submodel
self.q_z_s_given_x_t = self.z_pre_gumbel[self.num_current_data:]
assert self.q_z_s_given_x_t.get_shape().as_list()[0] \
== self.num_old_data
# _, _, self.q_z_s_given_x_t, _, _ \
# = self.encoder(self.xhat_tm1,
# rnd_sample=None,
# hard=False, # True?
# reuse=True)
# We also need to encode the data back through the teacher
# This is necessary because we need to evaluate the posterior
# in order to compare Q^T(x|z) against Q^S(x|z)
# Note2: discrete dimension is self.submodel - 1 [possibly?]
rnd_sample = self.rnd_sample[:, 0:self.vae_tm1.num_discrete]
_, _, _, self.q_z_t_given_x_t, _, _ \
= self.vae_tm1.encoder(self.xhat_tm1,
rnd_sample=rnd_sample,
hard=False, # True?
reuse=True)
# Get the number of gaussians for student and teacher
# We also only consider num_old_data of the batch
self.q_z_t_given_x_t = self.q_z_t_given_x_t[0:self.num_old_data]
self.q_z_s_given_x_t, self.q_z_t_given_x_t \
= VAE.zero_pad_smaller_cat(self.q_z_s_given_x_t,
self.q_z_t_given_x_t)
# Now we ONLY want eval the KL on the discrete z
# below is the reverse KL:
kl = self.kl_categorical(q=self.q_z_t_given_x_t,
p=self.q_z_s_given_x_t)
# forward KL :
# kl = self.kl_categorical(q=self.q_z_s_given_x_t,
# p=self.q_z_t_given_x_t)
print 'kl_consistency [prepad] : ', kl.get_shape().as_list()
kl = [tf.zeros([self.num_current_data]), kl]
self.kl_consistency = tf.concat(axis=0, values=kl)
else:
self.q_z_given_x = tf.zeros_like(self.x)
self.kl_consistency = tf.zeros([self.batch_size], dtype=tf.float32)
@staticmethod
def reparameterize(encoded, num_discrete, tau, hard=False,
rnd_sample=None, eps=1e-20):
eshp = encoded.get_shape().as_list()
print("encoded = ", eshp)
num_normal = eshp[1] - num_discrete
print 'num_normal = ', num_normal
logits_normal = encoded[:, 0:num_normal]
logits_gumbel = encoded[:, num_normal:eshp[1]]
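        # encoded layout: the first num_normal columns parameterize the gaussian,
        # the trailing num_discrete columns are the categorical (gumbel) logits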
# we reparameterize using both the N(0, I) and the gumbel(0, 1)
z_discrete, kl_discrete = gumbel_reparmeterization(logits_gumbel,
tau,
rnd_sample,
hard)
z_n, kl_n = gaussian_reparmeterization(logits_normal)
# merge and pad appropriately
z = tf.concat([z_n, z_discrete], axis=1)
return [slim.flatten(z),
slim.flatten(z_n),
slim.flatten(z_discrete),
slim.flatten(tf.nn.softmax(logits_gumbel)),
kl_n,
kl_discrete]
def encoder(self, X, rnd_sample=None, reuse=False, hard=False):
with tf.variable_scope(self.get_name() + "/encoder", reuse=reuse):
encoded = forward(X, self.encoder_model)
return VAE.reparameterize(encoded, self.num_discrete,
self.tau, hard=hard,
rnd_sample=rnd_sample)
def generator(self, Z, reuse=False):
with tf.variable_scope(self.get_name() + "/generator", reuse=reuse):
print 'generator scope: ', tf.get_variable_scope().name
logits = forward(Z, self.decoder_model)
if self.p_x_given_z_func == distributions.Bernoulli:
print 'generator: using bernoulli'
return self.p_x_given_z_func(logits=logits)
elif (self.p_x_given_z_func == distributions.Normal or self.p_x_given_z_func == distributions.Logistic) \
and self.encoder_model.layer_type == 'cnn':
print 'generator: using exponential family [cnn]'
channels = shp(logits)[3]
assert channels % 2 == 0, "need to project to 2x the channels for gaussian p(x|z)"
loc = logits[:, :, :, channels/2:] # tf.nn.sigmoid(logits[:, :, :, channels/2:])
scale = 1e-6 + tf.nn.softplus(logits[:, :, :, 0:channels/2])
return self.p_x_given_z_func(loc=loc,
scale=scale)
elif (self.p_x_given_z_func == distributions.Normal or self.p_x_given_z_func == distributions.Logistic) \
and self.encoder_model.layer_type == 'dnn':
print 'generator: using exponential family [dnn]'
features = shp(logits)[-1]
assert features % 2 == 0, "need to project to 2x the channels for gaussian p(x|z)"
loc = logits[:, features/2:] # tf.nn.sigmoid(logits[:, :, :, channels/2:])
scale = 1e-6 + tf.nn.softplus(logits[:, 0:features/2])
return self.p_x_given_z_func(loc=loc,
scale=scale)
else:
raise Exception("unknown distribution provided for likelihood")
def _augment_data(self):
'''
Augments [current_data ; old_data]
'''
def _train():
if hasattr(self, 'xhat_tm1'): # make sure we have forked
# zero pad the current data on the bottom and add to
# the data we generated in _generate_vae_tm1_data()
full_data = [self.x[0:self.num_current_data],
self.xhat_tm1[0:self.num_old_data]]
combined = tf.concat(axis=0, values=full_data,
name="current_data")
else:
combined = self.x
print 'augmented data = ', combined.get_shape().as_list()
return combined
def _test():
return self.x
#return tf.cond(self.is_training, _train, _test)
return _train()
def generate_at_least(self, vae_tm1, batch_size):
        # Returns:
        #   1) z     : a N(0, I) sample concatenated with a categorical sample
        #   2) z_cat : the categorical sample alone
        #   3) xhat  : the mean reconstruction of z under the old model (vae_tm1)
print 'generating data from previous #discrete: ', vae_tm1.num_discrete
z_cat = generate_random_categorical(vae_tm1.num_discrete,
batch_size)
z_normal = tf.random_normal([batch_size, vae_tm1.latent_size])
z = tf.concat([z_normal, z_cat], axis=1)
zshp = z.get_shape().as_list() # TODO: debug trace
print 'z_generated = ', zshp
# Generate reconstructions of historical Z's
# xr = tf.stop_gradient(tf.nn.sigmoid(vae_tm1.generator(z, reuse=True)))
p_x_given_z_tm1 = vae_tm1.generator(z, reuse=True)
return [z, z_cat, p_x_given_z_tm1.mean()]
def _generate_vae_tm1_data(self):
if self.vae_tm1 is not None:
num_instances = self.x.get_shape().as_list()[0]
self.num_current_data = int((1.0/(self.total_true_models + 1.0))
* float(num_instances))
self.num_old_data = num_instances - self.num_current_data
# TODO: Remove debug trace
print 'total instances: %d | current_model: %d | current_true_models: %d | current data number: %d | old data number: %d'\
% (num_instances, self.submodel, self.total_true_models,
self.num_current_data, self.num_old_data)
if self.num_old_data > 0: # make sure we aren't in base case
# generate data by randomly sampling a categorical for
# N-1 positions; also sample a N(0, I) in order to
# generate variability
self.z_tm1, self.z_discrete_tm1, self.xhat_tm1 \
= self.generate_at_least(self.vae_tm1,
self.batch_size)
print 'z_tm1 = ', self.z_tm1.get_shape().as_list(), \
'| xhat_tm1 = ', self.xhat_tm1.get_shape().as_list()
@staticmethod
def _z_to_one_hot(z, latent_size):
        indices = tf.argmax(z, 1)
return tf.one_hot(indices, latent_size, dtype=tf.float32)
def _shuffle_all_data_together(self):
if not hasattr(self, 'shuffle_indices'):
self.shuffle_indices = np.random.permutation(self.batch_size)
if self.vae_tm1 is not None:
# we get the total size of the cols and jointly shuffle
# using the perms generated above.
self.x_augmented = shuffle_rows_based_on_indices(self.shuffle_indices,
self.x_augmented)
'''
Helper op to create the network structure
'''
def _create_network(self, num_test_memories=10):
self.num_current_data = self.x.get_shape().as_list()[0]
# use the same rnd_sample for all the discrete generations
self.rnd_sample = sample_gumbel([self.x.get_shape().as_list()[0],
self.num_discrete])
# generate & shuffle data together
self._generate_vae_tm1_data()
self.x_augmented = self._augment_data()
assert self.x_augmented.get_shape().as_list() \
== self.x.get_shape().as_list()
# TODO: self._shuffle_all_data_together() possible?
# run the encoder operation
self.z, \
self.z_normal,\
self.z_discrete, \
self.z_pre_gumbel, \
self.kl_normal, \
self.kl_discrete = self.encoder(self.x_augmented,
rnd_sample=self.rnd_sample)
print 'z_encoded = ', self.z.get_shape().as_list()
print 'z_discrete = ', self.z_discrete.get_shape().as_list()
print 'z_normal = ', self.z_normal.get_shape().as_list()
# reconstruct x via the generator & run activation
#self.p_x_given_z_logits = self.generator(self.z)
self.p_x_given_z = self.generator(self.z)
print 'pxgivenz = ', shp(self.p_x_given_z.mean())
# self.x_reconstr_mean_activ = tf.nn.sigmoid(self.x_reconstr_mean)
def _loss_helper(self, truth, pred):
if self.reconstr_loss_type == "binary_cross_entropy":
loss = self._cross_entropy(truth, pred)
else:
loss = self._l2_loss(truth, pred)
        # sum the per-pixel loss over every non-batch dimension
        truth_shp = truth.get_shape().as_list()
        reduction_indices = [1, 2, 3] if len(truth_shp) > 3 else [1]
        return tf.reduce_sum(loss, reduction_indices)
@staticmethod
def _cross_entropy(x, x_reconstr):
# To ensure stability and avoid overflow, the implementation uses
# max(x, 0) - x * z + log(1 + exp(-abs(x)))
# return tf.maximum(x, 0) - x * z + tf.log(1.0 + tf.exp(-tf.abs(x)))
return tf.nn.sigmoid_cross_entropy_with_logits(logits=x_reconstr,
labels=x)
@staticmethod
def _l2_loss(x, x_reconstr):
return tf.square(x - x_reconstr)
@staticmethod
def mutual_information_bernouilli_cat(Q_z_given_x_softmax, eps=1e-9):
# we compute the mutual information term,
# which is the conditional entropy of the prior
# and our variational distribution, plus the entropy of our prior:
# first we build a uniform cat prior and sample it
qzshp = Q_z_given_x_softmax.get_shape().as_list()
# batch_size = qzshp[0]
# feature_size = qzshp[1]
# prior = tf.contrib.distributions.Categorical([1.0/feature_size]*feature_size)
# prior_sample = tf.one_hot(prior.sample(batch_size), feature_size, dtype=tf.float32)
prior_sample = generate_random_categorical(qzshp[1], qzshp[0])
cond_ent = tf.reduce_mean(-tf.reduce_sum(tf.log(Q_z_given_x_softmax + eps)
* prior_sample, 1))
ent = tf.reduce_mean(-tf.reduce_sum(tf.log(prior_sample + eps)
* prior_sample, 1))
return cond_ent + ent
# @staticmethod
# def mutual_information_bernouilli_cat(bern_logits, cat_probs, eps=1e-9):
# '''
# I(\hat{X} ; Z) = H(Z) - H(Z | \hat{X}) = H(\hat{X}) - H(\hat{X} | Z)
# '''
# p_x_given_z = distributions.Bernoulli(logits=bern_logits,
# dtype=tf.float32)
# q_z = distributions.Categorical(probs=cat_probs + eps,
# dtype=tf.float32)
# # TODO: debug traces
# # print 'q_z_entropy = ', q_z.entropy().get_shape().as_list()
# # print 'p_x_given_z.entropy() = ', p_x_given_z.entropy().get_shape().as_list()
# return q_z.entropy() - tf.reduce_sum(p_x_given_z.entropy(), 1)
def vae_loss(self, x, p_x_given_z, latent_kl, consistency_kl):
# the loss is composed of two terms:
# 1.) the reconstruction loss (the negative log probability
# of the input under the reconstructed bernoulli distribution
# induced by the decoder in the data space).
# this can be interpreted as the number of "nats" required
# for reconstructing the input when the activation in latent
# is given.
# log_likelihood = self._loss_helper(x, p_x_given_z.mean())
        xshp = x.get_shape().as_list()
        reduction_indices = [1, 2, 3] if len(xshp) > 3 else [1]
        log_likelihood = tf.reduce_sum(self.p_x_given_z.log_prob(x),
                                       reduction_indices)
mutual_info_regularizer \
= VAE.mutual_information_bernouilli_cat(self.z_pre_gumbel)
# 2.) the latent loss, which is defined as the kullback leibler divergence
# between the distribution in latent space induced by the encoder on
# the data and some prior. this acts as a kind of regularizer.
# this can be interpreted as the number of "nats" required
# for transmitting the the latent space distribution given
# the prior.
# kl_categorical(p=none, q=none, p_logits=none, q_logits=none, eps=1e-6):
# cost = reconstr_loss - latent_kl
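        # note: 'elbo' below is the negative ELBO (reconstruction nats + latent KL),
        # matching the "vae_negative_elbo" summary above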
elbo = -log_likelihood + latent_kl
cost = elbo + consistency_kl - self.mutual_info_reg * mutual_info_regularizer
# create the reductions only once
latent_loss_mean = tf.reduce_mean(latent_kl)
log_likelihood_mean = tf.reduce_mean(log_likelihood)
elbo_mean = tf.reduce_mean(elbo)
cost_mean = tf.reduce_mean(cost)
return [log_likelihood, log_likelihood_mean,
latent_loss_mean, cost, cost_mean,
elbo_mean]
def _create_loss_optimizer(self):
# build constraint graph
self._create_constraints()
with tf.variable_scope(self.get_name() + "/loss_optimizer"):
self.latent_kl = self.kl_normal + self.kl_discrete
# if self.submodel > 0:
# set the indexes[batch] of the latent kl to zero for the
# indices that we are constraining over as we are computing
# a regularizer in the above function
# zero_vals = [self.latent_kl[0:self.num_current_data],
# tf.zeros([self.num_old_data])]
# self.latent_kl = tf.concat(axis=0, values=zero_vals)
# tabulate total loss
self.reconstr_loss, self.reconstr_loss_mean, \
self.latent_loss_mean, \
self.cost, self.cost_mean, self.elbo_mean \
= self.vae_loss(self.x_augmented,
self.p_x_given_z,
self.latent_kl,
self.kl_consistency)
# construct our optimizer
#with tf.control_dependencies([self.p_x_given_z_logits]):
filtered = [v for v in tf.trainable_variables()
if v.name.startswith(self.get_name())]
self.optimizer = self._create_optimizer(filtered,
self.cost_mean,
self.learning_rate)
def _create_optimizer(self, tvars, cost, lr):
# optimizer = tf.train.rmspropoptimizer(self.learning_rate)
optimizer = tf.train.AdamOptimizer(learning_rate=lr)
print 'there are %d trainable vars in cost %s\n' % (len(tvars), cost.name)
grads = tf.gradients(cost, tvars)
# DEBUG: exploding gradients test with this:
# for index in range(len(grads)):
# if grads[index] is not None:
# gradstr = "\n grad [%i] | tvar [%s] =" % (index, tvars[index].name)
# grads[index] = tf.Print(grads[index], [grads[index]], gradstr, summarize=100)
# grads, _ = tf.clip_by_global_norm(grads, 5.0)
self.grad_norm = tf.norm(tf.concat([tf.reshape(t, [-1]) for t in grads],
axis=0))
return optimizer.apply_gradients(zip(grads, tvars))
# return tf.train.AdamOptimizer(learning_rate=lr).minimize(cost, var_list=tvars)
def partial_fit(self, inputs, iteration_print=10,
iteration_save_imgs=2000,
is_forked=False, summary="train"):
"""Train model based on mini-batch of input data.
Return cost of mini-batch.
"""
feed_dict = {self.x: inputs,
self.is_training: True if summary == "train" else False,
self.tau: self.tau_host}
if summary == "train":
writer = self.train_summary_writer
else:
writer = self.test_summary_writer
try:
# update tau for gumbel-softmax
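            # annealing schedule: tau = max(tau0 * exp(-anneal_rate * iteration), min_temp)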
if summary == "train" and self.iteration > 0 and self.iteration % 10 == 0:
rate = -self.anneal_rate*self.iteration
self.tau_host = np.maximum(self.tau0 * np.exp(rate),
self.min_temp)
print 'updated tau to ', self.tau_host
# full list of session ops
ops_to_run = [self.cost_mean, self.elbo_mean,
self.reconstr_loss_mean,
self.latent_loss_mean]
if summary == "train":
ops_to_run = [self.optimizer,
self.iteration_gpu_op] + ops_to_run
if iteration_save_imgs > 0 and self.iteration % iteration_save_imgs == 0:
# write images + summaries
if summary == "train":
_, _, cost, elbo, rloss, lloss, summary_img \
= self.sess.run(ops_to_run + [self.image_summaries],
feed_dict=feed_dict)
else:
cost, elbo, rloss, lloss, summary_img \
= self.sess.run(ops_to_run + [self.image_summaries],
feed_dict=feed_dict)
writer.add_summary(summary_img, self.iteration)
elif self.iteration % iteration_print == 0:
# write regular summaries
if summary == "train":
_, _, cost, elbo, rloss, lloss, summary \
= self.sess.run(ops_to_run + [self.summaries],
feed_dict=feed_dict)
else:
cost, elbo, rloss, lloss, summary \
= self.sess.run(ops_to_run + [self.summaries],
feed_dict=feed_dict)
writer.add_summary(summary, self.iteration)
else:
# write no summary
if summary == "train":
_, _, cost, elbo, rloss, lloss \
= self.sess.run(ops_to_run,
feed_dict=feed_dict)
else:
cost, elbo, rloss, lloss \
= self.sess.run(ops_to_run,
feed_dict=feed_dict)
        except Exception as e:
            print 'caught exception in partial fit: ', e
            # re-raise: the return below would otherwise fail on undefined locals
            raise
self.iteration += 1
return cost, elbo, rloss, lloss
def write_classes_to_file(self, filename, all_classes):
with open(filename, 'a') as f:
np.savetxt(f, self.sess.run(all_classes), delimiter=",")
def build_new_encoder_decoder_pair(self, num_new_classes=1):
updated_latent_size = 2*self.latent_size \
+ self.num_discrete \
+ num_new_classes
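        # sizing: 2*latent_size covers the gaussian parameters, num_discrete the
        # current categorical logits, and num_new_classes widens the categorical
        # head for the next task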
        if self.encoder_model.layer_type != 'cnn':
# increase the number of latent params
# if self.submodel % 4 == 0: # XXX
# print 'adding extra layer...'
# layer_sizes = self.encoder_model.sizes + [512]
# else:
# layer_sizes = self.encoder_model.sizes
layer_sizes = self.encoder_model.sizes
encoder = DenseEncoder(self.sess, updated_latent_size,
self.is_training,
scope="encoder",
sizes=layer_sizes,
use_ln=self.encoder_model.use_ln,
use_bn=self.encoder_model.use_bn)
is_dec_doubled = self.decoder_model.double_features > 1
decoder = DenseEncoder(self.sess, self.input_size,
self.is_training,
scope="decoder",
sizes=layer_sizes,
double_features=is_dec_doubled,
use_ln=self.decoder_model.use_ln,
use_bn=self.decoder_model.use_bn)
else:
encoder = CNNEncoder(self.sess, updated_latent_size,
self.is_training,
scope="encoder",
use_ln=self.encoder_model.use_ln,
use_bn=self.encoder_model.use_bn,)
decoder = CNNDecoder(self.sess,
scope="decoder",
double_channels=self.decoder_model.double_channels,
input_size=self.input_size,
is_training=self.is_training,
use_ln=self.decoder_model.use_ln,
use_bn=self.decoder_model.use_bn)
return encoder, decoder
def fork(self, num_new_class=1):
'''
Fork the current model by copying the model parameters
into the old ones.
Note: This is a slow op in tensorflow
because the session needs to be run
'''
encoder, decoder = self.build_new_encoder_decoder_pair(num_new_class)
print 'encoder = ', encoder.get_info()
print 'decoder = ', decoder.get_info()
vae_tp1 = VAE(self.sess, self.x,
input_size=self.input_size,
batch_size=self.batch_size,
latent_size=self.latent_size,
discrete_size=self.num_discrete + num_new_class,
encoder=encoder,
decoder=decoder,
p_x_given_z_func=self.p_x_given_z_func,
is_training=self.is_training,
activation=self.activation,
learning_rate=self.learning_rate,
submodel=self.submodel+1,
total_true_models=self.total_true_models+num_new_class,
vae_tm1=self,
img_shape=self.img_shape,
base_dir=self.base_dir)
# we want to reinit our weights and biases to their defaults
# after this we will copy the possible weights over
self.sess.run([vae_tp1.init_op]) # ,vae_tp1.init_local_op])
# copy the encoder and decoder layers
# this helps convergence time
copy_layer(self.sess, self.encoder_model, self.get_name(),
encoder, vae_tp1.get_name())
copy_layer(self.sess, self.decoder_model, self.get_name(),
decoder, vae_tp1.get_name())
return vae_tp1
def transform(self, X):
"""Transform data by mapping it into the latent space."""
# Note: This maps to mean of distribution, we could alternatively
# sample from Gaussian distribution
return self.sess.run(self.z, feed_dict={self.x: X,
self.tau: self.tau_host,
self.is_training: False})
def generate(self, z=None):
""" Generate data by sampling from latent space.
        If z is not None, data for this point in latent space is
        generated. Otherwise, z is drawn from the prior over the
        latent space.
"""
if z is None:
z = generate_random_categorical(self.latent_size, self.batch_size)
# Note: This maps to mean of distribution, we could alternatively
# sample from Gaussian distribution
return self.sess.run(self.p_x_given_z.mean(),
feed_dict={self.z: z,
self.tau: self.tau_host,
self.is_training: False})
def reconstruct(self, X, return_losses=False):
""" Use VAE to reconstruct given data. """
if return_losses:
ops = [self.p_x_given_z.mean(),
self.reconstr_loss, self.reconstr_loss_mean,
self.latent_kl, self.latent_loss_mean,
self.cost, self.cost_mean, self.elbo_mean]
else:
ops = self.p_x_given_z.mean()
return self.sess.run(ops,
feed_dict={self.x: X,
self.tau: self.tau_host,
self.is_training: False})
def test(self, source, batch_size, iteration_save_imgs=10):
n_samples = source.num_examples
avg_cost = avg_elbo = avg_recon = avg_latent = 0.
total_batch = int(n_samples / batch_size)
# Loop over all batches
for i in range(total_batch):
batch_xs, _ = source.next_batch(batch_size)
# only save imgs if we are on the Nth test iteration
if self.test_epoch % iteration_save_imgs == 0:
iteration_save_imgs_pf = 1
else:
iteration_save_imgs_pf = -1
# Fit training using batch data
cost, elbo, recon_cost, latent_cost \
= self.partial_fit(batch_xs, summary="test",
iteration_print=1, # always print
iteration_save_imgs=iteration_save_imgs_pf)
# Compute average loss
avg_cost += cost / n_samples * batch_size
avg_elbo += elbo / n_samples * batch_size
avg_recon += recon_cost / n_samples * batch_size
avg_latent += latent_cost / n_samples * batch_size
# Display logs at the end of testing
self.test_epoch += 1
print "[Test]", \
"avg cost = ", "{:.4f} | ".format(avg_cost), \
"avg latent cost = ", "{:.4f} | ".format(avg_latent), \
"avg elbo loss = ", "{:.4f} | ".format(avg_elbo), \
"avg recon loss = ", "{:.4f}".format(avg_recon)
def train(self, source, batch_size, training_epochs=10, display_step=5):
n_samples = source.train.num_examples
for epoch in range(training_epochs):
avg_cost = avg_elbo = avg_recon = avg_latent = 0.
total_batch = int(n_samples / batch_size)
# Loop over all batches
for i in range(total_batch):
batch_xs, _ = source.train.next_batch(batch_size)
# Fit training using batch data
cost, elbo, recon_cost, latent_cost\
= self.partial_fit(batch_xs)
# Compute average loss
avg_cost += cost / n_samples * batch_size
avg_elbo += elbo / n_samples * batch_size
avg_recon += recon_cost / n_samples * batch_size
avg_latent += latent_cost / n_samples * batch_size
# Display logs per epoch step
if epoch % display_step == 0:
print "[Epoch:", '%04d]' % (epoch+1), \
"current cost = ", "{:.4f} | ".format(cost), \
"avg cost = ", "{:.4f} | ".format(avg_cost), \
"avg elbo = ", "{:.4f} | ".format(avg_elbo), \
"avg latent = ", "{:.4f} | ".format(avg_latent), \
"avg recon = ", "{:.4f}".format(avg_recon)