Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Phi medium support #3802

Merged
merged 7 commits into from
Feb 4, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
$schema: https://azuremlschemas.azureedge.net/latest/commandComponent.schema.json
name: oss_distillation_generate_data
version: 0.0.8
version: 0.0.9
type: command

is_deterministic: True
Expand Down Expand Up @@ -121,6 +121,10 @@ inputs:
type: uri_file
description: Validation status.
mode: rw_mount

model_asset_id:
type: string
description: Student model to use

outputs:
generated_train_file_path:
Expand Down Expand Up @@ -152,5 +156,6 @@ command: >-
$[[--enable_chain_of_density ${{inputs.enable_chain_of_density}}]]
$[[--max_len_summary ${{inputs.max_len_summary}}]]
--data_generation_task_type ${{inputs.data_generation_task_type}}
--model_asset_id ${{inputs.model_asset_id}}
--generated_train_file_path ${{outputs.generated_train_file_path}}
--generated_validation_file_path ${{outputs.generated_validation_file_path}}
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,6 @@ inputs:
4. MATH: Generate Math data for numerical responses
5. SUMMARIZATION: Generate Key Summary for an Article


# Output of validation component.
validation_info:
type: uri_file
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ inputs:
type: uri_file
description: Connection config file for batch scoring


outputs:
generated_batch_train_file_path:
type: uri_file
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,10 @@ inputs:
4. MATH: Generate Math data for numerical responses
5. SUMMARIZATION: Generate Key Summary for an Article

model_asset_id:
type: string
description: The student model asset id
optional: false

# Training parameters
num_train_epochs:
Expand Down Expand Up @@ -212,7 +216,7 @@ outputs:
jobs:
oss_distillation_generate_data:
type: command
component: azureml:oss_distillation_generate_data:0.0.8
component: azureml:oss_distillation_generate_data:0.0.9
compute: '${{parent.inputs.compute_data_generation}}'
resources:
instance_type: '${{parent.inputs.instance_type_data_generation}}'
Expand All @@ -236,6 +240,7 @@ jobs:
request_batch_size: '${{parent.inputs.request_batch_size}}'
min_endpoint_success_ratio: '${{parent.inputs.min_endpoint_success_ratio}}'
validation_output: '${{parent.inputs.validation_output}}'
model_asset_id: '${{parent.inputs.model_asset_id}}'
outputs:
generated_train_file_path: '${{parent.outputs.generated_train_file_path}}'
generated_validation_file_path: '${{parent.outputs.generated_validation_file_path}}'
10 changes: 6 additions & 4 deletions assets/training/distillation/components/pipeline/spec.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -270,11 +270,11 @@ inputs:
optional: true
description: Validation parameters propagated from pipeline.

# Model parameters
# Student Model parameters
model_asset_id:
type: string
optional: false
description: Asset id of model
description: Asset id of the student model

# Model registration
registered_model_name:
Expand Down Expand Up @@ -323,6 +323,7 @@ jobs:
num_train_epochs: '${{parent.inputs.num_train_epochs}}'
per_device_train_batch_size: '${{parent.inputs.per_device_train_batch_size}}'
learning_rate: '${{parent.inputs.learning_rate}}'
model_asset_id: '${{parent.inputs.model_asset_id}}'
outputs:
validation_info:
type: uri_file
Expand Down Expand Up @@ -426,6 +427,7 @@ jobs:
max_len_summary: '${{parent.inputs.max_len_summary}}'
data_generation_task_type: '${{parent.inputs.data_generation_task_type}}'
validation_output: '${{parent.jobs.oss_distillation_validate_pipeline.outputs.validation_info}}'
model_asset_id: '${{parent.inputs.model_asset_id}}'
outputs:
generated_train_file_path:
type: uri_file
Expand Down Expand Up @@ -460,7 +462,7 @@ jobs:

oss_text_generation_data_import:
type: command
component: azureml:oss_text_generation_data_import:0.0.25
component: azureml:oss_text_generation_data_import:0.0.26
compute: '${{parent.inputs.compute_data_import}}'
resources:
instance_type: '${{parent.inputs.instance_type_data_import}}'
Expand All @@ -478,7 +480,7 @@ jobs:

oss_chat_completion_finetune:
type: command
component: azureml:oss_chat_completion_finetune:0.0.25
component: azureml:oss_chat_completion_finetune:0.0.26
compute: '${{parent.inputs.compute_finetune}}'
resources:
instance_type: '${{parent.inputs.instance_type_finetune}}'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,11 @@ inputs:
optional: true
description: Start learning rate.

model_asset_id:
type: string
optional: false
description: The student model to finetune

outputs:
validation_info:
type: uri_file
Expand Down Expand Up @@ -163,4 +168,5 @@ command: >-
$[[--num_train_epochs ${{inputs.num_train_epochs}}]]
$[[--per_device_train_batch_size ${{inputs.per_device_train_batch_size}}]]
$[[--learning_rate ${{inputs.learning_rate}}]]
--model_asset_id '${{inputs.model_asset_id}}'
--validation_info ${{outputs.validation_info}}
11 changes: 1 addition & 10 deletions assets/training/distillation/src/common/constants.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

"""Data generatior constants."""
"""Data generation constants."""

import re
from enum import EnumMeta, Enum
Expand Down Expand Up @@ -36,15 +36,6 @@
}
}

# SUPPORTED STUDENT MODEL
# MAP keys are model name in registry, which maps to specific model details like registry and supported versions
SUPPORTED_STUDENT_MODEL_MAP = {
"Meta-Llama-3.1-8B-Instruct": {
"supported_registries": ["azureml-meta"],
"supported_version_pattern": re.compile(r"\d+"),
}
}

