some small timing changes
olimastro committed Jun 20, 2017
1 parent 87bf136 commit 188ddba
Showing 2 changed files with 25 additions and 21 deletions.
4 changes: 4 additions & 0 deletions example/synchronous_resnet/resnet_controller.py
@@ -97,6 +97,10 @@ def handle_control(self, req, worker_id, req_info):
             self.test_history_errs += [[None for i in range(self.nb_worker)]]
             self._epoch += 1
 
+        elif req == 'time':
+            print("Epoch time", time.time() - self.start_time)
+            control_response = 'stop'
+
         elif req == 'splits':
             # the controller never loads the dataset but the worker doesn't
             # know how many workers there are
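The new branch makes the controller answer a worker's 'time' request by printing the wall-clock time since self.start_time and ordering a stop. A minimal sketch of that round trip, assuming start_time is recorded when the epoch begins (the class name here is hypothetical; the worker side triggers it with worker.send_req('time'), as the worker diff below shows):

import time

class TimingControllerSketch(object):
    """Stands in for the real controller; only the 'time' branch is shown."""

    def __init__(self):
        self.start_time = time.time()  # assumed to be reset at epoch start

    def handle_control(self, req, worker_id, req_info):
        control_response = None
        if req == 'time':
            # Mirror the added branch: report elapsed time, order a stop.
            print("Epoch time", time.time() - self.start_time)
            control_response = 'stop'
        return control_response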
42 changes: 21 additions & 21 deletions example/synchronous_resnet/resnet_worker.py
@@ -1,5 +1,6 @@
 from __future__ import absolute_import, print_function
 from collections import OrderedDict
+import time
 import six
 from six import iteritems
 from six.moves import range
@@ -29,15 +30,11 @@ def load_data():
     """
     create synthetic data
     """
-    def trgt_reshape(trgt):
-        return trgt.reshape((trgt.shape[0],1))
-
-    targets = numpy.arange(1000)
-    train_targets = trgt_reshape(numpy.repeat(targets, 2))
+    train_targets = numpy.random.randint(1000, size=(2048,1))
     train_data = numpy.random.random((train_targets.shape[0],3,224,224))
-    valid_targets = trgt_reshape(numpy.repeat(targets, 1))
+    valid_targets = numpy.random.randint(1000, size=(1024,1))
     valid_data = numpy.random.random((valid_targets.shape[0],3,224,224))
-    test_targets = trgt_reshape(numpy.repeat(targets, 1))
+    test_targets = numpy.random.randint(1000, size=(1024,1))
     test_data = numpy.random.random((test_targets.shape[0],3,224,224))
 
     rval = ([numpy_floatX(train_data), numpy_int32(train_targets)],
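This rewrite replaces the repeat-and-reshape labels with direct draws from numpy.random.randint, keeping the (N, 1) target shape the model expects. A standalone sketch of the resulting arrays (plain astype casts stand in for the example's numpy_floatX/numpy_int32 helpers):

import numpy

# Synthetic ImageNet-shaped data, matching the rewritten load_data:
# targets are (N, 1) int32 class ids in [0, 1000); inputs are
# (N, 3, 224, 224) float32 images at ResNet's 224x224 input resolution.
train_targets = numpy.random.randint(1000, size=(2048, 1)).astype('int32')
train_data = numpy.random.random(
    (train_targets.shape[0], 3, 224, 224)).astype('float32')

assert train_data.shape == (2048, 3, 224, 224)
assert 0 <= train_targets.min() and train_targets.max() < 1000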
@@ -95,14 +92,6 @@ def sgd(lr, tparams, grads, x, y, cost):
                                broadcastable=infer_bc_pattern(p.get_value().shape))
                for p in tparams]
     gsup = [(gs, g) for gs, g in zip(gshared, grads)]
-    #import ipdb; ipdb.set_trace()
-    #for i, gpair in enumerate(gsup):
-    #    g = gpair[0]
-    #    u = gpair[1]
-    #    if g.broadcastable != u.broadcastable:
-    #        #gsup[i] = (tensor.patternbroadcast(g, u.broadcastable), u)
-    #        u.broadcastable = g.broadcastable
-    #import ipdb; ipdb.set_trace()
 
     # Function that computes gradients for a mini-batch, but do not
     # updates the weights.
@@ -461,13 +450,13 @@ def pred_error(f_pred, data, iterator):
 
 
 def train_resnet(
-        batch_size=12,  # The batch size during training.
-        valid_batch_size=12,  # The batch size used for validation/test set.
+        batch_size=64,  # The batch size during training.
+        valid_batch_size=64,  # The batch size used for validation/test set.
         validFreq=5,
         lrate=1e-4,
         optimizer=sgd,
 ):
 
     print(theano.config.profile)
     # Each worker needs the same seed in order to draw the same parameters.
     # This will also make them shuffle the batches the same way, but splits are
     # different so doesnt matter
@@ -514,7 +503,7 @@ def train_resnet(
 
     def train_iter():
         while True:
-            kf = get_minibatches_idx(train[0].shape[0], batch_size, shuffle=True)
+            kf = get_minibatches_idx(train[0].shape[0], batch_size, shuffle=False)
             for _, train_index in kf:
                 y = [train[1][t] for t in train_index]
                 x = [train[0][t] for t in train_index]
@@ -523,13 +512,24 @@ def train_iter():
     train_it = train_iter()
     nb_train = train[0].shape[0] // batch_size
 
+    # first pass in function so it doesnt bias the next time count
+    # because of the dnn flags
+    dummy_x = numpy_floatX(numpy.random.random((batch_size,3,224,224)))
+    dummy_y = numpy_int32(numpy.random.randint(1000, size=(batch_size,1)))
+    dumz = f_grad_shared(dummy_x, dummy_y)
+
     epoch = 0
     while True:
         for i in range(nb_train):
             x, y = next(train_it)
+            func_time = time.time()
             cost = f_grad_shared(x, y)
             f_update(lrate)
+            print("Func call time", time.time() - func_time)
+            overhead_time = time.time()
             asgd()
+            print("Overhead time", time.time() - overhead_time)
+            res = worker.send_req('time')
 
             print('Train cost:', cost)
 
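The added comment carries the main idea of this hunk: the first call of a freshly compiled Theano function can pay one-time costs (for instance cuDNN algorithm selection under some dnn flags), so one dummy batch is pushed through f_grad_shared before any timed call. A generic sketch of that warm-up-then-time pattern (f and make_args are placeholders standing in for f_grad_shared and the batch source):

import time

def time_calls(f, make_args, n_calls=10):
    """Warm up once so one-time setup does not bias the per-call
    timings (the role of dumz above), then time each later call."""
    f(*make_args())  # warm-up call; result discarded
    timings = []
    for _ in range(n_calls):
        args = make_args()
        start = time.time()
        f(*args)
        timings.append(time.time() - start)
    return timings

# Example use: time_calls(f_grad_shared, lambda: (x, y))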
@@ -549,8 +549,8 @@ def train_iter():
             # should save the param at best
             pass
 
-            if res == 'stop':
-                break
+        if res == 'stop':
+            break
         epoch += 1
 
     # Release all shared resources.
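Two of the changes above trade realism for measurement stability: shuffling is turned off so every run visits batches in the same order, and the stop check moves out of the inner loop so a 'stop' reply ends training at an epoch boundary. get_minibatches_idx itself is not part of this diff; a sketch consistent with how the worker calls it (the real helper may differ in details such as handling of the last partial batch):

import numpy

def get_minibatches_idx(n, minibatch_size, shuffle=False):
    """Split range(n) into minibatches of indices; with shuffle=False,
    as the worker now passes, the order is deterministic across runs."""
    idx_list = numpy.arange(n, dtype='int64')
    if shuffle:
        numpy.random.shuffle(idx_list)
    minibatches = [idx_list[i:i + minibatch_size]
                   for i in range(0, n, minibatch_size)]
    return list(enumerate(minibatches))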
