diff --git a/baselines/ddpg/ddpg.py b/baselines/ddpg/ddpg.py
index 37551d4931..fab6befb6f 100755
--- a/baselines/ddpg/ddpg.py
+++ b/baselines/ddpg/ddpg.py
@@ -42,8 +42,12 @@ def learn(network, env,
           tau=0.01,
           eval_env=None,
           param_noise_adaption_interval=50,
+          load_path=None,
+          save_path='',
           **network_kwargs):
 
+    print("Save PATH: {}".format(save_path))
+    print("Load PATH: {}".format(load_path))
     set_global_seeds(seed)
 
     if total_timesteps is not None:
@@ -58,8 +62,7 @@ def learn(network, env,
         rank = 0
 
     nb_actions = env.action_space.shape[-1]
-    assert (np.abs(env.action_space.low) == env.action_space.high).all()  # we assume symmetric actions.
-
+    # assert (np.abs(env.action_space.low) == env.action_space.high).all()  # we assume symmetric actions.
     memory = Memory(limit=int(1e6), action_shape=env.action_space.shape, observation_shape=env.observation_space.shape)
     critic = Critic(network=network, **network_kwargs)
     actor = Actor(nb_actions, network=network, **network_kwargs)
@@ -91,14 +94,19 @@ def learn(network, env,
         batch_size=batch_size, action_noise=action_noise, param_noise=param_noise, critic_l2_reg=critic_l2_reg,
         actor_lr=actor_lr, critic_lr=critic_lr, enable_popart=popart, clip_norm=clip_norm,
         reward_scale=reward_scale)
+
     logger.info('Using agent with the following configuration:')
     logger.info(str(agent.__dict__.items()))
 
     eval_episode_rewards_history = deque(maxlen=100)
     episode_rewards_history = deque(maxlen=100)
-    sess = U.get_session()
     # Prepare everything.
+    sess = U.get_session()
     agent.initialize(sess)
+    checkpoint_num = 0
+    if load_path is not None:
+        agent.load(load_path)
+        checkpoint_num = int(os.path.split(load_path)[1]) + 1
     sess.graph.finalize()
 
     agent.reset()
@@ -124,6 +132,8 @@ def learn(network, env,
     epoch_actions = []
     epoch_qs = []
     epoch_episodes = 0
+    if load_path is None:
+        os.makedirs(save_path, exist_ok=True)
     for epoch in range(nb_epochs):
         for cycle in range(nb_epoch_cycles):
             # Perform rollouts.
@@ -269,5 +279,9 @@ def as_scalar(x):
                 with open(os.path.join(logdir, 'eval_env_state.pkl'), 'wb') as f:
                     pickle.dump(eval_env.get_state(), f)
+        savepath = os.path.join(save_path, str(epoch+checkpoint_num))
+        print('Saving to ', savepath)
+        agent.save(savepath)
+
 
 
     return agent
diff --git a/baselines/ddpg/ddpg_learner.py b/baselines/ddpg/ddpg_learner.py
index 3fc8a7744e..09c51216e6 100755
--- a/baselines/ddpg/ddpg_learner.py
+++ b/baselines/ddpg/ddpg_learner.py
@@ -1,6 +1,7 @@
 from copy import copy
 from functools import reduce
+import functools
 
 import numpy as np
 import tensorflow as tf
 import tensorflow.contrib as tc
@@ -9,6 +10,7 @@
 from baselines.common.mpi_adam import MpiAdam
 import baselines.common.tf_util as U
 from baselines.common.mpi_running_mean_std import RunningMeanStd
+from baselines.common.tf_util import save_variables, load_variables
 try:
     from mpi4py import MPI
 except ImportError:
@@ -98,6 +100,8 @@ def __init__(self, actor, critic, memory, observation_shape, action_shape, param
         self.batch_size = batch_size
         self.stats_sample = None
         self.critic_l2_reg = critic_l2_reg
+        self.save = None
+        self.load = None
 
         # Observation normalization.
         if self.normalize_observations:
@@ -333,6 +337,8 @@ def train(self):
     def initialize(self, sess):
         self.sess = sess
         self.sess.run(tf.global_variables_initializer())
+        self.save = functools.partial(save_variables, sess=self.sess)
+        self.load = functools.partial(load_variables, sess=self.sess)
         self.actor_optimizer.sync()
         self.critic_optimizer.sync()
         self.sess.run(self.target_init_updates)
diff --git a/baselines/run.py b/baselines/run.py
index 609de6ec5c..451544523e 100644
--- a/baselines/run.py
+++ b/baselines/run.py
@@ -5,7 +5,7 @@
 from collections import defaultdict
 import tensorflow as tf
 import numpy as np
-
+import MADRaS
 from baselines.common.vec_env.vec_video_recorder import VecVideoRecorder
 from baselines.common.vec_env.vec_frame_stack import VecFrameStack
 from baselines.common.cmd_util import common_arg_parser, parse_unknown_args, make_vec_env, make_env
@@ -50,7 +50,7 @@
     'SpaceInvaders-Snes',
 }
-
+_game_envs['madras'] = {'gym-torcs-v0','gym-madras-v0'}
 
 def train(args, extra_args):
     env_type, env_id = get_env_type(args.env)
     print('env_type: {}'.format(env_type))
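
For reference, the new save_path/load_path arguments make learn() write one checkpoint per epoch (files named 0, 1, 2, ... under save_path) and resume the numbering from a loaded checkpoint's file name. Below is a minimal usage sketch under those assumptions; the environment id, timestep budget, and checkpoint directory are placeholders for illustration, not part of the patch.

# Minimal usage sketch for the patched baselines.ddpg.ddpg.learn.
# 'Pendulum-v0' and './ddpg_ckpts' are placeholder choices.
import gym
from baselines.common.vec_env.dummy_vec_env import DummyVecEnv
from baselines.ddpg.ddpg import learn

env = DummyVecEnv([lambda: gym.make('Pendulum-v0')])

# First run: one checkpoint per epoch is written to ./ddpg_ckpts/0, /1, ...
learn(network='mlp', env=env, total_timesteps=10000, save_path='./ddpg_ckpts')

# A later run (fresh process) can resume from a saved checkpoint; the file
# name (here '4') is parsed into checkpoint_num, so new checkpoints continue
# at ./ddpg_ckpts/5, /6, ...
#   learn(network='mlp', env=env, total_timesteps=10000,
#         load_path='./ddpg_ckpts/4', save_path='./ddpg_ckpts')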
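
The agent.save/agent.load hooks wired up in initialize() are plain functools.partial bindings of the generic checkpoint helpers in baselines.common.tf_util to the agent's TensorFlow session. Roughly, they behave like the sketch below; the checkpoint path is a placeholder.

# Rough equivalent of the agent.save/agent.load bindings created in
# DDPG.initialize(); './ddpg_ckpts/0' is a placeholder path.
import functools
from baselines.common.tf_util import get_session, save_variables, load_variables

sess = get_session()
save = functools.partial(save_variables, sess=sess)  # what agent.save points to
load = functools.partial(load_variables, sess=sess)  # what agent.load points to

save('./ddpg_ckpts/0')  # pickles the graph's trainable variables to that path
load('./ddpg_ckpts/0')  # reads them back and assigns them in the same session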