Add instruct tuning support to LoRA training #1211

Open · wants to merge 1 commit into `main`
36 changes: 33 additions & 3 deletions llms/mlx_lm/LORA.md
@@ -307,18 +307,48 @@
the final text in the `chat` example above with Hugging Face's default template
becomes:

```text
<|im_start|>system
You are a helpful assistant.<|im_end|>
<|im_start|>user
Hello.<|im_end|>
<|im_start|>assistant
How can I assist you today.<|im_end|>
```

If you are unsure of which format to use, the `chat` or `completions` formats
are good starting points. For custom requirements on the format of the
dataset, use the `text` format to assemble the content yourself, for example:
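A hypothetical instruct-style entry assembled with the `text` format might look like this (the `[INST]` tags follow the Llama convention and are purely illustrative):

```jsonl
{"text": "[INST] Your input prompt here [/INST] The expected output result here"}
```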

## Instruct Tuning

Instruct tuning lets you fine-tune a model on prompt/completion pairs with alternative loss functions. This is useful when each training example consists of an input prompt and an expected completion, and you want the loss to respect that structure, for example by masking the prompt tokens so that only the completion contributes.

### Dataset Format

For instruct tuning, the dataset should be in the following format:

```jsonl
{"prompt": "[INST] Your input prompt here[/INST]", "completion": "The expected output result here"}
```
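Each prompt is paired with its completion and tokenized before training. A minimal sketch of the idea behind the PR's `CompletionsDataset` (the actual implementation in `llms/mlx_lm/tuner/datasets.py` may differ in details such as the separator and EOS handling):

```python
# Sketch only: the real CompletionsDataset in llms/mlx_lm/tuner/datasets.py
# may join the prompt and completion differently.
class CompletionsDataset:
    def __init__(self, data, tokenizer, prompt_key, completion_key):
        self._data = data
        self._tokenizer = tokenizer
        self._prompt_key = prompt_key
        self._completion_key = completion_key

    def __getitem__(self, idx):
        sample = self._data[idx]
        text = f"{sample[self._prompt_key]} {sample[self._completion_key]}"
        # Each item is the list of token ids for the concatenated pair.
        return self._tokenizer.encode(text) + [self._tokenizer.eos_token_id]

    def __len__(self):
        return len(self._data)
```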

### Fine-tune with Instruct Tuning

To fine-tune a model with instruct tuning, use the following command:

```shell
mlx_lm.lora \
--model <path_to_model> \
--train \
--data <path_to_data> \
--iters 600 \
--prompt-feature "prompt" \
--completion-feature "completion"
```
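If you keep options in a YAML config passed via `-c`, the equivalent keys would presumably mirror the flag names with underscores. This mapping is an assumption based on how `mlx_lm.lora` reads its other options, not something shown in this diff:

```yaml
# Assumed YAML equivalent of the command above; key names are inferred
# from the CLI flags and may differ in the merged version.
model: <path_to_model>
train: true
data: <path_to_data>
iters: 600
prompt_feature: "prompt"
completion_feature: "completion"
```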

### Alternative Loss Functions

You can use alternative loss functions for instruct tuning. To add a custom loss, define a function in `llms/mlx_lm/tuner/trainer.py` with the same signature as `default_loss` (see `instruct_loss` there for a reference) and wire it into the trainer.
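As an illustration, here is a sketch of a custom loss that up-weights completion tokens. The signature matches `default_loss` and `instruct_loss`; the weighting scheme itself is invented for this example and is not part of the PR:

```python
import mlx.core as mx
import mlx.nn as nn


def weighted_loss(model, inputs, targets, lengths, completion_weight=2.0):
    # Same signature as default_loss / instruct_loss, so it can be swapped
    # in wherever the trainer expects a loss function.
    logits = model(inputs).astype(mx.float32)
    length_mask = mx.arange(inputs.shape[1])[None, :] < lengths[:, None]

    # Illustrative: up-weight the second half of each sequence, mirroring
    # the prompt/completion split heuristic used by instruct_loss.
    completion_mask = mx.arange(inputs.shape[1])[None, :] >= (lengths[:, None] // 2)
    weights = mx.where(completion_mask, completion_weight, 1.0) * length_mask

    ce = nn.losses.cross_entropy(logits, targets) * weights
    ntoks = length_mask.sum()
    return ce.sum() / weights.sum(), ntoks
```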

## Memory Issues

Fine-tuning a large model with LoRA requires a machine with a decent amount
of memory.
2 changes: 1 addition & 1 deletion llms/mlx_lm/lora.py
@@ -13,7 +13,7 @@
import yaml

from .tokenizer_utils import TokenizerWrapper
from .tuner.datasets import load_dataset
from .tuner.datasets import CompletionsDataset, load_dataset
from .tuner.trainer import TrainingArgs, TrainingCallback, evaluate, train
from .tuner.utils import (
build_schedule,
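The excerpt above only shows the new import; presumably the train path then selects the loss based on the dataset type. A hypothetical sketch of that wiring (`train_set`, `valid_set`, `opt`, and `training_args` are assumed names, and the real diff may do this differently):

```python
# Hypothetical wiring in lora.py; the diff excerpt above shows only the
# import, so the names and call below are assumptions.
loss_fn = instruct_loss if isinstance(train_set, CompletionsDataset) else default_loss
train(
    model,
    tokenizer,
    opt,
    train_set,
    valid_set,
    args=training_args,
    loss=loss_fn,
)
```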
17 changes: 17 additions & 0 deletions llms/mlx_lm/tuner/trainer.py
@@ -76,6 +76,23 @@ def default_loss(model, inputs, targets, lengths):
return ce, ntoks


def instruct_loss(model, inputs, targets, lengths, mask_input=True):
    # Run the forward pass and compute the loss in float32 for numerical
    # stability.
    logits = model(inputs)
    logits = logits.astype(mx.float32)

    # Mask out padding positions beyond each example's true length.
    length_mask = mx.arange(inputs.shape[1])[None, :] < lengths[:, None]

    if mask_input:
        # Heuristic: treat the first half of each sequence as the prompt and
        # exclude it, so only completion tokens contribute to the loss.
        input_mask = mx.arange(inputs.shape[1])[None, :] < (lengths[:, None] // 2)
        length_mask = length_mask & ~input_mask

    # Average the per-token cross-entropy over the unmasked tokens.
    ce = nn.losses.cross_entropy(logits, targets) * length_mask
    ntoks = length_mask.sum()
    ce = ce.sum() / ntoks

    return ce, ntoks


def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False):
# Sort by length:
idx = sorted(range(len(dataset)), key=lambda idx: len(dataset[idx]))
54 changes: 54 additions & 0 deletions llms/tests/test_datasets.py
@@ -0,0 +1,54 @@
import unittest

from transformers import PreTrainedTokenizerFast

from llms.mlx_lm.tuner.datasets import CompletionsDataset, create_dataset


class TestCompletionsDataset(unittest.TestCase):

def setUp(self):
self.tokenizer = PreTrainedTokenizerFast.from_pretrained("gpt2")
self.data = [
{"prompt": "What is the capital of France?", "completion": "Paris."},
{"prompt": "What is the capital of Germany?", "completion": "Berlin."}
]

def test_completions_dataset(self):
dataset = CompletionsDataset(self.data, self.tokenizer, "prompt", "completion")
self.assertEqual(len(dataset), 2)
self.assertTrue(isinstance(dataset[0], list))
self.assertTrue(isinstance(dataset[1], list))


class TestCreateDataset(unittest.TestCase):

def setUp(self):
self.tokenizer = PreTrainedTokenizerFast.from_pretrained("gpt2")
self.data_completions = [
{"prompt": "What is the capital of France?", "completion": "Paris."},
{"prompt": "What is the capital of Germany?", "completion": "Berlin."}
]
self.data_text = [
{"text": "This is a sample text."},
{"text": "This is another sample text."}
]
self.data_chat = [
{"messages": [{"role": "user", "content": "Hello."}, {"role": "assistant", "content": "Hi there!"}]}
]

def test_create_completions_dataset(self):
dataset = create_dataset(self.data_completions, self.tokenizer, "prompt", "completion")
self.assertEqual(len(dataset), 2)
self.assertTrue(isinstance(dataset[0], list))
self.assertTrue(isinstance(dataset[1], list))

def test_create_text_dataset(self):
dataset = create_dataset(self.data_text, self.tokenizer)
self.assertEqual(len(dataset), 2)
self.assertTrue(isinstance(dataset[0], list))
self.assertTrue(isinstance(dataset[1], list))

    def test_create_chat_dataset(self):
        # Assumes the tokenizer provides a chat template (or that the
        # installed transformers version supplies a default one).
        dataset = create_dataset(self.data_chat, self.tokenizer)
self.assertEqual(len(dataset), 1)
self.assertTrue(isinstance(dataset[0], list))


if __name__ == "__main__":
unittest.main()
39 changes: 39 additions & 0 deletions llms/tests/test_trainer.py
@@ -0,0 +1,39 @@
import unittest

import mlx.core as mx
import mlx.nn as nn

from llms.mlx_lm.tuner.trainer import default_loss, instruct_loss


class DummyModel(nn.Module):
    """Tiny stand-in model mapping token ids to logits for the loss tests."""

    def __init__(self, vocab_size=16, dims=8):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, dims)
        self.out_proj = nn.Linear(dims, vocab_size)

    def __call__(self, inputs):
        return self.out_proj(self.embedding(inputs))


class TestLossFunctions(unittest.TestCase):

    def setUp(self):
        self.model = DummyModel()
        # The loss functions expect mx.array inputs, not NumPy arrays.
        self.inputs = mx.array([[1, 2, 3], [4, 5, 6]])
        self.targets = mx.array([[1, 2, 3], [4, 5, 6]])
        self.lengths = mx.array([3, 3])

    def test_default_loss(self):
        loss, ntoks = default_loss(self.model, self.inputs, self.targets, self.lengths)
        self.assertIsInstance(loss, mx.array)
        self.assertIsInstance(ntoks, mx.array)

    def test_instruct_loss(self):
        # mask_input defaults to True.
        loss, ntoks = instruct_loss(self.model, self.inputs, self.targets, self.lengths)
        self.assertIsInstance(loss, mx.array)
        self.assertIsInstance(ntoks, mx.array)

    def test_instruct_loss_with_masking(self):
        loss, ntoks = instruct_loss(
            self.model, self.inputs, self.targets, self.lengths, mask_input=True
        )
        self.assertIsInstance(loss, mx.array)
        self.assertIsInstance(ntoks, mx.array)

    def test_instruct_loss_without_masking(self):
        loss, ntoks = instruct_loss(
            self.model, self.inputs, self.targets, self.lengths, mask_input=False
        )
        self.assertIsInstance(loss, mx.array)
        self.assertIsInstance(ntoks, mx.array)


if __name__ == "__main__":
    unittest.main()