diff --git a/src/instructlab/sdg/generate_data.py b/src/instructlab/sdg/generate_data.py index d7bfd9f2..13267422 100644 --- a/src/instructlab/sdg/generate_data.py +++ b/src/instructlab/sdg/generate_data.py @@ -74,7 +74,9 @@ def _convert_to_messages(sample): return sample -def _gen_train_data(machine_instruction_data, output_file_train, output_file_messages): +def _gen_train_data( + machine_instruction_data, output_file_train, output_file_messages, system_prompt +): """ Generate training data in the legacy system/user/assistant format used in train_*.jsonl as well as the legacy messages format used @@ -94,7 +96,7 @@ def _gen_train_data(machine_instruction_data, output_file_train, output_file_mes user += "\n" + synth_example["context"] assistant = _unescape(_get_response_hack(synth_example)) train_entry = { - "system": _SYS_PROMPT, + "system": system_prompt if system_prompt is not None else _SYS_PROMPT, "user": _unescape(user), "assistant": assistant, } @@ -102,7 +104,7 @@ def _gen_train_data(machine_instruction_data, output_file_train, output_file_mes sample = { "inputs": _unescape(user), "targets": assistant, - "system": _SYS_PROMPT, + "system": system_prompt if system_prompt is not None else _SYS_PROMPT, } messages_data.append(_convert_to_messages(sample)) @@ -117,13 +119,13 @@ def _gen_train_data(machine_instruction_data, output_file_train, output_file_mes outfile.write("\n") -def _knowledge_seed_example_to_test_data(seed_example): +def _knowledge_seed_example_to_test_data(seed_example, system_prompt): res = [] for qna in seed_example["questions_and_answers"]: user = qna["question"] + "\n" + seed_example["context"] res.append( { - "system": _SYS_PROMPT, + "system": system_prompt if system_prompt is not None else _SYS_PROMPT, "user": _unescape(user), "assistant": _unescape(qna["answer"]), } @@ -134,6 +136,7 @@ def _knowledge_seed_example_to_test_data(seed_example): def _gen_test_data( leaf_nodes, output_file_test, + system_prompt, ): """ Generate test data in the format needed by the legacy Linux training @@ -143,7 +146,9 @@ def _gen_test_data( for _, leaf_node in leaf_nodes.items(): for seed_example in leaf_node: if "questions_and_answers" in seed_example: - test_data.extend(_knowledge_seed_example_to_test_data(seed_example)) + test_data.extend( + _knowledge_seed_example_to_test_data(seed_example, system_prompt) + ) continue # skill seed example @@ -155,7 +160,9 @@ def _gen_test_data( test_data.append( { - "system": _SYS_PROMPT, + "system": system_prompt + if system_prompt is not None + else _SYS_PROMPT, "user": _unescape(user), "assistant": _unescape(seed_example["output"]), # answer } @@ -244,7 +251,7 @@ def load_pipeline(yaml_basename): ) -def _mixer_init(ctx, output_dir, date_suffix, knowledge_auxiliary_inst): +def _mixer_init(ctx, output_dir, date_suffix, knowledge_auxiliary_inst, system_prompt): data_dirs = [os.path.join(xdg_data_home(), "instructlab", "sdg")] data_dirs.extend(os.path.join(dir, "instructlab", "sdg") for dir in xdg_data_dirs()) @@ -252,7 +259,7 @@ def _mixer_init(ctx, output_dir, date_suffix, knowledge_auxiliary_inst): data_dirs, output_dir, date_suffix, - _SYS_PROMPT, + system_prompt, ctx.dataset_num_procs, knowledge_auxiliary_inst, ) @@ -264,6 +271,7 @@ def _mixer_init(ctx, output_dir, date_suffix, knowledge_auxiliary_inst): def generate_data( client: openai.OpenAI, logger: logging.Logger = logger, # pylint: disable=redefined-outer-name + system_prompt: Optional[str] = None, model_family: Optional[str] = None, model_name: Optional[str] = None, num_cpus: Optional[int] = None, @@ -322,6 +330,7 @@ def generate_data( _gen_test_data( leaf_nodes, os.path.join(output_dir, output_file_test), + system_prompt, ) logger.debug(f"Generating to: {os.path.join(output_dir, output_file_test)}") @@ -350,7 +359,9 @@ def generate_data( mmlu_ctx = dataclasses.replace(ctx, checkpoint_dir=None) mmlu_bench_pipe = mmlubench_pipe_init(mmlu_ctx) - mixer = _mixer_init(ctx, output_dir, date_suffix, knowledge_pipe.auxiliary_inst) + mixer = _mixer_init( + ctx, output_dir, date_suffix, knowledge_pipe.auxiliary_inst, system_prompt + ) if console_output: logger.info( @@ -412,6 +423,7 @@ def generate_data( generated_data, os.path.join(output_dir, output_file_train), os.path.join(output_dir, output_file_messages), + system_prompt, ) mixer.generate() diff --git a/tests/test_generate_data.py b/tests/test_generate_data.py index b5dc7c6f..f382a351 100644 --- a/tests/test_generate_data.py +++ b/tests/test_generate_data.py @@ -20,10 +20,12 @@ import yaml # First Party -from instructlab.sdg.generate_data import _SYS_PROMPT, _context_init, generate_data +from instructlab.sdg.generate_data import _context_init, generate_data from instructlab.sdg.llmblock import LLMBlock from instructlab.sdg.pipeline import PipelineContext +TEST_SYS_PROMPT = "I am, Red Hat® Instruct Model based on Granite 7B, an AI language model developed by Red Hat and IBM Research, based on the Granite-7b-base language model. My primary function is to be a chat assistant." + TEST_TAXONOMY_BASE = "main" TEST_DATA_DIR = os.path.join(os.path.dirname(__file__), "testdata") @@ -50,7 +52,7 @@ def validate_legacy_dataset(dataset_file_name, expected_samples): assert ds.features[feature].dtype == "string" for idx, sample in enumerate(expected_samples): - assert ds[idx]["system"] == _SYS_PROMPT + assert ds[idx]["system"] == TEST_SYS_PROMPT assert ds[idx]["user"] == sample["user"] assert ds[idx]["assistant"] == sample["assistant"] @@ -79,7 +81,7 @@ def validate_messages_dataset(dataset_file_name, expected_samples): assert ds[idx]["messages"][0]["content"] == sample["user"] assert ds[idx]["messages"][1]["role"] == "assistant" assert ds[idx]["messages"][1]["content"] == sample["assistant"] - assert ds[idx]["metadata"] == json.dumps({"system": _SYS_PROMPT}) + assert ds[idx]["metadata"] == json.dumps({"system": TEST_SYS_PROMPT}) def validate_skill_leaf_node_dataset(dataset_file_name): @@ -123,7 +125,7 @@ def validate_recipe(recipe_file_name): assert len(yaml_contents["datasets"]) == 1 assert yaml_contents["datasets"][0]["path"].endswith(".jsonl") assert "sampling_size" in yaml_contents["datasets"][0] - assert yaml_contents["metadata"]["sys_prompt"] == _SYS_PROMPT + assert yaml_contents["metadata"]["sys_prompt"] == TEST_SYS_PROMPT def validate_mixed_dataset(dataset_file_name): @@ -318,6 +320,7 @@ def test_generate(self): taxonomy_base=TEST_TAXONOMY_BASE, output_dir=self.tmp_path, pipeline="simple", + system_prompt=TEST_SYS_PROMPT, ) for name in ["test_*.jsonl", "train_*.jsonl", "messages_*.jsonl"]: @@ -396,6 +399,7 @@ def test_generate(self): chunk_word_count=1000, server_ctx_size=4096, pipeline="simple", + system_prompt=TEST_SYS_PROMPT, ) for name in ["test_*.jsonl", "train_*.jsonl", "messages_*.jsonl"]: @@ -493,6 +497,7 @@ def test_generate(self): chunk_word_count=1000, server_ctx_size=4096, pipeline="simple", + system_prompt=TEST_SYS_PROMPT, ) mocked_logger.warning.assert_called() assert re.search(