Commit

polish iql
zjowowen committed Dec 12, 2024
1 parent 1941ed8 commit 851ef54
Showing 10 changed files with 443 additions and 24 deletions.
24 changes: 11 additions & 13 deletions ding/model/template/qvac.py
@@ -179,19 +179,17 @@ def setup_conv_encoder():
)
)
)
self.critic_v_head.append(
nn.Sequential(
nn.Linear(critic_v_input_size, critic_head_hidden_size), activation,
RegressionHead(
critic_head_hidden_size,
1,
critic_head_layer_num,
final_tanh=False,
activation=activation,
norm_type=norm_type
)
)
self.critic_v_head = nn.Sequential(
nn.Linear(critic_v_input_size, critic_head_hidden_size), activation,
RegressionHead(
critic_head_hidden_size,
1,
critic_head_layer_num,
final_tanh=False,
activation=activation,
norm_type=norm_type
)
)
else:
self.critic_q_head = nn.Sequential(
nn.Linear(critic_q_input_size, critic_head_hidden_size), activation,
@@ -356,7 +354,7 @@ def compute_critic(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
x = torch.cat([obs, action], dim=1)
if self.twin_critic:
x = [m(x)['pred'] for m in self.critic_q_head]
y = [m(obs)['pred'] for m in self.critic_v_head]
y = self.critic_v_head(obs)['pred']
else:
x = self.critic_q_head(x)['pred']
y = self.critic_v_head(obs)['pred']
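The hunks above are the heart of the polish: the Q side keeps its twin heads, while the state-value side collapses from a per-critic list into a single head, so `compute_critic` now returns one `v_value` tensor alongside the list of Q predictions. A minimal sketch of the resulting structure, using plain `torch.nn` layers in place of DI-engine's `RegressionHead` (the class and method names below are illustrative, not the repository's API):

```python
import torch
import torch.nn as nn


class QVCriticSketch(nn.Module):
    """Illustrative twin-Q / single-V critic matching the polished layout."""

    def __init__(self, obs_dim: int, act_dim: int, hidden: int = 256):
        super().__init__()
        # Twin Q heads over (obs, action), kept as a list of independent networks.
        self.critic_q_head = nn.ModuleList([
            nn.Sequential(
                nn.Linear(obs_dim + act_dim, hidden), nn.ReLU(),
                nn.Linear(hidden, 1),
            ) for _ in range(2)
        ])
        # A single V head over obs only -- the duplicated value head is gone.
        self.critic_v_head = nn.Sequential(
            nn.Linear(obs_dim, hidden), nn.ReLU(),
            nn.Linear(hidden, 1),
        )

    def compute_critic(self, obs: torch.Tensor, action: torch.Tensor) -> dict:
        x = torch.cat([obs, action], dim=1)
        q_value = [m(x).squeeze(-1) for m in self.critic_q_head]  # list of two tensors
        v_value = self.critic_v_head(obs).squeeze(-1)             # one tensor, not a list
        return {'q_value': q_value, 'v_value': v_value}
```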
14 changes: 3 additions & 11 deletions ding/policy/iql.py
@@ -217,15 +217,11 @@ def _init_learn(self) -> None:
self._model.critic_q_head[0][-1].last.bias.data.uniform_(-init_w, init_w)
self._model.critic_q_head[1][-1].last.weight.data.uniform_(-init_w, init_w)
self._model.critic_q_head[1][-1].last.bias.data.uniform_(-init_w, init_w)
self._model.critic_v_head[0][-1].last.weight.data.uniform_(-init_w, init_w)
self._model.critic_v_head[0][-1].last.bias.data.uniform_(-init_w, init_w)
self._model.critic_v_head[1][-1].last.weight.data.uniform_(-init_w, init_w)
self._model.critic_v_head[1][-1].last.bias.data.uniform_(-init_w, init_w)
else:
self._model.critic_q_head[2].last.weight.data.uniform_(-init_w, init_w)
self._model.critic_q_head[-1].last.bias.data.uniform_(-init_w, init_w)
self._model.critic_v_head[2].last.weight.data.uniform_(-init_w, init_w)
self._model.critic_v_head[-1].last.bias.data.uniform_(-init_w, init_w)
self._model.critic_v_head[2].last.weight.data.uniform_(-init_w, init_w)
self._model.critic_v_head[-1].last.bias.data.uniform_(-init_w, init_w)

# Optimizers
self._optimizer_q = Adam(
@@ -321,16 +317,13 @@ def _forward_learn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]:
# 3. compute v loss
if self._twin_critic:
q_value_min = torch.min(q_value[0], q_value[1]).detach()
v_loss_0 = asymmetric_l2_loss(q_value_min - v_value[0], self._tau)
v_loss_1 = asymmetric_l2_loss(q_value_min - v_value[1], self._tau)
v_loss = (v_loss_0 + v_loss_1) / 2
v_loss = asymmetric_l2_loss(q_value_min - v_value, self._tau)
else:
advantage = q_value.detach() - v_value
v_loss = asymmetric_l2_loss(advantage, self._tau)

# 4. compute q loss
if self._twin_critic:
next_v_value = torch.min(next_v_value[0], next_v_value[1])
q_data0 = v_1step_td_data(q_value[0], next_v_value, reward, done, data['weight'])
loss_dict['critic_q_loss'], td_error_per_sample0 = v_1step_td_error(q_data0, self._gamma)
q_data1 = v_1step_td_data(q_value[1], next_v_value, reward, done, data['weight'])
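With a single V head, the twin-critic branch now computes one expectile-regression term against min(Q1, Q2), and each Q head still regresses onto the one-step target r + gamma * V(s') via `v_1step_td_error`. For reference, a minimal plain-tensor sketch of that recipe, assuming `asymmetric_l2_loss` is the standard IQL expectile loss (the helper body below is an assumption, not copied from the repository):

```python
import torch
import torch.nn.functional as F


def asymmetric_l2_loss(u: torch.Tensor, tau: float) -> torch.Tensor:
    # Expectile regression: |tau - 1(u < 0)| * u^2, averaged over the batch.
    return torch.mean(torch.abs(tau - (u < 0).float()) * u ** 2)


def iql_value_losses(q_value, v_value, next_v_value, reward, done, gamma=0.99, tau=0.7):
    """Illustrative V and Q losses for the twin-critic branch after this change."""
    # The V head chases an expectile of the (detached) minimum of the twin Q estimates.
    q_min = torch.min(q_value[0], q_value[1]).detach()
    v_loss = asymmetric_l2_loss(q_min - v_value, tau)
    # Each Q head regresses onto the one-step target built from the single V head.
    target_q = (reward + gamma * (1.0 - done.float()) * next_v_value).detach()
    q_loss = 0.5 * (F.mse_loss(q_value[0], target_q) + F.mse_loss(q_value[1], target_q))
    return v_loss, q_loss
```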
@@ -362,7 +355,6 @@ def _forward_learn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]:
new_q_value, new_v_value = new_value['q_value'], new_value['v_value']
if self._twin_critic:
new_q_value = torch.min(new_q_value[0], new_q_value[1])
new_v_value = torch.min(new_v_value[0], new_v_value[1])
new_advantage = new_q_value - new_v_value

# 8. compute policy loss
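The last hunk feeds the now-scalar advantage min(Q1, Q2) - V(s) into the advantage-weighted policy update. One common form of that step, sketched for context (the function name, the clamp value, and the exact way the config's `beta` enters the exponent are assumptions, not read from this diff):

```python
import torch


def iql_policy_loss(new_advantage: torch.Tensor, log_prob: torch.Tensor,
                    beta: float, clip: float = 100.0) -> torch.Tensor:
    # Advantage-weighted regression: weight the log-likelihood of dataset actions
    # by exp(beta * advantage), clipping the weight for numerical stability.
    weight = torch.exp(beta * new_advantage.detach()).clamp(max=clip)
    return -(weight * log_prob).mean()
```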
54 changes: 54 additions & 0 deletions dizoo/d4rl/config/halfcheetah_medium_expert_iql_config.py
@@ -0,0 +1,54 @@
# You can conduct Experiments on D4RL with this config file through the following command:
# cd ../entry && python d4rl_iql_main.py
from easydict import EasyDict

main_config = dict(
exp_name="halfcheetah_medium_expert_iql_seed0",
env=dict(
env_id='halfcheetah-medium-expert-v2',
collector_env_num=1,
evaluator_env_num=8,
use_act_scale=True,
n_evaluator_episode=8,
stop_value=6000,
reward_norm="iql_locomotion",
),
policy=dict(
cuda=True,
model=dict(
obs_shape=17,
action_shape=6,

),
learn=dict(
data_path=None,
train_epoch=30000,
batch_size=4096,
learning_rate_q=3e-4,
learning_rate_policy=1e-4,
beta=0.05,
tau=0.7,
),
collect=dict(data_type='d4rl', ),
eval=dict(evaluator=dict(eval_freq=5000, )),
other=dict(replay_buffer=dict(replay_buffer_size=2000000, ), ),
),
)

main_config = EasyDict(main_config)
main_config = main_config

create_config = dict(
env=dict(
type='d4rl',
import_names=['dizoo.d4rl.envs.d4rl_env'],
),
env_manager=dict(type='base'),
policy=dict(
type='iql',
import_names=['ding.policy.iql'],
),
replay_buffer=dict(type='naive', ),
)
create_config = EasyDict(create_config)
create_config = create_config
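Each of the new configs follows DI-engine's two-dict convention (`main_config` plus `create_config`). Besides the `d4rl_iql_main.py` entry mentioned in the header comment, a config like this can typically be launched through the offline serial pipeline; the snippet below is a sketch of that pattern, not part of this commit, and assumes D4RL and the referenced dataset are available locally:

```python
# Hypothetical launcher script -- not included in this commit.
from ding.entry import serial_pipeline_offline

from dizoo.d4rl.config.halfcheetah_medium_expert_iql_config import create_config, main_config

if __name__ == "__main__":
    serial_pipeline_offline([main_config, create_config], seed=0)
```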
54 changes: 54 additions & 0 deletions dizoo/d4rl/config/halfcheetah_medium_replay_iql_config.py
@@ -0,0 +1,54 @@
# You can conduct Experiments on D4RL with this config file through the following command:
# cd ../entry && python d4rl_iql_main.py
from easydict import EasyDict

main_config = dict(
exp_name="halfcheetah_medium_replay_iql_seed0",
env=dict(
env_id='halfcheetah-medium-replay-v2',
collector_env_num=1,
evaluator_env_num=8,
use_act_scale=True,
n_evaluator_episode=8,
stop_value=6000,
reward_norm="iql_locomotion",
),
policy=dict(
cuda=True,
model=dict(
obs_shape=17,
action_shape=6,

),
learn=dict(
data_path=None,
train_epoch=30000,
batch_size=4096,
learning_rate_q=3e-4,
learning_rate_policy=1e-4,
beta=0.05,
tau=0.7,
),
collect=dict(data_type='d4rl', ),
eval=dict(evaluator=dict(eval_freq=5000, )),
other=dict(replay_buffer=dict(replay_buffer_size=2000000, ), ),
),
)

main_config = EasyDict(main_config)
main_config = main_config

create_config = dict(
env=dict(
type='d4rl',
import_names=['dizoo.d4rl.envs.d4rl_env'],
),
env_manager=dict(type='base'),
policy=dict(
type='iql',
import_names=['ding.policy.iql'],
),
replay_buffer=dict(type='naive', ),
)
create_config = EasyDict(create_config)
create_config = create_config
54 changes: 54 additions & 0 deletions dizoo/d4rl/config/hopper_medium_expert_iql_config.py
@@ -0,0 +1,54 @@
# You can conduct Experiments on D4RL with this config file through the following command:
# cd ../entry && python d4rl_iql_main.py
from easydict import EasyDict

main_config = dict(
exp_name="hopper_medium_expert_iql_seed0",
env=dict(
env_id='hopper-medium-expert-v2',
collector_env_num=1,
evaluator_env_num=8,
use_act_scale=True,
n_evaluator_episode=8,
stop_value=6000,
reward_norm="iql_locomotion",
),
policy=dict(
cuda=True,
model=dict(
obs_shape=11,
action_shape=3,

),
learn=dict(
data_path=None,
train_epoch=30000,
batch_size=4096,
learning_rate_q=3e-4,
learning_rate_policy=1e-4,
beta=0.05,
tau=0.7,
),
collect=dict(data_type='d4rl', ),
eval=dict(evaluator=dict(eval_freq=5000, )),
other=dict(replay_buffer=dict(replay_buffer_size=2000000, ), ),
),
)

main_config = EasyDict(main_config)
main_config = main_config

create_config = dict(
env=dict(
type='d4rl',
import_names=['dizoo.d4rl.envs.d4rl_env'],
),
env_manager=dict(type='base'),
policy=dict(
type='iql',
import_names=['ding.policy.iql'],
),
replay_buffer=dict(type='naive', ),
)
create_config = EasyDict(create_config)
create_config = create_config
54 changes: 54 additions & 0 deletions dizoo/d4rl/config/hopper_medium_iql_config.py
@@ -0,0 +1,54 @@
# You can conduct Experiments on D4RL with this config file through the following command:
# cd ../entry && python d4rl_iql_main.py
from easydict import EasyDict

main_config = dict(
exp_name="hopper_medium_iql_seed0",
env=dict(
env_id='hopper-medium-v2',
collector_env_num=1,
evaluator_env_num=8,
use_act_scale=True,
n_evaluator_episode=8,
stop_value=6000,
reward_norm="iql_locomotion",
),
policy=dict(
cuda=True,
model=dict(
obs_shape=11,
action_shape=3,

),
learn=dict(
data_path=None,
train_epoch=30000,
batch_size=4096,
learning_rate_q=3e-4,
learning_rate_policy=1e-4,
beta=0.05,
tau=0.7,
),
collect=dict(data_type='d4rl', ),
eval=dict(evaluator=dict(eval_freq=5000, )),
other=dict(replay_buffer=dict(replay_buffer_size=2000000, ), ),
),
)

main_config = EasyDict(main_config)
main_config = main_config

create_config = dict(
env=dict(
type='d4rl',
import_names=['dizoo.d4rl.envs.d4rl_env'],
),
env_manager=dict(type='base'),
policy=dict(
type='iql',
import_names=['ding.policy.iql'],
),
replay_buffer=dict(type='naive', ),
)
create_config = EasyDict(create_config)
create_config = create_config
54 changes: 54 additions & 0 deletions dizoo/d4rl/config/hopper_medium_replay_iql_config.py
@@ -0,0 +1,54 @@
# You can conduct Experiments on D4RL with this config file through the following command:
# cd ../entry && python d4rl_iql_main.py
from easydict import EasyDict

main_config = dict(
exp_name="hopper_medium_replay_iql_seed0",
env=dict(
env_id='hopper-medium-replay-v2',
collector_env_num=1,
evaluator_env_num=8,
use_act_scale=True,
n_evaluator_episode=8,
stop_value=6000,
reward_norm="iql_locomotion",
),
policy=dict(
cuda=True,
model=dict(
obs_shape=11,
action_shape=3,

),
learn=dict(
data_path=None,
train_epoch=30000,
batch_size=4096,
learning_rate_q=3e-4,
learning_rate_policy=1e-4,
beta=0.05,
tau=0.7,
),
collect=dict(data_type='d4rl', ),
eval=dict(evaluator=dict(eval_freq=5000, )),
other=dict(replay_buffer=dict(replay_buffer_size=2000000, ), ),
),
)

main_config = EasyDict(main_config)
main_config = main_config

create_config = dict(
env=dict(
type='d4rl',
import_names=['dizoo.d4rl.envs.d4rl_env'],
),
env_manager=dict(type='base'),
policy=dict(
type='iql',
import_names=['ding.policy.iql'],
),
replay_buffer=dict(type='naive', ),
)
create_config = EasyDict(create_config)
create_config = create_config