diff --git a/example/synchronous_resnet/resnet_worker.py b/example/synchronous_resnet/resnet_worker.py index 56a4960..ed56c99 100644 --- a/example/synchronous_resnet/resnet_worker.py +++ b/example/synchronous_resnet/resnet_worker.py @@ -444,6 +444,7 @@ def pred_error(f_pred, data, iterator): x = [data[0][t] for t in valid_index] y = [data[1][t] for t in valid_index] valid_err += f_pred(x, y) + f_pred.sync_shared() i += 1 return valid_err / i @@ -524,7 +525,9 @@ def train_iter(): x, y = next(train_it) func_time = time.time() cost = f_grad_shared(x, y) + cost.sync_shared() f_update(lrate) + f_update.sync_shared() print("Func call time", time.time() - func_time) overhead_time = time.time() asgd()