From 49cd760dd71c51062b33cf2ef62c112a55b92cba Mon Sep 17 00:00:00 2001 From: Alex Sanchez Date: Fri, 31 Jan 2025 09:26:20 -0800 Subject: [PATCH 1/6] Initial changes for phi medium support --- .../components/data_generation/spec.yaml | 7 +- .../spec.yaml | 18 ++- .../spec.yaml | 7 +- .../spec.yaml | 2 +- .../spec.yaml | 2 +- .../data_generation_file_selector/spec.yaml | 2 +- .../spec.yaml | 9 +- .../spec.yaml | 2 +- .../components/pipeline/spec.yaml | 27 ++-- .../components/pipeline_validation/spec.yaml | 8 +- .../distillation/src/common/constants.py | 11 +- .../distillation/src/common/student_models.py | 126 ++++++++++++++++++ .../training/distillation/src/common/utils.py | 7 +- .../distillation/src/generate_data.py | 38 +++++- .../src/generate_data_postprocess.py | 18 +++ .../distillation/src/validate_pipeline.py | 2 +- 16 files changed, 237 insertions(+), 49 deletions(-) create mode 100644 assets/training/distillation/src/common/student_models.py diff --git a/assets/training/distillation/components/data_generation/spec.yaml b/assets/training/distillation/components/data_generation/spec.yaml index 0de2961f05..15092911d2 100644 --- a/assets/training/distillation/components/data_generation/spec.yaml +++ b/assets/training/distillation/components/data_generation/spec.yaml @@ -1,6 +1,6 @@ $schema: https://azuremlschemas.azureedge.net/latest/commandComponent.schema.json name: oss_distillation_generate_data -version: 0.0.8 +version: 0.0.9.test1 type: command is_deterministic: True @@ -121,6 +121,10 @@ inputs: type: uri_file description: Validation status. mode: rw_mount + + model_asset_id: + type: string + description: Student model to use outputs: generated_train_file_path: @@ -152,5 +156,6 @@ command: >- $[[--enable_chain_of_density ${{inputs.enable_chain_of_density}}]] $[[--max_len_summary ${{inputs.max_len_summary}}]] --data_generation_task_type ${{inputs.data_generation_task_type}} + --model_asset_id ${{inputs.model_asset_id}} --generated_train_file_path ${{outputs.generated_train_file_path}} --generated_validation_file_path ${{outputs.generated_validation_file_path}} diff --git a/assets/training/distillation/components/data_generation_batch_scoring_pipeline/spec.yaml b/assets/training/distillation/components/data_generation_batch_scoring_pipeline/spec.yaml index 0d4d8454f0..6e9ea394ad 100644 --- a/assets/training/distillation/components/data_generation_batch_scoring_pipeline/spec.yaml +++ b/assets/training/distillation/components/data_generation_batch_scoring_pipeline/spec.yaml @@ -1,6 +1,6 @@ $schema: https://azuremlschemas.azureedge.net/latest/pipelineComponent.schema.json name: oss_distillation_batchscoring_datagen_pipeline -version: 0.0.1 +version: 0.0.1.test1 type: pipeline @@ -160,6 +160,9 @@ inputs: 4. MATH: Generate Math data for numerical responses 5. SUMMARIZATION: Generate Key Summary for an Article + model_asset_id: + type: string + description: The student model to finetune # Output of validation component. validation_info: @@ -256,7 +259,7 @@ outputs: jobs: oss_distillation_generate_data_batch_preprocess: type: command - component: azureml:oss_distillation_generate_data_batch_preprocess:0.0.1 + component: azureml:oss_distillation_generate_data_batch_preprocess:0.0.1.test1 compute: '${{parent.inputs.compute_data_generation}}' resources: instance_type: '${{parent.inputs.instance_type_data_generation}}' @@ -296,7 +299,7 @@ jobs: # Config generator job oss_distillation_generate_data_config_generator: type: command - component: azureml:batch_benchmark_config_generator:0.0.9 + component: azureml://registries/azureml/components/batch_benchmark_config_generator/versions/0.0.9 compute: '${{parent.inputs.compute_pipeline_validation}}' resources: instance_type: '${{parent.inputs.instance_type_pipeline_validation}}' @@ -322,7 +325,7 @@ jobs: # Batch score job oss_distillation_train_data_batch_score: type: parallel - component: azureml:batch_score_oss:0.0.1 + component: azureml://registries/azureml/components/batch_score_oss/versions/0.0.1 compute: '${{parent.inputs.compute_data_generation}}' identity: type: user_identity @@ -349,7 +352,7 @@ jobs: validation_file_path_exists: type: command - component: azureml:oss_distillation_data_generation_validation_file_checker:0.0.1 + component: azureml:oss_distillation_data_generation_validation_file_checker:0.0.1.test1 compute: '${{parent.inputs.compute_pipeline_validation}}' resources: instance_type: '${{parent.inputs.instance_type_pipeline_validation}}' @@ -366,7 +369,7 @@ jobs: # Batch score job oss_distillation_validation_data_batch_score: type: parallel - component: azureml:batch_score_oss:0.0.1 + component: azureml://registries/azureml/components/batch_score_oss/versions/0.0.1 compute: '${{parent.inputs.compute_data_generation}}' identity: type: user_identity @@ -393,7 +396,7 @@ jobs: oss_distillation_generate_data_batch_postprocess: type: command - component: azureml:oss_distillation_generate_data_batch_postprocess:0.0.1 + component: azureml:oss_distillation_generate_data_batch_postprocess:0.0.1.test1 compute: '${{parent.inputs.compute_data_generation}}' resources: instance_type: '${{parent.inputs.instance_type_data_generation}}' @@ -410,6 +413,7 @@ jobs: enable_chain_of_density: '${{parent.inputs.enable_chain_of_density}}' data_generation_task_type: '${{parent.inputs.data_generation_task_type}}' min_endpoint_success_ratio: '${{parent.inputs.min_endpoint_success_ratio}}' + model_asset_id: '${{parent.inputs.model_asset_id}}' connection_config_file: ${{parent.jobs.oss_distillation_generate_data_batch_preprocess.outputs.batch_config_connection}} outputs: generated_batch_train_file_path: '${{parent.outputs.generated_batch_train_file_path}}' diff --git a/assets/training/distillation/components/data_generation_batch_scoring_postprocess/spec.yaml b/assets/training/distillation/components/data_generation_batch_scoring_postprocess/spec.yaml index 1720d6d8f6..6b8a7a0fa8 100644 --- a/assets/training/distillation/components/data_generation_batch_scoring_postprocess/spec.yaml +++ b/assets/training/distillation/components/data_generation_batch_scoring_postprocess/spec.yaml @@ -1,6 +1,6 @@ $schema: https://azuremlschemas.azureedge.net/latest/commandComponent.schema.json name: oss_distillation_generate_data_batch_postprocess -version: 0.0.1 +version: 0.0.1.test1 type: command is_deterministic: False @@ -82,6 +82,10 @@ inputs: type: uri_file description: Connection config file for batch scoring + model_asset_id: + type: string + description: The student model to finetune + outputs: generated_batch_train_file_path: type: uri_file @@ -104,6 +108,7 @@ command: >- --min_endpoint_success_ratio ${{inputs.min_endpoint_success_ratio}} $[[--enable_chain_of_thought ${{inputs.enable_chain_of_thought}}]] $[[--enable_chain_of_density ${{inputs.enable_chain_of_density}}]] + --model_asset_id ${{inputs.model_asset_id}} --data_generation_task_type ${{inputs.data_generation_task_type}} --connection_config_file ${{inputs.connection_config_file}} --generated_batch_train_file_path ${{outputs.generated_batch_train_file_path}} diff --git a/assets/training/distillation/components/data_generation_batch_scoring_preprocess/spec.yaml b/assets/training/distillation/components/data_generation_batch_scoring_preprocess/spec.yaml index 5ae8aa2f68..11637bbbcc 100644 --- a/assets/training/distillation/components/data_generation_batch_scoring_preprocess/spec.yaml +++ b/assets/training/distillation/components/data_generation_batch_scoring_preprocess/spec.yaml @@ -1,6 +1,6 @@ $schema: https://azuremlschemas.azureedge.net/latest/commandComponent.schema.json name: oss_distillation_generate_data_batch_preprocess -version: 0.0.1 +version: 0.0.1.test1 type: command is_deterministic: False diff --git a/assets/training/distillation/components/data_generation_batch_scoring_selector/spec.yaml b/assets/training/distillation/components/data_generation_batch_scoring_selector/spec.yaml index 7c3dfe12e6..93ea3f3ce9 100644 --- a/assets/training/distillation/components/data_generation_batch_scoring_selector/spec.yaml +++ b/assets/training/distillation/components/data_generation_batch_scoring_selector/spec.yaml @@ -1,6 +1,6 @@ $schema: https://azuremlschemas.azureedge.net/latest/commandComponent.schema.json name: oss_distillation_data_generation_batch_scoring_selector -version: 0.0.1 +version: 0.0.1.test1 type: command is_deterministic: True diff --git a/assets/training/distillation/components/data_generation_file_selector/spec.yaml b/assets/training/distillation/components/data_generation_file_selector/spec.yaml index 448eb8d2f1..5fb7f4a10a 100644 --- a/assets/training/distillation/components/data_generation_file_selector/spec.yaml +++ b/assets/training/distillation/components/data_generation_file_selector/spec.yaml @@ -1,6 +1,6 @@ $schema: https://azuremlschemas.azureedge.net/latest/commandComponent.schema.json name: oss_distillation_data_generation_file_selector -version: 0.0.1 +version: 0.0.1.test1 type: command is_deterministic: True diff --git a/assets/training/distillation/components/data_generation_seq_scoring_pipeline/spec.yaml b/assets/training/distillation/components/data_generation_seq_scoring_pipeline/spec.yaml index ce2d2cae1a..9e3db28ecd 100644 --- a/assets/training/distillation/components/data_generation_seq_scoring_pipeline/spec.yaml +++ b/assets/training/distillation/components/data_generation_seq_scoring_pipeline/spec.yaml @@ -1,6 +1,6 @@ $schema: https://azuremlschemas.azureedge.net/latest/pipelineComponent.schema.json name: oss_distillation_seq_scoring_pipeline -version: 0.0.1 +version: 0.0.1.test1 type: pipeline @@ -172,6 +172,10 @@ inputs: 4. MATH: Generate Math data for numerical responses 5. SUMMARIZATION: Generate Key Summary for an Article + model_asset_id: + type: string + description: The student model asset id + optional: false # Training parameters num_train_epochs: @@ -212,7 +216,7 @@ outputs: jobs: oss_distillation_generate_data: type: command - component: azureml:oss_distillation_generate_data:0.0.8 + component: azureml:oss_distillation_generate_data:0.0.9.test1 compute: '${{parent.inputs.compute_data_generation}}' resources: instance_type: '${{parent.inputs.instance_type_data_generation}}' @@ -236,6 +240,7 @@ jobs: request_batch_size: '${{parent.inputs.request_batch_size}}' min_endpoint_success_ratio: '${{parent.inputs.min_endpoint_success_ratio}}' validation_output: '${{parent.inputs.validation_output}}' + model_asset_id: '${{parent.inputs.model_asset_id}}' outputs: generated_train_file_path: '${{parent.outputs.generated_train_file_path}}' generated_validation_file_path: '${{parent.outputs.generated_validation_file_path}}' diff --git a/assets/training/distillation/components/data_generation_validation_file_checker/spec.yaml b/assets/training/distillation/components/data_generation_validation_file_checker/spec.yaml index e44d9b6ceb..0ba460f58b 100644 --- a/assets/training/distillation/components/data_generation_validation_file_checker/spec.yaml +++ b/assets/training/distillation/components/data_generation_validation_file_checker/spec.yaml @@ -1,6 +1,6 @@ $schema: https://azuremlschemas.azureedge.net/latest/commandComponent.schema.json name: oss_distillation_data_generation_validation_file_checker -version: 0.0.1 +version: 0.0.1.test1 type: command is_deterministic: True diff --git a/assets/training/distillation/components/pipeline/spec.yaml b/assets/training/distillation/components/pipeline/spec.yaml index 8e179ea29c..81e88e34ce 100644 --- a/assets/training/distillation/components/pipeline/spec.yaml +++ b/assets/training/distillation/components/pipeline/spec.yaml @@ -1,6 +1,6 @@ $schema: https://azuremlschemas.azureedge.net/latest/pipelineComponent.schema.json name: oss_distillation_pipeline -version: 0.0.10 +version: 0.0.10.test1 type: pipeline @@ -270,11 +270,11 @@ inputs: optional: true description: Validation parameters propagated from pipeline. - # Model parameters + # Student Model parameters model_asset_id: type: string optional: false - description: Asset id of model + description: Asset id of the student model # Model registration registered_model_name: @@ -297,7 +297,7 @@ outputs: jobs: oss_distillation_validate_pipeline: type: command - component: azureml:oss_distillation_validate_pipeline:0.0.5 + component: azureml:oss_distillation_validate_pipeline:0.0.5.test1 compute: '${{parent.inputs.compute_pipeline_validation}}' resources: instance_type: '${{parent.inputs.instance_type_pipeline_validation}}' @@ -323,6 +323,7 @@ jobs: num_train_epochs: '${{parent.inputs.num_train_epochs}}' per_device_train_batch_size: '${{parent.inputs.per_device_train_batch_size}}' learning_rate: '${{parent.inputs.learning_rate}}' + model_asset_id: '${{parent.inputs.model_asset_id}}' outputs: validation_info: type: uri_file @@ -330,7 +331,7 @@ jobs: data_generation_batch_scoring_selector: type: command - component: azureml:oss_distillation_data_generation_batch_scoring_selector:0.0.1 + component: azureml:oss_distillation_data_generation_batch_scoring_selector:0.0.1.test1 compute: '${{parent.inputs.compute_pipeline_validation}}' resources: instance_type: '${{parent.inputs.instance_type_pipeline_validation}}' @@ -347,7 +348,7 @@ jobs: oss_distillation_batchscoring_datagen_pipeline: type: pipeline - component: azureml:oss_distillation_batchscoring_datagen_pipeline:0.0.1 + component: azureml:oss_distillation_batchscoring_datagen_pipeline:0.0.1.test1 inputs: instance_type_pipeline_validation: '${{parent.inputs.instance_type_pipeline_validation}}' instance_type_data_generation: '${{parent.inputs.instance_type_data_generation}}' @@ -387,6 +388,7 @@ jobs: max_concurrency_per_instance: '${{parent.inputs.max_concurrency_per_instance}}' mini_batch_size: '${{parent.inputs.mini_batch_size}}' validation_info: '${{parent.jobs.oss_distillation_validate_pipeline.outputs.validation_info}}' + model_asset_id: '${{parent.inputs.model_asset_id}}' outputs: generated_batch_train_file_path: @@ -398,7 +400,7 @@ jobs: oss_distillation_seq_scoring_pipeline: type: pipeline - component: azureml:oss_distillation_seq_scoring_pipeline:0.0.1 + component: azureml:oss_distillation_seq_scoring_pipeline:0.0.1.test1 inputs: instance_type_pipeline_validation: '${{parent.inputs.instance_type_pipeline_validation}}' instance_type_data_generation: '${{parent.inputs.instance_type_data_generation}}' @@ -426,6 +428,7 @@ jobs: max_len_summary: '${{parent.inputs.max_len_summary}}' data_generation_task_type: '${{parent.inputs.data_generation_task_type}}' validation_output: '${{parent.jobs.oss_distillation_validate_pipeline.outputs.validation_info}}' + model_asset_id: '${{parent.inputs.model_asset_id}}' outputs: generated_train_file_path: type: uri_file @@ -437,7 +440,7 @@ jobs: oss_distillation_train_data_generation_file_selector: type: command - component: azureml:oss_distillation_data_generation_file_selector:0.0.1 + component: azureml:oss_distillation_data_generation_file_selector:0.0.1.test1 compute: '${{parent.inputs.compute_pipeline_validation}}' resources: instance_type: '${{parent.inputs.instance_type_pipeline_validation}}' @@ -460,7 +463,7 @@ jobs: oss_text_generation_data_import: type: command - component: azureml:oss_text_generation_data_import:0.0.25 + component: azureml://registries/azureml/components/oss_text_generation_data_import/versions/0.0.26 compute: '${{parent.inputs.compute_data_import}}' resources: instance_type: '${{parent.inputs.instance_type_data_import}}' @@ -472,13 +475,13 @@ jobs: environment_variables: _AZUREML_CR_ENABLE_ITP_CAP: "false" inputs: - train_file_path: '${{parent.jobs.oss_distillation_train_data_generation_file_selector.outputs.ft_input_train_file_path}}' - validation_file_path: '${{parent.jobs.oss_distillation_train_data_generation_file_selector.outputs.ft_input_validation_file_path}}' + train_file_path: '${{parent.jobs.oss_distillation_generate_data.outputs.generated_train_file_path}}' + validation_file_path: '${{parent.jobs.oss_distillation_generate_data.outputs.generated_validation_file_path}}' system_properties: '${{parent.inputs.system_properties}}' oss_chat_completion_finetune: type: command - component: azureml:oss_chat_completion_finetune:0.0.25 + component: azureml://registries/azureml/components/oss_chat_completion_finetune/versions/0.0.26 compute: '${{parent.inputs.compute_finetune}}' resources: instance_type: '${{parent.inputs.instance_type_finetune}}' diff --git a/assets/training/distillation/components/pipeline_validation/spec.yaml b/assets/training/distillation/components/pipeline_validation/spec.yaml index 84a99106f0..490f5165bd 100644 --- a/assets/training/distillation/components/pipeline_validation/spec.yaml +++ b/assets/training/distillation/components/pipeline_validation/spec.yaml @@ -1,6 +1,6 @@ $schema: https://azuremlschemas.azureedge.net/latest/commandComponent.schema.json name: oss_distillation_validate_pipeline -version: 0.0.5 +version: 0.0.5.test1 type: command is_deterministic: true @@ -135,6 +135,11 @@ inputs: optional: true description: Start learning rate. + model_asset_id: + type: string + optional: false + description: The student model to finetune + outputs: validation_info: type: uri_file @@ -163,4 +168,5 @@ command: >- $[[--num_train_epochs ${{inputs.num_train_epochs}}]] $[[--per_device_train_batch_size ${{inputs.per_device_train_batch_size}}]] $[[--learning_rate ${{inputs.learning_rate}}]] + --model_asset_id '${{inputs.model_asset_id}}' --validation_info ${{outputs.validation_info}} diff --git a/assets/training/distillation/src/common/constants.py b/assets/training/distillation/src/common/constants.py index 6835f7f713..8d02799bff 100644 --- a/assets/training/distillation/src/common/constants.py +++ b/assets/training/distillation/src/common/constants.py @@ -1,7 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -"""Data generatior constants.""" +"""Data generation constants.""" import re from enum import EnumMeta, Enum @@ -36,15 +36,6 @@ } } -# SUPPORTED STUDENT MODEL -# MAP keys are model name in registry, which maps to specific model details like registry and supported versions -SUPPORTED_STUDENT_MODEL_MAP = { - "Meta-Llama-3.1-8B-Instruct": { - "supported_registries": ["azureml-meta"], - "supported_version_pattern": re.compile(r"\d+"), - } -} - # Scoring paths VLLM_CHAT_SCORE_PATH = "/v1/chat/completions" HFTV2_TEXT_GEN_SCORE_PATH = "/score" diff --git a/assets/training/distillation/src/common/student_models.py b/assets/training/distillation/src/common/student_models.py new file mode 100644 index 0000000000..e1395e108f --- /dev/null +++ b/assets/training/distillation/src/common/student_models.py @@ -0,0 +1,126 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Student Model Info and Requirements.""" + +import re +from typing import Dict, List + +from common.constants import REGISTRY_MODEL_PATTERN, DataGenerationTaskType + + +class StudentModels: + SUPPORTED_STUDENT_MODELS = { + "Meta-Llama-3.1-8B-Instruct": { + "supported_registries": ["azureml-meta"], + "supported_version_pattern": re.compile(r"\d+") + }, + "Phi-3-mini-4k-instruct": { + "supported_registries": ["azureml"], + "supported_version_pattern": re.compile(r"\d+") + }, + "Phi-3-mini-128k-instruct": { + "supported_registries": ["azureml"], + "supported_version_pattern": re.compile(r"\d+") + }, + "Phi-3.5-mini-instruct": { + "supported_registries": ["azureml"], + "supported_version_pattern": re.compile(r"\d+") + }, + "Phi-3.5-MoE-instruct": { + "supported_registries": ["azureml"], + "supported_version_pattern": re.compile(r"\d+") + }, + "Phi-3-medium-4k-instruct": { + "supported_registries": ["azureml"], + "supported_version_pattern": re.compile(r"\d+"), + }, + "Phi-3-medium-128k-instruct": { + "supported_registries": ["azureml"], + "supported_version_pattern": re.compile(r"\d+"), + }, + } + + # Student models that do not recognize system prompts + NO_SYSTEM_PROMPT_MODELS = [ + "Phi-3-medium-4k-instruct", + "Phi-3-medium-128k-instruct" + ] + + @classmethod + def no_system_prompt_reformat(cls, data: List[Dict[str, list]]) -> List[Dict[str, list]]: + """Adds system prompt to user prompt for student models that do not + accept system prompts. + + :param data: The synthetic data generated from the teacher model + :type data: List[Dict[str, list]] + :return: Reformated data + :rtype: List[Dict[str, list]] + """ + new_data = [] + system_message = "" + for messages in data: + system_message = messages["messages"][0]["content"] + question = messages["messages"][1]["content"] + reformatted_data = { + "messages": + [ + {"role": "user", "content": system_message + " " + question}, + messages["messages"][2] + ] + } + new_data.append(reformatted_data) + return new_data + + @classmethod + def no_system_prompt_reformat_conversation(cls, data: List[Dict[str, list]]): + new_data = [] + system_message = "" + for messages in data: + system_message = messages["messages"][0]["content"] + user_prompt = messages["messages"][1]["content"] + reformatted_data = { + "messages": + [ + {"role": "user", "content": system_message + " " + user_prompt}, + messages["messages"][2:] + ] + } + new_data.append(reformatted_data) + return new_data + + @classmethod + def reformat(cls, student_model: str, task_type: str, data: List[Dict[str, list]]) -> List[Dict[str, list]]: + """Reformats synthetic data based on the student model and task type requirements. + + :param student_model: The student model to finetune + :type student_model: str + :param task_type: The data generation task type + :type task_type: str + :param data: The synthetic data generated from the teacher model + :type data: List[Dict[str, list]] + :return: Reformatted data based on student model and task type + :rtype: List[Dict[str, list]] + """ + if student_model in cls.NO_SYSTEM_PROMPT_MODELS: + if task_type == DataGenerationTaskType.CONVERSATION: + return cls.no_system_prompt_reformat_conversation(data) + return cls.no_system_prompt_reformat(data) + return data + + @classmethod + def parse_model_asset_id(cls, asset_id: str) -> str: + """Parse asset id to extract the student model name. + + :param asset_id: The asset id of the student model in the form + azureml://registries/{registry}/models/{model}/versions/{version}. + :type asset_id: str + """ + match = re.search(REGISTRY_MODEL_PATTERN, asset_id) + model = match.group("model") + + if model not in cls.NO_SYSTEM_PROMPT_MODELS: + raise Exception("Model is not in supported student model list") + return model + + diff --git a/assets/training/distillation/src/common/utils.py b/assets/training/distillation/src/common/utils.py index b0dac291d4..c5c7e48248 100644 --- a/assets/training/distillation/src/common/utils.py +++ b/assets/training/distillation/src/common/utils.py @@ -33,11 +33,12 @@ from common.constants import ( REQUESTS_RETRY_DELAY, REGISTRY_MODEL_PATTERN, - SUPPORTED_STUDENT_MODEL_MAP, SUPPORTED_TEACHER_MODEL_MAP, BackoffConstants, ) +from common.student_models import StudentModels + logger = get_logger_app( "azureml.acft.contrib.hf.nlp.entry_point.data_import.data_import" @@ -396,7 +397,7 @@ def validate_student_model_details(model_asset_id: str) -> Tuple[str, str, str]: Returns: Tuple[str, str, str]: Tuple containing registry name, model name and model version """ - return _get_model_details(model_asset_id, SUPPORTED_STUDENT_MODEL_MAP) + return _get_model_details(model_asset_id, StudentModels.SUPPORTED_STUDENT_MODELS) def get_base_url(url: str) -> str: @@ -499,4 +500,4 @@ def get_hash_value(data: Union[Dict[str, Any], str]) -> str: """ if isinstance(data, str): return hashlib.sha256(data.encode()).hexdigest() - return hashlib.sha256(json.dumps(data).encode()).hexdigest() + return hashlib.sha256(json.dumps(data).encode()).hexdigest() \ No newline at end of file diff --git a/assets/training/distillation/src/generate_data.py b/assets/training/distillation/src/generate_data.py index 2a50a281af..7fc3508c1a 100644 --- a/assets/training/distillation/src/generate_data.py +++ b/assets/training/distillation/src/generate_data.py @@ -52,6 +52,8 @@ DEFAULT_MAX_LEN_SUMMARY, ) +from common.student_models import StudentModels + from common.utils import ( get_workspace_mlclient, get_endpoint_details, @@ -219,6 +221,13 @@ def get_parser(): choices=[v.value for v in DataGenerationTaskType], ) + parser.add_argument( + "--model_asset_id", + type=str, + required=True, + help="Student model to use" + ) + return parser @@ -273,7 +282,9 @@ def generate_synthetic_data( generated_validation_file_path: Path, train_file_path: Path, data_generation_task_type: str, - validation_file_path: Path = None, + student_model: str, + validation_file_path: Path = None + ): """Generate and save synthentic data under output_dataset. @@ -288,6 +299,7 @@ def generate_synthetic_data( max_len_summary (int): Maximum word count for text summarization output_dataset (Path): Path to output directory train_file_path (Path): Train JSONL file path + student_model (str): Student model name validation_file_path (Path, optional): Validation JSONL file path. Defaults to None. """ @@ -351,7 +363,7 @@ def process_request(idx: str, data: dict, url: str, endpoint_key: str): dict: result dictionary """ try: - # Basic validation for the input data + # Basic validation for the input data messages = data.pop("messages", []) if not messages: # empty messages return { @@ -527,10 +539,20 @@ def batch_process_data( ) else: output_data.append({"messages": future_result["messages"]}) - Path(output_file_path.parent).mkdir(exist_ok=True, parents=True) - with open(output_file_path, "w") as f: - for entry in output_data: - f.write(json.dumps(entry) + "\n") + Path(output_file_path.parent).mkdir(exist_ok=True, parents=True) + + # Reformat finetune data based on student model limitations + logger.info(f"output data before reformatting: {output_data}") + + output_data = StudentModels.reformat( + student_model=student_model, + task_type=data_generation_task_type, + data=output_data + ) + logger.info(f"output data after reformatting: {output_data}") + with open(output_file_path, "w") as f: + for entry in output_data: + f.write(json.dumps(entry) + "\n") if error_map: logger.info( @@ -594,6 +616,7 @@ def data_import(args: Namespace): enable_cod_str = args.enable_chain_of_density max_len_summary = args.max_len_summary data_generation_task_type = args.data_generation_task_type + model_asset_id = args.model_asset_id # validate file formats validate_file_paths_with_supported_formats( @@ -680,7 +703,8 @@ def data_import(args: Namespace): generated_validation_file_path=generated_validation_file_path, train_file_path=train_file_path, data_generation_task_type=data_generation_task_type, - validation_file_path=validation_file_path, + student_model=StudentModels.parse_model_asset_id(model_asset_id), + validation_file_path=validation_file_path ) diff --git a/assets/training/distillation/src/generate_data_postprocess.py b/assets/training/distillation/src/generate_data_postprocess.py index c90a84075f..92822a1240 100644 --- a/assets/training/distillation/src/generate_data_postprocess.py +++ b/assets/training/distillation/src/generate_data_postprocess.py @@ -36,6 +36,7 @@ STATUS_SUCCESS, FINISH_REASON_STOP, ) +from common.student_models import StudentModels from common.utils import ( get_hash_value, @@ -166,6 +167,13 @@ def get_parser(): help="A config file path that contains deployment configurations.", ) + parser.add_argument( + "--model_asset_id", + type=str, + required=True, + help="The student model asset id" + ) + return parser @@ -201,6 +209,7 @@ def postprocess_data( min_endpoint_success_ratio: float, output_file_path: str, hash_data: str, + student_model: str ): """Generate and save synthentic data under output_dataset. @@ -213,6 +222,7 @@ def postprocess_data( min_endpoint_success_ratio (float): Minimum success ratio below which run will be considered a failure output_file_path (str): Output JSONL file path. hash_data (str): Path to the jsonl file containing the hash for each payload. + student_model (str): The student model to finetune """ error_count = 0 output_data = [] @@ -288,6 +298,11 @@ def postprocess_data( if success_ratio < min_endpoint_success_ratio: msg = f"Success ratio for dataset {input_file_path}: {success_ratio} < {min_endpoint_success_ratio}." raise Exception(msg) + + # Reformat finetune data based on student model limitations + logger.info(f"output data before reformatting: {output_data}") + output_data = StudentModels.reformat(student_model=student_model, task_type=data_generation_task_type, data=output_data) + logger.info(f"output data after reformatting: {output_data}") with open(output_file_path, "w") as f: for record in output_data: f.write(json.dumps(record) + "\n") @@ -308,6 +323,7 @@ def data_import(args: Namespace): hash_train_data = args.hash_train_data hash_validation_data = args.hash_validation_data connection_config_file = args.connection_config_file + model_asset_id = args.model_asset_id enable_cot = True if enable_cot_str.lower() == "true" else False enable_cod = True if enable_cod_str.lower() == "true" else False @@ -331,6 +347,7 @@ def data_import(args: Namespace): min_endpoint_success_ratio=min_endpoint_success_ratio, output_file_path=generated_batch_train_file_path, hash_data=hash_train_data, + student_model=StudentModels.parse_model_asset_id(model_asset_id) ) if validation_file_path: with log_activity( @@ -350,6 +367,7 @@ def data_import(args: Namespace): min_endpoint_success_ratio=min_endpoint_success_ratio, output_file_path=generated_batch_validation_file_path, hash_data=hash_validation_data, + student_model=StudentModels.parse_model_asset_id(model_asset_id) ) else: Path(generated_batch_validation_file_path.parent).mkdir( diff --git a/assets/training/distillation/src/validate_pipeline.py b/assets/training/distillation/src/validate_pipeline.py index a916ba7dea..eaa0d236a1 100644 --- a/assets/training/distillation/src/validate_pipeline.py +++ b/assets/training/distillation/src/validate_pipeline.py @@ -205,7 +205,7 @@ def _validate_model_inference(self): url = url if VLLM_CHAT_SCORE_PATH in url else f"{url}{VLLM_CHAT_SCORE_PATH}" logger.info(f"Model endpoint: {url}") response = requests.post( - url=url, headers=headers, data=json.dumps(inference_params) + url=url, headers=headers, data=json.dumps(inference_params), timeout=180 ) response.raise_for_status() From aec59c7e7351a9709284655f4ed9905741a3ca55 Mon Sep 17 00:00:00 2001 From: Alex Sanchez Date: Fri, 31 Jan 2025 09:53:59 -0800 Subject: [PATCH 2/6] Cleanuo --- .../components/data_generation/spec.yaml | 2 +- .../spec.yaml | 14 ++++++------- .../spec.yaml | 2 +- .../spec.yaml | 2 +- .../spec.yaml | 2 +- .../data_generation_file_selector/spec.yaml | 2 +- .../spec.yaml | 4 ++-- .../spec.yaml | 2 +- .../components/pipeline/spec.yaml | 20 +++++++++---------- .../components/pipeline_validation/spec.yaml | 2 +- .../distillation/src/common/student_models.py | 4 +--- .../distillation/src/generate_data.py | 5 +---- .../src/generate_data_postprocess.py | 4 +--- 13 files changed, 29 insertions(+), 36 deletions(-) diff --git a/assets/training/distillation/components/data_generation/spec.yaml b/assets/training/distillation/components/data_generation/spec.yaml index 15092911d2..8116018547 100644 --- a/assets/training/distillation/components/data_generation/spec.yaml +++ b/assets/training/distillation/components/data_generation/spec.yaml @@ -1,6 +1,6 @@ $schema: https://azuremlschemas.azureedge.net/latest/commandComponent.schema.json name: oss_distillation_generate_data -version: 0.0.9.test1 +version: 0.0.9 type: command is_deterministic: True diff --git a/assets/training/distillation/components/data_generation_batch_scoring_pipeline/spec.yaml b/assets/training/distillation/components/data_generation_batch_scoring_pipeline/spec.yaml index 6e9ea394ad..fedd32f225 100644 --- a/assets/training/distillation/components/data_generation_batch_scoring_pipeline/spec.yaml +++ b/assets/training/distillation/components/data_generation_batch_scoring_pipeline/spec.yaml @@ -1,6 +1,6 @@ $schema: https://azuremlschemas.azureedge.net/latest/pipelineComponent.schema.json name: oss_distillation_batchscoring_datagen_pipeline -version: 0.0.1.test1 +version: 0.0.1 type: pipeline @@ -259,7 +259,7 @@ outputs: jobs: oss_distillation_generate_data_batch_preprocess: type: command - component: azureml:oss_distillation_generate_data_batch_preprocess:0.0.1.test1 + component: azureml:oss_distillation_generate_data_batch_preprocess:0.0.1 compute: '${{parent.inputs.compute_data_generation}}' resources: instance_type: '${{parent.inputs.instance_type_data_generation}}' @@ -299,7 +299,7 @@ jobs: # Config generator job oss_distillation_generate_data_config_generator: type: command - component: azureml://registries/azureml/components/batch_benchmark_config_generator/versions/0.0.9 + component: azureml:batch_benchmark_config_generator:0.0.9 compute: '${{parent.inputs.compute_pipeline_validation}}' resources: instance_type: '${{parent.inputs.instance_type_pipeline_validation}}' @@ -325,7 +325,7 @@ jobs: # Batch score job oss_distillation_train_data_batch_score: type: parallel - component: azureml://registries/azureml/components/batch_score_oss/versions/0.0.1 + component: azureml:batch_score_oss:0.0.1 compute: '${{parent.inputs.compute_data_generation}}' identity: type: user_identity @@ -352,7 +352,7 @@ jobs: validation_file_path_exists: type: command - component: azureml:oss_distillation_data_generation_validation_file_checker:0.0.1.test1 + component: azureml:oss_distillation_data_generation_validation_file_checker:0.0.1 compute: '${{parent.inputs.compute_pipeline_validation}}' resources: instance_type: '${{parent.inputs.instance_type_pipeline_validation}}' @@ -369,7 +369,7 @@ jobs: # Batch score job oss_distillation_validation_data_batch_score: type: parallel - component: azureml://registries/azureml/components/batch_score_oss/versions/0.0.1 + component: azureml:batch_score_oss:0.0.1 compute: '${{parent.inputs.compute_data_generation}}' identity: type: user_identity @@ -396,7 +396,7 @@ jobs: oss_distillation_generate_data_batch_postprocess: type: command - component: azureml:oss_distillation_generate_data_batch_postprocess:0.0.1.test1 + component: azureml:oss_distillation_generate_data_batch_postprocess:0.0.1s compute: '${{parent.inputs.compute_data_generation}}' resources: instance_type: '${{parent.inputs.instance_type_data_generation}}' diff --git a/assets/training/distillation/components/data_generation_batch_scoring_postprocess/spec.yaml b/assets/training/distillation/components/data_generation_batch_scoring_postprocess/spec.yaml index 6b8a7a0fa8..3291a9fe70 100644 --- a/assets/training/distillation/components/data_generation_batch_scoring_postprocess/spec.yaml +++ b/assets/training/distillation/components/data_generation_batch_scoring_postprocess/spec.yaml @@ -1,6 +1,6 @@ $schema: https://azuremlschemas.azureedge.net/latest/commandComponent.schema.json name: oss_distillation_generate_data_batch_postprocess -version: 0.0.1.test1 +version: 0.0.1 type: command is_deterministic: False diff --git a/assets/training/distillation/components/data_generation_batch_scoring_preprocess/spec.yaml b/assets/training/distillation/components/data_generation_batch_scoring_preprocess/spec.yaml index 11637bbbcc..5ae8aa2f68 100644 --- a/assets/training/distillation/components/data_generation_batch_scoring_preprocess/spec.yaml +++ b/assets/training/distillation/components/data_generation_batch_scoring_preprocess/spec.yaml @@ -1,6 +1,6 @@ $schema: https://azuremlschemas.azureedge.net/latest/commandComponent.schema.json name: oss_distillation_generate_data_batch_preprocess -version: 0.0.1.test1 +version: 0.0.1 type: command is_deterministic: False diff --git a/assets/training/distillation/components/data_generation_batch_scoring_selector/spec.yaml b/assets/training/distillation/components/data_generation_batch_scoring_selector/spec.yaml index 93ea3f3ce9..7c3dfe12e6 100644 --- a/assets/training/distillation/components/data_generation_batch_scoring_selector/spec.yaml +++ b/assets/training/distillation/components/data_generation_batch_scoring_selector/spec.yaml @@ -1,6 +1,6 @@ $schema: https://azuremlschemas.azureedge.net/latest/commandComponent.schema.json name: oss_distillation_data_generation_batch_scoring_selector -version: 0.0.1.test1 +version: 0.0.1 type: command is_deterministic: True diff --git a/assets/training/distillation/components/data_generation_file_selector/spec.yaml b/assets/training/distillation/components/data_generation_file_selector/spec.yaml index 5fb7f4a10a..448eb8d2f1 100644 --- a/assets/training/distillation/components/data_generation_file_selector/spec.yaml +++ b/assets/training/distillation/components/data_generation_file_selector/spec.yaml @@ -1,6 +1,6 @@ $schema: https://azuremlschemas.azureedge.net/latest/commandComponent.schema.json name: oss_distillation_data_generation_file_selector -version: 0.0.1.test1 +version: 0.0.1 type: command is_deterministic: True diff --git a/assets/training/distillation/components/data_generation_seq_scoring_pipeline/spec.yaml b/assets/training/distillation/components/data_generation_seq_scoring_pipeline/spec.yaml index 9e3db28ecd..edcc5951b2 100644 --- a/assets/training/distillation/components/data_generation_seq_scoring_pipeline/spec.yaml +++ b/assets/training/distillation/components/data_generation_seq_scoring_pipeline/spec.yaml @@ -1,6 +1,6 @@ $schema: https://azuremlschemas.azureedge.net/latest/pipelineComponent.schema.json name: oss_distillation_seq_scoring_pipeline -version: 0.0.1.test1 +version: 0.0.1 type: pipeline @@ -216,7 +216,7 @@ outputs: jobs: oss_distillation_generate_data: type: command - component: azureml:oss_distillation_generate_data:0.0.9.test1 + component: azureml:oss_distillation_generate_data:0.0.9 compute: '${{parent.inputs.compute_data_generation}}' resources: instance_type: '${{parent.inputs.instance_type_data_generation}}' diff --git a/assets/training/distillation/components/data_generation_validation_file_checker/spec.yaml b/assets/training/distillation/components/data_generation_validation_file_checker/spec.yaml index 0ba460f58b..e44d9b6ceb 100644 --- a/assets/training/distillation/components/data_generation_validation_file_checker/spec.yaml +++ b/assets/training/distillation/components/data_generation_validation_file_checker/spec.yaml @@ -1,6 +1,6 @@ $schema: https://azuremlschemas.azureedge.net/latest/commandComponent.schema.json name: oss_distillation_data_generation_validation_file_checker -version: 0.0.1.test1 +version: 0.0.1 type: command is_deterministic: True diff --git a/assets/training/distillation/components/pipeline/spec.yaml b/assets/training/distillation/components/pipeline/spec.yaml index 81e88e34ce..ba38bdaf96 100644 --- a/assets/training/distillation/components/pipeline/spec.yaml +++ b/assets/training/distillation/components/pipeline/spec.yaml @@ -1,6 +1,6 @@ $schema: https://azuremlschemas.azureedge.net/latest/pipelineComponent.schema.json name: oss_distillation_pipeline -version: 0.0.10.test1 +version: 0.0.10 type: pipeline @@ -297,7 +297,7 @@ outputs: jobs: oss_distillation_validate_pipeline: type: command - component: azureml:oss_distillation_validate_pipeline:0.0.5.test1 + component: azureml:oss_distillation_validate_pipeline:0.0.5 compute: '${{parent.inputs.compute_pipeline_validation}}' resources: instance_type: '${{parent.inputs.instance_type_pipeline_validation}}' @@ -331,7 +331,7 @@ jobs: data_generation_batch_scoring_selector: type: command - component: azureml:oss_distillation_data_generation_batch_scoring_selector:0.0.1.test1 + component: azureml:oss_distillation_data_generation_batch_scoring_selector:0.0.1 compute: '${{parent.inputs.compute_pipeline_validation}}' resources: instance_type: '${{parent.inputs.instance_type_pipeline_validation}}' @@ -348,7 +348,7 @@ jobs: oss_distillation_batchscoring_datagen_pipeline: type: pipeline - component: azureml:oss_distillation_batchscoring_datagen_pipeline:0.0.1.test1 + component: azureml:oss_distillation_batchscoring_datagen_pipeline:0.0.1 inputs: instance_type_pipeline_validation: '${{parent.inputs.instance_type_pipeline_validation}}' instance_type_data_generation: '${{parent.inputs.instance_type_data_generation}}' @@ -400,7 +400,7 @@ jobs: oss_distillation_seq_scoring_pipeline: type: pipeline - component: azureml:oss_distillation_seq_scoring_pipeline:0.0.1.test1 + component: azureml:oss_distillation_seq_scoring_pipeline:0.0.1 inputs: instance_type_pipeline_validation: '${{parent.inputs.instance_type_pipeline_validation}}' instance_type_data_generation: '${{parent.inputs.instance_type_data_generation}}' @@ -440,7 +440,7 @@ jobs: oss_distillation_train_data_generation_file_selector: type: command - component: azureml:oss_distillation_data_generation_file_selector:0.0.1.test1 + component: azureml:oss_distillation_data_generation_file_selector:0.0.1 compute: '${{parent.inputs.compute_pipeline_validation}}' resources: instance_type: '${{parent.inputs.instance_type_pipeline_validation}}' @@ -463,7 +463,7 @@ jobs: oss_text_generation_data_import: type: command - component: azureml://registries/azureml/components/oss_text_generation_data_import/versions/0.0.26 + component: azureml:oss_text_generation_data_import:0.0.26 compute: '${{parent.inputs.compute_data_import}}' resources: instance_type: '${{parent.inputs.instance_type_data_import}}' @@ -475,13 +475,13 @@ jobs: environment_variables: _AZUREML_CR_ENABLE_ITP_CAP: "false" inputs: - train_file_path: '${{parent.jobs.oss_distillation_generate_data.outputs.generated_train_file_path}}' - validation_file_path: '${{parent.jobs.oss_distillation_generate_data.outputs.generated_validation_file_path}}' + train_file_path: '${{parent.jobs.oss_distillation_train_data_generation_file_selector.outputs.ft_input_train_file_path}}' + validation_file_path: '${{parent.jobs.oss_distillation_train_data_generation_file_selector.outputs.ft_input_validation_file_path}}' system_properties: '${{parent.inputs.system_properties}}' oss_chat_completion_finetune: type: command - component: azureml://registries/azureml/components/oss_chat_completion_finetune/versions/0.0.26 + component: azureml:oss_chat_completion_finetune:0.0.26 compute: '${{parent.inputs.compute_finetune}}' resources: instance_type: '${{parent.inputs.instance_type_finetune}}' diff --git a/assets/training/distillation/components/pipeline_validation/spec.yaml b/assets/training/distillation/components/pipeline_validation/spec.yaml index 490f5165bd..a52d287e19 100644 --- a/assets/training/distillation/components/pipeline_validation/spec.yaml +++ b/assets/training/distillation/components/pipeline_validation/spec.yaml @@ -1,6 +1,6 @@ $schema: https://azuremlschemas.azureedge.net/latest/commandComponent.schema.json name: oss_distillation_validate_pipeline -version: 0.0.5.test1 +version: 0.0.5 type: command is_deterministic: true diff --git a/assets/training/distillation/src/common/student_models.py b/assets/training/distillation/src/common/student_models.py index e1395e108f..2a3a461c84 100644 --- a/assets/training/distillation/src/common/student_models.py +++ b/assets/training/distillation/src/common/student_models.py @@ -85,7 +85,7 @@ def no_system_prompt_reformat_conversation(cls, data: List[Dict[str, list]]): {"role": "user", "content": system_message + " " + user_prompt}, messages["messages"][2:] ] - } + } new_data.append(reformatted_data) return new_data @@ -122,5 +122,3 @@ def parse_model_asset_id(cls, asset_id: str) -> str: if model not in cls.NO_SYSTEM_PROMPT_MODELS: raise Exception("Model is not in supported student model list") return model - - diff --git a/assets/training/distillation/src/generate_data.py b/assets/training/distillation/src/generate_data.py index 7fc3508c1a..a3251cbbda 100644 --- a/assets/training/distillation/src/generate_data.py +++ b/assets/training/distillation/src/generate_data.py @@ -541,15 +541,12 @@ def batch_process_data( output_data.append({"messages": future_result["messages"]}) Path(output_file_path.parent).mkdir(exist_ok=True, parents=True) - # Reformat finetune data based on student model limitations - logger.info(f"output data before reformatting: {output_data}") - + # Reformat data based on student model limitations output_data = StudentModels.reformat( student_model=student_model, task_type=data_generation_task_type, data=output_data ) - logger.info(f"output data after reformatting: {output_data}") with open(output_file_path, "w") as f: for entry in output_data: f.write(json.dumps(entry) + "\n") diff --git a/assets/training/distillation/src/generate_data_postprocess.py b/assets/training/distillation/src/generate_data_postprocess.py index 92822a1240..80e62a8900 100644 --- a/assets/training/distillation/src/generate_data_postprocess.py +++ b/assets/training/distillation/src/generate_data_postprocess.py @@ -299,10 +299,8 @@ def postprocess_data( msg = f"Success ratio for dataset {input_file_path}: {success_ratio} < {min_endpoint_success_ratio}." raise Exception(msg) - # Reformat finetune data based on student model limitations - logger.info(f"output data before reformatting: {output_data}") + # Reformat data based on student model limitations output_data = StudentModels.reformat(student_model=student_model, task_type=data_generation_task_type, data=output_data) - logger.info(f"output data after reformatting: {output_data}") with open(output_file_path, "w") as f: for record in output_data: f.write(json.dumps(record) + "\n") From b95453e2866b49913250668e484b30d8fc7f53cb Mon Sep 17 00:00:00 2001 From: Alex Sanchez Date: Mon, 3 Feb 2025 15:01:17 -0800 Subject: [PATCH 3/6] Remove batch scoring changes --- .../spec.yaml | 7 +---- .../spec.yaml | 4 --- .../components/pipeline/spec.yaml | 1 - .../src/generate_data_postprocess.py | 30 +++++++++---------- 4 files changed, 16 insertions(+), 26 deletions(-) diff --git a/assets/training/distillation/components/data_generation_batch_scoring_pipeline/spec.yaml b/assets/training/distillation/components/data_generation_batch_scoring_pipeline/spec.yaml index fedd32f225..751d5b535e 100644 --- a/assets/training/distillation/components/data_generation_batch_scoring_pipeline/spec.yaml +++ b/assets/training/distillation/components/data_generation_batch_scoring_pipeline/spec.yaml @@ -160,10 +160,6 @@ inputs: 4. MATH: Generate Math data for numerical responses 5. SUMMARIZATION: Generate Key Summary for an Article - model_asset_id: - type: string - description: The student model to finetune - # Output of validation component. validation_info: type: uri_file @@ -396,7 +392,7 @@ jobs: oss_distillation_generate_data_batch_postprocess: type: command - component: azureml:oss_distillation_generate_data_batch_postprocess:0.0.1s + component: azureml:oss_distillation_generate_data_batch_postprocess:0.0.1 compute: '${{parent.inputs.compute_data_generation}}' resources: instance_type: '${{parent.inputs.instance_type_data_generation}}' @@ -413,7 +409,6 @@ jobs: enable_chain_of_density: '${{parent.inputs.enable_chain_of_density}}' data_generation_task_type: '${{parent.inputs.data_generation_task_type}}' min_endpoint_success_ratio: '${{parent.inputs.min_endpoint_success_ratio}}' - model_asset_id: '${{parent.inputs.model_asset_id}}' connection_config_file: ${{parent.jobs.oss_distillation_generate_data_batch_preprocess.outputs.batch_config_connection}} outputs: generated_batch_train_file_path: '${{parent.outputs.generated_batch_train_file_path}}' diff --git a/assets/training/distillation/components/data_generation_batch_scoring_postprocess/spec.yaml b/assets/training/distillation/components/data_generation_batch_scoring_postprocess/spec.yaml index 3291a9fe70..f89a121e99 100644 --- a/assets/training/distillation/components/data_generation_batch_scoring_postprocess/spec.yaml +++ b/assets/training/distillation/components/data_generation_batch_scoring_postprocess/spec.yaml @@ -82,9 +82,6 @@ inputs: type: uri_file description: Connection config file for batch scoring - model_asset_id: - type: string - description: The student model to finetune outputs: generated_batch_train_file_path: @@ -108,7 +105,6 @@ command: >- --min_endpoint_success_ratio ${{inputs.min_endpoint_success_ratio}} $[[--enable_chain_of_thought ${{inputs.enable_chain_of_thought}}]] $[[--enable_chain_of_density ${{inputs.enable_chain_of_density}}]] - --model_asset_id ${{inputs.model_asset_id}} --data_generation_task_type ${{inputs.data_generation_task_type}} --connection_config_file ${{inputs.connection_config_file}} --generated_batch_train_file_path ${{outputs.generated_batch_train_file_path}} diff --git a/assets/training/distillation/components/pipeline/spec.yaml b/assets/training/distillation/components/pipeline/spec.yaml index ba38bdaf96..2984f6a8e4 100644 --- a/assets/training/distillation/components/pipeline/spec.yaml +++ b/assets/training/distillation/components/pipeline/spec.yaml @@ -388,7 +388,6 @@ jobs: max_concurrency_per_instance: '${{parent.inputs.max_concurrency_per_instance}}' mini_batch_size: '${{parent.inputs.mini_batch_size}}' validation_info: '${{parent.jobs.oss_distillation_validate_pipeline.outputs.validation_info}}' - model_asset_id: '${{parent.inputs.model_asset_id}}' outputs: generated_batch_train_file_path: diff --git a/assets/training/distillation/src/generate_data_postprocess.py b/assets/training/distillation/src/generate_data_postprocess.py index 80e62a8900..7e037dbc08 100644 --- a/assets/training/distillation/src/generate_data_postprocess.py +++ b/assets/training/distillation/src/generate_data_postprocess.py @@ -36,7 +36,7 @@ STATUS_SUCCESS, FINISH_REASON_STOP, ) -from common.student_models import StudentModels +# from common.student_models import StudentModels from common.utils import ( get_hash_value, @@ -167,12 +167,12 @@ def get_parser(): help="A config file path that contains deployment configurations.", ) - parser.add_argument( - "--model_asset_id", - type=str, - required=True, - help="The student model asset id" - ) + # parser.add_argument( + # "--model_asset_id", + # type=str, + # required=True, + # help="The student model asset id" + # ) return parser @@ -208,8 +208,8 @@ def postprocess_data( data_generation_task_type: str, min_endpoint_success_ratio: float, output_file_path: str, - hash_data: str, - student_model: str + hash_data: str + # student_model: str ): """Generate and save synthentic data under output_dataset. @@ -300,7 +300,7 @@ def postprocess_data( raise Exception(msg) # Reformat data based on student model limitations - output_data = StudentModels.reformat(student_model=student_model, task_type=data_generation_task_type, data=output_data) + # output_data = StudentModels.reformat(student_model=student_model, task_type=data_generation_task_type, data=output_data) with open(output_file_path, "w") as f: for record in output_data: f.write(json.dumps(record) + "\n") @@ -321,7 +321,7 @@ def data_import(args: Namespace): hash_train_data = args.hash_train_data hash_validation_data = args.hash_validation_data connection_config_file = args.connection_config_file - model_asset_id = args.model_asset_id + # model_asset_id = args.model_asset_id enable_cot = True if enable_cot_str.lower() == "true" else False enable_cod = True if enable_cod_str.lower() == "true" else False @@ -344,8 +344,8 @@ def data_import(args: Namespace): data_generation_task_type=data_generation_task_type, min_endpoint_success_ratio=min_endpoint_success_ratio, output_file_path=generated_batch_train_file_path, - hash_data=hash_train_data, - student_model=StudentModels.parse_model_asset_id(model_asset_id) + hash_data=hash_train_data + # student_model=StudentModels.parse_model_asset_id(model_asset_id) ) if validation_file_path: with log_activity( @@ -364,8 +364,8 @@ def data_import(args: Namespace): data_generation_task_type=data_generation_task_type, min_endpoint_success_ratio=min_endpoint_success_ratio, output_file_path=generated_batch_validation_file_path, - hash_data=hash_validation_data, - student_model=StudentModels.parse_model_asset_id(model_asset_id) + hash_data=hash_validation_data + # student_model=StudentModels.parse_model_asset_id(model_asset_id) ) else: Path(generated_batch_validation_file_path.parent).mkdir( From d9aef5c3d61125fed75347606a6fb978b3456751 Mon Sep 17 00:00:00 2001 From: Alex Sanchez Date: Mon, 3 Feb 2025 15:32:19 -0800 Subject: [PATCH 4/6] Code quality fix --- assets/training/distillation/src/common/student_models.py | 2 +- assets/training/distillation/src/common/utils.py | 2 +- .../training/distillation/src/generate_data_postprocess.py | 6 +++++- 3 files changed, 7 insertions(+), 3 deletions(-) diff --git a/assets/training/distillation/src/common/student_models.py b/assets/training/distillation/src/common/student_models.py index 2a3a461c84..254d51ac80 100644 --- a/assets/training/distillation/src/common/student_models.py +++ b/assets/training/distillation/src/common/student_models.py @@ -85,7 +85,7 @@ def no_system_prompt_reformat_conversation(cls, data: List[Dict[str, list]]): {"role": "user", "content": system_message + " " + user_prompt}, messages["messages"][2:] ] - } + } new_data.append(reformatted_data) return new_data diff --git a/assets/training/distillation/src/common/utils.py b/assets/training/distillation/src/common/utils.py index c5c7e48248..de778ab54f 100644 --- a/assets/training/distillation/src/common/utils.py +++ b/assets/training/distillation/src/common/utils.py @@ -500,4 +500,4 @@ def get_hash_value(data: Union[Dict[str, Any], str]) -> str: """ if isinstance(data, str): return hashlib.sha256(data.encode()).hexdigest() - return hashlib.sha256(json.dumps(data).encode()).hexdigest() \ No newline at end of file + return hashlib.sha256(json.dumps(data).encode()).hexdigest() diff --git a/assets/training/distillation/src/generate_data_postprocess.py b/assets/training/distillation/src/generate_data_postprocess.py index 7e037dbc08..b308f674a7 100644 --- a/assets/training/distillation/src/generate_data_postprocess.py +++ b/assets/training/distillation/src/generate_data_postprocess.py @@ -300,7 +300,11 @@ def postprocess_data( raise Exception(msg) # Reformat data based on student model limitations - # output_data = StudentModels.reformat(student_model=student_model, task_type=data_generation_task_type, data=output_data) + # output_data = StudentModels.reformat( + # student_model=student_model, + # task_type=data_generation_task_type, + # data=output_data + # ) with open(output_file_path, "w") as f: for record in output_data: f.write(json.dumps(record) + "\n") From 66d520aefd7ab639498a1eb24221828d45eef81a Mon Sep 17 00:00:00 2001 From: Alex Sanchez Date: Mon, 3 Feb 2025 15:54:22 -0800 Subject: [PATCH 5/6] Fix pydoc issues --- .../distillation/src/common/student_models.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/assets/training/distillation/src/common/student_models.py b/assets/training/distillation/src/common/student_models.py index 254d51ac80..4552df460d 100644 --- a/assets/training/distillation/src/common/student_models.py +++ b/assets/training/distillation/src/common/student_models.py @@ -10,6 +10,8 @@ class StudentModels: + """Student model information and requirements.""" + SUPPORTED_STUDENT_MODELS = { "Meta-Llama-3.1-8B-Instruct": { "supported_registries": ["azureml-meta"], @@ -49,8 +51,7 @@ class StudentModels: @classmethod def no_system_prompt_reformat(cls, data: List[Dict[str, list]]) -> List[Dict[str, list]]: - """Adds system prompt to user prompt for student models that do not - accept system prompts. + """Add system prompt to user prompt for student models that do not accept system prompts. :param data: The synthetic data generated from the teacher model :type data: List[Dict[str, list]] @@ -73,7 +74,14 @@ def no_system_prompt_reformat(cls, data: List[Dict[str, list]]) -> List[Dict[str return new_data @classmethod - def no_system_prompt_reformat_conversation(cls, data: List[Dict[str, list]]): + def no_system_prompt_reformat_conversation(cls, data: List[Dict[str, list]]) -> List[Dict[str, list]]: + """Add system prompt to user prompt for student models that do not accept system prompts. + + :param data: The synthetic data generated from the teacher model + :type data: List[Dict[str, list]] + :return: Reformated data + :rtype: List[Dict[str, list]] + """ new_data = [] system_message = "" for messages in data: From 0f96b891ac968e735e75b0b43b4fa68140a017b0 Mon Sep 17 00:00:00 2001 From: Alex Sanchez Date: Mon, 3 Feb 2025 15:56:40 -0800 Subject: [PATCH 6/6] Small indent --- assets/training/distillation/src/common/student_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/assets/training/distillation/src/common/student_models.py b/assets/training/distillation/src/common/student_models.py index 4552df460d..813614aaee 100644 --- a/assets/training/distillation/src/common/student_models.py +++ b/assets/training/distillation/src/common/student_models.py @@ -75,7 +75,7 @@ def no_system_prompt_reformat(cls, data: List[Dict[str, list]]) -> List[Dict[str @classmethod def no_system_prompt_reformat_conversation(cls, data: List[Dict[str, list]]) -> List[Dict[str, list]]: - """Add system prompt to user prompt for student models that do not accept system prompts. + """Add system prompt to user prompt for student models that do not accept system prompts. :param data: The synthetic data generated from the teacher model :type data: List[Dict[str, list]]