Skip to content

Commit

Permalink
added saving function for ddpg
Browse files Browse the repository at this point in the history
  • Loading branch information
rudrasohan committed Nov 30, 2018
1 parent 2d5593f commit bba5cfc
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 1 deletion.
10 changes: 10 additions & 0 deletions baselines/ddpg/ddpg.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ def learn(network, env,
tau=0.01,
eval_env=None,
param_noise_adaption_interval=50,
load_path = None,
save_path = '<specify/path>'
**network_kwargs):

set_global_seeds(seed)
Expand Down Expand Up @@ -91,6 +93,9 @@ def learn(network, env,
batch_size=batch_size, action_noise=action_noise, param_noise=param_noise, critic_l2_reg=critic_l2_reg,
actor_lr=actor_lr, critic_lr=critic_lr, enable_popart=popart, clip_norm=clip_norm,
reward_scale=reward_scale)

if load_path is not None:
agent.load(load_path)
logger.info('Using agent with the following configuration:')
logger.info(str(agent.__dict__.items()))

Expand Down Expand Up @@ -269,5 +274,10 @@ def as_scalar(x):
with open(os.path.join(logdir, 'eval_env_state.pkl'), 'wb') as f:
pickle.dump(eval_env.get_state(), f)

os.mkdirs(logdir,exist_ok=True)
savepath = os.path.join(save_path, str(epoch))
print('Saving to ',savepath)
agent.save(savepath)


return agent
6 changes: 6 additions & 0 deletions baselines/ddpg/ddpg_learner.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from copy import copy
from functools import reduce

import functools
import numpy as np
import tensorflow as tf
import tensorflow.contrib as tc
Expand All @@ -9,6 +10,7 @@
from baselines.common.mpi_adam import MpiAdam
import baselines.common.tf_util as U
from baselines.common.mpi_running_mean_std import RunningMeanStd
from baselines.common.tf_util import save_variables, load_variables
try:
from mpi4py import MPI
except ImportError:
Expand Down Expand Up @@ -98,6 +100,8 @@ def __init__(self, actor, critic, memory, observation_shape, action_shape, param
self.batch_size = batch_size
self.stats_sample = None
self.critic_l2_reg = critic_l2_reg
self.save = None
self.load = None

# Observation normalization.
if self.normalize_observations:
Expand Down Expand Up @@ -333,6 +337,8 @@ def train(self):
def initialize(self, sess):
self.sess = sess
self.sess.run(tf.global_variables_initializer())
self.save = functools.partial(save_variables, sess=self.sess)
self.load = functools.partial(load_variables, sess=self.load)
self.actor_optimizer.sync()
self.critic_optimizer.sync()
self.sess.run(self.target_init_updates)
Expand Down
2 changes: 1 addition & 1 deletion baselines/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from collections import defaultdict
import tensorflow as tf
import numpy as np

import MADRaS
from baselines.common.vec_env.vec_video_recorder import VecVideoRecorder
from baselines.common.vec_env.vec_frame_stack import VecFrameStack
from baselines.common.cmd_util import common_arg_parser, parse_unknown_args, make_vec_env, make_env
Expand Down

0 comments on commit bba5cfc

Please sign in to comment.