Skip to content

Commit

Permalink
Adding allow_growth option (deepfakes#140)
Browse files Browse the repository at this point in the history
  • Loading branch information
Clorr authored Feb 14, 2018
1 parent 0085b5b commit 51f1993
Showing 1 changed file with 16 additions and 0 deletions.
16 changes: 16 additions & 0 deletions scripts/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,11 @@ def parse_arguments(self, description, subparser, command):
type=int,
default=64,
help="Batch size, as a power of 2 (64, 128, 256, etc)")
parser.add_argument('-ag', '--allow-growth',
action="store_true",
dest="allow_growth",
default=False,
help="Sets allow_growth option of Tensorflow to spare memory on some configs")
parser.add_argument('-ep', '--epochs',
type=int,
default=1000000,
Expand Down Expand Up @@ -122,6 +127,9 @@ def process(self):
thr.join() # waits until thread finishes

def processThread(self):
if self.arguments.allow_growth:
self.set_tf_allow_growth()

print('Loading data, this may take a while...')
# this is so that you can enter case insensitive values for trainer
trainer = self.arguments.trainer
Expand Down Expand Up @@ -164,6 +172,14 @@ def processThread(self):
print(e)
exit(1)

def set_tf_allow_growth(self):
import tensorflow as tf
from keras.backend.tensorflow_backend import set_session
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
config.gpu_options.visible_device_list="0"
set_session(tf.Session(config=config))

preview_buffer = {}

def show(self, image, name=''):
Expand Down

0 comments on commit 51f1993

Please sign in to comment.