Learning to Modulate pre-trained Models in RL (Decision Transformer, LoRA, Fine-tuning)


Learning to Modulate pre-trained Models in RL

arXiv License: MIT

Thomas Schmied¹, Markus Hofmarcher², Fabian Paischer¹, Razvan Pascanu³,⁴, Sepp Hochreiter¹,⁵

¹ELLIS Unit Linz and LIT AI Lab, Institute for Machine Learning, Johannes Kepler University Linz, Austria
²JKU LIT SAL eSPML Lab, Institute for Machine Learning, Johannes Kepler University Linz, Austria
³Google DeepMind
⁴UCL
⁵Institute of Advanced Research in Artificial Intelligence (IARAI), Vienna, Austria

This repository contains the source code for "Learning to Modulate pre-trained Models in RL", accepted at NeurIPS 2023. The paper is available on arXiv.

Multi-domain Decision Transformer (MDDT)

Overview

This codebase supports training Decision Transformer (DT) models online or from offline datasets on the following domains: Meta-World, Continual-World, DMControl (including visual DMControl), Atari, and D4RL.

This codebase relies on open-source frameworks, including Hydra for experiment configuration and PyTorch's torchrun for multi-GPU training.

What is in this repository?

.
├── configs                    # Contains all .yaml config files for Hydra to configure agents, envs, etc.
│   ├── agent_params            
│   ├── wandb_callback_params
│   ├── env_params
│   ├── eval_params
│   ├── run_params
│   └── config.yaml            # Main config file for Hydra - specifies log/data/model directories.
├── continual_world            # Submodule for Continual-World.
├── dmc2gym_custom             # Custom wrapper for DMControl.
├── figures             
├── scripts                    # Scripts for running experiments on Slurm/PBS in multi-GPU/multi-node setups.
├── src                        # Main source directory.
│   ├── algos                  # Contains agent/model/prompt classes.
│   ├── augmentations          # Image augmentations.
│   ├── buffers                # Contains replay trajectory buffers.
│   ├── callbacks              # Contains callbacks for training (e.g., WandB, evaluation, etc.).
│   ├── data                   # Contains data utilities (e.g., for downloading Atari)
│   ├── envs                   # Contains functionality for creating environments.
│   ├── exploration            # Contains exploration strategies.
│   ├── optimizers             # Contains (custom) optimizers.
│   ├── schedulers             # Contains learning rate schedulers.
│   ├── tokenizers_custom      # Contains custom tokenizers for discretizing states/actions.
│   ├── utils                  
│   └── __init__.py
├── LICENSE
├── README.md
├── environment.yaml
├── requirements.txt
└── main.py                     # Main entry point for training/evaluating agents.

Installation

Environment configuration and dependencies are available in environment.yaml and requirements.txt.

First, create the conda environment.

conda env create -f environment.yaml
conda activate mddt

Then install the remaining requirements (this assumes MuJoCo is already installed; if it is not, see the MuJoCo installation section below):

pip install -r requirements.txt

Initialize the Continual-World submodule and install it:

git submodule init
git submodule update
cd continual_world
pip install .
cd ..

Install Meta-World (pinned to a specific commit):

pip install git+https://github.com/rlworkgroup/metaworld.git@18118a28c06893da0f363786696cc792457b062b

Install our custom version of dmc2gym. It makes flatten_obs optional and thus allows us to construct the full observation space of all DMControl environments.

cd dmc2gym_custom
pip install -e .

MuJoCo installation

Download MuJoCo:

mkdir ~/.mujoco
cd ~/.mujoco
wget https://www.roboti.us/download/mujoco200_linux.zip
unzip mujoco200_linux.zip
mv mujoco200_linux mujoco200
wget https://www.roboti.us/file/mjkey.txt

Then add the following line to ~/.bashrc:

export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:~/.mujoco/mujoco200/bin
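After reloading the shell, you can confirm that the MuJoCo binaries are on the library path before launching any training. The following is a minimal stdlib-only sketch of such a check; the helper name `mujoco_on_library_path` is our own and not part of the repository, and it only inspects the environment variable without loading any library.

```python
import os
from typing import Mapping, Optional


def mujoco_on_library_path(env: Optional[Mapping[str, str]] = None,
                           home: Optional[str] = None) -> bool:
    """Return True if ~/.mujoco/mujoco200/bin is listed in LD_LIBRARY_PATH."""
    env = os.environ if env is None else env
    home = os.path.expanduser("~") if home is None else home
    expected = os.path.normpath(os.path.join(home, ".mujoco", "mujoco200", "bin"))
    for entry in env.get("LD_LIBRARY_PATH", "").split(":"):
        if not entry:
            # Skip empty entries, e.g. the leading ":" produced when
            # LD_LIBRARY_PATH was empty before the export line ran.
            continue
        if entry == "~" or entry.startswith("~/"):
            entry = home + entry[1:]  # expand a literal "~" the shell did not expand
        if os.path.normpath(entry) == expected:
            return True
    return False


if __name__ == "__main__":
    print("MuJoCo on LD_LIBRARY_PATH:", mujoco_on_library_path())
```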

Troubleshooting on a cluster (without root access)

The following issues were helpful:

First, install the following packages:

conda install -c conda-forge glew mesalib
conda install -c menpo glfw3 osmesa
pip install patchelf

Copy libGL into the conda environment and create the symlink manually:

cp /usr/lib64/libGL.so.1 $CONDA_PREFIX/lib
ln -s $CONDA_PREFIX/lib/libGL.so.1 $CONDA_PREFIX/lib/libGL.so

Then download libgcrypt11 and unpack it into a local directory:

mkdir ~/rpm
cd ~/rpm
curl -o libgcrypt11.rpm ftp://ftp.pbone.net/mirror/ftp5.gwdg.de/pub/opensuse/repositories/home:/bosconovic:/branches:/home:/elimat:/lsi/openSUSE_Leap_15.1/x86_64/libgcrypt11-1.5.4-lp151.23.29.x86_64.rpm
rpm2cpio libgcrypt11.rpm | cpio -id

Finally, export the paths to the unpacked libraries (add these lines to ~/.bashrc):

export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:~/rpm/usr/lib64
export LDFLAGS="-L$HOME/rpm/usr/lib64"

Setup

Experiment configuration

This codebase relies on Hydra, which configures experiments via .yaml files. Hydra automatically creates the log folder structure for a given run, as specified in the respective config.yaml file.

config.yaml is the main configuration entry point and contains the default parameters. It references the default parameter files in its defaults block. In addition, config.yaml defines four important constants that configure the directory paths:

LOG_DIR: ../logs
DATA_DIR: ../data
SSD_DATA_DIR: ../data
MODELS_DIR: ../models

Datasets

The generated datasets are currently hosted on our web server. Download the Meta-World and DMControl datasets to the specified DATA_DIR:

# Meta-World
wget --recursive --no-parent --no-host-directories --cut-dirs=2 -R "index.html*" https://ml.jku.at/research/l2m/metaworld
# DMControl
wget --recursive --no-parent --no-host-directories --cut-dirs=2 -R "index.html*" https://ml.jku.at/research/l2m/dm_control_1M
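The two wget flags determine where the files land: --no-host-directories drops the host name, and --cut-dirs=2 drops the first two directory components of the URL path (research/l2m). The data therefore ends up in metaworld/ and dm_control_1M/ below the directory you run wget from, so run the commands inside DATA_DIR (or pass -P with the target directory). The sketch below reproduces that path mapping for a single file URL; the file name used in the example is made up.

```python
import posixpath
from urllib.parse import urlparse


def wget_local_path(url: str, cut_dirs: int = 2) -> str:
    """Where `wget --no-host-directories --cut-dirs=N` saves a file URL (relative to the cwd)."""
    parts = [p for p in urlparse(url).path.split("/") if p]
    dirs, filename = parts[:-1], parts[-1]
    # The host is never part of the local path (--no-host-directories);
    # --cut-dirs removes the first N directory components of the URL path.
    return posixpath.join(*dirs[cut_dirs:], filename)


# Hypothetical file name, for illustration only.
print(wget_local_path("https://ml.jku.at/research/l2m/metaworld/example.pkl"))
```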

The datasets are also available on the Hugging Face Hub. Download them using huggingface-cli:

# Meta-World
huggingface-cli download ml-jku/meta-world --local-dir=./meta-world --repo-type dataset
# DMControl
huggingface-cli download ml-jku/dm_control --local-dir=./dm_control --repo-type dataset
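After downloading, a quick listing of the per-task files helps confirm that the data landed where the configs expect it. This is a minimal sketch under an assumption: per-task datasets are stored as files named after the task id with a .pkl extension, as the single-task fine-tuning command below implies (names='${env_params.envid}.pkl'). If your download uses a different format, change the pattern. The helper name is hypothetical; for example, list_task_datasets("./meta-world") lists the Meta-World tasks.

```python
from pathlib import Path
from typing import List


def list_task_datasets(data_dir: str, pattern: str = "*.pkl") -> List[str]:
    """Sorted file stems (assumed to be task ids) of all files under data_dir matching pattern."""
    # rglob searches data_dir and all of its subdirectories.
    return sorted(p.stem for p in Path(data_dir).rglob(pattern) if p.is_file())
```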

The framework also supports Atari, D4RL, and visual DMControl datasets. For Atari and visual DMControl, we refer to the respective READMEs.

Running experiments

In the following, we provide some illustrative examples of how to run the experiments in the paper.
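The commands below combine several kinds of Hydra command-line overrides. Dotted keys such as agent_params.kind=MDDT change a value. Keys such as env_params=... or agent_params/model_kwargs=... select a config-group option, where the slash addresses a nested group. A leading + adds a key or group that is not in the defaults. Comma-separated values such as seed=42,43,44 define a sweep when -m (multirun) is passed. The sketch below splits one override token into these parts. It is a simplified illustration written for this README, not Hydra's actual override parser, and it ignores quoting, escaping, and Hydra's sweep functions.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class Override:
    prefix: str          # "" = change a value, "+" = add, "++" = add or override, "~" = delete
    key: str             # e.g. "agent_params.kind" or "agent_params/model_kwargs"
    values: List[str]    # more than one value defines a sweep when -m is passed
    nested_group: bool   # "/" in the key addresses a nested config group


def parse_override(token: str) -> Override:
    """Split one command-line override (simplified: no quoting or escaping)."""
    prefix = ""
    for candidate in ("++", "+", "~"):  # check "++" before "+"
        if token.startswith(candidate):
            prefix, token = candidate, token[len(candidate):]
            break
    key, sep, value = token.partition("=")
    # In this sketch, "key=" yields [""] and "~key" (no "=") yields [].
    values = value.split(",") if sep else []
    return Override(prefix, key, values, "/" in key)
```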

Pre-training runs

To train a 40M-parameter multi-domain Decision Transformer (MDDT) model on MT40 + DMC10 with 3 seeds on a single GPU, run:

python main.py -m experiment_name=pretrain seed=42,43,44 env_params=multi_domain_mtdmc run_params=pretrain eval_params=pretrain_disc agent_params=cdt_pretrain_disc agent_params.kind=MDDT agent_params/model_kwargs=multi_domain_mtdmc agent_params/data_paths=mt40v2_dmc10 +agent_params/replay_buffer_kwargs=multi_domain_mtdmc +agent_params.accumulation_steps=2
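With -m, Hydra's basic sweeper launches one job per combination of comma-separated override values, so seed=42,43,44 starts three runs that differ only in the seed. The sketch below mimics that cartesian-product expansion in plain Python for illustration; it is not Hydra's code, and because it ignores quoting, a quoted value containing a comma would be split incorrectly.

```python
import itertools
from typing import Dict, List


def expand_sweep(overrides: List[str]) -> List[Dict[str, str]]:
    """Cartesian product of comma-separated override values (one dict per job)."""
    keys, choices = [], []
    for token in overrides:
        key, _, value = token.partition("=")
        keys.append(key)
        choices.append(value.split(","))
    return [dict(zip(keys, combo)) for combo in itertools.product(*choices)]


jobs = expand_sweep(["seed=42,43,44", "env_params=multi_domain_mtdmc"])
print(len(jobs))  # one job per seed
```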

Single-task fine-tuning

To fine-tune the pre-trained model using LoRA on a single CW10 task with 3 seeds, run:

python main.py -m experiment_name=cw10_lora seed=42,43,44 env_params=mt50_pretrain run_params=finetune eval_params=finetune agent_params=cdt_mpdt_disc agent_params/model_kwargs=mdmpdt_mtdmc agent_params/data_paths=cw10_v2_cwnet_2M +agent_params/replay_buffer_kwargs=mtdmc_ft agent_params/model_kwargs/prompt_kwargs=lora env_params.envid=hammer-v2 agent_params.data_paths.names='${env_params.envid}.pkl' env_params.eval_env_names=
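LoRA (low-rank adaptation) keeps the pre-trained weights frozen and learns a low-rank additive update for selected weight matrices, so only a small fraction of parameters is trained during fine-tuning. The sketch below shows the standard LoRA forward pass for a single linear layer in plain Python. It follows the general formulation of the original LoRA paper (Hu et al., 2021), not necessarily this repository's implementation, and the values of alpha and rank are illustrative. B is commonly initialized to zero, so fine-tuning starts exactly from the pre-trained function.

```python
from typing import List

Matrix = List[List[float]]
Vector = List[float]


def matvec(m: Matrix, v: Vector) -> Vector:
    """Dense matrix-vector product."""
    return [sum(a * b for a, b in zip(row, v)) for row in m]


def lora_forward(w: Matrix, a: Matrix, b: Matrix, x: Vector,
                 alpha: float, rank: int) -> Vector:
    """y = W x + (alpha / rank) * B (A x); W is frozen, only A and B are trained.

    Shapes: W is d_out x d_in, A is rank x d_in, B is d_out x rank.
    """
    base = matvec(w, x)                # frozen pre-trained projection
    update = matvec(b, matvec(a, x))   # low-rank path through a rank-dimensional bottleneck
    scale = alpha / rank
    return [y + scale * u for y, u in zip(base, update)]


W = [[1.0, 0.0], [0.0, 1.0]]
A = [[1.0, 1.0]]          # rank 1
B = [[0.0], [0.0]]        # zero init: output equals the frozen layer's output
print(lora_forward(W, A, B, [2.0, 3.0], alpha=1.0, rank=1))
```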

Continual fine-tuning

To fine-tune the pre-trained model using L2M on all CW10 tasks sequentially with 3 seeds, run:

python main.py -m experiment_name=cw10_cl_l2m seed=42,43,44 env_params=multi_domain_ft env_params.eval_env_names=cw10_v2 run_params=finetune_coff eval_params=finetune_md_cl agent_params=cdt_mpdt_disc +agent_params.steps_per_task=100000 agent_params/model_kwargs=mdmpdt_mtdmc agent_params/data_paths=cw10_v2_cwnet_2M +agent_params/replay_buffer_kwargs=mtdmc_ft +agent_params.replay_buffer_kwargs.kind=continual agent_params/model_kwargs/prompt_kwargs=l2m_lora
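L2M keeps the pre-trained model frozen and modulates it with entries from a pool of learnable modulators (with l2m_lora, LoRA weights), each associated with a key; an entry is selected by comparing a query derived from the input sequence with the keys. The sketch below illustrates only this retrieval step, under simplifying assumptions of our own: the query is the mean of the input embeddings, similarity is cosine similarity, and the single best entry is chosen. See the paper and src/algos for how this repository actually builds queries, trains the keys, and applies the selected weights.

```python
import math
from typing import List, Sequence

Vector = Sequence[float]


def cosine(u: Vector, v: Vector) -> float:
    """Cosine similarity; returns 0.0 if either vector has zero norm."""
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / norm if norm > 0 else 0.0


def mean_pool(embeddings: List[Vector]) -> List[float]:
    """Average a non-empty sequence of embeddings into a single query vector."""
    if not embeddings:
        raise ValueError("need at least one embedding")
    n = len(embeddings)
    return [sum(col) / n for col in zip(*embeddings)]


def select_modulator(embeddings: List[Vector], keys: List[Vector]) -> int:
    """Index of the pool entry whose key is most similar to the pooled query."""
    query = mean_pool(embeddings)
    scores = [cosine(query, k) for k in keys]
    return max(range(len(keys)), key=scores.__getitem__)


keys = [[1.0, 0.0], [0.0, 1.0]]  # toy 2-entry pool
print(select_modulator([[0.9, 0.1], [0.8, 0.3]], keys))
```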

Multi-GPU training

For multi-GPU training, we use torchrun. Because torchrun conflicts with Hydra, we created a launcher plugin, hydra_torchrun_launcher.

To enable the plugin, clone the Hydra repository, change into contrib/hydra_torchrun_launcher, and install the plugin with pip:

git clone https://github.com/facebookresearch/hydra.git
cd hydra/contrib/hydra_torchrun_launcher
pip install -e .

The plugin can be used from the command line:

python main.py -m hydra/launcher=torchrun hydra.launcher.nproc_per_node=4 [...]
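torchrun starts nproc_per_node worker processes and passes each one its identity through environment variables such as RANK, LOCAL_RANK, and WORLD_SIZE. The sketch below shows how a script can read them, falling back to single-process values when launched without torchrun. The DistInfo and dist_info names are hypothetical, and the repository's own code may obtain this information differently (for example, through torch.distributed).

```python
import os
from dataclasses import dataclass
from typing import Mapping, Optional


@dataclass(frozen=True)
class DistInfo:
    rank: int        # global rank of this process
    local_rank: int  # index of this process on its node (typically its GPU index)
    world_size: int  # total number of processes across all nodes


def dist_info(env: Optional[Mapping[str, str]] = None) -> DistInfo:
    """Read the variables torchrun sets for each worker; default to a single process."""
    env = os.environ if env is None else env
    return DistInfo(
        rank=int(env.get("RANK", 0)),
        local_rank=int(env.get("LOCAL_RANK", 0)),
        world_size=int(env.get("WORLD_SIZE", 1)),
    )


print(dist_info())
```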

To run experiments on a single node of a local cluster, use CUDA_VISIBLE_DEVICES to specify which GPUs to use:

CUDA_VISIBLE_DEVICES=0,1,2,3 python main.py -m hydra/launcher=torchrun hydra.launcher.nproc_per_node=4 [...]
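Typically, nproc_per_node matches the number of GPUs listed in CUDA_VISIBLE_DEVICES, so that each process gets its own GPU. A small pre-launch check along these lines can catch a mismatch early; both helpers are hypothetical and not part of the repository.

```python
import os
from typing import Mapping, Optional


def visible_gpu_count(env: Optional[Mapping[str, str]] = None) -> Optional[int]:
    """Number of GPUs listed in CUDA_VISIBLE_DEVICES, or None if the variable is unset."""
    env = os.environ if env is None else env
    value = env.get("CUDA_VISIBLE_DEVICES")
    if value is None:
        return None  # unset: all GPUs are visible, count unknown from the env alone
    return len([d for d in value.split(",") if d.strip()])


def check_nproc(nproc_per_node: int, env: Optional[Mapping[str, str]] = None) -> None:
    """Raise if torchrun would start more processes than there are visible GPUs."""
    count = visible_gpu_count(env)
    if count is not None and nproc_per_node > count:
        raise ValueError(f"nproc_per_node={nproc_per_node} but only {count} GPU(s) visible")


check_nproc(4, {"CUDA_VISIBLE_DEVICES": "0,1,2,3"})  # passes silently
```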

On Slurm, running torchrun on a single node works the same way. For example, to run on 2 GPUs on a single node:

#!/bin/bash
#SBATCH --account=X
#SBATCH --qos=X
#SBATCH --partition=X
#SBATCH --nodes=1
#SBATCH --gpus=2
#SBATCH --cpus-per-task=32

source activate mddt
python main.py -m hydra/launcher=torchrun hydra.launcher.nproc_per_node=2 [...]

Example scripts for multi-GPU training on Slurm or PBS are available in the scripts directory.

Multi-node training

Running on Slurm/PBS in a multi-node setup requires a little more care. Example scripts are provided in the scripts directory.

Citation

If you find this useful, please consider citing our work:

@article{schmied2024learning,
  title={Learning to Modulate pre-trained Models in RL},
  author={Schmied, Thomas and Hofmarcher, Markus and Paischer, Fabian and Pascanu, Razvan and Hochreiter, Sepp},
  journal={Advances in Neural Information Processing Systems},
  volume={36},
  year={2024}
}
