Skip to content

Commit

Permalink
fix multiprocessing bug with empty ReplayBuffer
Browse files Browse the repository at this point in the history
  • Loading branch information
Yonv1943 authored Mar 1, 2021
1 parent 2bc7b5a commit 76d2a4d
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 6 deletions.
10 changes: 5 additions & 5 deletions elegantrl/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -516,7 +516,7 @@ def __init__(self, net_dim, state_dim, action_dim, learning_rate=1e-4):

self.obj_c = (-np.log(0.5)) ** 0.5 # for reliable_lambda

def update_net(self, buffer, max_step, batch_size, repeat_times):
def update_net(self, buffer, target_step, batch_size, repeat_times):
"""ModSAC (Modified SAC using Reliable lambda)
1. Reliable Lambda is calculated based on Critic's loss function value.
2. Increasing batch_size and update_times
Expand All @@ -527,7 +527,7 @@ def update_net(self, buffer, max_step, batch_size, repeat_times):

k = 1.0 + buffer.now_len / buffer.max_len
batch_size_ = int(batch_size * k)
train_steps = int(max_step * k * repeat_times)
train_steps = int(target_step * k * repeat_times)

alpha = self.alpha_log.exp().detach()
update_a = 0
Expand Down Expand Up @@ -559,7 +559,7 @@ def update_net(self, buffer, max_step, batch_size, repeat_times):

'''objective of actor using reliable_lambda and TTUR (Two Time-scales Update Rule)'''
reliable_lambda = np.exp(-self.obj_c ** 2) # for reliable_lambda
if_update_a = update_a / update_c < 1 / (2 - reliable_lambda)
if_update_a = (update_a / update_c) < (1 / (2 - reliable_lambda))
if if_update_a: # auto TTUR
update_a += 1

Expand Down Expand Up @@ -601,7 +601,7 @@ def select_actions(self, states):
actions = self.act.get__noise_action(states)
return actions.detach().cpu().numpy()

def update_net(self, buffer, max_step, batch_size, repeat_times): # 1111
def update_net(self, buffer, target_step, batch_size, repeat_times): # 1111
"""Contribution of InterSAC (Integrated network for SAC)
1. Encoder-DenseNetLikeNet-Decoder network architecture.
share parameter between two **different input** network
Expand All @@ -617,7 +617,7 @@ def update_net(self, buffer, max_step, batch_size, repeat_times): # 1111

k = 1.0 + buffer.now_len / buffer.max_len
batch_size_ = int(batch_size * k) # increase batch_size
train_steps = int(max_step * k * repeat_times) # increase training_step
train_steps = int(target_step * k * repeat_times) # increase training_step

update_a = 0
for update_c in range(1, train_steps):
Expand Down
5 changes: 4 additions & 1 deletion elegantrl/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def run__demo():
args.break_step = int(5e6) # 5e6 (15e6) UsedTime 3,000s (9,000s)
args.net_dim = 2 ** 8
args.target_step = args.env.max_step
args.max_memo = (args.max_step - 1) * 8
args.max_memo = (args.target_step - 1) * 8
args.batch_size = 2 ** 11
args.repeat_times = 2 ** 4
args.eval_times1 = 2 ** 4
Expand Down Expand Up @@ -776,9 +776,12 @@ def mp_explore_in_env(args, pipe2_exp, worker_id):
agent.update_buffer(env, buffer, exp_step, reward_scale, gamma)

buffer.update__now_len__before_sample()

pipe2_exp.send((buffer.buf_state[:buffer.now_len], buffer.buf_other[:buffer.now_len]))
# buf_state, buf_other = pipe1_exp.recv()

buffer.empty_memories__before_explore()


def mp_evaluate_agent(args, pipe2_eva):
env = args.env
Expand Down

0 comments on commit 76d2a4d

Please sign in to comment.