Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add model_prompt config param for LLMBlock #141

Merged
merged 4 commits into from
Jul 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions docs/pipeline_config.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Pipeline Configuration

Built-in pipeline configurations can be found in [`src/instructlab/sdg/pipelines/`](../src/instructlab/sdg/pipelines/).

## Pipeline Configuration Schema

A schema for validating pipeline configuration can be found in [`src/instructlab/sdg/pipelines/schema/v1.json`](../src/instructlab/sdg/pipelines/schema/v1.json).
Copy link
Contributor

@derekhiggins derekhiggins Jul 16, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit:
double "/"'s in path name

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oops


## Version History

| Version | Description |
| --- | --- |
| 1.0 | Initial version |
24 changes: 18 additions & 6 deletions src/instructlab/sdg/llmblock.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def __init__(
block_name,
config_path,
output_cols,
model_prompt=None,
parser_kwargs={},
batch_kwargs={},
) -> None:
Expand All @@ -69,7 +70,7 @@ def __init__(
"""{system}\n{introduction}\n{principles}\n{examples}\n{generation}"""
)
self.prompt_template = self.prompt_struct.format(**self.block_config)
self.model_prompt = _get_model_prompt(self.ctx.model_family)
self.model_prompt = model_prompt
self.output_cols = output_cols
self.batch_params = batch_kwargs
self.parser_name = parser_kwargs.get("parser_name", None)
Expand Down Expand Up @@ -124,8 +125,20 @@ def _parse(self, generated_string) -> dict:

return matches

def _format_prompt(self, sample: Dict) -> str:
    """Render the prompt template for *sample* and wrap it in a model prompt.

    Three cases for self.model_prompt:
      * None         -- nothing specified; look up a wrapper by model family
      * non-empty    -- the pipeline supplied a custom model prompt
      * empty string -- the pipeline explicitly wants no model prompt
    """
    rendered = self.prompt_template.format(**sample).strip()

    if self.model_prompt is None:
        # No explicit choice from the pipeline: fall back to the family default.
        wrapper = _get_model_prompt(self.ctx.model_family)
    else:
        # An empty string deliberately disables wrapping; normalize it to None.
        wrapper = self.model_prompt or None

    if wrapper is None:
        return rendered
    return wrapper.format(prompt=rendered)

def _gen_kwargs(self, **gen_kwargs):
gen_kwargs = {**self.defaults, **gen_kwargs}
Expand All @@ -136,10 +149,7 @@ def _gen_kwargs(self, **gen_kwargs):
return gen_kwargs

def _generate(self, samples, **gen_kwargs) -> list:
prompts = [
self.model_prompt.format(prompt=self._format_prompt(sample))
for sample in samples
]
prompts = [self._format_prompt(sample) for sample in samples]
generate_args = self._gen_kwargs(**gen_kwargs)

if self.server_supports_batched:
Expand Down Expand Up @@ -221,6 +231,7 @@ def __init__(
config_paths,
output_cols,
selector_column_name,
model_prompt=None,
parser_kwargs={},
batch_kwargs={},
) -> None:
Expand All @@ -230,6 +241,7 @@ def __init__(
block_name,
config_paths[0][0],
output_cols,
model_prompt=model_prompt,
parser_kwargs=parser_kwargs,
batch_kwargs=batch_kwargs,
)
Expand Down
6 changes: 6 additions & 0 deletions src/instructlab/sdg/pipelines/schema/v1.json
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,9 @@
"type": "string"
}
},
"model_prompt": {
"type": "string"
},
"parser_kwargs": {
"type": "object",
"properties": {
Expand Down Expand Up @@ -171,6 +174,9 @@
"type": "string"
}
},
"model_prompt": {
"type": "string"
},
"selector_column_name": {
"type": "string"
},
Expand Down
86 changes: 86 additions & 0 deletions tests/test_llmblock.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
# Standard
from unittest.mock import MagicMock, patch
import unittest

# Third Party
from datasets import Dataset, Features, Value

# First Party
from src.instructlab.sdg.llmblock import LLMBlock


class TestLLMBlockModelPrompt(unittest.TestCase):
    """Tests for LLMBlock._format_prompt's handling of model_prompt.

    Covers the three documented cases: empty string (no model prompt),
    None (wrapper chosen from the model family), and a custom wrapper string.
    """

    def setUp(self):
        # Minimal stand-ins for the pipeline context/pipeline objects; only
        # model_family and model_id are read by the code under test.
        self.mock_ctx = MagicMock()
        self.mock_ctx.model_family = "mixtral"
        self.mock_ctx.model_id = "test_model"
        self.mock_pipe = MagicMock()
        self.config_return_value = {
            "system": "{fruit}",
            "introduction": "introduction",
            "principles": "principles",
            "examples": "examples",
            "generation": "generation",
        }
        self.dataset = Dataset.from_dict(
            {"fruit": ["apple", "pear", "mango"]},
            features=Features({"fruit": Value("string")}),
        )

    @patch("src.instructlab.sdg.block.Block._load_config")
    def test_model_prompt_empty_string(self, mock_load_config):
        mock_load_config.return_value = self.config_return_value
        # Ensure that if an empty model_prompt is specified, no model prompt is used.
        block = LLMBlock(
            ctx=self.mock_ctx,
            pipe=self.mock_pipe,
            block_name="test_block",
            config_path="",
            output_cols=[],
            model_prompt="",
        )
        prompt = block._format_prompt(self.dataset[0])
        self.assertEqual(
            prompt,
            "apple\nintroduction\nprinciples\nexamples\ngeneration",
            "no model prompt should be used when explicitly set to an empty string",
        )

    @patch("src.instructlab.sdg.block.Block._load_config")
    def test_model_prompt_none(self, mock_load_config):
        mock_load_config.return_value = self.config_return_value
        # Ensure that if a custom model_prompt is not specified, it defaults to
        # one based on the model family (i.e. mixtral).
        block = LLMBlock(
            ctx=self.mock_ctx,
            pipe=self.mock_pipe,
            block_name="test_block",
            config_path="",
            output_cols=[],
            model_prompt=None,  # Or simply omit model_prompt as it defaults to None
        )
        prompt = block._format_prompt(self.dataset[1])
        self.assertEqual(
            prompt,
            "<s> [INST] pear\nintroduction\nprinciples\nexamples\ngeneration [/INST]",
            "model_prompt based on model_family should be used when set to None",
        )

    # BUG FIX: this method was previously also named test_model_prompt_none,
    # which redefined (shadowed) the test above so it was never collected/run.
    @patch("src.instructlab.sdg.block.Block._load_config")
    def test_model_prompt_custom(self, mock_load_config):
        mock_load_config.return_value = self.config_return_value
        # Ensure that if a custom model_prompt is specified, it is used correctly.
        block = LLMBlock(
            ctx=self.mock_ctx,
            pipe=self.mock_pipe,
            block_name="test_block",
            config_path="",
            output_cols=[],
            model_prompt="FOO {prompt} BAR",
        )
        prompt = block._format_prompt(self.dataset[1])
        self.assertEqual(
            prompt,
            "FOO pear\nintroduction\nprinciples\nexamples\ngeneration BAR",
            "custom model_prompt should be used when set to a non-empty string",
        )