llmblock: allow model_prompt=""
An anticipated use case for model_prompt is simply to disable
this additional prompt. Currently the pipeline author would
need to specify model_prompt="{prompt}" to achieve this.

Make this easier by allowing model_prompt="" to have this
meaning.

Signed-off-by: Mark McLoughlin <[email protected]>
markmc authored and russellb committed Jul 15, 2024
1 parent 4f042ae commit 7045456
Showing 2 changed files with 52 additions and 18 deletions.
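
To see the equivalence described in the commit message in isolation, here is a minimal plain-Python sketch (illustrative only, not the LLMBlock API; the prompt value is made up):

# The old "{prompt}" workaround and the new model_prompt="" both leave the
# formatted prompt unwrapped.
prompt = "introduction\nprinciples\nexamples\ngeneration"

old_workaround = "{prompt}".format(prompt=prompt)  # model_prompt="{prompt}"
new_spelling = prompt                              # model_prompt="" (this commit)

assert old_workaround == new_spelling
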
src/instructlab/sdg/llmblock.py (15 additions, 10 deletions)
@@ -70,11 +70,7 @@ def __init__(
             """{system}\n{introduction}\n{principles}\n{examples}\n{generation}"""
         )
         self.prompt_template = self.prompt_struct.format(**self.block_config)
-        self.model_prompt = (
-            model_prompt
-            if model_prompt is not None
-            else _get_model_prompt(self.ctx.model_family)
-        )
+        self.model_prompt = model_prompt
         self.output_cols = output_cols
         self.batch_params = batch_kwargs
         self.parser_name = parser_kwargs.get("parser_name", None)
@@ -129,8 +125,20 @@ def _parse(self, generated_string) -> dict:

         return matches

+    # There are three cases to handle for self.model_prompt
+    # 1. None - no model_prompt specified, look one up based on model family
+    # 2. Non-empty string - the pipeline has specified a custom model prompt
+    # 3. Empty string - the pipeline has specified that no model prompt is needed
     def _format_prompt(self, sample: Dict) -> str:
-        return self.prompt_template.format(**sample).strip()
+        prompt = self.prompt_template.format(**sample).strip()
+
+        model_prompt = None
+        if self.model_prompt is None:
+            model_prompt = _get_model_prompt(self.ctx.model_family)
+        elif self.model_prompt:
+            model_prompt = self.model_prompt
+
+        return prompt if model_prompt is None else model_prompt.format(prompt=prompt)

     def _gen_kwargs(self, **gen_kwargs):
         gen_kwargs = {**self.defaults, **gen_kwargs}
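
The three cases enumerated in the comment above can be exercised in isolation. A self-contained sketch of the same selection logic (the mixtral wrapper string is an assumption taken from the test expectations below, not from this diff):

def resolve_prompt(prompt, model_prompt, family_prompt):
    """Standalone mirror of the selection logic in _format_prompt above."""
    wrapper = None
    if model_prompt is None:   # case 1: fall back to the model-family default
        wrapper = family_prompt
    elif model_prompt:         # case 2: non-empty custom template
        wrapper = model_prompt
    # case 3: "" leaves wrapper as None, meaning no model prompt at all
    return prompt if wrapper is None else wrapper.format(prompt=prompt)

MIXTRAL = "<s> [INST] {prompt} [/INST]"  # assumed family default, per the tests
assert resolve_prompt("hi", None, MIXTRAL) == "<s> [INST] hi [/INST]"
assert resolve_prompt("hi", "FOO {prompt} BAR", MIXTRAL) == "FOO hi BAR"
assert resolve_prompt("hi", "", MIXTRAL) == "hi"
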
@@ -141,10 +149,7 @@ def _gen_kwargs(self, **gen_kwargs):
         return gen_kwargs

     def _generate(self, samples, **gen_kwargs) -> list:
-        prompts = [
-            self.model_prompt.format(prompt=self._format_prompt(sample))
-            for sample in samples
-        ]
+        prompts = [self._format_prompt(sample) for sample in samples]
         generate_args = self._gen_kwargs(**gen_kwargs)

         if self.server_supports_batched:
tests/test_llmblock.py (37 additions, 8 deletions)
@@ -2,6 +2,9 @@
 from unittest.mock import MagicMock, patch
 import unittest

+# Third Party
+from datasets import Dataset, Features, Value
+
 # First Party
 from src.instructlab.sdg.llmblock import LLMBlock

@@ -13,16 +16,21 @@ def setUp(self):
         self.mock_ctx.model_id = "test_model"
         self.mock_pipe = MagicMock()
         self.config_return_value = {
-            "system": "system",
+            "system": "{fruit}",
             "introduction": "introduction",
             "principles": "principles",
             "examples": "examples",
             "generation": "generation",
         }
+        self.dataset = Dataset.from_dict(
+            {"fruit": ["apple", "pear", "mango"]},
+            features=Features({"fruit": Value("string")}),
+        )

     @patch("src.instructlab.sdg.block.Block._load_config")
     def test_model_prompt_empty_string(self, mock_load_config):
         mock_load_config.return_value = self.config_return_value
+        # Ensure that no model prompt is used when explicitly set to an empty string
         block = LLMBlock(
             ctx=self.mock_ctx,
             pipe=self.mock_pipe,
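
A note on the new fixture: rows of a datasets Dataset index as plain dicts, which is what lets _format_prompt splat a sample into the prompt template. A small sketch assuming the Hugging Face datasets library imported above:

from datasets import Dataset, Features, Value

ds = Dataset.from_dict(
    {"fruit": ["apple", "pear", "mango"]},
    features=Features({"fruit": Value("string")}),
)
print(ds[1])  # {'fruit': 'pear'}
print("{fruit}\nintroduction".format(**ds[1]))  # "pear" then "introduction" on two lines
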
@@ -31,17 +39,18 @@ def test_model_prompt_empty_string(self, mock_load_config):
             output_cols=[],
             model_prompt="",
         )
+        prompt = block._format_prompt(self.dataset[0])
         self.assertEqual(
-            block.model_prompt,
-            "",
-            "model_prompt should be an empty string when explicitly set to an empty string",
+            prompt,
+            "apple\nintroduction\nprinciples\nexamples\ngeneration",
+            "no model prompt should be used when explicitly set to an empty string",
         )

     @patch("src.instructlab.sdg.block.Block._load_config")
     def test_model_prompt_none(self, mock_load_config):
         mock_load_config.return_value = self.config_return_value
         # Ensure that if a custom model_prompt is not specified, it defaults to setting it to
-        # something based on the model family. For this we just make sure it's not an empty string.
+        # something based on the model family (i.e. mixtral).
         block = LLMBlock(
             ctx=self.mock_ctx,
             pipe=self.mock_pipe,
@@ -50,8 +59,28 @@ def test_model_prompt_none(self, mock_load_config):
             output_cols=[],
             model_prompt=None,  # Or simply omit model_prompt as it defaults to None
         )
-        self.assertNotEqual(
-            block.model_prompt,
-            "",
-            "model_prompt should be a non-empty string when set to None",
+        prompt = block._format_prompt(self.dataset[1])
+        self.assertEqual(
+            prompt,
+            "<s> [INST] pear\nintroduction\nprinciples\nexamples\ngeneration [/INST]",
+            "model_prompt based on model_family should be used when set to None",
+        )
+
+    @patch("src.instructlab.sdg.block.Block._load_config")
+    def test_model_prompt(self, mock_load_config):
+        mock_load_config.return_value = self.config_return_value
+        # Ensure that if a custom model_prompt is specified, it is used correctly
+        block = LLMBlock(
+            ctx=self.mock_ctx,
+            pipe=self.mock_pipe,
+            block_name="test_block",
+            config_path="",
+            output_cols=[],
+            model_prompt="FOO {prompt} BAR",
+        )
+        prompt = block._format_prompt(self.dataset[1])
+        self.assertEqual(
+            prompt,
+            "FOO pear\nintroduction\nprinciples\nexamples\ngeneration BAR",
+            "custom model_prompt should be used when specified",
         )
