PyTorch implementation for the paper:
DataMUX: Data Multiplexing for Neural Networks
Vishvak Murahari, Carlos E. Jimenez, Runzhe Yang, Karthik Narasimhan
This repository contains the code to reproduce our results. We provide pretrained model weights and the associated configs, which can be used to run inference or to train these models from scratch. If you find this work useful in your research, please cite:
```bibtex
@inproceedings{
murahari2022datamux,
title={Data{MUX}: Data Multiplexing for Neural Networks},
author={Vishvak Murahari and Carlos E Jimenez and Runzhe Yang and Karthik R Narasimhan},
booktitle={Thirty-Sixth Conference on Neural Information Processing Systems},
year={2022},
url={https://openreview.net/forum?id=UdgtTVTdswg}
}
```
Our code is implemented in PyTorch. To set up, do the following:

- Install Python 3.6
- Get the source:
  ```bash
  git clone https://github.com/princeton-nlp/DataMUX.git datamux
  ```
- Install the requirements into the `datamux` virtual environment using Anaconda:
  ```bash
  conda env create -f env.yaml
  ```
For sentence-level classification tasks, refer to `run_glue.py` and `run_glue.sh`. For token-level classification tasks, refer to `run_ner.py` and `run_ner.sh`.
We release all pretrained checkpoints on the Hugging Face Model Hub; they are listed below. Replace `<num_instances>` with 2, 5, 10, 20, or 40.
Task | Model name on hub | Full path |
---|---|---|
Retrieval Warmup | datamux-retrieval-<num_instances> | princeton-nlp/datamux-retrieval-<num_instances> |
MNLI | datamux-mnli-<num_instances> | princeton-nlp/datamux-mnli-<num_instances> |
QNLI | datamux-qnli-<num_instances> | princeton-nlp/datamux-qnli-<num_instances> |
QQP | datamux-qqp-<num_instances> | princeton-nlp/datamux-qqp-<num_instances> |
SST2 | datamux-sst2-<num_instances> | princeton-nlp/datamux-sst2-<num_instances> |
NER | datamux-ner-<num_instances> | princeton-nlp/datamux-ner-<num_instances> |
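Because every released checkpoint follows the pattern `princeton-nlp/datamux-<task>-<num_instances>`, the full path can be assembled programmatically and passed to `--model_path`. The snippet below is a minimal sketch; `hub_model_id` is a hypothetical helper, not part of this repository, and it only encodes the names and instance counts listed in the table above.

```python
# Hypothetical helper (not part of the DataMUX codebase): builds the Hugging Face
# Hub ID of a released checkpoint from the naming pattern in the table above.

VALID_TASKS = {"retrieval", "mnli", "qnli", "qqp", "sst2", "ner"}
VALID_NUM_INSTANCES = {2, 5, 10, 20, 40}


def hub_model_id(task: str, num_instances: int) -> str:
    """Return the full Hub path for a task and a number of multiplexed instances."""
    task = task.lower()
    if task not in VALID_TASKS:
        raise ValueError(
            f"unknown task {task!r}; expected one of {sorted(VALID_TASKS)}"
        )
    if num_instances not in VALID_NUM_INSTANCES:
        raise ValueError(
            f"no released checkpoint for N={num_instances}; "
            f"expected one of {sorted(VALID_NUM_INSTANCES)}"
        )
    return f"princeton-nlp/datamux-{task}-{num_instances}"


if __name__ == "__main__":
    print(hub_model_id("retrieval", 2))  # princeton-nlp/datamux-retrieval-2
    print(hub_model_id("mnli", 10))      # princeton-nlp/datamux-mnli-10
```

Validating the instance count up front gives a clear error for values such as N=3, for which no checkpoint is listed, instead of a failed download later.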
The bash scripts `run_ner.sh` and `run_glue.sh` take the following arguments:
Argument | Flag | Explanation | Argument Choices |
---|---|---|---|
NUM_INSTANCES | -N, --num_instances | Number of multiplexing instances | 2, 5, 10, 20, 40 (1 for the non-multiplexed baseline) |
DEMUXING | -d, --demuxing | Demultiplexing architecture | "index", "mlp" |
MUXING | -m, --muxing | Multiplexing architecture | "gaussian_hadamard", "binary_hadamard", "random_ortho" |
SETTING | -s, --setting | Training setting | "baseline", "finetuning", "retrieval_pretraining" |
TASK_NAME | --task | Task name during finetuning | "mnli", "qnli", "sst2", "qqp" for run_glue.sh; "ner" for run_ner.sh |
LEARNING_RATE | --lr | Learning rate for optimization | Any float; we use either 2e-5 or 5e-5 |
BATCH_SIZE | --batch_size | Batch size after multiplexing; the effective batch size is BATCH_SIZE * NUM_INSTANCES | Any integer; if unset, it is chosen automatically based on N |
CONFIG_NAME | --config_name | Config path for the backbone Transformer model | Any config file in the configs directory |
MODEL_PATH | --model_path | Checkpoint to continue training from, or retrieval-pretrained checkpoint to initialize from | Path to a local checkpoint or to a model on the Hub |
LEARN_MUXING | --learn_muxing | Whether to learn the instance embeddings used for multiplexing | |
DO_TRAIN | --do_train | Pass this flag to run training | |
DO_EVAL | --do_eval | Pass this flag to run evaluation | |
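To make the MUXING and BATCH_SIZE rows concrete, here is a toy, pure-Python sketch of one way to combine N instances into a single input. It assumes that the `gaussian_hadamard` option multiplies each instance's embedding element-wise by a fixed random Gaussian vector assigned to that instance position and then averages the results; it is an illustration of that idea, not the repository's implementation, which operates on batched PyTorch tensors. The function names are hypothetical.

```python
import random
from typing import List

Vector = List[float]


def make_gaussian_keys(num_instances: int, dim: int, seed: int = 0) -> List[Vector]:
    """One fixed random vector per instance position (assumption: i.i.d. N(0, 1) entries)."""
    rng = random.Random(seed)
    return [[rng.gauss(0.0, 1.0) for _ in range(dim)] for _ in range(num_instances)]


def gaussian_hadamard_mux(instances: List[Vector], keys: List[Vector]) -> Vector:
    """Average the element-wise (Hadamard) products of each instance with its key."""
    if not instances or len(instances) != len(keys):
        raise ValueError("need at least one instance and exactly one key per instance")
    n = len(instances)
    dim = len(instances[0])
    if any(len(x) != dim for x in instances) or any(len(v) != dim for v in keys):
        raise ValueError("all instances and keys must have the same dimension")
    mixed = [0.0] * dim
    for x, v in zip(instances, keys):
        for j in range(dim):
            mixed[j] += x[j] * v[j] / n
    return mixed


def effective_batch_size(batch_size: int, num_instances: int) -> int:
    """Each multiplexed input carries N instances, so a batch covers B * N instances."""
    return batch_size * num_instances


if __name__ == "__main__":
    keys = make_gaussian_keys(num_instances=2, dim=4)
    x1 = [1.0, 0.0, 2.0, -1.0]
    x2 = [0.5, 0.5, 0.5, 0.5]
    mixed = gaussian_hadamard_mux([x1, x2], keys)
    print(len(mixed))                   # 4: same width as a single instance
    print(effective_batch_size(32, 2))  # 64
```

The multiplexed vector has the same width as a single instance, so the backbone processes one input per group of N instances; this is why a BATCH_SIZE of B multiplexed inputs corresponds to B * N original examples.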
Below are example commands for the different training settings.
This command runs retrieval pretraining with N=2:

```bash
sh run_glue.sh \
  -N 2 \
  -d index \
  -m gaussian_hadamard \
  -s retrieval_pretraining \
  --config_name configs/ablations/base_model/roberta.json \
  --lr 5e-5 \
  --do_train \
  --do_eval
```
This command finetunes on MNLI from a retrieval-pretrained checkpoint with N=2:

```bash
sh run_glue.sh \
  -N 2 \
  -d index \
  -m gaussian_hadamard \
  -s finetuning \
  --config_name configs/ablations/base_model/roberta.json \
  --lr 5e-5 \
  --task mnli \
  --model_path princeton-nlp/datamux-retrieval-2 \
  --do_train \
  --do_eval
```
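When sweeping over several settings (for example, every N in 2, 5, 10, 20, 40), it can help to build these commands in code and validate each flag against the argument table before launching a run. The sketch below is a hypothetical helper, not part of this repository; it only covers the two multiplexed settings shown so far, and the rule that finetuning requires `--task` is an assumption drawn from the example commands.

```python
import shlex
from typing import List, Optional

# Allowed values copied from the argument table above.
NUM_INSTANCES = {2, 5, 10, 20, 40}
DEMUXING = {"index", "mlp"}
MUXING = {"gaussian_hadamard", "binary_hadamard", "random_ortho"}
MUX_SETTINGS = {"finetuning", "retrieval_pretraining"}
GLUE_TASKS = {"mnli", "qnli", "sst2", "qqp"}


def build_glue_command(
    num_instances: int,
    setting: str,
    config_name: str,
    lr: str,
    demuxing: str = "index",
    muxing: str = "gaussian_hadamard",
    task: Optional[str] = None,
    model_path: Optional[str] = None,
    do_train: bool = True,
    do_eval: bool = True,
) -> List[str]:
    """Return the argv list for a multiplexed run_glue.sh invocation (hypothetical helper)."""
    if setting not in MUX_SETTINGS:
        raise ValueError(f"setting must be one of {sorted(MUX_SETTINGS)}")
    if num_instances not in NUM_INSTANCES:
        raise ValueError(f"num_instances must be one of {sorted(NUM_INSTANCES)}")
    if demuxing not in DEMUXING or muxing not in MUXING:
        raise ValueError("unsupported muxing/demuxing architecture")
    # Assumption based on the examples: finetuning always names a GLUE task.
    if setting == "finetuning" and task not in GLUE_TASKS:
        raise ValueError(f"finetuning needs --task in {sorted(GLUE_TASKS)}")

    argv = [
        "sh", "run_glue.sh",
        "-N", str(num_instances),
        "-d", demuxing,
        "-m", muxing,
        "-s", setting,
        "--config_name", config_name,
        "--lr", lr,  # kept as a string so "5e-5" is passed through unchanged
    ]
    if task is not None:
        argv += ["--task", task]
    if model_path is not None:
        argv += ["--model_path", model_path]
    if do_train:
        argv.append("--do_train")
    if do_eval:
        argv.append("--do_eval")
    return argv


if __name__ == "__main__":
    argv = build_glue_command(
        2, "finetuning", "configs/ablations/base_model/roberta.json", "5e-5",
        task="mnli", model_path="princeton-nlp/datamux-retrieval-2",
    )
    # Print a copy-pasteable shell command (shlex.quote works on Python 3.6).
    print(" ".join(shlex.quote(arg) for arg in argv))
```

Returning an argv list rather than a single string means the result can be handed directly to `subprocess.run(argv, check=True)` without shell quoting issues.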
Similarly, to run token-level classification tasks such as NER, replace `run_glue.sh` with `run_ner.sh`:

```bash
sh run_ner.sh \
  -N 2 \
  -d index \
  -m gaussian_hadamard \
  -s finetuning \
  --config_name configs/ablations/base_model/roberta.json \
  --lr 5e-5 \
  --task ner \
  --model_path princeton-nlp/datamux-retrieval-2 \
  --do_train \
  --do_eval
```
For the non-multiplexed baselines, run the following command:

```bash
sh run_glue.sh \
  -N 1 \
  -s baseline \
  --config_name configs/ablations/base_model/roberta.json \
  --lr 2e-5 \
  --task mnli
```
To reproduce the results on the vision tasks for MLPs and CNNs, please use this notebook.
See `LICENSE.md` for license details.