Skip to content

Commit

Permalink
Phi medium support (#3802)
Browse files Browse the repository at this point in the history
* Initial changes for phi medium support

* Cleanup

* Remove batch scoring changes

* Code quality fix

* Fix pydoc issues

* Small indent
  • Loading branch information
sanchez-alex authored Feb 4, 2025
1 parent 8dc4e07 commit bb084d5
Show file tree
Hide file tree
Showing 12 changed files with 213 additions and 30 deletions.
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
$schema: https://azuremlschemas.azureedge.net/latest/commandComponent.schema.json
name: oss_distillation_generate_data
version: 0.0.8
version: 0.0.9
type: command

is_deterministic: True
Expand Down Expand Up @@ -121,6 +121,10 @@ inputs:
type: uri_file
description: Validation status.
mode: rw_mount

model_asset_id:
type: string
description: Student model to use

outputs:
generated_train_file_path:
Expand Down Expand Up @@ -152,5 +156,6 @@ command: >-
$[[--enable_chain_of_density ${{inputs.enable_chain_of_density}}]]
$[[--max_len_summary ${{inputs.max_len_summary}}]]
--data_generation_task_type ${{inputs.data_generation_task_type}}
--model_asset_id ${{inputs.model_asset_id}}
--generated_train_file_path ${{outputs.generated_train_file_path}}
--generated_validation_file_path ${{outputs.generated_validation_file_path}}
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,6 @@ inputs:
4. MATH: Generate Math data for numerical responses
5. SUMMARIZATION: Generate Key Summary for an Article
# Output of validation component.
validation_info:
type: uri_file
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ inputs:
type: uri_file
description: Connection config file for batch scoring


outputs:
generated_batch_train_file_path:
type: uri_file
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,10 @@ inputs:
4. MATH: Generate Math data for numerical responses
5. SUMMARIZATION: Generate Key Summary for an Article
model_asset_id:
type: string
description: The student model asset id
optional: false

# Training parameters
num_train_epochs:
Expand Down Expand Up @@ -212,7 +216,7 @@ outputs:
jobs:
oss_distillation_generate_data:
type: command
component: azureml:oss_distillation_generate_data:0.0.8
component: azureml:oss_distillation_generate_data:0.0.9
compute: '${{parent.inputs.compute_data_generation}}'
resources:
instance_type: '${{parent.inputs.instance_type_data_generation}}'
Expand All @@ -236,6 +240,7 @@ jobs:
request_batch_size: '${{parent.inputs.request_batch_size}}'
min_endpoint_success_ratio: '${{parent.inputs.min_endpoint_success_ratio}}'
validation_output: '${{parent.inputs.validation_output}}'
model_asset_id: '${{parent.inputs.model_asset_id}}'
outputs:
generated_train_file_path: '${{parent.outputs.generated_train_file_path}}'
generated_validation_file_path: '${{parent.outputs.generated_validation_file_path}}'
10 changes: 6 additions & 4 deletions assets/training/distillation/components/pipeline/spec.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -270,11 +270,11 @@ inputs:
optional: true
description: Validation parameters propagated from pipeline.

# Model parameters
# Student Model parameters
model_asset_id:
type: string
optional: false
description: Asset id of model
description: Asset id of the student model

# Model registration
registered_model_name:
Expand Down Expand Up @@ -323,6 +323,7 @@ jobs:
num_train_epochs: '${{parent.inputs.num_train_epochs}}'
per_device_train_batch_size: '${{parent.inputs.per_device_train_batch_size}}'
learning_rate: '${{parent.inputs.learning_rate}}'
model_asset_id: '${{parent.inputs.model_asset_id}}'
outputs:
validation_info:
type: uri_file
Expand Down Expand Up @@ -426,6 +427,7 @@ jobs:
max_len_summary: '${{parent.inputs.max_len_summary}}'
data_generation_task_type: '${{parent.inputs.data_generation_task_type}}'
validation_output: '${{parent.jobs.oss_distillation_validate_pipeline.outputs.validation_info}}'
model_asset_id: '${{parent.inputs.model_asset_id}}'
outputs:
generated_train_file_path:
type: uri_file
Expand Down Expand Up @@ -460,7 +462,7 @@ jobs:

oss_text_generation_data_import:
type: command
component: azureml:oss_text_generation_data_import:0.0.25
component: azureml:oss_text_generation_data_import:0.0.26
compute: '${{parent.inputs.compute_data_import}}'
resources:
instance_type: '${{parent.inputs.instance_type_data_import}}'
Expand All @@ -478,7 +480,7 @@ jobs:

oss_chat_completion_finetune:
type: command
component: azureml:oss_chat_completion_finetune:0.0.25
component: azureml:oss_chat_completion_finetune:0.0.26
compute: '${{parent.inputs.compute_finetune}}'
resources:
instance_type: '${{parent.inputs.instance_type_finetune}}'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,11 @@ inputs:
optional: true
description: Start learning rate.

model_asset_id:
type: string
optional: false
description: The student model to finetune

outputs:
validation_info:
type: uri_file
Expand Down Expand Up @@ -163,4 +168,5 @@ command: >-
$[[--num_train_epochs ${{inputs.num_train_epochs}}]]
$[[--per_device_train_batch_size ${{inputs.per_device_train_batch_size}}]]
$[[--learning_rate ${{inputs.learning_rate}}]]
--model_asset_id '${{inputs.model_asset_id}}'
--validation_info ${{outputs.validation_info}}
11 changes: 1 addition & 10 deletions assets/training/distillation/src/common/constants.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

"""Data generatior constants."""
"""Data generation constants."""

import re
from enum import EnumMeta, Enum
Expand Down Expand Up @@ -36,15 +36,6 @@
}
}

