From d9aef5c3d61125fed75347606a6fb978b3456751 Mon Sep 17 00:00:00 2001 From: Alex Sanchez Date: Mon, 3 Feb 2025 15:32:19 -0800 Subject: [PATCH] Code quality fix --- assets/training/distillation/src/common/student_models.py | 2 +- assets/training/distillation/src/common/utils.py | 2 +- .../training/distillation/src/generate_data_postprocess.py | 6 +++++- 3 files changed, 7 insertions(+), 3 deletions(-) diff --git a/assets/training/distillation/src/common/student_models.py b/assets/training/distillation/src/common/student_models.py index 2a3a461c84..254d51ac80 100644 --- a/assets/training/distillation/src/common/student_models.py +++ b/assets/training/distillation/src/common/student_models.py @@ -85,7 +85,7 @@ def no_system_prompt_reformat_conversation(cls, data: List[Dict[str, list]]): {"role": "user", "content": system_message + " " + user_prompt}, messages["messages"][2:] ] - } + } new_data.append(reformatted_data) return new_data diff --git a/assets/training/distillation/src/common/utils.py b/assets/training/distillation/src/common/utils.py index c5c7e48248..de778ab54f 100644 --- a/assets/training/distillation/src/common/utils.py +++ b/assets/training/distillation/src/common/utils.py @@ -500,4 +500,4 @@ def get_hash_value(data: Union[Dict[str, Any], str]) -> str: """ if isinstance(data, str): return hashlib.sha256(data.encode()).hexdigest() - return hashlib.sha256(json.dumps(data).encode()).hexdigest() \ No newline at end of file + return hashlib.sha256(json.dumps(data).encode()).hexdigest() diff --git a/assets/training/distillation/src/generate_data_postprocess.py b/assets/training/distillation/src/generate_data_postprocess.py index 7e037dbc08..b308f674a7 100644 --- a/assets/training/distillation/src/generate_data_postprocess.py +++ b/assets/training/distillation/src/generate_data_postprocess.py @@ -300,7 +300,11 @@ def postprocess_data( raise Exception(msg) # Reformat data based on student model limitations - # output_data = StudentModels.reformat(student_model=student_model, task_type=data_generation_task_type, data=output_data) + # output_data = StudentModels.reformat( + # student_model=student_model, + # task_type=data_generation_task_type, + # data=output_data + # ) with open(output_file_path, "w") as f: for record in output_data: f.write(json.dumps(record) + "\n")