# Scoring paths
VLLM_CHAT_SCORE_PATH = "/v1/chat/completions"
HFTV2_TEXT_GEN_SCORE_PATH = "/score"
Expand Down
132 changes: 132 additions & 0 deletions assets/training/distillation/src/common/student_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

"""Student Model Info and Requirements."""

import re
from typing import Dict, List

from common.constants import REGISTRY_MODEL_PATTERN, DataGenerationTaskType


class StudentModels:
"""Student model information and requirements."""

SUPPORTED_STUDENT_MODELS = {
"Meta-Llama-3.1-8B-Instruct": {
"supported_registries": ["azureml-meta"],
"supported_version_pattern": re.compile(r"\d+")
},
"Phi-3-mini-4k-instruct": {
"supported_registries": ["azureml"],
"supported_version_pattern": re.compile(r"\d+")
},
"Phi-3-mini-128k-instruct": {
"supported_registries": ["azureml"],
"supported_version_pattern": re.compile(r"\d+")
},
"Phi-3.5-mini-instruct": {
"supported_registries": ["azureml"],
"supported_version_pattern": re.compile(r"\d+")
},
"Phi-3.5-MoE-instruct": {
"supported_registries": ["azureml"],
"supported_version_pattern": re.compile(r"\d+")
},
"Phi-3-medium-4k-instruct": {
"supported_registries": ["azureml"],
"supported_version_pattern": re.compile(r"\d+"),
},
"Phi-3-medium-128k-instruct": {
"supported_registries": ["azureml"],
"supported_version_pattern": re.compile(r"\d+"),
},
}

# Student models that do not recognize system prompts
NO_SYSTEM_PROMPT_MODELS = [
"Phi-3-medium-4k-instruct",
"Phi-3-medium-128k-instruct"
]

@classmethod
def no_system_prompt_reformat(cls, data: List[Dict[str, list]]) -> List[Dict[str, list]]:
"""Add system prompt to user prompt for student models that do not accept system prompts.

:param data: The synthetic data generated from the teacher model
:type data: List[Dict[str, list]]
:return: Reformated data
:rtype: List[Dict[str, list]]
"""
new_data = []
system_message = ""
for messages in data:
system_message = messages["messages"][0]["content"]
question = messages["messages"][1]["content"]
reformatted_data = {
"messages":
[
{"role": "user", "content": system_message + " " + question},
messages["messages"][2]
]
}
new_data.append(reformatted_data)
return new_data

@classmethod
def no_system_prompt_reformat_conversation(cls, data: List[Dict[str, list]]) -> List[Dict[str, list]]:
"""Add system prompt to user prompt for student models that do not accept system prompts.

:param data: The synthetic data generated from the teacher model
:type data: List[Dict[str, list]]
:return: Reformated data
:rtype: List[Dict[str, list]]
"""
new_data = []
system_message = ""
for messages in data:
system_message = messages["messages"][0]["content"]
user_prompt = messages["messages"][1]["content"]
reformatted_data = {
"messages":
[
{"role": "user", "content": system_message + " " + user_prompt},
messages["messages"][2:]
]
}
new_data.append(reformatted_data)
return new_data

@classmethod
def reformat(cls, student_model: str, task_type: str, data: List[Dict[str, list]]) -> List[Dict[str, list]]:
"""Reformats synthetic data based on the student model and task type requirements.

:param student_model: The student model to finetune
:type student_model: str
:param task_type: The data generation task type
:type task_type: str
:param data: The synthetic data generated from the teacher model
:type data: List[Dict[str, list]]
:return: Reformatted data based on student model and task type
:rtype: List[Dict[str, list]]
"""
if student_model in cls.NO_SYSTEM_PROMPT_MODELS:
if task_type == DataGenerationTaskType.CONVERSATION:
return cls.no_system_prompt_reformat_conversation(data)
return cls.no_system_prompt_reformat(data)
return data

@classmethod
def parse_model_asset_id(cls, asset_id: str) -> str:
"""Parse asset id to extract the student model name.

:param asset_id: The asset id of the student model in the form
azureml://registries/{registry}/models/{model}/versions/{version}.
:type asset_id: str
"""
match = re.search(REGISTRY_MODEL_PATTERN, asset_id)
model = match.group("model")

if model not in cls.NO_SYSTEM_PROMPT_MODELS:
raise Exception("Model is not in supported student model list")
return model
5 changes: 3 additions & 2 deletions assets/training/distillation/src/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,12 @@
from common.constants import (
REQUESTS_RETRY_DELAY,
REGISTRY_MODEL_PATTERN,
SUPPORTED_STUDENT_MODEL_MAP,
SUPPORTED_TEACHER_MODEL_MAP,
BackoffConstants,
)

from common.student_models import StudentModels


logger = get_logger_app(
"azureml.acft.contrib.hf.nlp.entry_point.data_import.data_import"
Expand Down Expand Up @@ -396,7 +397,7 @@ def validate_student_model_details(model_asset_id: str) -> Tuple[str, str, str]:
Returns:
Tuple[str, str, str]: Tuple containing registry name, model name and model version
"""
return _get_model_details(model_asset_id, SUPPORTED_STUDENT_MODEL_MAP)
return _get_model_details(model_asset_id, StudentModels.SUPPORTED_STUDENT_MODELS)


def get_base_url(url: str) -> str:
Expand Down
Loading
Loading