
Commit

Merge branch 'main' of https://github.com/zjowowen/DI-engine into iql
zjowowen committed Dec 11, 2024
2 parents 63363f4 + 765b8fb commit 1941ed8
Showing 93 changed files with 1,817 additions and 310 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/platform_test.yml
@@ -11,7 +11,7 @@ jobs:
if: "!contains(github.event.head_commit.message, 'ci skip')"
strategy:
matrix:
os: [macos-12, windows-latest]
os: [macos-13, windows-latest]
python-version: [3.8, 3.9]

steps:
9 changes: 2 additions & 7 deletions README.md
@@ -58,7 +58,7 @@ It provides **python-first** and **asynchronous-native** task and middleware abs
- Offline RL algorithms: BCQ, CQL, TD3BC, Decision Transformer, EDAC, Diffuser, Decision Diffuser, SO2
- Model-based RL algorithms: SVG, STEVE, MBPO, DDPPO, DreamerV3
- Exploration algorithms: HER, RND, ICM, NGU
- LLM + RL Algorithms: PPO-max, DPO, PromptPG
- LLM + RL Algorithms: PPO-max, DPO, PromptPG, PromptAWR
- Other algorithms: such as PER, PLR, PCGrad
- MCTS + RL algorithms: AlphaZero, MuZero, please refer to [LightZero](https://github.com/opendilab/LightZero)
- Generative Model + RL algorithms: Diffusion-QL, QGPO, SRPO, please refer to [GenerativeRL](https://github.com/opendilab/GenerativeRL)
@@ -150,12 +150,6 @@ You can simply install DI-engine from PyPI with the following command:
pip install DI-engine
```

If you use Anaconda or Miniconda, you can install DI-engine from conda-forge through the following command:

```bash
conda install -c opendilab di-engine
```

For more information about installation, you can refer to [installation](https://di-engine-docs.readthedocs.io/en/latest/01_quickstart/installation.html).

And our DockerHub repo can be found [here](https://hub.docker.com/repository/docker/opendilab/ding); we prepare a `base image` and an `env image` with common RL environments.
@@ -283,6 +277,7 @@ P.S: The `.py` file in `Runnable Demo` can be found in `dizoo`
| 54 | [ST-DIM](https://arxiv.org/pdf/1906.08226.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [torch_utils/loss/contrastive_loss](https://github.com/opendilab/DI-engine/blob/main/ding/torch_utils/loss/contrastive_loss.py) | ding -m serial -c cartpole_dqn_stdim_config.py -s 0 |
| 55 | [PLR](https://arxiv.org/pdf/2010.03934.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [PLR doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/plr.html)<br>[data/level_replay/level_sampler](https://github.com/opendilab/DI-engine/blob/main/ding/data/level_replay/level_sampler.py) | python3 -u bigfish_plr_config.py -s 0 |
| 56 | [PCGrad](https://arxiv.org/pdf/2001.06782.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [torch_utils/optimizer_helper/PCGrad](https://github.com/opendilab/DI-engine/blob/main/ding/data/torch_utils/optimizer_helper.py) | python3 -u multi_mnist_pcgrad_main.py -s 0 |
| 57 | [AWR](https://arxiv.org/pdf/1910.00177) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | [policy/prompt_awr](https://github.com/opendilab/DI-engine/blob/main/ding/policy/prompt_awr.py) | python3 -u tabmwp_awr_config.py |

</details>

9 changes: 8 additions & 1 deletion ding/data/buffer/deque_buffer.py
@@ -239,10 +239,17 @@ def count(self) -> int:
def get(self, idx: int) -> BufferedData:
"""
Overview:
The method that returns the BufferedData object given a specific index.
The method that returns the BufferedData object by subscript idx (int).
"""
return self.storage[idx]

def get_by_index(self, index: str) -> BufferedData:
"""
Overview:
The method that returns the BufferedData object given a specific index (str).
"""
return self.storage[self.indices.get(index)]

@apply_middleware("clear")
def clear(self) -> None:
"""
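The new `get_by_index` complements `get`: `get` subscripts the underlying storage by position, while `get_by_index` resolves a sample's string index. A minimal usage sketch (the `DequeBuffer`/`push` API is assumed from the surrounding code, not shown in this hunk):

```python
from ding.data.buffer import DequeBuffer

buffer = DequeBuffer(size=4)
stored = buffer.push({"obs": 0})               # push is assumed to return a BufferedData record
by_position = buffer.get(0)                    # positional subscript into storage (int)
by_index = buffer.get_by_index(stored.index)   # lookup via the sample's string index
assert by_position.data == by_index.data
```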
12 changes: 6 additions & 6 deletions ding/data/buffer/middleware/priority.py
@@ -108,12 +108,12 @@ def update(self, chain: Callable, index: str, data: Any, meta: Any, *args, **kwa
self.max_priority = max(self.max_priority, new_priority)

def delete(self, chain: Callable, index: str, *args, **kwargs) -> None:
for item in self.buffer.storage:
meta = item.meta
priority_idx = meta['priority_idx']
self.sum_tree[priority_idx] = self.sum_tree.neutral_element
self.min_tree[priority_idx] = self.min_tree.neutral_element
self.buffer_idx.pop(priority_idx)
item = self.buffer.get_by_index(index)
meta = item.meta
priority_idx = meta['priority_idx']
self.sum_tree[priority_idx] = self.sum_tree.neutral_element
self.min_tree[priority_idx] = self.min_tree.neutral_element
self.buffer_idx.pop(priority_idx)
return chain(index, *args, **kwargs)

def clear(self, chain: Callable) -> None:
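This change fixes `delete` in the priority middleware: instead of looping over the whole storage and resetting every sample's priority slots, it fetches only the sample being deleted via the new `get_by_index` and clears just its entries in the sum/min trees. A hedged usage sketch (import paths and constructor arguments are assumed from the file layout, not shown here):

```python
from ding.data.buffer import DequeBuffer
from ding.data.buffer.middleware import PriorityExperienceReplay

buffer = DequeBuffer(size=8)
buffer.use(PriorityExperienceReplay(buffer))        # attach the priority middleware
for i in range(4):
    buffer.push({"obs": i}, meta={"priority": 1.0})

sample = buffer.sample(1)[0]
buffer.delete(sample.index)     # now clears only this sample's priority bookkeeping
assert buffer.count() == 3      # the other samples keep their priorities
```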
1 change: 1 addition & 0 deletions ding/data/buffer/tests/test_middleware.py
@@ -106,6 +106,7 @@ def test_priority():
assert data[0].meta['priority'] == 3.0
buffer.delete(data[0].index)
assert buffer.count() == N + N - 1
assert len(buffer._middleware[0].buffer_idx) == N + N - 1
buffer.clear()
assert buffer.count() == 0

12 changes: 11 additions & 1 deletion ding/entry/serial_entry.py
@@ -47,7 +47,15 @@ def serial_pipeline(
cfg, create_cfg = deepcopy(input_cfg)
create_cfg.policy.type = create_cfg.policy.type + '_command'
env_fn = None if env_setting is None else env_setting[0]
cfg = compile_config(cfg, seed=seed, env=env_fn, auto=True, create_cfg=create_cfg, save_cfg=True)
cfg = compile_config(
cfg,
seed=seed,
env=env_fn,
auto=True,
create_cfg=create_cfg,
save_cfg=True,
renew_dir=not cfg.policy.learn.get('resume_training', False)
)
# Create main components: env, policy
if env_setting is None:
env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env)
@@ -86,6 +94,8 @@ def serial_pipeline(
# ==========
# Learner's before_run hook.
learner.call_hook('before_run')
if cfg.policy.learn.get('resume_training', False):
collector.envstep = learner.collector_envstep

# Accumulate plenty of data at the beginning of training.
if cfg.policy.get('random_collect_size', 0) > 0:
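The `resume_training` flag threads through the serial entries in two places: `compile_config` keeps the existing experiment directory (`renew_dir=False`) when resuming, and the collector's env-step counter is restored from the learner after `before_run` loads the checkpoint. A hedged sketch of how a resumed run might be launched (the config module and the `load_ckpt_before_run` hook key are assumptions, not part of this diff):

```python
from copy import deepcopy
from ding.entry import serial_pipeline
from dizoo.classic_control.cartpole.config.cartpole_dqn_config import main_config, create_config

cfg, create_cfg = deepcopy(main_config), deepcopy(create_config)
cfg.policy.learn.resume_training = True   # reuse the previous exp dir instead of renewing it
# Assumed learner hook for loading the checkpoint written by the earlier run:
cfg.policy.learn.learner = dict(hook=dict(load_ckpt_before_run='./cartpole_dqn_seed0/ckpt/ckpt_best.pth.tar'))
serial_pipeline((cfg, create_cfg), seed=0)
```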
27 changes: 14 additions & 13 deletions ding/entry/serial_entry_mbrl.py
@@ -30,7 +30,15 @@ def mbrl_entry_setup(
cfg, create_cfg = deepcopy(input_cfg)
create_cfg.policy.type = create_cfg.policy.type + '_command'
env_fn = None if env_setting is None else env_setting[0]
cfg = compile_config(cfg, seed=seed, env=env_fn, auto=True, create_cfg=create_cfg, save_cfg=True)
cfg = compile_config(
cfg,
seed=seed,
env=env_fn,
auto=True,
create_cfg=create_cfg,
save_cfg=True,
renew_dir=not cfg.policy.learn.get('resume_training', False)
)

if env_setting is None:
env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env)
@@ -70,18 +78,7 @@ def mbrl_entry_setup(
cfg.policy.other.commander, learner, collector, evaluator, env_buffer, policy.command_mode
)

return (
cfg,
policy,
world_model,
env_buffer,
learner,
collector,
collector_env,
evaluator,
commander,
tb_logger,
)
return (cfg, policy, world_model, env_buffer, learner, collector, collector_env, evaluator, commander, tb_logger)


def create_img_buffer(
@@ -131,6 +128,8 @@ def serial_pipeline_dyna(
img_buffer = create_img_buffer(cfg, input_cfg, world_model, tb_logger)

learner.call_hook('before_run')
if cfg.policy.learn.get('resume_training', False):
collector.envstep = learner.collector_envstep

if cfg.policy.get('random_collect_size', 0) > 0:
random_collect(cfg.policy, policy, collector, collector_env, commander, env_buffer)
@@ -202,6 +201,8 @@ def serial_pipeline_dream(
mbrl_entry_setup(input_cfg, seed, env_setting, model)

learner.call_hook('before_run')
if cfg.policy.learn.get('resume_training', False):
collector.envstep = learner.collector_envstep

if cfg.policy.get('random_collect_size', 0) > 0:
random_collect(cfg.policy, policy, collector, collector_env, commander, env_buffer)
12 changes: 11 additions & 1 deletion ding/entry/serial_entry_ngu.py
@@ -47,7 +47,15 @@ def serial_pipeline_ngu(
cfg, create_cfg = deepcopy(input_cfg)
create_cfg.policy.type = create_cfg.policy.type + '_command'
env_fn = None if env_setting is None else env_setting[0]
cfg = compile_config(cfg, seed=seed, env=env_fn, auto=True, create_cfg=create_cfg, save_cfg=True)
cfg = compile_config(
cfg,
seed=seed,
env=env_fn,
auto=True,
create_cfg=create_cfg,
save_cfg=True,
renew_dir=not cfg.policy.learn.get('resume_training', False)
)
# Create main components: env, policy
if env_setting is None:
env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env)
@@ -89,6 +97,8 @@ def serial_pipeline_ngu(
# ==========
# Learner's before_run hook.
learner.call_hook('before_run')
if cfg.policy.learn.get('resume_training', False):
collector.envstep = learner.collector_envstep

# Accumulate plenty of data at the beginning of training.
if cfg.policy.get('random_collect_size', 0) > 0:
13 changes: 12 additions & 1 deletion ding/entry/serial_entry_onpolicy.py
@@ -45,7 +45,16 @@ def serial_pipeline_onpolicy(
cfg, create_cfg = deepcopy(input_cfg)
create_cfg.policy.type = create_cfg.policy.type + '_command'
env_fn = None if env_setting is None else env_setting[0]
cfg = compile_config(cfg, seed=seed, env=env_fn, auto=True, create_cfg=create_cfg, save_cfg=True)
cfg = compile_config(
cfg,
seed=seed,
env=env_fn,
auto=True,
create_cfg=create_cfg,
save_cfg=True,
renew_dir=not cfg.policy.learn.get('resume_training', False)
)

# Create main components: env, policy
if env_setting is None:
env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env)
@@ -80,6 +89,8 @@ def serial_pipeline_onpolicy(
# ==========
# Learner's before_run hook.
learner.call_hook('before_run')
if cfg.policy.learn.get('resume_training', False):
collector.envstep = learner.collector_envstep

while True:
collect_kwargs = commander.step()
12 changes: 11 additions & 1 deletion ding/entry/serial_entry_onpolicy_ppg.py
@@ -45,7 +45,15 @@ def serial_pipeline_onpolicy_ppg(
cfg, create_cfg = deepcopy(input_cfg)
create_cfg.policy.type = create_cfg.policy.type + '_command'
env_fn = None if env_setting is None else env_setting[0]
cfg = compile_config(cfg, seed=seed, env=env_fn, auto=True, create_cfg=create_cfg, save_cfg=True)
cfg = compile_config(
cfg,
seed=seed,
env=env_fn,
auto=True,
create_cfg=create_cfg,
save_cfg=True,
renew_dir=not cfg.policy.learn.get('resume_training', False)
)
# Create main components: env, policy
if env_setting is None:
env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env)
@@ -80,6 +88,8 @@ def serial_pipeline_onpolicy_ppg(
# ==========
# Learner's before_run hook.
learner.call_hook('before_run')
if cfg.policy.learn.get('resume_training', False):
collector.envstep = learner.collector_envstep

while True:
collect_kwargs = commander.step()
8 changes: 5 additions & 3 deletions ding/envs/env/ding_env_wrapper.py
@@ -71,9 +71,11 @@ def __init__(
self._observation_space = self._env.observation_space
self._action_space = self._env.action_space
self._action_space.seed(0) # default seed
self._reward_space = gym.spaces.Box(
low=self._env.reward_range[0], high=self._env.reward_range[1], shape=(1, ), dtype=np.float32
)
try:
low, high = self._env.reward_range
except: # for compatibility with newer versions of the gymnasium API that may not define reward_range
low, high = -1, 1
self._reward_space = gym.spaces.Box(low=low, high=high, shape=(1, ), dtype=np.float32)
self._init_flag = True
else:
assert 'env_id' in self._cfg
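The wrapper change guards against newer gymnasium environments that no longer expose `reward_range`, falling back to a placeholder `(-1, 1)` box. A standalone sketch of the same fallback (the environment id is chosen arbitrarily for illustration):

```python
import gym
import gymnasium
import numpy as np

env = gymnasium.make("CartPole-v1")
try:
    low, high = env.reward_range          # removed in recent gymnasium releases
except (AttributeError, TypeError, ValueError):
    low, high = -1, 1                     # placeholder bounds, as in the diff
reward_space = gym.spaces.Box(low=low, high=high, shape=(1, ), dtype=np.float32)
print(reward_space)
```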
3 changes: 1 addition & 2 deletions ding/envs/env/tests/test_ding_env_wrapper.py
@@ -74,7 +74,6 @@ def test_cartpole_pendulum(self, env_id):
def test_cartpole_pendulum_gymnasium(self, env_id):
env = gymnasium.make(env_id)
ding_env = DingEnvWrapper(env=env)
print(ding_env.observation_space, ding_env.action_space, ding_env.reward_space)
cfg = EasyDict(dict(
collector_env_num=16,
evaluator_env_num=3,
@@ -142,7 +141,7 @@ def test_atari(self, atari_env_id):
# assert isinstance(action, np.ndarray)
assert action.shape == (1, )

@pytest.mark.unittest
@pytest.mark.envtest
@pytest.mark.parametrize('lun_bip_env_id', ['LunarLander-v2', 'LunarLanderContinuous-v2', 'BipedalWalker-v3'])
def test_lunarlander_bipedalwalker(self, lun_bip_env_id):
env_cfg = EasyDict(
2 changes: 2 additions & 0 deletions ding/framework/middleware/functional/collector.py
@@ -159,6 +159,8 @@ def _rollout(ctx: "OnlineRLContext"):
'step': env_info[timestep.env_id.item()]['step'],
'train_sample': env_info[timestep.env_id.item()]['train_sample'],
}
# reset corresponding env info
env_info[timestep.env_id.item()] = {'time': 0., 'step': 0, 'train_sample': 0}

episode_info.append(info)
policy.reset([timestep.env_id.item()])
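The added line zeroes the per-env accounting dict once an episode's stats have been recorded, so the next episode for that env starts from fresh counters. A toy sketch of the pattern (env id and values are made up):

```python
# Per-env collection stats, keyed by env_id (values here are illustrative).
env_info = {0: {'time': 1.2, 'step': 37, 'train_sample': 9}}
episode_info = []

done_env_id = 0
episode_info.append(dict(env_info[done_env_id]))                      # record the finished episode
env_info[done_env_id] = {'time': 0., 'step': 0, 'train_sample': 0}    # reset for the next episode

assert env_info[done_env_id]['step'] == 0
```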
18 changes: 15 additions & 3 deletions ding/framework/middleware/tests/mock_for_test.py
@@ -1,5 +1,6 @@
from typing import Union, Any, List, Callable, Dict, Optional
from collections import namedtuple
import random
import torch
import treetensor.numpy as tnp
from easydict import EasyDict
@@ -75,6 +76,7 @@ def __init__(self) -> None:
self.obs_dim = obs_dim
self.closed = False
self._reward_grow_indicator = 1
self._steps = [0 for _ in range(self.env_num)]

@property
def ready_obs(self) -> tnp.array:
@@ -90,16 +92,26 @@ def launch(self, reset_param: Optional[Dict] = None) -> None:
return

def reset(self, reset_param: Optional[Dict] = None) -> None:
return
self._steps = [0 for _ in range(self.env_num)]

def step(self, actions: tnp.ndarray) -> List[tnp.ndarray]:
timesteps = []
for i in range(self.env_num):
if self._steps[i] < 5:
done = False
elif self._steps[i] < 10:
done = random.random() > 0.5
else:
done = True
if done:
self._steps[i] = 0
else:
self._steps[i] += 1
timestep = dict(
obs=torch.rand(self.obs_dim),
reward=1.0,
done=True,
info={'eval_episode_return': self._reward_grow_indicator * 1.0},
done=done,
info={'eval_episode_return': self._reward_grow_indicator * 1.0} if done else {},
env_id=i,
)
timesteps.append(tnp.array(timestep))
13 changes: 8 additions & 5 deletions ding/framework/middleware/tests/test_collector.py
@@ -22,16 +22,19 @@ def test_inferencer():

@pytest.mark.unittest
def test_rolloutor():
N = 20
ctx = OnlineRLContext()
transitions = TransitionList(2)
with patch("ding.policy.Policy", MockPolicy), patch("ding.envs.BaseEnvManagerV2", MockEnv):
policy = MockPolicy()
env = MockEnv()
for _ in range(10):
inferencer(0, policy, env)(ctx)
rolloutor(policy, env, transitions)(ctx)
assert ctx.env_episode == 20 # 10 * env_num
assert ctx.env_step == 20 # 10 * env_num
i = inferencer(0, policy, env)
r = rolloutor(policy, env, transitions)
for _ in range(N):
i(ctx)
r(ctx)
assert ctx.env_step == N * 2 # N * env_num
assert ctx.env_episode >= N // 10 * 2 # at least N // 10 episodes per env (env_num = 2)


@pytest.mark.unittest
11 changes: 5 additions & 6 deletions ding/framework/middleware/tests/test_logger.py
@@ -1,12 +1,10 @@
from os import path
import os
import copy
from functools import partial
from easydict import EasyDict
from collections import deque
import pytest
import shutil
import wandb
import h5py
import torch.nn as nn
from unittest.mock import MagicMock
from unittest.mock import Mock, patch
@@ -207,7 +205,6 @@ def test_wandb_online_logger():
env = TheEnvClass()
ctx = OnlineRLContext()
ctx.train_output = [{'reward': 1, 'q_value': [1.0]}]
model = TheModelClass()
wandb.init(config=cfg, anonymous="must")

def mock_metric_logger(data, step):
@@ -233,15 +230,17 @@ def mock_metric_logger(data, step):
]
assert set(data.keys()) <= set(metric_list)

def mock_gradient_logger(input_model, log, log_freq, log_graph):
def mock_gradient_logger(input_model, model, log, log_freq, log_graph):
assert input_model == model

def test_wandb_online_logger_metric():
model = TheModelClass()
with patch.object(wandb, 'log', new=mock_metric_logger):
wandb_online_logger(record_path, cfg, env=env, model=model, anonymous=True)(ctx)

def test_wandb_online_logger_gradient():
with patch.object(wandb, 'watch', new=mock_gradient_logger):
model = TheModelClass()
with patch.object(wandb, 'watch', new=partial(mock_gradient_logger, model=model)):
wandb_online_logger(record_path, cfg, env=env, model=model, anonymous=True)(ctx)

test_wandb_online_logger_metric()
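The logger test now binds the expected model into the mock with `functools.partial`, so the patched `wandb.watch` can assert it received the same object without relying on an outer closure variable. A minimal sketch of that pattern (the stand-in object and argument values are hypothetical):

```python
from functools import partial
from unittest.mock import patch
import wandb

def mock_gradient_logger(input_model, model, log=None, log_freq=100, log_graph=True):
    # `model` is pre-bound via partial; `input_model` is whatever wandb.watch received.
    assert input_model is model

expected_model = object()   # stand-in for TheModelClass()
with patch.object(wandb, 'watch', new=partial(mock_gradient_logger, model=expected_model)):
    wandb.watch(expected_model, log="all", log_freq=100, log_graph=True)
```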