Merge branch 'main' of https://github.com/zjowowen/DI-engine into iql
zjowowen committed Dec 19, 2024
2 parents 851ef54 + 9a6e46f commit 8b7e4fc
Showing 18 changed files with 317 additions and 36 deletions.
10 changes: 0 additions & 10 deletions .github/workflows/unit_test.yml
@@ -26,11 +26,6 @@ jobs:
python -m pip install .
python -m pip install ".[test,k8s]"
python -m pip install transformers
if python --version | grep -q "Python 3.7"; then
python -m pip install wandb==0.16.4
else
echo "Python version is not 3.7, skipping wandb installation"
fi
./ding/scripts/install-k8s-tools.sh
make unittest
- name: Upload coverage to Codecov
@@ -60,10 +55,5 @@ jobs:
python -m pip install .
python -m pip install ".[test,k8s]"
python -m pip install transformers
if python --version | grep -q "Python 3.7"; then
python -m pip install wandb==0.16.4
else
echo "Python version is not 3.7, skipping wandb installation"
fi
./ding/scripts/install-k8s-tools.sh
make benchmark
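
For reference, a minimal Python rendering of the shell gate removed above: wandb was pinned to 0.16.4 only on Python 3.7 runners. This is only an illustration of the deleted logic, not part of the workflow file.

import subprocess
import sys

# Mirror of the removed CI step: pin wandb==0.16.4 on Python 3.7, otherwise skip.
if sys.version_info[:2] == (3, 7):
    subprocess.run([sys.executable, "-m", "pip", "install", "wandb==0.16.4"], check=True)
else:
    print("Python version is not 3.7, skipping wandb installation")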
4 changes: 2 additions & 2 deletions ding/entry/serial_entry_onpolicy.py
@@ -12,7 +12,7 @@
from ding.config import read_config, compile_config
from ding.policy import create_policy, PolicyFactory
from ding.reward_model import create_reward_model
from ding.utils import set_pkg_seed
from ding.utils import set_pkg_seed, get_rank


def serial_pipeline_onpolicy(
@@ -68,7 +68,7 @@ def serial_pipeline_onpolicy(
policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval', 'command'])

# Create worker components: learner, collector, evaluator, replay buffer, commander.
tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial')) if get_rank() == 0 else None
learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
collector = create_serial_collector(
cfg.policy.collect.collector,
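
The one-line change above restricts TensorBoard writing to the DDP rank-0 process. A minimal sketch of that pattern, assuming ding.utils.get_rank and the standard torch SummaryWriter; the helper name make_tb_logger is hypothetical.

import os
from torch.utils.tensorboard import SummaryWriter
from ding.utils import get_rank

def make_tb_logger(exp_name: str):
    # Only rank 0 creates a SummaryWriter so parallel workers do not write
    # duplicate event files; every other rank gets None.
    if get_rank() == 0:
        return SummaryWriter(os.path.join('./{}/log/'.format(exp_name), 'serial'))
    return None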
@@ -17,6 +17,8 @@
def test_serial_pipeline_trex_onpolicy():
exp_name = 'trex_onpolicy_test_serial_pipeline_trex_onpolicy_expert'
config = [deepcopy(cartpole_ppo_config), deepcopy(cartpole_ppo_create_config)]
config[0].policy.learn.learner = dict()
config[0].policy.learn.learner.hook = dict()
config[0].policy.learn.learner.hook.save_ckpt_after_iter = 100
config[0].exp_name = exp_name
expert_policy = serial_pipeline_onpolicy(config, seed=0)
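
The two added lines create the intermediate dicts before setting save_ckpt_after_iter. A small sketch of why, assuming EasyDict semantics (a plain dict assigned as an attribute is converted to an EasyDict, while reading a missing attribute raises AttributeError) and assuming the copied cartpole_ppo_config does not already define policy.learn.learner.

from easydict import EasyDict

cfg = EasyDict(dict(policy=dict(learn=dict())))
# cfg.policy.learn.learner.hook = dict() would fail here: `learner` does not exist yet.
cfg.policy.learn.learner = dict()
cfg.policy.learn.learner.hook = dict()
cfg.policy.learn.learner.hook.save_ckpt_after_iter = 100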
6 changes: 3 additions & 3 deletions ding/framework/middleware/collector.py
@@ -101,7 +101,7 @@ def __call__(self, ctx: "OnlineRLContext") -> None:
"""
Overview:
An encapsulation of inference and rollout middleware. Stop when completing \
the target number of steps.
the target number of steps.
Input of ctx:
- env_step (:obj:`int`): The env steps which will increase during collection.
"""
@@ -143,7 +143,7 @@ class EpisodeCollector:
"""
Overview:
The class of the collector running by episodes, including model inference and transition \
process. Use the `__call__` method to execute the whole collection process.
process. Use the `__call__` method to execute the whole collection process.
"""

def __init__(self, cfg: EasyDict, policy, env: BaseEnvManager, random_collect_size: int = 0) -> None:
@@ -168,7 +168,7 @@ def __call__(self, ctx: "OnlineRLContext") -> None:
"""
Overview:
An encapsulation of inference and rollout middleware. Stop when completing the \
target number of episodes.
target number of episodes.
Input of ctx:
- env_episode (:obj:`int`): The env env_episode which will increase during collection.
"""
6 changes: 2 additions & 4 deletions ding/framework/middleware/tests/mock_for_test.py
@@ -63,7 +63,7 @@ def process_transition(self, obs: Any, model_output: dict, timestep: namedtuple)
'logit': 1.0,
'value': 2.0,
'reward': 0.1,
'done': True,
'done': timestep.done,
}
return transition

@@ -75,7 +75,6 @@ def __init__(self) -> None:
self.env_num = env_num
self.obs_dim = obs_dim
self.closed = False
self._reward_grow_indicator = 1
self._steps = [0 for _ in range(self.env_num)]

@property
@@ -111,11 +110,10 @@ def step(self, actions: tnp.ndarray) -> List[tnp.ndarray]:
obs=torch.rand(self.obs_dim),
reward=1.0,
done=done,
info={'eval_episode_return': self._reward_grow_indicator * 1.0} if done else {},
info={'eval_episode_return': 10.0} if done else {},
env_id=i,
)
timesteps.append(tnp.array(timestep))
self._reward_grow_indicator += 1 # eval_episode_return will increase as step method is called
return timesteps


2 changes: 1 addition & 1 deletion ding/framework/middleware/tests/test_evaluator.py
@@ -24,4 +24,4 @@ def test_interaction_evaluator():
# there are 2 env_num and 5 episodes in the test.
# so when interaction_evaluator runs the first time, reward is [[1, 2, 3], [2, 3]] and the avg = 2.2
# the second time, reward is [[4, 5, 6], [5, 6]] . . .
assert ctx.eval_value == 2.2 + i // 10 * 3.0
assert ctx.eval_value == 10.0
2 changes: 1 addition & 1 deletion ding/torch_utils/optimizer_helper.py
@@ -495,7 +495,7 @@ def _state_init(self, p, momentum, centered):
# wait torch upgrad to 1.4, 1.3.1 didn't support memory format state['step'] = 0
else:
state['step'] = torch.zeros((1,), dtype=torch.float, device=p.device) \
if self.defaults['capturable'] else torch.tensor(0.)
if ('capturable' in self.defaults and self.defaults['capturable']) else torch.tensor(0.)
state['thre_square_avg'] = torch.zeros_like(p.data, device=p.data.device)
state['square_avg'] = torch.zeros_like(p.data, device=p.data.device)
if momentum:
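
The guarded lookup above avoids a KeyError on setups where the optimizer's defaults carry no 'capturable' entry. A tiny illustration under that assumption; the local `defaults` dict stands in for the optimizer's self.defaults.

import torch

defaults = {'lr': 1e-3}  # older configurations may define no 'capturable' key
state_step = torch.zeros((1, ), dtype=torch.float) \
    if ('capturable' in defaults and defaults['capturable']) else torch.tensor(0.)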
6 changes: 5 additions & 1 deletion ding/worker/collector/interaction_serial_evaluator.py
@@ -204,6 +204,8 @@ def eval(
'''
# evaluator only work on rank0
stop_flag = False
episode_info = None # Initialize to ensure it's defined in all ranks

if get_rank() == 0:
if n_episode is None:
n_episode = self._default_n_episode
@@ -317,5 +319,7 @@ def eval(
broadcast_object_list(objects, src=0)
stop_flag, episode_info = objects

episode_info = to_item(episode_info)
# Ensure episode_info is converted to the correct format
episode_info = to_item(episode_info) if episode_info is not None else {}

return stop_flag, episode_info
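
The evaluator changes above pre-initialize episode_info so that every rank has the variable defined before the rank-0 results are broadcast. A hedged sketch of that flow, using ding.utils.get_rank, get_world_size, and broadcast_object_list as they appear in this diff; run_eval_on_rank0 is a hypothetical stand-in for the rank-0 evaluation loop.

from ding.utils import get_rank, get_world_size, broadcast_object_list

def eval_across_ranks(run_eval_on_rank0):
    stop_flag, episode_info = False, None  # defined on every rank up front
    if get_rank() == 0:
        stop_flag, episode_info = run_eval_on_rank0()
    if get_world_size() > 1:
        objects = [stop_flag, episode_info]
        broadcast_object_list(objects, src=0)  # rank 0 fills in the other ranks
        stop_flag, episode_info = objects
    if episode_info is None:
        episode_info = {}
    return stop_flag, episode_info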
2 changes: 1 addition & 1 deletion ding/worker/collector/sample_serial_collector.py
@@ -7,7 +7,7 @@

from ding.envs import BaseEnvManager
from ding.utils import build_logger, EasyTimer, SERIAL_COLLECTOR_REGISTRY, one_time_warning, get_rank, get_world_size, \
broadcast_object_list, allreduce_data
allreduce_data
from ding.torch_utils import to_tensor, to_ndarray
from .base_serial_collector import ISerialCollector, CachePool, TrajBuffer, INF, to_tensor_transitions

67 changes: 67 additions & 0 deletions dizoo/atari/config/serial/pong/pong_dqn_ddp_config.py
@@ -0,0 +1,67 @@
from easydict import EasyDict

pong_dqn_config = dict(
exp_name='data_pong/pong_dqn_ddp_seed0',
env=dict(
collector_env_num=4,
evaluator_env_num=4,
n_evaluator_episode=8,
stop_value=20,
env_id='PongNoFrameskip-v4',
#'ALE/Pong-v5' is available. But special setting is needed after gym make.
frame_stack=4,
),
policy=dict(
multi_gpu=True,
cuda=True,
priority=False,
model=dict(
obs_shape=[4, 84, 84],
action_shape=6,
encoder_hidden_size_list=[128, 128, 512],
),
nstep=3,
discount_factor=0.99,
learn=dict(
update_per_collect=10,
batch_size=32,
learning_rate=0.0001,
target_update_freq=500,
),
collect=dict(n_sample=96, ),
eval=dict(evaluator=dict(eval_freq=4000, )),
other=dict(
eps=dict(
type='exp',
start=1.,
end=0.05,
decay=250000,
),
replay_buffer=dict(replay_buffer_size=100000, ),
),
),
)
pong_dqn_config = EasyDict(pong_dqn_config)
main_config = pong_dqn_config
pong_dqn_create_config = dict(
env=dict(
type='atari',
import_names=['dizoo.atari.envs.atari_env'],
),
env_manager=dict(type='subprocess'),
policy=dict(type='dqn'),
)
pong_dqn_create_config = EasyDict(pong_dqn_create_config)
create_config = pong_dqn_create_config

if __name__ == '__main__':
"""
Overview:
This script should be executed with <nproc_per_node> GPUs.
Run the following command to launch the script:
python -m torch.distributed.launch --nproc_per_node=2 ./dizoo/atari/config/serial/pong/pong_dqn_ddp_config.py
"""
from ding.utils import DDPContext
from ding.entry import serial_pipeline
with DDPContext():
serial_pipeline((main_config, create_config), seed=0, max_env_step=int(3e6))
@@ -1,6 +1,6 @@
from easydict import EasyDict

pong_onppo_config = dict(
pong_ppo_config = dict(
env=dict(
collector_env_num=8,
evaluator_env_num=8,
@@ -49,19 +49,19 @@
eval=dict(evaluator=dict(eval_freq=5000, )),
),
)
main_config = EasyDict(pong_onppo_config)
main_config = EasyDict(pong_ppo_config)

pong_onppo_create_config = dict(
pong_ppo_create_config = dict(
env=dict(
type='atari',
import_names=['dizoo.atari.envs.atari_env'],
),
env_manager=dict(type='subprocess'),
policy=dict(type='ppo'),
)
create_config = EasyDict(pong_onppo_create_config)
create_config = EasyDict(pong_ppo_create_config)

if __name__ == "__main__":
# or you can enter `ding -m serial_onpolicy -c pong_onppo_config.py -s 0`
# or you can enter `ding -m serial_onpolicy -c pong_ppo_config.py -s 0`
from ding.entry import serial_pipeline_onpolicy
serial_pipeline_onpolicy((main_config, create_config), seed=0)
76 changes: 76 additions & 0 deletions dizoo/atari/config/serial/pong/pong_ppo_ddp_config.py
@@ -0,0 +1,76 @@
from easydict import EasyDict

pong_ppo_config = dict(
exp_name='data_pong/pong_ppo_ddp_seed0',
env=dict(
collector_env_num=8,
evaluator_env_num=8,
n_evaluator_episode=8,
stop_value=20,
env_id='PongNoFrameskip-v4',
#'ALE/Pong-v5' is available. But special setting is needed after gym make.
frame_stack=4,
),
policy=dict(
multi_gpu=True,
cuda=True,
recompute_adv=True,
action_space='discrete',
model=dict(
obs_shape=[4, 84, 84],
action_shape=6,
action_space='discrete',
encoder_hidden_size_list=[64, 64, 128],
actor_head_hidden_size=128,
critic_head_hidden_size=128,
),
learn=dict(
epoch_per_collect=10,
update_per_collect=1,
batch_size=320,
learning_rate=3e-4,
value_weight=0.5,
entropy_weight=0.001,
clip_ratio=0.2,
adv_norm=True,
value_norm=True,
# for ppo, when we recompute adv, we need the key done in data to split traj, so we must
# use ignore_done=False here,
# but when we add key traj_flag in data as the backup for key done, we could choose to use ignore_done=True
# for halfcheetah, the length=1000
ignore_done=False,
grad_clip_type='clip_norm',
grad_clip_value=0.5,
),
collect=dict(
n_sample=3200,
unroll_len=1,
discount_factor=0.99,
gae_lambda=0.95,
),
eval=dict(evaluator=dict(eval_freq=1000, )),
),
)
main_config = EasyDict(pong_ppo_config)

pong_ppo_create_config = dict(
env=dict(
type='atari',
import_names=['dizoo.atari.envs.atari_env'],
),
env_manager=dict(type='subprocess'),
policy=dict(type='ppo'),
)
create_config = EasyDict(pong_ppo_create_config)

if __name__ == "__main__":
"""
Overview:
This script should be executed with <nproc_per_node> GPUs.
Run the following command to launch the script:
python -m torch.distributed.launch --nproc_per_node=2 ./dizoo/atari/config/serial/pong/pong_ppo_ddp_config.py
"""
from ding.utils import DDPContext
from ding.entry import serial_pipeline_onpolicy
with DDPContext():
serial_pipeline_onpolicy((main_config, create_config), seed=0, max_env_step=int(3e6))
6 changes: 6 additions & 0 deletions dizoo/atari/example/atari_dqn_ddp.py
@@ -56,4 +56,10 @@ def main():


if __name__ == "__main__":
"""
Overview:
This script should be executed with <nproc_per_node> GPUs.
Run the following command to launch the script:
python -m torch.distributed.launch --nproc_per_node=2 ./dizoo/atari/example/atari_dqn_ddp.py
"""
main()
2 changes: 1 addition & 1 deletion dizoo/atari/example/atari_ppo.py
@@ -11,7 +11,7 @@
gae_estimator, termination_checker
from ding.utils import set_pkg_seed
from dizoo.atari.envs.atari_env import AtariEnv
from dizoo.atari.config.serial.pong.pong_onppo_config import main_config, create_config
from dizoo.atari.config.serial.pong.pong_ppo_config import main_config, create_config


def main():
15 changes: 11 additions & 4 deletions dizoo/atari/example/atari_ppo_ddp.py
@@ -9,14 +9,14 @@
from ding.framework.context import OnlineRLContext
from ding.framework.middleware import multistep_trainer, StepCollector, interaction_evaluator, CkptSaver, \
gae_estimator, ddp_termination_checker, online_logger
from ding.utils import set_pkg_seed, DistContext, get_rank, get_world_size
from ding.utils import set_pkg_seed, DDPContext, get_rank, get_world_size
from dizoo.atari.envs.atari_env import AtariEnv
from dizoo.atari.config.serial.pong.pong_onppo_config import main_config, create_config
from dizoo.atari.config.serial.pong.pong_ppo_config import main_config, create_config


def main():
logging.getLogger().setLevel(logging.INFO)
with DistContext():
with DDPContext():
rank, world_size = get_rank(), get_world_size()
main_config.example = 'pong_ppo_seed0_ddp_avgsplit'
main_config.policy.multi_gpu = True
@@ -45,12 +45,19 @@ def main():
task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env))
task.use(StepCollector(cfg, policy.collect_mode, collector_env))
task.use(gae_estimator(cfg, policy.collect_mode))
task.use(multistep_trainer(cfg, policy.learn_mode))
task.use(multistep_trainer(policy.learn_mode))
if rank == 0:
task.use(CkptSaver(policy, cfg.exp_name, train_freq=1000))
task.use(online_logger(record_train_iter=True))
task.use(ddp_termination_checker(max_env_step=int(1e7), rank=rank))
task.run()


if __name__ == "__main__":
"""
Overview:
This script should be executed with <nproc_per_node> GPUs.
Run the following command to launch the script:
python -m torch.distributed.launch --nproc_per_node=2 ./dizoo/atari/example/atari_ppo_ddp.py
"""
main()