diff --git a/baselines/ddpg/ddpg.py b/baselines/ddpg/ddpg.py index 4bbda692b9..3fb4f72ac6 100755 --- a/baselines/ddpg/ddpg.py +++ b/baselines/ddpg/ddpg.py @@ -42,6 +42,8 @@ def learn(network, env, tau=0.01, eval_env=None, param_noise_adaption_interval=50, + load_path = None, + save_path = '' **network_kwargs): set_global_seeds(seed) @@ -91,6 +93,9 @@ def learn(network, env, batch_size=batch_size, action_noise=action_noise, param_noise=param_noise, critic_l2_reg=critic_l2_reg, actor_lr=actor_lr, critic_lr=critic_lr, enable_popart=popart, clip_norm=clip_norm, reward_scale=reward_scale) + + if load_path is not None: + agent.load(load_path) logger.info('Using agent with the following configuration:') logger.info(str(agent.__dict__.items())) @@ -269,5 +274,10 @@ def as_scalar(x): with open(os.path.join(logdir, 'eval_env_state.pkl'), 'wb') as f: pickle.dump(eval_env.get_state(), f) + os.mkdirs(logdir,exist_ok=True) + savepath = os.path.join(save_path, str(epoch)) + print('Saving to ',savepath) + agent.save(savepath) + return agent diff --git a/baselines/ddpg/ddpg_learner.py b/baselines/ddpg/ddpg_learner.py index 3fc8a7744e..09c51216e6 100755 --- a/baselines/ddpg/ddpg_learner.py +++ b/baselines/ddpg/ddpg_learner.py @@ -1,6 +1,7 @@ from copy import copy from functools import reduce +import functools import numpy as np import tensorflow as tf import tensorflow.contrib as tc @@ -9,6 +10,7 @@ from baselines.common.mpi_adam import MpiAdam import baselines.common.tf_util as U from baselines.common.mpi_running_mean_std import RunningMeanStd +from baselines.common.tf_util import save_variables, load_variables try: from mpi4py import MPI except ImportError: @@ -98,6 +100,8 @@ def __init__(self, actor, critic, memory, observation_shape, action_shape, param self.batch_size = batch_size self.stats_sample = None self.critic_l2_reg = critic_l2_reg + self.save = None + self.load = None # Observation normalization. if self.normalize_observations: @@ -333,6 +337,8 @@ def train(self): def initialize(self, sess): self.sess = sess self.sess.run(tf.global_variables_initializer()) + self.save = functools.partial(save_variables, sess=self.sess) + self.load = functools.partial(load_variables, sess=self.load) self.actor_optimizer.sync() self.critic_optimizer.sync() self.sess.run(self.target_init_updates) diff --git a/baselines/run.py b/baselines/run.py index caf00e0a1c..451544523e 100644 --- a/baselines/run.py +++ b/baselines/run.py @@ -5,7 +5,7 @@ from collections import defaultdict import tensorflow as tf import numpy as np - +import MADRaS from baselines.common.vec_env.vec_video_recorder import VecVideoRecorder from baselines.common.vec_env.vec_frame_stack import VecFrameStack from baselines.common.cmd_util import common_arg_parser, parse_unknown_args, make_vec_env, make_env