Add QDT #442

Merged
merged 8 commits on Feb 11, 2025
332 changes: 332 additions & 0 deletions reproductions/offline/qdt.py
@@ -0,0 +1,332 @@
import argparse
from datetime import datetime
from typing import Optional, Union

import gym
import numpy as np
from torch.optim.lr_scheduler import CosineAnnealingLR

import d3rlpy
from d3rlpy.algos import CQL, IQL
from d3rlpy.dataset import InfiniteBuffer, ReplayBuffer
from d3rlpy.types import NDArray

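# Example invocations (paths and device IDs below are illustrative):
#   python qdt.py --dataset hopper-medium-v2 --q_learning_type iql --gpu 0
#   python qdt.py --dataset hopper-medium-v2 --model_file <path/to/q_model.d3> --gpu 0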

def main() -> None:
parser = argparse.ArgumentParser()
parser.add_argument("--dataset", type=str, default="hopper-medium-v2")
parser.add_argument("--context_size", type=int, default=20)
parser.add_argument("--model_file", type=str, default=None)
parser.add_argument(
"--q_learning_type", type=str, default="cql"
) # Q-learning algorithm ("cql" or "iql")
parser.add_argument("--seed", type=int, default=1)
parser.add_argument("--gpu", type=int)
args = parser.parse_args()

# get timestamp
timestamp = datetime.now().strftime("%Y%m%d%H%M%S")

dataset, env = d3rlpy.datasets.get_dataset(args.dataset)

# first fit Q-learning algorithm to the dataset
if args.model_file is not None:
# load model and assert type
q_algo_loaded = d3rlpy.load_learnable(args.model_file)
if not isinstance(q_algo_loaded, (CQL, IQL)):
raise ValueError(
"The loaded model is not an instance of CQL or IQL."
)
# cast to the expected type
q_algo = q_algo_loaded
    else:
        if args.q_learning_type == "cql":
            q_algo = fit_cql(dataset, env, args.seed, args.gpu, timestamp)
        elif args.q_learning_type == "iql":
            q_algo = fit_iql(dataset, env, args.seed, args.gpu, timestamp)
        else:
            raise ValueError(
                f"Unsupported q_learning_type: {args.q_learning_type}"
            )

# relabel dataset RTGs with the learned value functions
print("Relabeling dataset with RTGs...")
if not isinstance(dataset._buffer, InfiniteBuffer):
raise ValueError("Dataset must be an InfiniteBuffer.")

relabel_dataset_rtg(
dataset._buffer, q_algo, args.context_size, seed=args.seed
)

# fit decision transformer to the relabeled dataset
fit_dt(
dataset, env, args.context_size, args.seed, args.gpu, False, timestamp
)

return


""" --------------------------------------------------------------------
Augment the dataset by relabelling RTGs with the learned Q-function.
-------------------------------------------------------------------- """


def relabel_dataset_rtg(
buffer: InfiniteBuffer,
q_algo: Union[CQL, IQL],
k: int,
num_action_samples: int = 10,
seed: int = 0,
) -> None:
"""
    Relabel the RTGs (rewards-to-go) of the given dataset using the given
    Q-function.

    Args:
        buffer (InfiniteBuffer): Buffer holding the trajectory dataset.
        q_algo (Union[CQL, IQL]): Trained Q-learning algorithm.
        k (int): Context length for DT.
        num_action_samples (int, optional): The number of action samples for
            V-function estimation. Defaults to 10.
        seed (int): The random seed. Defaults to 0.
"""
# fix seed
d3rlpy.seed(seed)

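    # Walk the transitions backwards and relabel the return-to-go as
    # RTG_t = r_t + max(RTG_{t+1}, V(s_{t+1})), where V(s) is estimated by
    # averaging Q(s, a) over actions sampled from the learned policy.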
prev_idx = -1
for n in range(buffer.transition_count):
episode, idx = buffer._transitions[-n - 1] # get transitions backwards
if idx > prev_idx:
# get values for all observations in the episode
values = []
for _ in range(num_action_samples):
sampled_actions = q_algo.sample_action(episode.observations)
values.append(
q_algo.predict_value(episode.observations, sampled_actions)
)
value = np.array(values).mean(axis=0)
rewards = np.squeeze(episode.rewards, axis=1)
rtg = 0
else:
start = max(0, idx - k + 1)
rtg = rewards[idx] + np.maximum(rtg, value[idx + 1]) # relabel rtg
relabelled_rewards = np.zeros_like(rewards)
relabelled_rewards[idx] = rtg
relabelled_rewards[start:idx] = rewards[start:idx]
relabelled_episode = d3rlpy.dataset.components.Episode(
observations=episode.observations,
actions=episode.actions,
rewards=np.expand_dims(relabelled_rewards, axis=1),
terminated=episode.terminated,
)
buffer._transitions[-n - 1] = (relabelled_episode, idx)

prev_idx = idx

return


""" --------------------------------------------------------------------
Fit offline RL algorithms to the given dataset.
-------------------------------------------------------------------- """


def fit_cql(
dataset: ReplayBuffer,
env: gym.Env[NDArray, int],
seed: int = 1,
gpu: Optional[int] = None,
timestamp: Optional[str] = None,
) -> "CQL":
"""
Fit the CQL algorithm to the given dataset and environment.

Args:
        dataset (ReplayBuffer): Dataset for the training.
env (gym.Env): The environment instance.
seed (int): The random seed.
gpu (int, optional): The GPU device ID. Defaults to None.
timestamp (str, optional): The timestamp for experiment name.
Defaults to None.
"""
# fix seed
d3rlpy.seed(seed)
d3rlpy.envs.seed_env(env, seed)

encoder = d3rlpy.models.encoders.VectorEncoderFactory([256, 256, 256])

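    # a larger conservative weight for the *-medium-v0 datasets, following the
    # settings used in the d3rlpy CQL reproduction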
if "medium-v0" in env.spec.id:
conservative_weight = 10.0
else:
conservative_weight = 5.0

cql = d3rlpy.algos.CQLConfig(
actor_learning_rate=1e-4,
critic_learning_rate=3e-4,
temp_learning_rate=1e-4,
actor_encoder_factory=encoder,
critic_encoder_factory=encoder,
batch_size=256,
n_action_samples=10,
alpha_learning_rate=0.0,
conservative_weight=conservative_weight,
).create(device=gpu)

cql.fit(
dataset,
n_steps=500000,
n_steps_per_epoch=1000,
save_interval=50,
evaluators={"environment": d3rlpy.metrics.EnvironmentEvaluator(env)},
experiment_name=(
f"CQL_{env.spec.id}_{seed}"
if timestamp is None
else f"CQL_{env.spec.id}_{seed}_{timestamp}"
),
with_timestamp=False,
)

return cql


def fit_iql(
dataset: ReplayBuffer,
env: gym.Env[NDArray, int],
seed: int = 1,
gpu: Optional[int] = None,
timestamp: Optional[str] = None,
) -> "IQL":
"""
Fit the IQL algorithm to the given dataset and environment.

Args:
        dataset (ReplayBuffer): Dataset for the training.
env (gym.Env): The environment instance.
seed (int): The random seed.
gpu (int, optional): The GPU device ID. Defaults to None.
timestamp (str, optional): The timestamp for experiment name.
Defaults to None.
"""
# fix seed
d3rlpy.seed(seed)
d3rlpy.envs.seed_env(env, seed)

reward_scaler = d3rlpy.preprocessing.ReturnBasedRewardScaler(
multiplier=1000.0
)
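    # ReturnBasedRewardScaler rescales rewards by
    # multiplier / (max episode return - min episode return), matching the
    # reward normalization used in the IQL paper for locomotion tasks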

iql = d3rlpy.algos.IQLConfig(
actor_learning_rate=3e-4,
critic_learning_rate=3e-4,
batch_size=256,
gamma=0.99,
weight_temp=3.0,
max_weight=100.0,
expectile=0.7,
reward_scaler=reward_scaler,
).create(device=gpu)

    # workaround for the actor learning-rate schedule: build the torch modules
    # first so the actor optimizer exists, then attach a cosine annealing
    # scheduler to it and step it from the training callback
iql.build_with_dataset(dataset)
assert iql.impl
scheduler = CosineAnnealingLR(
iql.impl._modules.actor_optim, # pylint: disable=protected-access
500000,
)

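    # fit() calls the callback after every training step, so the scheduler
    # advances once per gradient step and T_max matches n_steps below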
def callback(algo: d3rlpy.algos.IQL, epoch: int, total_step: int) -> None:
scheduler.step()

iql.fit(
dataset,
n_steps=500000,
n_steps_per_epoch=1000,
save_interval=10,
callback=callback,
evaluators={
"environment": d3rlpy.metrics.EnvironmentEvaluator(env, n_trials=10)
},
experiment_name=(
f"IQL_{env.spec.id}_{seed}"
if timestamp is None
else f"IQL_{env.spec.id}_{seed}_{timestamp}"
),
with_timestamp=False,
)

return iql


def fit_dt(
dataset: ReplayBuffer,
env: gym.Env[NDArray, int],
context_size: int = 20,
seed: int = 1,
gpu: Optional[int] = None,
compile: bool = False,
timestamp: Optional[str] = None,
) -> None:
"""
    Fit the decision transformer to the given dataset and environment.

    Args:
        dataset (ReplayBuffer): Dataset for the training.
        env (gym.Env): The environment instance.
        context_size (int): Context length for DT. Defaults to 20.
        seed (int): The random seed.
        gpu (int, optional): The GPU device ID. Defaults to None.
        compile (bool): Whether to compile the model graph. Defaults to False.
        timestamp (str, optional): The timestamp for experiment name.
            Defaults to None.
"""
# fix seed
d3rlpy.seed(seed)
d3rlpy.envs.seed_env(env, seed)

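    # per-environment evaluation target returns, following the values commonly
    # used for Decision Transformer on the D4RL locomotion tasks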
if "halfcheetah" in env.spec.id:
target_return = 6000
elif "hopper" in env.spec.id:
target_return = 3600
elif "walker" in env.spec.id:
target_return = 5000
else:
raise ValueError("unsupported dataset")

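    # transformer hyperparameters (3 layers, 1 head, 128-dim encoder, 10k-step
    # LR warmup, grad clip 0.25) follow the standard Decision Transformer
    # settings for the gym locomotion tasks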
dt = d3rlpy.algos.DecisionTransformerConfig(
batch_size=64,
learning_rate=1e-4,
optim_factory=d3rlpy.optimizers.AdamWFactory(
weight_decay=1e-4,
clip_grad_norm=0.25,
lr_scheduler_factory=d3rlpy.optimizers.WarmupSchedulerFactory(
warmup_steps=10000
),
),
encoder_factory=d3rlpy.models.VectorEncoderFactory(
[128],
exclude_last_activation=True,
),
observation_scaler=d3rlpy.preprocessing.StandardObservationScaler(),
reward_scaler=d3rlpy.preprocessing.MultiplyRewardScaler(0.001),
position_encoding_type=d3rlpy.PositionEncodingType.SIMPLE,
context_size=context_size,
num_heads=1,
num_layers=3,
max_timestep=1000,
compile_graph=compile,
).create(device=gpu)

dt.fit(
dataset,
n_steps=100000,
n_steps_per_epoch=1000,
save_interval=10,
eval_env=env,
eval_target_return=target_return,
experiment_name=(
f"QDT_{env.spec.id}_{seed}"
if timestamp is None
else f"QDT_{env.spec.id}_{seed}_{timestamp}"
),
with_timestamp=False,
)


if __name__ == "__main__":
main()