Skip to content

Commit

Permalink
feat: parametrize system prompt
Browse files Browse the repository at this point in the history
Signed-off-by: Jaideep Rao <[email protected]>
  • Loading branch information
jaideepr97 committed Nov 6, 2024
1 parent 82153b1 commit 70c39b3
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 14 deletions.
32 changes: 22 additions & 10 deletions src/instructlab/sdg/generate_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,9 @@ def _convert_to_messages(sample):
return sample


def _gen_train_data(machine_instruction_data, output_file_train, output_file_messages):
def _gen_train_data(
machine_instruction_data, output_file_train, output_file_messages, system_prompt
):
"""
Generate training data in the legacy system/user/assistant format
used in train_*.jsonl as well as the legacy messages format used
Expand All @@ -94,15 +96,15 @@ def _gen_train_data(machine_instruction_data, output_file_train, output_file_mes
user += "\n" + synth_example["context"]
assistant = _unescape(_get_response_hack(synth_example))
train_entry = {
"system": _SYS_PROMPT,
"system": system_prompt if system_prompt is not None else _SYS_PROMPT,
"user": _unescape(user),
"assistant": assistant,
}
train_data.append(train_entry)
sample = {
"inputs": _unescape(user),
"targets": assistant,
"system": _SYS_PROMPT,
"system": system_prompt if system_prompt is not None else _SYS_PROMPT,
}
messages_data.append(_convert_to_messages(sample))

Expand All @@ -117,13 +119,13 @@ def _gen_train_data(machine_instruction_data, output_file_train, output_file_mes
outfile.write("\n")


def _knowledge_seed_example_to_test_data(seed_example, system_prompt):
    """Convert one knowledge seed example into legacy test-data entries.

    Emits one system/user/assistant dict per Q&A pair, with the seed
    example's context appended to the question as the user turn.

    Args:
        seed_example: Mapping with a "questions_and_answers" list (each
            item holding "question" and "answer" strings) and a "context"
            string shared by every pair.
        system_prompt: System prompt to stamp on each entry; falls back to
            the module-level _SYS_PROMPT default when None.

    Returns:
        List of dicts with "system", "user" and "assistant" keys.
    """
    res = []
    for qna in seed_example["questions_and_answers"]:
        # The context is shared across the whole seed example, so it is
        # re-attached to every individual question.
        user = qna["question"] + "\n" + seed_example["context"]
        res.append(
            {
                "system": system_prompt if system_prompt is not None else _SYS_PROMPT,
                "user": _unescape(user),
                "assistant": _unescape(qna["answer"]),
            }
        )
    return res
def _gen_test_data(
leaf_nodes,
output_file_test,
system_prompt,
):
"""
Generate test data in the format needed by the legacy Linux training
Expand All @@ -143,7 +146,9 @@ def _gen_test_data(
for _, leaf_node in leaf_nodes.items():
for seed_example in leaf_node:
if "questions_and_answers" in seed_example:
test_data.extend(_knowledge_seed_example_to_test_data(seed_example))
test_data.extend(
_knowledge_seed_example_to_test_data(seed_example, system_prompt)
)
continue

# skill seed example
Expand All @@ -155,7 +160,9 @@ def _gen_test_data(

test_data.append(
{
"system": _SYS_PROMPT,
"system": system_prompt
if system_prompt is not None
else _SYS_PROMPT,
"user": _unescape(user),
"assistant": _unescape(seed_example["output"]), # answer
}
Expand Down Expand Up @@ -244,15 +251,15 @@ def load_pipeline(yaml_basename):
)


def _mixer_init(ctx, output_dir, date_suffix, knowledge_auxiliary_inst, system_prompt):
    """Build the DataMixer used to produce the final mixed datasets.

    Args:
        ctx: Pipeline context; only ctx.dataset_num_procs is read here.
        output_dir: Directory the mixer writes mixed datasets into.
        date_suffix: Timestamp suffix used to name the output files.
        knowledge_auxiliary_inst: Auxiliary instructions from the knowledge
            pipeline, forwarded to the mixer.
        system_prompt: System prompt recorded in the mix recipes (the
            caller decides any fallback; passed through as-is).

    Returns:
        A configured DataMixer instance.
    """
    # Search XDG data home first, then each XDG data dir, for bundled
    # instructlab/sdg recipe data.  ("base" avoids shadowing builtin dir.)
    data_dirs = [os.path.join(xdg_data_home(), "instructlab", "sdg")]
    data_dirs.extend(
        os.path.join(base, "instructlab", "sdg") for base in xdg_data_dirs()
    )

    return DataMixer(
        data_dirs,
        output_dir,
        date_suffix,
        system_prompt,
        ctx.dataset_num_procs,
        knowledge_auxiliary_inst,
    )
Expand All @@ -264,6 +271,7 @@ def _mixer_init(ctx, output_dir, date_suffix, knowledge_auxiliary_inst):
def generate_data(
client: openai.OpenAI,
logger: logging.Logger = logger, # pylint: disable=redefined-outer-name
system_prompt: Optional[str] = None,
model_family: Optional[str] = None,
model_name: Optional[str] = None,
num_cpus: Optional[int] = None,
Expand Down Expand Up @@ -322,6 +330,7 @@ def generate_data(
_gen_test_data(
leaf_nodes,
os.path.join(output_dir, output_file_test),
system_prompt,
)

logger.debug(f"Generating to: {os.path.join(output_dir, output_file_test)}")
Expand Down Expand Up @@ -350,7 +359,9 @@ def generate_data(
mmlu_ctx = dataclasses.replace(ctx, checkpoint_dir=None)
mmlu_bench_pipe = mmlubench_pipe_init(mmlu_ctx)

mixer = _mixer_init(ctx, output_dir, date_suffix, knowledge_pipe.auxiliary_inst)
mixer = _mixer_init(
ctx, output_dir, date_suffix, knowledge_pipe.auxiliary_inst, system_prompt
)

if console_output:
logger.info(
Expand Down Expand Up @@ -412,6 +423,7 @@ def generate_data(
generated_data,
os.path.join(output_dir, output_file_train),
os.path.join(output_dir, output_file_messages),
system_prompt,
)

mixer.generate()
Expand Down
13 changes: 9 additions & 4 deletions tests/test_generate_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,12 @@
import yaml

# First Party
from instructlab.sdg.generate_data import _SYS_PROMPT, _context_init, generate_data
from instructlab.sdg.generate_data import _context_init, generate_data
from instructlab.sdg.llmblock import LLMBlock
from instructlab.sdg.pipeline import PipelineContext

TEST_SYS_PROMPT = "I am, Red Hat® Instruct Model based on Granite 7B, an AI language model developed by Red Hat and IBM Research, based on the Granite-7b-base language model. My primary function is to be a chat assistant."

TEST_TAXONOMY_BASE = "main"

TEST_DATA_DIR = os.path.join(os.path.dirname(__file__), "testdata")
Expand All @@ -50,7 +52,7 @@ def validate_legacy_dataset(dataset_file_name, expected_samples):
assert ds.features[feature].dtype == "string"

for idx, sample in enumerate(expected_samples):
assert ds[idx]["system"] == _SYS_PROMPT
assert ds[idx]["system"] == TEST_SYS_PROMPT
assert ds[idx]["user"] == sample["user"]
assert ds[idx]["assistant"] == sample["assistant"]

Expand Down Expand Up @@ -79,7 +81,7 @@ def validate_messages_dataset(dataset_file_name, expected_samples):
assert ds[idx]["messages"][0]["content"] == sample["user"]
assert ds[idx]["messages"][1]["role"] == "assistant"
assert ds[idx]["messages"][1]["content"] == sample["assistant"]
assert ds[idx]["metadata"] == json.dumps({"system": _SYS_PROMPT})
assert ds[idx]["metadata"] == json.dumps({"system": TEST_SYS_PROMPT})


def validate_skill_leaf_node_dataset(dataset_file_name):
Expand Down Expand Up @@ -123,7 +125,7 @@ def validate_recipe(recipe_file_name):
assert len(yaml_contents["datasets"]) == 1
assert yaml_contents["datasets"][0]["path"].endswith(".jsonl")
assert "sampling_size" in yaml_contents["datasets"][0]
assert yaml_contents["metadata"]["sys_prompt"] == _SYS_PROMPT
assert yaml_contents["metadata"]["sys_prompt"] == TEST_SYS_PROMPT


def validate_mixed_dataset(dataset_file_name):
Expand Down Expand Up @@ -318,6 +320,7 @@ def test_generate(self):
taxonomy_base=TEST_TAXONOMY_BASE,
output_dir=self.tmp_path,
pipeline="simple",
system_prompt=TEST_SYS_PROMPT,
)

for name in ["test_*.jsonl", "train_*.jsonl", "messages_*.jsonl"]:
Expand Down Expand Up @@ -396,6 +399,7 @@ def test_generate(self):
chunk_word_count=1000,
server_ctx_size=4096,
pipeline="simple",
system_prompt=TEST_SYS_PROMPT,
)

for name in ["test_*.jsonl", "train_*.jsonl", "messages_*.jsonl"]:
Expand Down Expand Up @@ -493,6 +497,7 @@ def test_generate(self):
chunk_word_count=1000,
server_ctx_size=4096,
pipeline="simple",
system_prompt=TEST_SYS_PROMPT,
)
mocked_logger.warning.assert_called()
assert re.search(
Expand Down

0 comments on commit 70c39b3

Please sign in to comment.