"\n",
"Axolotl is \"a tool designed to streamline the fine-tuning of various AI models.\" It is primarily for training Hugging Face models via full fine-tuning, lora, qlora, relora, gptq. Configurations are specified in yaml files. It supports a variety of different dataset formats. It supports additional libraries such as xformer and flash attention. It is compatible with FSDP and deepspeed for multi-gpu training. It supports logging to MLflow or WandB.\n",
"\n",
"# The recommended workflow is to pick a quickstart notebook from the [examples](https://github.com/OpenAccess-AI-Collective/axolotl/tree/main/examples) directory and modify it as needed.\n"
"The recommended workflow is to pick a quickstart notebook from the [examples](https://github.com/OpenAccess-AI-Collective/axolotl/tree/main/examples) directory and modify it as needed.\n",
"\n",
"Let's fine-tune the Gemma 2B model using Axolotl. There is already an example script for fine-tuning the 7B model, so we will adapt that to our needs.\n",
"\n",
"Note that I ran all of this in a databricks worspace on one a10 GPU. I used qlora for fine-tuning.\n",
"\n",
"## Setup\n",
"First, we install the necessary dependencies. Note that we are following the [quickstart](https://github.com/OpenAccess-AI-Collective/axolotl?tab=readme-ov-file#quickstart-) in the axolotl readme.\n",
"\n",
"This part is important and less straightforward than it might seem. It is important to install the pytorch version corresponding to the correct CUDA version (see [here](https://pytorch.org/get-started/locally/)). And make sure your pip version is up to date. Getting this part right took a fair amount of trial and error."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%pip install --upgrade pip\n",
"%pip install --upgrade torch==2.2.1 torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118\n",
"\n",
"%pip install --upgrade mlflow\n",
"%pip install --upgrade packaging deepspeed transformers"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Next, clone the axolotl repository."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%%sh\n",
"git clone https://github.com/OpenAccess-AI-Collective/axolotl"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"`cd` into the `axolotl` directory and install Axolotl, deepspeed, and flash-attention."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%%sh\n",
"cd axolotl\n",
"pip install -e '.[flash-attn,deepspeed]'"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Next, because Gemma 2 is in a gated repo on Hugging Face, we need to log in to Hugging Face before we obtain and train the model."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from huggingface_hub import login\n",
"login()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Obtain and modify the training YAML file\n",
"\n",
"Axolotl is a configuration-based fine-tuning tool. Let's get the configuration from the gemma 7b example file and modify it."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"--2024-03-28 16:49:54-- https://raw.githubusercontent.com/OpenAccess-AI-Collective/axolotl/main/examples/gemma/qlora.yml\n",
"Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 2606:50c0:8001::154, 2606:50c0:8003::154, 2606:50c0:8000::154, ...\n",
"Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|2606:50c0:8001::154|:443... connected.\n",
"HTTP request sent, awaiting response... 200 OK\n",
"Length: 1095 (1.1K) [text/plain]\n",
"Saving to: ‘qlora.yml’\n",
"\n",
" 0K . 100% 45.4M=0s\n",
"\n",
"2024-03-28 16:49:55 (45.4 MB/s) - ‘qlora.yml’ saved [1095/1095]\n",
"\n"
]
}
],
"source": [
"%%sh\n",
"wget https://raw.githubusercontent.com/OpenAccess-AI-Collective/axolotl/main/examples/gemma/qlora.yml"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We make the following modifications to the configuration:\n",
"- change the model to `google/gemma-2b`\n",
"- change `sequence_len` to 2048 (otherwise we will encounter OOM errors)\n",
"\n",
"With this completed, we can run the qlora fine-tuning job."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Pre-process the data\n",
"The axolotl repo quickstart recommends pre-processing the data before training. This can be done as follows (using the example dataset):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%%sh\n",
"CUDA_VISIBLE_DEVICES=\"\" python -m axolotl.cli.preprocess examples/openllama-3b/lora.yml"
]
},
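{
"cell_type": "markdown",
"metadata": {},
"source": [
"To pre-process the dataset referenced in our own config instead, the same command can presumably be pointed at `./qlora.yml` (untested sketch):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%%sh\n",
"# same pre-processing step, but against the Gemma 2B config we modified above\n",
"CUDA_VISIBLE_DEVICES=\"\" python -m axolotl.cli.preprocess ./qlora.yml"
]
},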
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Train the Model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%%sh\n",
"accelerate launch -m axolotl.cli.train ./qlora.yml"
]
},
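{
"cell_type": "markdown",
"metadata": {},
"source": [
"Optionally, once training finishes, the Axolotl quickstart suggests testing the adapter with the inference CLI. Point `--lora_model_dir` at the `output_dir` from your config (shown here with a hypothetical `./qlora-out`):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%%sh\n",
"accelerate launch -m axolotl.cli.inference ./qlora.yml --lora_model_dir=\"./qlora-out\""
]
},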
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Resolving Issues\n",
"\n",
"Hopefully, this just works as-is. However, I ran into a few different issues across the different environments I tested this in. Here are a few tips for debugging:\n",
"- Double check your python, torch, and CUDA versions. At the time of writing, axolotl requires Python >=3.10 and Pytorch >=2.1.1.\n",
"- Make sure Pytorch is compiled with the correct CUDA version.\n",
"- Make sure pip is up to date.\n",
"- Get the torch/cuda installations figured out before worrying about the other dependencies.\n",
"\n",
"There were a few databricks-specific issues as well.\n",
"- I had to run `databricks configure` in the CLI and put in my credentials, otherwise I ran into errors. I believe this is due to the transformers library attempting to autolog to mlflow.\n",
"- to log to MLflow, the following lines were necessary in the config:\n",
"\n",
"```yaml\n",
"# mlflow configuration\n",
"mlflow_tracking_uri: \"<your_mlflow_tracking_uri>\"\n",
"mlflow_experiment_name: \"<your_mlflow_experiment>\"\n",
"# optionally, save checkpoints to artifact registry\n",
"hf_mlflow_log_artifacts: false\n",
"```\n",
"To *avoid* using MLflow, even without this config, I needed to prepend the `accelerate launch...` command with `export DISABLE_MLFLOW_INTEGRATION=true`. Otherwise the autologging built into the transformers library (I think?) would attempt to log to mlflow, but would encounter an error without a run or experiment set.\n",
"\n",
"For more, see [this guide](https://github.com/OpenAccess-AI-Collective/axolotl/blob/main/docs/debugging.qmd) on debugging axolotl."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "myenv",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python"
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.13"
}
},
"nbformat": 4,