# SUPPORTED STUDENT MODEL
# MAP keys are model name in registry, which maps to specific model details like registry and supported versions
SUPPORTED_STUDENT_MODEL_MAP = {
"Meta-Llama-3.1-8B-Instruct": {
"supported_registries": ["azureml-meta"],
"supported_version_pattern": re.compile(r"\d+"),
}
}

# Scoring paths
VLLM_CHAT_SCORE_PATH = "/v1/chat/completions"
HFTV2_TEXT_GEN_SCORE_PATH = "/score"
Expand Down
132 changes: 132 additions & 0 deletions assets/training/distillation/src/common/student_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

"""Student Model Info and Requirements."""

import re
from typing import Dict, List

from common.constants import REGISTRY_MODEL_PATTERN, DataGenerationTaskType


class StudentModels:
    """Student model information and requirements.

    Holds the table of student models that may be finetuned and the helpers that
    adapt teacher-generated chat data to the prompt format each student accepts.
    """

    # Keys are model names as they appear in the registry; each entry lists the
    # registries the model may come from and a pattern for acceptable versions.
    SUPPORTED_STUDENT_MODELS = {
        "Meta-Llama-3.1-8B-Instruct": {
            "supported_registries": ["azureml-meta"],
            "supported_version_pattern": re.compile(r"\d+"),
        },
        "Phi-3-mini-4k-instruct": {
            "supported_registries": ["azureml"],
            "supported_version_pattern": re.compile(r"\d+"),
        },
        "Phi-3-mini-128k-instruct": {
            "supported_registries": ["azureml"],
            "supported_version_pattern": re.compile(r"\d+"),
        },
        "Phi-3.5-mini-instruct": {
            "supported_registries": ["azureml"],
            "supported_version_pattern": re.compile(r"\d+"),
        },
        "Phi-3.5-MoE-instruct": {
            "supported_registries": ["azureml"],
            "supported_version_pattern": re.compile(r"\d+"),
        },
        "Phi-3-medium-4k-instruct": {
            "supported_registries": ["azureml"],
            "supported_version_pattern": re.compile(r"\d+"),
        },
        "Phi-3-medium-128k-instruct": {
            "supported_registries": ["azureml"],
            "supported_version_pattern": re.compile(r"\d+"),
        },
    }

    # Student models that do not recognize system prompts
    NO_SYSTEM_PROMPT_MODELS = [
        "Phi-3-medium-4k-instruct",
        "Phi-3-medium-128k-instruct",
    ]

    @staticmethod
    def _merge_system_into_user(messages: list) -> Dict[str, str]:
        """Build a user message holding the system prompt followed by the first user prompt.

        :param messages: Chat messages where index 0 carries the system prompt and
            index 1 carries the first user prompt
        :type messages: list
        :return: A single ``user`` message with both contents joined by a space
        :rtype: Dict[str, str]
        """
        system_message = messages[0]["content"]
        user_prompt = messages[1]["content"]
        return {"role": "user", "content": system_message + " " + user_prompt}

    @classmethod
    def no_system_prompt_reformat(cls, data: List[Dict[str, list]]) -> List[Dict[str, list]]:
        """Add system prompt to user prompt for student models that do not accept system prompts.

        Each sample is reduced to two messages: the merged user message and the
        message at index 2 of the original sample.

        :param data: The synthetic data generated from the teacher model
        :type data: List[Dict[str, list]]
        :return: Reformatted data
        :rtype: List[Dict[str, list]]
        """
        return [
            {
                "messages": [
                    cls._merge_system_into_user(sample["messages"]),
                    sample["messages"][2],
                ]
            }
            for sample in data
        ]

    @classmethod
    def no_system_prompt_reformat_conversation(cls, data: List[Dict[str, list]]) -> List[Dict[str, list]]:
        """Add system prompt to user prompt for student models that do not accept system prompts.

        Unlike :meth:`no_system_prompt_reformat`, every message after the first
        user prompt is kept so multi-turn conversations stay intact.

        :param data: The synthetic data generated from the teacher model
        :type data: List[Dict[str, list]]
        :return: Reformatted data
        :rtype: List[Dict[str, list]]
        """
        return [
            {
                "messages": [
                    cls._merge_system_into_user(sample["messages"]),
                    # Splice the remaining turns in; appending the slice itself would
                    # nest a list inside the message list and break the chat format.
                    *sample["messages"][2:],
                ]
            }
            for sample in data
        ]

    @classmethod
    def reformat(cls, student_model: str, task_type: str, data: List[Dict[str, list]]) -> List[Dict[str, list]]:
        """Reformats synthetic data based on the student model and task type requirements.

        :param student_model: The student model to finetune
        :type student_model: str
        :param task_type: The data generation task type
        :type task_type: str
        :param data: The synthetic data generated from the teacher model
        :type data: List[Dict[str, list]]
        :return: Reformatted data based on student model and task type
        :rtype: List[Dict[str, list]]
        """
        # Models that understand system prompts need no changes.
        if student_model not in cls.NO_SYSTEM_PROMPT_MODELS:
            return data
        if task_type == DataGenerationTaskType.CONVERSATION:
            return cls.no_system_prompt_reformat_conversation(data)
        return cls.no_system_prompt_reformat(data)

    @classmethod
    def parse_model_asset_id(cls, asset_id: str) -> str:
        """Parse asset id to extract the student model name.

        :param asset_id: The asset id of the student model in the form
            azureml://registries/{registry}/models/{model}/versions/{version}.
        :type asset_id: str
        :return: The model name taken from the ``model`` group of the pattern
        :rtype: str
        :raises ValueError: If the asset id does not match the registry model
            pattern or the model is not a supported student model
        """
        match = re.search(REGISTRY_MODEL_PATTERN, asset_id)
        if match is None:
            raise ValueError(
                f"Invalid model asset id {asset_id!r}; expected the form "
                "azureml://registries/{registry}/models/{model}/versions/{version}"
            )
        model = match.group("model")

        # Validate against every supported student model, not only the subset
        # that needs system-prompt reformatting.
        if model not in cls.SUPPORTED_STUDENT_MODELS:
            raise ValueError("Model is not in supported student model list")
        return model
5 changes: 3 additions & 2 deletions assets/training/distillation/src/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,12 @@
from common.constants import (
REQUESTS_RETRY_DELAY,
REGISTRY_MODEL_PATTERN,
SUPPORTED_STUDENT_MODEL_MAP,
SUPPORTED_TEACHER_MODEL_MAP,
BackoffConstants,
)

from common.student_models import StudentModels


logger = get_logger_app(
"azureml.acft.contrib.hf.nlp.entry_point.data_import.data_import"
Expand Down Expand Up @@ -396,7 +397,7 @@ def validate_student_model_details(model_asset_id: str) -> Tuple[str, str, str]:
Returns:
Tuple[str, str, str]: Tuple containing registry name, model name and model version
"""
return _get_model_details(model_asset_id, SUPPORTED_STUDENT_MODEL_MAP)
return _get_model_details(model_asset_id, StudentModels.SUPPORTED_STUDENT_MODELS)


def get_base_url(url: str) -> str:
Expand Down
Loading

0 comments on commit bb084d5

Please sign in to comment.