Merge pull request #3 from rudrasohan/comp
Added Saving functionality.
buridiaditya authored Dec 5, 2018
2 parents a75102a + 0ddec01 commit 198bbed
Showing 3 changed files with 25 additions and 5 deletions.
20 changes: 17 additions & 3 deletions baselines/ddpg/ddpg.py
@@ -42,8 +42,12 @@ def learn(network, env,
tau=0.01,
eval_env=None,
param_noise_adaption_interval=50,
load_path = None,
save_path = '<specify/path>',
**network_kwargs):

print("Save PATH;{}".format(save_path))
print("Load PATH;{}".format(load_path))
set_global_seeds(seed)

if total_timesteps is not None:
@@ -58,8 +62,7 @@ def learn(network, env,
rank = 0

nb_actions = env.action_space.shape[-1]
assert (np.abs(env.action_space.low) == env.action_space.high).all() # we assume symmetric actions.

#assert (np.abs(env.action_space.low) == env.action_space.high).all() # we assume symmetric actions.
memory = Memory(limit=int(1e6), action_shape=env.action_space.shape, observation_shape=env.observation_space.shape)
critic = Critic(network=network, **network_kwargs)
actor = Actor(nb_actions, network=network, **network_kwargs)
@@ -91,14 +94,19 @@ def learn(network, env,
batch_size=batch_size, action_noise=action_noise, param_noise=param_noise, critic_l2_reg=critic_l2_reg,
actor_lr=actor_lr, critic_lr=critic_lr, enable_popart=popart, clip_norm=clip_norm,
reward_scale=reward_scale)

logger.info('Using agent with the following configuration:')
logger.info(str(agent.__dict__.items()))

eval_episode_rewards_history = deque(maxlen=100)
episode_rewards_history = deque(maxlen=100)
sess = U.get_session()
# Prepare everything.
sess = U.get_session()
agent.initialize(sess)
checkpoint_num = 0
if load_path is not None:
agent.load(load_path)
checkpoint_num = int(os.path.split(load_path)[1]) + 1
sess.graph.finalize()

agent.reset()
@@ -124,6 +132,8 @@ def learn(network, env,
epoch_actions = []
epoch_qs = []
epoch_episodes = 0
if load_path is None:
os.makedirs(save_path, exist_ok=True)
for epoch in range(nb_epochs):
for cycle in range(nb_epoch_cycles):
# Perform rollouts.
@@ -269,5 +279,9 @@ def as_scalar(x):
with open(os.path.join(logdir, 'eval_env_state.pkl'), 'wb') as f:
pickle.dump(eval_env.get_state(), f)

savepath = os.path.join(save_path, str(epoch+checkpoint_num))
print('Saving to ', savepath)
agent.save(savepath)


return agent
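
With these changes, learn writes one checkpoint per epoch under save_path (named by its epoch index), and load_path restores a checkpoint and continues the numbering from it. A minimal usage sketch, assuming an installed gym environment; the environment id, nb_epochs value, and paths below are illustrative and not part of this commit:

# Illustrative usage of the new save_path/load_path arguments; the env id and
# checkpoint paths are placeholders, not taken from this commit.
import gym
from baselines.ddpg.ddpg import learn

env = gym.make('Pendulum-v0')  # stand-in for a TORCS/MADRaS environment

# First run: checkpoints are written to /tmp/ddpg_ckpts/0, /tmp/ddpg_ckpts/1, ...
agent = learn(network='mlp', env=env, nb_epochs=5, save_path='/tmp/ddpg_ckpts')

# Resumed run: restore checkpoint 4 and continue numbering from 5 onwards.
agent = learn(network='mlp', env=env, nb_epochs=5,
              save_path='/tmp/ddpg_ckpts', load_path='/tmp/ddpg_ckpts/4')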
6 changes: 6 additions & 0 deletions baselines/ddpg/ddpg_learner.py
@@ -1,6 +1,7 @@
from copy import copy
from functools import reduce

import functools
import numpy as np
import tensorflow as tf
import tensorflow.contrib as tc
@@ -9,6 +10,7 @@
from baselines.common.mpi_adam import MpiAdam
import baselines.common.tf_util as U
from baselines.common.mpi_running_mean_std import RunningMeanStd
from baselines.common.tf_util import save_variables, load_variables
try:
from mpi4py import MPI
except ImportError:
@@ -98,6 +100,8 @@ def __init__(self, actor, critic, memory, observation_shape, action_shape, param
self.batch_size = batch_size
self.stats_sample = None
self.critic_l2_reg = critic_l2_reg
self.save = None
self.load = None

# Observation normalization.
if self.normalize_observations:
@@ -333,6 +337,8 @@ def train(self):
def initialize(self, sess):
self.sess = sess
self.sess.run(tf.global_variables_initializer())
self.save = functools.partial(save_variables, sess=self.sess)
self.load = functools.partial(load_variables, sess=self.sess)
self.actor_optimizer.sync()
self.critic_optimizer.sync()
self.sess.run(self.target_init_updates)
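
The agent's save and load hooks are bound to baselines' generic save_variables/load_variables helpers with functools.partial once the session exists. A standalone sketch of that pattern; the dummy variable and checkpoint path are illustrative:

# Standalone sketch of the functools.partial save/load pattern; the variable
# and checkpoint path are illustrative, not taken from this commit.
import functools
import tensorflow as tf
from baselines.common.tf_util import save_variables, load_variables

w = tf.get_variable('w', shape=[2, 2])  # dummy variable so there is something to save
sess = tf.Session()
sess.run(tf.global_variables_initializer())

save = functools.partial(save_variables, sess=sess)
load = functools.partial(load_variables, sess=sess)

save('/tmp/ddpg_ckpts/0')  # serialize the session's variables to this path
load('/tmp/ddpg_ckpts/0')  # restore them into the same session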
4 changes: 2 additions & 2 deletions baselines/run.py
@@ -5,7 +5,7 @@
from collections import defaultdict
import tensorflow as tf
import numpy as np

import MADRaS
from baselines.common.vec_env.vec_video_recorder import VecVideoRecorder
from baselines.common.vec_env.vec_frame_stack import VecFrameStack
from baselines.common.cmd_util import common_arg_parser, parse_unknown_args, make_vec_env, make_env
@@ -50,7 +50,7 @@
'SpaceInvaders-Snes',
}


_game_envs['madras'] = {'gym-torcs-v0','gym-madras-v0'}
def train(args, extra_args):
env_type, env_id = get_env_type(args.env)
print('env_type: {}'.format(env_type))
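
The run.py changes import MADRaS and register its environment ids under a new 'madras' env type so baselines.run can resolve them; from the command line this corresponds to something like python -m baselines.run --alg=ddpg --env=gym-madras-v0. A small sketch of the resolution, assuming MADRaS is installed (importing baselines.run now imports it at module level):

# Sketch of how the new _game_envs entry is resolved; requires MADRaS to be
# installed because baselines.run now imports it at module level.
from baselines.run import get_env_type, _game_envs

print(_game_envs['madras'])                # {'gym-torcs-v0', 'gym-madras-v0'}
env_type, env_id = get_env_type('gym-madras-v0')
print(env_type, env_id)                    # 'madras' 'gym-madras-v0'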
