This is the official repository for the paper Iterated Denoising Energy Matching for Sampling from Boltzmann Densities.
We propose iDEM, a scalable and efficient method for sampling from unnormalized probability distributions. iDEM uses the DEM objective, which is inspired by the stochastic-regression and simulation-free
principles of score and flow matching objectives, while allowing off-policy learning in a loop in which the model itself generates (optionally exploratory) new states that are subsequently
learned on. iDEM is also capable of incorporating symmetries, namely those represented by the product group of
- DW4 -- the 4-particle double well potential (8 dimensions total)
- LJ13 -- the 13-particle Lennard-Jones potential (39 dimensions total)
- LJ55 -- the 55-particle Lennard-Jones potential (165 dimensions total)
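As a rough illustration of where the dimensionalities above come from, the sketch below flattens n particles in d spatial dimensions into one vector of length n * d and evaluates a standard pairwise Lennard-Jones energy on it. The function name lj_energy, the choice eps = sigma = 1, and the omission of any extra terms are assumptions made for illustration; this is not the energy implementation used in this repository.

```python
import itertools
import math


def lj_energy(x, n_particles, n_dims, eps=1.0, sigma=1.0):
    """Standard pairwise Lennard-Jones energy of a flattened configuration.

    x is a flat sequence of length n_particles * n_dims.
    eps and sigma are illustrative defaults, not values from the repository.
    """
    assert len(x) == n_particles * n_dims
    # Un-flatten into one coordinate tuple per particle.
    pos = [tuple(x[i * n_dims:(i + 1) * n_dims]) for i in range(n_particles)]
    energy = 0.0
    for a, b in itertools.combinations(pos, 2):
        r = math.dist(a, b)
        sr6 = (sigma / r) ** 6
        energy += 4.0 * eps * (sr6 ** 2 - sr6)
    return energy


# Flattened dimensionality of each task: particles * spatial dimensions.
task_dims = {"DW4": 4 * 2, "LJ13": 13 * 3, "LJ55": 55 * 3}

# Two particles at the pair minimum r = 2^(1/6) * sigma have energy -eps.
pair = [0.0, 0.0, 0.0, 2 ** (1 / 6), 0.0, 0.0]
pair_energy = lj_energy(pair, n_particles=2, n_dims=3)
```

The same flattening convention explains why DW4 (4 particles in 2D) has 8 dimensions while LJ13 and LJ55 (3D) have 39 and 165.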
This code was taken from an internal repository, so the commit history is not available here. Development credit for this repository goes primarily to @atong01, @jarridrb and @taraak, who built out most of the code and experiments with help from @sarthmit and @msendera. Finally, the code is based on the hydra lightning template by @ashleve and makes use of the FAB torch code for the GMM task and replay buffers.
For installation, we recommend the use of Micromamba. Please refer to the Micromamba installation guide if you do not already have it installed. First, we install dependencies:
# clone project
git clone git@github.com:jarridrb/DEM.git
cd DEM
# create micromamba environment
micromamba create -f environment.yaml
micromamba activate dem
# install requirements
pip install -r requirements.txt
Note that the hydra configs interpolate using some environment variables set in the file .env. We provide an example .env.example file for convenience. Note that to use wandb we require that you set WANDB_ENTITY in your .env file.
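Before launching a run with wandb logging, it can be handy to confirm that your .env really defines WANDB_ENTITY. The sketch below is a hypothetical standalone helper (parse_env_text is not part of this repository) that parses simple KEY=VALUE lines and strips surrounding quotes; it assumes a plain dotenv layout without multi-line values, and the entity name shown is a placeholder.

```python
def parse_env_text(text):
    """Parse simple KEY=VALUE lines, skipping blanks and # comments."""
    env = {}
    for line in text.splitlines():
        line = line.strip()
        if not line or line.startswith("#") or "=" not in line:
            continue
        key, _, value = line.partition("=")
        env[key.strip()] = value.strip().strip('"').strip("'")
    return env


# Example contents; "my-team" is a placeholder, not a real entity.
example = """
# wandb settings
WANDB_ENTITY="my-team"
"""
env = parse_env_text(example)
# Names of required variables that are absent or empty.
missing = [k for k in ("WANDB_ENTITY",) if not env.get(k)]
```

To check a real file, pass its contents instead, for example parse_env_text(open(".env").read()).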
To run an experiment, e.g., GMM with iDEM, you can run on the command line
python dem/train.py experiment=gmm_idem
We include configs for all experiments, matching the settings used in our paper, for both iDEM and pDEM. The one exception is LJ55, for which we only include an iDEM config because pDEM had convergence issues on this task.
The current repository contains code for experiments for iDEM and pDEM as specified in our paper.
In this update we provide code and more detailed instructions on how to run the CFM models including log Z and ESS computation. In doing this, we also found a few bugs in the public code implementation for LJ55 (note that this codebase is an adaptation of a large number of notebooks used for the paper) which we have fixed in a set of code updates just merged to the repository.
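For readers unfamiliar with the quantity, log Z for an unnormalized density exp(-E(x)) can be estimated by importance sampling with a proposal q whose log-density is tractable (presumably the role the CFM model plays here). The sketch below shows the standard estimator log Z ≈ logsumexp(-E(x_i) - log q(x_i)) - log N on a 1D toy problem. The function names and the toy energy are assumptions, and this is not the repository's implementation.

```python
import math
import random


def logsumexp(values):
    """Numerically stable log(sum(exp(v)))."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def estimate_log_z(energies, log_q):
    """Importance-sampling estimate of log Z from samples x_i ~ q.

    energies[i] = E(x_i), log_q[i] = log q(x_i).
    """
    log_w = [-e - lq for e, lq in zip(energies, log_q)]
    return logsumexp(log_w) - math.log(len(log_w))


# Toy check: E(x) = x^2 / 2 has Z = sqrt(2 * pi).
# Proposal q = N(0, 1.5^2), a deliberately imperfect sampler.
rng = random.Random(0)
s = 1.5
xs = [rng.gauss(0.0, s) for _ in range(20000)]
energies = [0.5 * x * x for x in xs]
log_q = [-0.5 * (x / s) ** 2 - math.log(s * math.sqrt(2 * math.pi)) for x in xs]
log_z_hat = estimate_log_z(energies, log_q)
true_log_z = 0.5 * math.log(2 * math.pi)
```

On this toy problem the estimate should land close to the analytic value; the quality of the estimate in practice depends on how well the proposal covers the target.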
We will use the example of LJ55 in detailing the pipeline. First, run the training script as normal as follows
python dem/train.py experiment=lj55_idem
After training is complete, find the epochs with the best val/2-Wasserstein values in wandb. We will use the best checkpoint to generate a training dataset for CFM in the following command. This command will also log the 2-Wasserstein and total variation distance for the dataset generated from the trained iDEM model compared to the test set. To run this, you must provide the eval script with the checkpoint path you are using.
python dem/eval.py experiment=lj55_idem ckpt_path=<path_to_ckpt>
This will take some time to run and will generate a file named samples_<n_samples_to_generate>.pt in the hydra runtime directory for the eval run. We can now use these samples to train a CFM model. We provide a config lj55_idem_cfm which has the settings to enable the CFM pipeline to run by default for the LJ55 task, though doing so for other tasks is also simple. The main config changes required are to set model.debug_use_train_data=true, model.nll_with_cfm=true and model.logz_with_cfm=true. To point the CFM training run to the dataset generated from iDEM samples we can set the energy.data_path_train attribute to the path of the generated samples. CFM training in this example can then be done with
python dem/train.py experiment=lj55_idem_cfm energy.data_path_train=<path_to_samples>
Finally, to evaluate test set NLL, take the checkpoint of the CFM run with the best val/nll and run the eval script again
python dem/eval.py experiment=lj55_idem_cfm ckpt_path=<path_to_cfm_ckpt>
Note that you may need to try a couple of different checkpoints from the original python dem/train.py experiment=lj55_idem run when generating samples and running downstream CFM training/eval in order to get the best combination of eval metrics.
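As a compact recap, the four stages of the LJ55 pipeline can be written out as argument lists. The sketch below only assembles the commands quoted in this section; the helper name pipeline_commands and the three path values are hypothetical placeholders that you would replace with the checkpoints and sample file obtained between stages.

```python
import shlex


def pipeline_commands(idem_ckpt, samples_path, cfm_ckpt):
    """Return the four commands of the LJ55 pipeline, in order.

    The three arguments are paths obtained by hand between stages
    (best checkpoints from wandb, sample file from the hydra run directory).
    """
    return [
        ["python", "dem/train.py", "experiment=lj55_idem"],
        ["python", "dem/eval.py", "experiment=lj55_idem",
         f"ckpt_path={idem_ckpt}"],
        ["python", "dem/train.py", "experiment=lj55_idem_cfm",
         f"energy.data_path_train={samples_path}"],
        ["python", "dem/eval.py", "experiment=lj55_idem_cfm",
         f"ckpt_path={cfm_ckpt}"],
    ]


# Hypothetical placeholder paths, for illustration only.
cmds = pipeline_commands("idem.ckpt", "samples.pt", "cfm.ckpt")
printable = [shlex.join(c) for c in cmds]
```

Each argument list could be handed to subprocess.run(cmd, check=True), although in practice the stages are run by hand because choosing the best checkpoint requires inspecting wandb in between.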
In preparing this update we noticed that our original ESS evaluation used a batch size of 16 on all tasks. We recommend that users of our repository instead evaluate ESS on a larger batch size (the default is 1000 in the updated code). To reproduce the results in the paper, you can either set this to 16 or look at the wandb logs during validation when training the CFM model, which evaluates on a batch size of 16.
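To see why the batch size matters, recall the common (Kish) definition of normalized effective sample size from importance weights w_i: ESS = (sum_i w_i)^2 / (N * sum_i w_i^2). The sketch below computes it from log-weights in a numerically stable way and evaluates it on batches of 16 and 1000 drawn from the same toy weights. The function name and the toy weights are assumptions; the repository's ESS code may differ in detail.

```python
import math
import random


def normalized_ess(log_w):
    """Kish effective sample size divided by N, computed from log-weights."""
    m = max(log_w)
    w = [math.exp(lw - m) for lw in log_w]  # shift for numerical stability
    return sum(w) ** 2 / (len(w) * sum(x * x for x in w))


rng = random.Random(0)
# Toy log-weights: Gaussian noise stands in for log p(x) - log q(x).
log_w = [rng.gauss(0.0, 1.0) for _ in range(1000)]
ess_16 = normalized_ess(log_w[:16])    # small batch of 16
ess_1000 = normalized_ess(log_w)       # larger batch of 1000
uniform = normalized_ess([0.3] * 16)   # equal weights give ESS = 1
```

An estimate from 16 weights depends heavily on which few samples happen to land in the batch, so it is much noisier than one computed from 1000 weights.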
In our original manuscript, for LJ55 we used 10 steps of "negative time" (described in Section 4 of our manuscript) during inference, where we continued SDE inference for 10 extra steps using the true score at time 0. The configs in the repository had the flag for this turned on, but the code ignored it. This has been corrected in the update.
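One plausible reading of "negative time" is sketched below on a 1D toy problem: an Euler-Maruyama loop integrates a reverse-time SDE from t = 1 to t = 0 with a time-dependent score, then takes a few extra steps that reuse the time-0 noise level but query the true score of the target. Everything here (the function name, the constant diffusion coefficient g, the step size, and the Gaussian target) is an assumption for illustration; please consult the repository code and Section 4 of the paper for the actual procedure.

```python
import math
import random


def sample_with_negative_time(model_score, true_score, g, x, n_steps,
                              n_negative_steps, rng):
    """Euler-Maruyama for a reverse-time SDE, then extra steps at t = 0."""
    dt = 1.0 / n_steps
    # Regular reverse-time integration from t = 1 towards t = 0,
    # using the time-dependent (learned) score.
    for i in range(n_steps):
        t = 1.0 - i * dt
        noise = rng.gauss(0.0, 1.0)
        x = x + g(t) ** 2 * model_score(x, t) * dt + g(t) * math.sqrt(dt) * noise
    # "Negative time": keep stepping at the time-0 noise level,
    # but use the true score of the target instead of the model.
    for _ in range(n_negative_steps):
        noise = rng.gauss(0.0, 1.0)
        x = x + g(0.0) ** 2 * true_score(x) * dt + g(0.0) * math.sqrt(dt) * noise
    return x


# Toy 1D setup (all assumed): target N(0, 1), constant diffusion g = 2,
# so the noised marginal at time t is N(0, 1 + 4 t).
G = 2.0
rng = random.Random(0)
x0 = rng.gauss(0.0, math.sqrt(1.0 + G ** 2))  # draw from the time-1 marginal
x_final = sample_with_negative_time(
    model_score=lambda x, t: -x / (1.0 + G ** 2 * t),  # exact noised score
    true_score=lambda x: -x,                           # score of N(0, 1)
    g=lambda t: G,
    x=x0,
    n_steps=100,
    n_negative_steps=10,
    rng=rng,
)
```

In this toy the "model" score is exact, so the extra steps change little; the idea is more relevant when the learned score is imperfect near t = 0 and the true score is cheap to evaluate.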
If this codebase is useful towards other research efforts, please consider citing us.
@misc{akhoundsadegh2024iterated,
title={Iterated Denoising Energy Matching for Sampling from Boltzmann Densities},
author={Tara Akhound-Sadegh and Jarrid Rector-Brooks and Avishek Joey Bose and Sarthak Mittal and Pablo Lemos and Cheng-Hao Liu and Marcin Sendera and Siamak Ravanbakhsh and Gauthier Gidel and Yoshua Bengio and Nikolay Malkin and Alexander Tong},
year={2024},
eprint={2402.06121},
archivePrefix={arXiv},
primaryClass={cs.LG}
}
We welcome issues, pull requests (especially bug fixes), and other contributions. We will try our best to improve readability and answer questions!
This repo is licensed under the MIT License.
The code makes heavy use of the functorch library, which is included in torch 2.0.0 (as torch.func), as well as torch vmap.