Skip to content

Commit

Permalink
Remove batch scoring changes
Browse files Browse the repository at this point in the history
  • Loading branch information
sanchez-alex committed Feb 3, 2025
1 parent aec59c7 commit b95453e
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 26 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -160,10 +160,6 @@ inputs:
4. MATH: Generate Math data for numerical responses
5. SUMMARIZATION: Generate Key Summary for an Article
model_asset_id:
type: string
description: The student model to finetune

# Output of validation component.
validation_info:
type: uri_file
Expand Down Expand Up @@ -396,7 +392,7 @@ jobs:

oss_distillation_generate_data_batch_postprocess:
type: command
component: azureml:oss_distillation_generate_data_batch_postprocess:0.0.1s
component: azureml:oss_distillation_generate_data_batch_postprocess:0.0.1
compute: '${{parent.inputs.compute_data_generation}}'
resources:
instance_type: '${{parent.inputs.instance_type_data_generation}}'
Expand All @@ -413,7 +409,6 @@ jobs:
enable_chain_of_density: '${{parent.inputs.enable_chain_of_density}}'
data_generation_task_type: '${{parent.inputs.data_generation_task_type}}'
min_endpoint_success_ratio: '${{parent.inputs.min_endpoint_success_ratio}}'
model_asset_id: '${{parent.inputs.model_asset_id}}'
connection_config_file: ${{parent.jobs.oss_distillation_generate_data_batch_preprocess.outputs.batch_config_connection}}
outputs:
generated_batch_train_file_path: '${{parent.outputs.generated_batch_train_file_path}}'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,6 @@ inputs:
type: uri_file
description: Connection config file for batch scoring

model_asset_id:
type: string
description: The student model to finetune

outputs:
generated_batch_train_file_path:
Expand All @@ -108,7 +105,6 @@ command: >-
--min_endpoint_success_ratio ${{inputs.min_endpoint_success_ratio}}
$[[--enable_chain_of_thought ${{inputs.enable_chain_of_thought}}]]
$[[--enable_chain_of_density ${{inputs.enable_chain_of_density}}]]
--model_asset_id ${{inputs.model_asset_id}}
--data_generation_task_type ${{inputs.data_generation_task_type}}
--connection_config_file ${{inputs.connection_config_file}}
--generated_batch_train_file_path ${{outputs.generated_batch_train_file_path}}
Expand Down
1 change: 0 additions & 1 deletion assets/training/distillation/components/pipeline/spec.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -388,7 +388,6 @@ jobs:
max_concurrency_per_instance: '${{parent.inputs.max_concurrency_per_instance}}'
mini_batch_size: '${{parent.inputs.mini_batch_size}}'
validation_info: '${{parent.jobs.oss_distillation_validate_pipeline.outputs.validation_info}}'
model_asset_id: '${{parent.inputs.model_asset_id}}'

outputs:
generated_batch_train_file_path:
Expand Down
30 changes: 15 additions & 15 deletions assets/training/distillation/src/generate_data_postprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
STATUS_SUCCESS,
FINISH_REASON_STOP,
)
from common.student_models import StudentModels
# from common.student_models import StudentModels

from common.utils import (
get_hash_value,
Expand Down Expand Up @@ -167,12 +167,12 @@ def get_parser():
help="A config file path that contains deployment configurations.",
)

parser.add_argument(
"--model_asset_id",
type=str,
required=True,
help="The student model asset id"
)
# parser.add_argument(
# "--model_asset_id",
# type=str,
# required=True,
# help="The student model asset id"
# )

return parser

Expand Down Expand Up @@ -208,8 +208,8 @@ def postprocess_data(
data_generation_task_type: str,
min_endpoint_success_ratio: float,
output_file_path: str,
hash_data: str,
student_model: str
hash_data: str
# student_model: str
):
    """Generate and save synthetic data under output_dataset.
Expand Down Expand Up @@ -300,7 +300,7 @@ def postprocess_data(
raise Exception(msg)

# Reformat data based on student model limitations
output_data = StudentModels.reformat(student_model=student_model, task_type=data_generation_task_type, data=output_data)
# output_data = StudentModels.reformat(student_model=student_model, task_type=data_generation_task_type, data=output_data)
with open(output_file_path, "w") as f:
for record in output_data:
f.write(json.dumps(record) + "\n")
Expand All @@ -321,7 +321,7 @@ def data_import(args: Namespace):
hash_train_data = args.hash_train_data
hash_validation_data = args.hash_validation_data
connection_config_file = args.connection_config_file
model_asset_id = args.model_asset_id
# model_asset_id = args.model_asset_id

enable_cot = True if enable_cot_str.lower() == "true" else False
enable_cod = True if enable_cod_str.lower() == "true" else False
Expand All @@ -344,8 +344,8 @@ def data_import(args: Namespace):
data_generation_task_type=data_generation_task_type,
min_endpoint_success_ratio=min_endpoint_success_ratio,
output_file_path=generated_batch_train_file_path,
hash_data=hash_train_data,
student_model=StudentModels.parse_model_asset_id(model_asset_id)
hash_data=hash_train_data
# student_model=StudentModels.parse_model_asset_id(model_asset_id)
)
if validation_file_path:
with log_activity(
Expand All @@ -364,8 +364,8 @@ def data_import(args: Namespace):
data_generation_task_type=data_generation_task_type,
min_endpoint_success_ratio=min_endpoint_success_ratio,
output_file_path=generated_batch_validation_file_path,
hash_data=hash_validation_data,
student_model=StudentModels.parse_model_asset_id(model_asset_id)
hash_data=hash_validation_data
# student_model=StudentModels.parse_model_asset_id(model_asset_id)
)
else:
Path(generated_batch_validation_file_path.parent).mkdir(
Expand Down

0 comments on commit b95453e

Please sign in to comment.