Add Support for Gradient Checkpointing (#759)
# Add Support for Gradient Checkpointing
This PR adds support for gradient checkpointing.

Gradient checkpointing is a technique that trades computation for memory by recomputing intermediate activations during the backward pass instead of storing them. This is particularly useful when training large models. Because values are recomputed during backpropagation, the original ForwardContext must be preserved in this phase. I solved this by overriding the `gradient_checkpointing_enable` function so that the checkpoint function receives the current ForwardContext as the context manager for the backward-pass recomputation.
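
A rough sketch of the idea (illustrative only; the import path `adapters.context.ForwardContext` and the use of Transformers' `_set_gradient_checkpointing` hook are assumptions, and the actual implementation differs in detail):

```python
from contextlib import nullcontext

from torch.utils.checkpoint import checkpoint

from adapters.context import ForwardContext


def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
    # Overridden on the adapters model mixin; gradient_checkpointing_kwargs is
    # ignored in this simplified sketch.

    def checkpoint_with_context(function, *args, **kwargs):
        # Grab the ForwardContext of the current forward pass and re-enter it
        # when the checkpointed segment is recomputed during backpropagation.
        context = ForwardContext.get_context()
        return checkpoint(
            function,
            *args,
            use_reentrant=False,  # context_fn requires the non-reentrant checkpoint
            context_fn=lambda: (nullcontext(), context),
            **kwargs,
        )

    self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=checkpoint_with_context)
```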

---------

Co-authored-by: calpt <[email protected]>
lenglaender and calpt authored Jan 26, 2025
1 parent 127f51b commit adef6dc
Showing 20 changed files with 659 additions and 10 deletions.
4 changes: 4 additions & 0 deletions docs/training.md
@@ -223,3 +223,7 @@ trainer = AdapterTrainer(
_Adapters_ supports fine-tuning of quantized language models similar to [QLoRA (Dettmers et al., 2023)](https://arxiv.org/pdf/2305.14314.pdf) via the `bitsandbytes` library integrated into Transformers.
Quantized training is supported for LoRA-based adapters as well as bottleneck adapters and prefix tuning.
Please refer to [this notebook](https://colab.research.google.com/github/Adapter-Hub/adapters/blob/main/notebooks/QLoRA_Llama_Finetuning.ipynb) for a hands-on guide.

## Gradient Checkpointing
Gradient checkpointing is supported for all models (e.g. Llama 1/2/3) except for those whose Hugging Face Transformers implementation does not support it (such as ALBERT). Please refer to [this notebook](https://colab.research.google.com/github/Adapter-Hub/adapters/blob/main/notebooks/Gradient_Checkpointing_Llama.ipynb) for a hands-on guide.
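
A minimal sketch of enabling gradient checkpointing on an adapter model (the model checkpoint, adapter name, and config below are placeholders):

```python
import adapters
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B")
adapters.init(model)

model.add_adapter("my_adapter", config="lora")
model.train_adapter("my_adapter")

# Store only a subset of activations during the forward pass and
# recompute the rest during backpropagation
model.gradient_checkpointing_enable()
```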

339 changes: 339 additions & 0 deletions notebooks/Gradient_Checkpointing_Llama.ipynb
@@ -0,0 +1,339 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "introduction",
"metadata": {},
"source": [
"# Efficient Llama Training with Gradient Checkpointing and _Adapters_\n",
"\n",
"In this notebook, we show how to efficiently fine-tune a **Llama 3** model using **gradient checkpointing** and adapter methods.\n",
"\n",
"**Gradient checkpointing** is a technique to reduce peak memory usage significantly and thus enables training larger models with larger batch sizes. Gradient checkpointing achieves this by trading compute for memory: During the forward pass, gradient checkpointing only stores a subset of activations (thus saving memory). During backpropagation, gradient checkpointing recomputes the activations that were not stored. This can significantly reduce memory requirements at the cost of slightly increased computation time.\n",
"\n",
"In this notebook, we finetune Llama-3 8B on supervised instruction tuning data collected by the [Open Assistant project](https://github.com/LAION-AI/Open-Assistant) for training chatbots.\n",
"\n",
"Another way to reduce memore usage is to use quantization. Have a look a the [QLora notebook](QLoRA_Llama_Finetuning.ipynb) for an example. This gradient checkpointing notebook is based on the QLoRA notebook. While we use a normal LoRA setup in this notebook, you can easily replace LoRA with QLoRA to reduce memory usage even further."
]
},
{
"cell_type": "markdown",
"id": "installation",
"metadata": {},
"source": [
"## Installation\n",
"\n",
"We need `adapters`, `datasets` and `pytorch` for training."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "install",
"metadata": {},
"outputs": [],
"source": [
"%pip install -qq -U adapters datasets torch"
]
},
{
"cell_type": "markdown",
"id": "dataset",
"metadata": {},
"source": [
"## Load Open Assistant dataset\n",
"\n",
"We use the [`timdettmers/openassistant-guanaco`](https://huggingface.co/datasets/timdettmers/openassistant-guanaco) dataset, which contains a small subset of conversations from the full Open Assistant database."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "load_dataset",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"DatasetDict({\n",
" train: Dataset({\n",
" features: ['text'],\n",
" num_rows: 9846\n",
" })\n",
" test: Dataset({\n",
" features: ['text'],\n",
" num_rows: 518\n",
" })\n",
"})"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from datasets import load_dataset\n",
"\n",
"dataset = load_dataset(\"timdettmers/openassistant-guanaco\")\n",
"dataset"
]
},
{
"cell_type": "markdown",
"id": "model_setup",
"metadata": {},
"source": [
"## Load and prepare model\n",
"\n",
"We download the official Llama-2 7B/ Llama-3 8B checkpoint from the HuggingFace Hub. Note that you must request access to this model on the HuggingFace website and use an API token to download it.\n",
"\n",
"The key difference in this notebook is that we'll enable gradient checkpointing to reduce memory usage during training."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "load_model",
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "83e60dee3c434bb3a2bc656bd7f4b667",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Loading checkpoint shards: 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import torch\n",
"from transformers import AutoModelForCausalLM, AutoTokenizer\n",
"\n",
"# To access the Llama 3 model, you need to provide your key:\n",
"HUGGINGFACE_ACCESS_TOKEN = \"<PASTE_YOUR_TOKEN_HERE>\"\n",
"\n",
"modelpath=\"meta-llama/Meta-Llama-3-8B\"\n",
"\n",
"# Load model with gradient checkpointing enabled\n",
"model = AutoModelForCausalLM.from_pretrained(\n",
" modelpath, \n",
" device_map=\"auto\",\n",
" torch_dtype=torch.bfloat16,\n",
" token=HUGGINGFACE_ACCESS_TOKEN,\n",
")\n",
"model.config.use_cache = False\n",
"\n",
"tokenizer = AutoTokenizer.from_pretrained(modelpath, token=HUGGINGFACE_ACCESS_TOKEN)\n",
"tokenizer.pad_token = tokenizer.eos_token"
]
},
{
"cell_type": "markdown",
"id": "5cd73b7d",
"metadata": {},
"source": [
"If you get a message similar to `WARNING:accelerate.big_modeling:Some parameters are on the meta device because they were offloaded to the cpu and disk.`, then the model itself is too big for your GPU. If you don't have a bigger / additional GPU at hand, you can use a quantization method like we show in the [QLoRA notebook](QLoRA_Llama_Finetuning.ipynb). Adding the quantization_config when loading the model and choosing a quantized `LoRAConfig` in the next step will enable quantized training."
]
},
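{
"cell_type": "markdown",
"id": "quantization_sketch_md",
"metadata": {},
"source": [
"The next (optional) cell is a rough sketch of quantized loading, kept commented out because this notebook uses a plain LoRA setup. The `BitsAndBytesConfig` values are illustrative defaults borrowed from the QLoRA notebook, not tuned recommendations."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "quantization_sketch",
"metadata": {},
"outputs": [],
"source": [
"# Optional: load the model quantized to 4-bit instead (requires `bitsandbytes`).\n",
"# from transformers import BitsAndBytesConfig\n",
"#\n",
"# quantization_config = BitsAndBytesConfig(\n",
"#     load_in_4bit=True,\n",
"#     bnb_4bit_quant_type=\"nf4\",\n",
"#     bnb_4bit_use_double_quant=True,\n",
"#     bnb_4bit_compute_dtype=torch.bfloat16,\n",
"# )\n",
"#\n",
"# model = AutoModelForCausalLM.from_pretrained(\n",
"#     modelpath,\n",
"#     device_map=\"auto\",\n",
"#     quantization_config=quantization_config,\n",
"#     torch_dtype=torch.bfloat16,\n",
"#     token=HUGGINGFACE_ACCESS_TOKEN,\n",
"# )"
]
},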
{
"cell_type": "markdown",
"id": "adapter_setup",
"metadata": {},
"source": [
"## Initialize adapter\n",
"\n",
"We initialize the adapter functionality and add a LoRA adapter. When using gradient checkpointing with adapters, we need to enable input gradients explicitly."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "init_adapter",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"================================================================================\n",
"Name Architecture #Param %Param Active Train\n",
"--------------------------------------------------------------------------------\n",
"lora_adapter lora 3,407,872 0.085 1 1\n",
"--------------------------------------------------------------------------------\n",
"Full model 4,015,263,744 100.000 0\n",
"================================================================================\n"
]
}
],
"source": [
"import adapters\n",
"from adapters import LoRAConfig\n",
"\n",
"adapters.init(model)\n",
"\n",
"config = LoRAConfig()\n",
"model.add_adapter(\"lora_adapter\", config=config)\n",
"model.train_adapter(\"lora_adapter\")\n",
"\n",
"# Activate gradient checkpointing\n",
"model.gradient_checkpointing_enable()\n",
"\n",
"print(model.adapter_summary())"
]
},
{
"cell_type": "markdown",
"id": "data_prep",
"metadata": {},
"source": [
"## Prepare data for training\n",
"\n",
"The dataset is tokenized and truncated."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "tokenize",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"\n",
"def tokenize(element):\n",
" return tokenizer(\n",
" element[\"text\"],\n",
" truncation=True,\n",
" max_length=512,\n",
" add_special_tokens=False,\n",
" )\n",
"\n",
"dataset_tokenized = dataset.map(\n",
" tokenize, \n",
" batched=True, \n",
" num_proc=os.cpu_count(),\n",
" remove_columns=[\"text\"]\n",
")"
]
},
{
"cell_type": "markdown",
"id": "training",
"metadata": {},
"source": [
"## Training\n",
"\n",
"We specify training hyperparameters and train the model using the `AdapterTrainer` class. With gradient checkpointing enabled, we can use larger batch sizes than would otherwise be possible."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "training_args",
"metadata": {},
"outputs": [],
"source": [
"from transformers import TrainingArguments\n",
"\n",
"args = TrainingArguments(\n",
" output_dir=\"output/llama_gradient_checkpointing\",\n",
" per_device_train_batch_size=1,\n",
" per_device_eval_batch_size=1,\n",
" evaluation_strategy=\"steps\",\n",
" logging_steps=10,\n",
" save_steps=500,\n",
" eval_steps=187,\n",
" save_total_limit=3,\n",
" gradient_accumulation_steps=16,\n",
" max_steps=1875,\n",
" learning_rate=0.0002,\n",
" bf16=True,\n",
" warmup_ratio=0.03,\n",
" group_by_length=True,\n",
" lr_scheduler_type=\"constant\",\n",
" optim=\"adamw_torch\"\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "train",
"metadata": {},
"outputs": [],
"source": [
"from adapters import AdapterTrainer\n",
"from transformers import DataCollatorForLanguageModeling\n",
"\n",
"trainer = AdapterTrainer(\n",
" model=model,\n",
" tokenizer=tokenizer,\n",
" data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),\n",
" train_dataset=dataset_tokenized[\"train\"],\n",
" eval_dataset=dataset_tokenized[\"test\"],\n",
" args=args,\n",
")\n",
"\n",
"trainer.train()"
]
},
{
"cell_type": "markdown",
"id": "inference",
"metadata": {},
"source": [
"## Inference\n",
"\n",
"For inference, we can disable gradient checkpointing since we don't need gradients:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "inference_setup",
"metadata": {},
"outputs": [],
"source": [
"# Disable gradient checkpointing for inference\n",
"model.gradient_checkpointing_disable()\n",
"model.config.use_cache = True\n",
"\n",
"def prompt_model(model, text: str):\n",
" batch = tokenizer(f\"### Human: {text}\\n### Assistant:\", return_tensors=\"pt\")\n",
" batch = batch.to(model.device)\n",
" \n",
" model.eval()\n",
" with torch.inference_mode():\n",
" output_tokens = model.generate(**batch, max_new_tokens=50)\n",
"\n",
" return tokenizer.decode(output_tokens[0], skip_special_tokens=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "test_inference",
"metadata": {},
"outputs": [],
"source": [
"print(prompt_model(model, \"Explain gradient checkpointing in simple terms\"))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
7 changes: 6 additions & 1 deletion notebooks/README.md
@@ -28,11 +28,16 @@ As adapters is fully compatible with HuggingFace's Transformers, you can also us
| Notebook | Description | |
|:----------------|:---------------------|--:|
| [Text Generation](https://github.com/Adapter-Hub/adapters/blob/main/notebooks/Text_Generation_Training.ipynb) | How to train an adapter for language generation. | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Adapter-Hub/adapters/blob/main/notebooks/Text_Generation_Training.ipynb) |
| [QLoRA LLama Finetuning](https://github.com/Adapter-Hub/adapters/blob/main/notebooks/QLoRA_Llama_Finetuning.ipynb) | How to finetune a quantized Llama model for using QLoRA. | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Adapter-Hub/adapters/blob/main/notebooks/QLoRA_Llama_Finetuning.ipynb) |
| [Training a NER Adapter](https://github.com/Adapter-Hub/adapters/blob/main/notebooks/Adapter_train_NER_with_id2label.ipynb) | How to train an adapter on a named entity recognition task. | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Adapter-Hub/adapters/blob/main/notebooks/Adapter_train_NER_with_id2label.ipynb) |
| [Adapter Drop Training](https://github.com/Adapter-Hub/adapters/blob/main/notebooks/Adapter_Drop_Training.ipynb) | How to train an adapter using AdapterDrop | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Adapter-Hub/adapters/blob/main/notebooks/Adapter_Drop_Training.ipynb) |
| [Inference example for id2label](https://github.com/Adapter-Hub/adapters/blob/main/notebooks/Adapter_id2label_inference.ipynb) | How to use the id2label dictionary for inference | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Adapter-Hub/adapters/blob/main/notebooks/Adapter_id2label_inference.ipynb) |
| [NER on Wikiann](https://github.com/Adapter-Hub/adapters/blob/main/notebooks/08_NER_Wikiann.ipynb) | Evaluating adapters on NER on the wikiann dataset | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Adapter-Hub/adapters/blob/main/notebooks/08_NER_Wikiann.ipynb) |
| [Finetuning Whisper with Adapters](https://github.com/Adapter-Hub/adapters/blob/main/notebooks/Adapter_Whisper_Audio_FineTuning.ipynb) | Fine Tuning Whisper using LoRA | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Adapter-Hub/adapters/blob/main/notebooks/Adapter_Whisper_Audio_FineTuning.ipynb) |
| [Adapter Training with ReFT](https://github.com/Adapter-Hub/adapters/blob/main/notebooks/ReFT_Adapters_Finetuning.ipynb) | Fine Tuning using ReFT Adapters | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Adapter-Hub/adapters/blob/main/notebooks/ReFT_Adapters_Finetuning.ipynb) |
| [ViT Fine-Tuning with AdapterPlus](https://github.com/Adapter-Hub/adapters/blob/main/notebooks/ViT_AdapterPlus_FineTuning.ipynb) | ViT Fine-Tuning with AdapterPlus | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Adapter-Hub/adapters/blob/main/notebooks/ViT_AdapterPlus_FineTuning.ipynb) |

### Memory Efficient Training
| Notebook | Description | |
|:----------------|:---------------------|--:|
| [QLoRA Llama Finetuning](https://github.com/Adapter-Hub/adapters/blob/main/notebooks/QLoRA_Llama_Finetuning.ipynb) | How to finetune a quantized Llama model using QLoRA. | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Adapter-Hub/adapters/blob/main/notebooks/QLoRA_Llama_Finetuning.ipynb) |
| [Gradient Checkpointing](https://github.com/Adapter-Hub/adapters/blob/main/notebooks/Gradient_Checkpointing_Llama.ipynb) | How to finetune a Llama model using gradient checkpointing to reduce memory usage. | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Adapter-Hub/adapters/blob/main/notebooks/Gradient_Checkpointing_Llama.ipynb) |
4 changes: 2 additions & 2 deletions src/adapters/composition.py
@@ -273,7 +273,7 @@ def adjust_tensors_for_parallel(hidden_states, *tensors):
"""
outputs = []
for tensor in tensors:
if tensor is not None and hidden_states.shape[0] >= tensor.shape[0]:
if tensor is not None and hidden_states.shape[0] > tensor.shape[0]:
repeats = [1] * len(tensor.shape)
repeats[0] = hidden_states.shape[0] // tensor.shape[0]
new_tensor = tensor.repeat(*repeats)
@@ -288,7 +288,7 @@ def adjust_tensors_for_parallel_(hidden_states, *tensors):
     In-place version of adjust_tensors_for_parallel().
     """
     for tensor in tensors:
-        if tensor is not None and hidden_states.shape[0] >= tensor.shape[0]:
+        if tensor is not None and hidden_states.shape[0] > tensor.shape[0]:
             repeats = [1] * len(tensor.shape)
             repeats[0] = hidden_states.shape[0] // tensor.shape[0]
             new_tensor = tensor.repeat(*repeats)
