diff --git a/assets/training/distillation/components/data_generation_batch_scoring_pipeline/spec.yaml b/assets/training/distillation/components/data_generation_batch_scoring_pipeline/spec.yaml index fedd32f225..751d5b535e 100644 --- a/assets/training/distillation/components/data_generation_batch_scoring_pipeline/spec.yaml +++ b/assets/training/distillation/components/data_generation_batch_scoring_pipeline/spec.yaml @@ -160,10 +160,6 @@ inputs: 4. MATH: Generate Math data for numerical responses 5. SUMMARIZATION: Generate Key Summary for an Article - model_asset_id: - type: string - description: The student model to finetune - # Output of validation component. validation_info: type: uri_file @@ -396,7 +392,7 @@ jobs: oss_distillation_generate_data_batch_postprocess: type: command - component: azureml:oss_distillation_generate_data_batch_postprocess:0.0.1s + component: azureml:oss_distillation_generate_data_batch_postprocess:0.0.1 compute: '${{parent.inputs.compute_data_generation}}' resources: instance_type: '${{parent.inputs.instance_type_data_generation}}' @@ -413,7 +409,6 @@ jobs: enable_chain_of_density: '${{parent.inputs.enable_chain_of_density}}' data_generation_task_type: '${{parent.inputs.data_generation_task_type}}' min_endpoint_success_ratio: '${{parent.inputs.min_endpoint_success_ratio}}' - model_asset_id: '${{parent.inputs.model_asset_id}}' connection_config_file: ${{parent.jobs.oss_distillation_generate_data_batch_preprocess.outputs.batch_config_connection}} outputs: generated_batch_train_file_path: '${{parent.outputs.generated_batch_train_file_path}}' diff --git a/assets/training/distillation/components/data_generation_batch_scoring_postprocess/spec.yaml b/assets/training/distillation/components/data_generation_batch_scoring_postprocess/spec.yaml index 3291a9fe70..f89a121e99 100644 --- a/assets/training/distillation/components/data_generation_batch_scoring_postprocess/spec.yaml +++ b/assets/training/distillation/components/data_generation_batch_scoring_postprocess/spec.yaml @@ -82,9 +82,6 @@ inputs: type: uri_file description: Connection config file for batch scoring - model_asset_id: - type: string - description: The student model to finetune outputs: generated_batch_train_file_path: @@ -108,7 +105,6 @@ command: >- --min_endpoint_success_ratio ${{inputs.min_endpoint_success_ratio}} $[[--enable_chain_of_thought ${{inputs.enable_chain_of_thought}}]] $[[--enable_chain_of_density ${{inputs.enable_chain_of_density}}]] - --model_asset_id ${{inputs.model_asset_id}} --data_generation_task_type ${{inputs.data_generation_task_type}} --connection_config_file ${{inputs.connection_config_file}} --generated_batch_train_file_path ${{outputs.generated_batch_train_file_path}} diff --git a/assets/training/distillation/components/pipeline/spec.yaml b/assets/training/distillation/components/pipeline/spec.yaml index ba38bdaf96..2984f6a8e4 100644 --- a/assets/training/distillation/components/pipeline/spec.yaml +++ b/assets/training/distillation/components/pipeline/spec.yaml @@ -388,7 +388,6 @@ jobs: max_concurrency_per_instance: '${{parent.inputs.max_concurrency_per_instance}}' mini_batch_size: '${{parent.inputs.mini_batch_size}}' validation_info: '${{parent.jobs.oss_distillation_validate_pipeline.outputs.validation_info}}' - model_asset_id: '${{parent.inputs.model_asset_id}}' outputs: generated_batch_train_file_path: diff --git a/assets/training/distillation/src/generate_data_postprocess.py b/assets/training/distillation/src/generate_data_postprocess.py index 80e62a8900..7e037dbc08 100644 --- a/assets/training/distillation/src/generate_data_postprocess.py +++ b/assets/training/distillation/src/generate_data_postprocess.py @@ -36,7 +36,7 @@ STATUS_SUCCESS, FINISH_REASON_STOP, ) -from common.student_models import StudentModels +# from common.student_models import StudentModels from common.utils import ( get_hash_value, @@ -167,12 +167,12 @@ def get_parser(): help="A config file path that contains deployment configurations.", ) - parser.add_argument( - "--model_asset_id", - type=str, - required=True, - help="The student model asset id" - ) + # parser.add_argument( + # "--model_asset_id", + # type=str, + # required=True, + # help="The student model asset id" + # ) return parser @@ -208,8 +208,8 @@ def postprocess_data( data_generation_task_type: str, min_endpoint_success_ratio: float, output_file_path: str, - hash_data: str, - student_model: str + hash_data: str + # student_model: str ): """Generate and save synthentic data under output_dataset. @@ -300,7 +300,7 @@ def postprocess_data( raise Exception(msg) # Reformat data based on student model limitations - output_data = StudentModels.reformat(student_model=student_model, task_type=data_generation_task_type, data=output_data) + # output_data = StudentModels.reformat(student_model=student_model, task_type=data_generation_task_type, data=output_data) with open(output_file_path, "w") as f: for record in output_data: f.write(json.dumps(record) + "\n") @@ -321,7 +321,7 @@ def data_import(args: Namespace): hash_train_data = args.hash_train_data hash_validation_data = args.hash_validation_data connection_config_file = args.connection_config_file - model_asset_id = args.model_asset_id + # model_asset_id = args.model_asset_id enable_cot = True if enable_cot_str.lower() == "true" else False enable_cod = True if enable_cod_str.lower() == "true" else False @@ -344,8 +344,8 @@ def data_import(args: Namespace): data_generation_task_type=data_generation_task_type, min_endpoint_success_ratio=min_endpoint_success_ratio, output_file_path=generated_batch_train_file_path, - hash_data=hash_train_data, - student_model=StudentModels.parse_model_asset_id(model_asset_id) + hash_data=hash_train_data + # student_model=StudentModels.parse_model_asset_id(model_asset_id) ) if validation_file_path: with log_activity( @@ -364,8 +364,8 @@ def data_import(args: Namespace): data_generation_task_type=data_generation_task_type, min_endpoint_success_ratio=min_endpoint_success_ratio, output_file_path=generated_batch_validation_file_path, - hash_data=hash_validation_data, - student_model=StudentModels.parse_model_asset_id(model_asset_id) + hash_data=hash_validation_data + # student_model=StudentModels.parse_model_asset_id(model_asset_id) ) else: Path(generated_batch_validation_file_path.parent).mkdir(