From af2f782c501cda27b5d17db1c98214c2c536ca65 Mon Sep 17 00:00:00 2001
From: jazelly
Date: Mon, 17 Jun 2024 23:11:04 +0930
Subject: [PATCH] feat: trainer can actively push back to clients by group
 identifier

---
 trainer/requirements.txt               |  1 +
 trainer/trainer_api/finetune/sft.py    | 31 +++++++++----------
 trainer/trainer_api/scheduler/task.py  | 42 +++++++++++++++---------
 trainer/trainer_api/utils/constants.py |  1 +
 trainer/trainer_api/ws/consumers.py    | 40 +++++++++++++++---------
 5 files changed, 70 insertions(+), 45 deletions(-)

diff --git a/trainer/requirements.txt b/trainer/requirements.txt
index ff0211f..65e8b81 100644
--- a/trainer/requirements.txt
+++ b/trainer/requirements.txt
@@ -2,6 +2,7 @@ django
 django-ratelimit
 python-dotenv
 channels
+channels_redis
 daphne
 watchgod
 psutil
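The new channels_redis dependency backs the group fan-out used in consumers.py below; the default in-memory channel layer does not work across processes, which matters once training tasks run outside the web process. The settings change itself is not part of this patch, so the following CHANNEL_LAYERS block is only a minimal sketch of what it assumes (the Redis host and port are placeholders):

    # settings.py (sketch, not part of this patch; host/port are assumptions)
    CHANNEL_LAYERS = {
        "default": {
            "BACKEND": "channels_redis.core.RedisChannelLayer",
            "CONFIG": {
                "hosts": [("127.0.0.1", 6379)],
            },
        },
    }
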
diff --git a/trainer/trainer_api/finetune/sft.py b/trainer/trainer_api/finetune/sft.py
index e244e76..15e281f 100644
--- a/trainer/trainer_api/finetune/sft.py
+++ b/trainer/trainer_api/finetune/sft.py
@@ -408,6 +408,17 @@ def train(trainer, new_model=get_new_model_name()):
     # Save trained model
     trainer.model.save_pretrained(new_model)
 
+def save_trained_model():
+    # Reload model in FP16 and merge it with LoRA weights
+    base_model = AutoModelForCausalLM.from_pretrained(
+        model_name,
+        low_cpu_mem_usage=True,
+        return_dict=True,
+        torch_dtype=torch.float16,
+        device_map=device_map,
+    )
+    model = PeftModel.from_pretrained(base_model, new_model)
+    model = model.merge_and_unload()
 
 
 model = get_quantized_model(model_name, bnb_config, device_map)
@@ -468,7 +479,7 @@ def train(trainer, new_model=get_new_model_name()):
     task="text-generation", model=model, tokenizer=tokenizer, max_length=200
 )
 result = pipe(f"[INST] {prompt} [/INST]")
-print(result[0]["generated_text"])
+logger.info(result[0]["generated_text"])
 
 del model
 del pipe
@@ -479,19 +490,7 @@ def train(trainer, new_model=get_new_model_name()):
 gc.collect()
 
 
-def save_trained_model():
-    # Reload model in FP16 and merge it with LoRA weights
-    base_model = AutoModelForCausalLM.from_pretrained(
-        model_name,
-        low_cpu_mem_usage=True,
-        return_dict=True,
-        torch_dtype=torch.float16,
-        device_map=device_map,
-    )
-    model = PeftModel.from_pretrained(base_model, new_model)
-    model = model.merge_and_unload()
-
-
-save_trained_model()
+if trainable:
+    save_trained_model()
 
-saved_tokenizer = tokenizer.save_pretrained(new_model)
+    saved_tokenizer = tokenizer.save_pretrained(new_model)
diff --git a/trainer/trainer_api/scheduler/task.py b/trainer/trainer_api/scheduler/task.py
index 64df197..1d524cf 100644
--- a/trainer/trainer_api/scheduler/task.py
+++ b/trainer/trainer_api/scheduler/task.py
@@ -1,3 +1,4 @@
+import json
 import subprocess
 import sys
 import os
@@ -7,12 +8,12 @@
 from trainer_api.utils.errors import InvalidArgumentError
 from trainer_api.utils.constants import (
     BASE_MODELS,
-    FINETUNE_SCRIPT_DIR,
+    FINETUNE_SCRIPT_PATH,
     TRAINING_METHODS,
 )
 from trainer_api.utils import logging
 
-FINETUNE_SCRIPT_PATH = os.path.join(FINETUNE_SCRIPT_DIR, "./sft.py")
+logger = logging.get_stream_logger("trainer_api.scheduler.task", "Task")
 
 
 class Task:
@@ -54,27 +55,38 @@ def run(self):
         process.wait()
 
     def _assemble_command(self):
-        r = ["python", FINETUNE_SCRIPT_PATH]
-        r.append("--model")
-        r.append(f"{self.model}")
-        r.append("--method")
-        r.append(f"{self.method}")
-        r.append("--dataset")
-        r.append(f"{self.dataset}")
+        r = [
+            "python",
+            FINETUNE_SCRIPT_PATH,
+            "--model",
+            self.model,
+            "--method",
+            self.method,
+            "--dataset",
+            self.dataset,
+        ]
         return r
 
     def _consume_logs_from_subprocess(self, process):
         for pipe in (process.stdout, process.stderr):
             for line in iter(pipe.readline, b""):
                 if len(line) > 0:
-                    print(f"for ws: {line}")
+                    logger.info(f"for ws: {line}")
                     self.ws.send_message_to_client_sync(
-                        client_id=self.ws.scope["client"][1],
-                        responseJson={
-                            "message": line,
-                            "type": "log",
-                        },
+                        response=json.dumps(
+                            {
+                                "type": "info",
+                                "message": "new log",
+                                "data": {
+                                    "task_id": str(self.id),
+                                    "log": line.decode("utf-8", errors="replace"),
+                                },
+                                "code": 200,
+                            }
+                        ),
                     )
+                # else:
+                #     logger.warning(f"an empty line: {line}")
 
     def __str__(self):
         return f"[Task] method: {self.method if hasattr(self, 'method') else 'None'} training | model: {self.model if hasattr(self, 'model') else 'None'}"
diff --git a/trainer/trainer_api/utils/constants.py b/trainer/trainer_api/utils/constants.py
index d490bc0..440d2f5 100644
--- a/trainer/trainer_api/utils/constants.py
+++ b/trainer/trainer_api/utils/constants.py
@@ -17,6 +17,7 @@ class WorkerStates(Enum):
 
 
 FINETUNE_SCRIPT_DIR = os.path.join(settings.BASE_DIR, "./trainer_api/finetune")
+FINETUNE_SCRIPT_PATH = os.path.join(FINETUNE_SCRIPT_DIR, "./sft.py")
 
 LOG_DIR = log_path = os.path.join(settings.BASE_DIR, "trainer_api/logs/")
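With the task.py change above, each subprocess log line reaches the client as a JSON string with a fixed envelope of "type", "message", "data", and "code" keys. A minimal client sketch, assuming the websockets package and a ws://.../ws/training/ route; the URL, route, and field values are assumptions, since neither is defined in this patch:

    # client sketch; URL, route, and data field values are assumptions
    import asyncio
    import json

    import websockets


    async def follow_logs():
        async with websockets.connect("ws://localhost:8000/ws/training/") as ws:
            # start a job; the fields mirror what receive() validates
            await ws.send(json.dumps({
                "type": "command",
                "data": {"baseModel": "...", "trainingMethod": "...", "dataset": "..."},
            }))
            # print every pushed log line for this task
            async for raw in ws:
                event = json.loads(raw)
                if event.get("message") == "new log":
                    print(event["data"]["task_id"], event["data"]["log"])


    asyncio.run(follow_logs())
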
diff --git a/trainer/trainer_api/ws/consumers.py b/trainer/trainer_api/ws/consumers.py
index a459217..79f360b 100644
--- a/trainer/trainer_api/ws/consumers.py
+++ b/trainer/trainer_api/ws/consumers.py
@@ -18,8 +18,9 @@ async def connect(self):
             f"----------NEW CONNECT COMING-----------------------"
         )
         training_consumer_logger.info(f"[SCOPE]: {self.scope}")
-        client_port = self.scope["client"][1]
-        await self.channel_layer.group_add(str(client_port), self.channel_name)
+        # TODO: use a jobId/taskId as the group name instead of client_port
+        self.client_port = self.scope["client"][1]
+        await self.channel_layer.group_add(str(self.client_port), self.channel_name)
 
         # Accept the WebSocket connection
         await self.accept()
@@ -30,11 +31,10 @@ async def connect(self):
 
     async def disconnect(self, close_code):
         # Clean up when the WebSocket closes
-        client_port = self.scope["client"][1]
         training_consumer_logger.info(
             f"Client disaconnected: {self.scope['client']} disconnected with {self.channel_name} | Close Code: {close_code}"
         )
-        await self.channel_layer.group_discard(str(client_port), self.channel_name)
+        await self.channel_layer.group_discard(str(self.client_port), self.channel_name)
 
     async def receive(self, text_data):
         """
@@ -52,11 +52,10 @@ async def receive(self, text_data):
         data = text_data_json.get("data")
 
         training_consumer_logger.info(
-            f"Received message from {self.scope['client'][1]}: {type} | {message} | {data}"
+            f"Received message for {self.client_port}: {type} | {message} | {data}"
         )
 
-        if type == "start":
-
+        if type == "command":
             if (
                 not data
                 or data.get("baseModel") not in BASE_MODELS
@@ -79,7 +78,7 @@ async def receive(self, text_data):
                 return
 
             try:
-                # schedule the task and repond immediately
+                # schedule the task and respond immediately
                 training_consumer_logger.info("[Worker] Submitting task")
 
                 worker = Worker()
@@ -115,16 +114,29 @@ async def receive(self, text_data):
             )
         )
 
-    async def send_message_to_client(self, client_id, responseJson):
+    async def send_job_update(self, event):
+        message = event["message"]
+        assert isinstance(message, str), "job update must be str"
+
+        await self.send(text_data=message)
+
+    async def send_message_to_client(self, responseJson):
         channel_layer = get_channel_layer()
-        training_consumer_logger.info(f"Sending message to {client_id}: {responseJson}")
+        training_consumer_logger.info(
+            f"Sending message to {self.client_port}: {responseJson}"
+        )
         await channel_layer.group_send(
-            str(client_id), {"type": "targeted", **responseJson}
+            str(self.client_port),
+            {"type": "send_job_update", "message": responseJson},
         )
 
-    def send_message_to_client_sync(self, client_id, responseJson):
+    def send_message_to_client_sync(self, response):
+        assert isinstance(response, str), "response must be str"
         channel_layer = get_channel_layer()
-        training_consumer_logger.info(f"Sending message to {client_id}: {responseJson}")
+        training_consumer_logger.info(
+            f"Sending message to group {self.client_port}: {response}"
+        )
         async_to_sync(channel_layer.group_send)(
-            client_id, {"type": "targeted", **responseJson}
+            str(self.client_port),
+            {"type": "send_job_update", "message": response},
         )
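The TODO in connect() marks the natural follow-up: key the group by task id rather than client port, so any process that shares the channel layer can address a job's subscribers without holding a consumer reference. A sketch of that push, with the group naming scheme as an assumption:

    # sketch of a task-id-keyed push; the "task_<id>" naming is an assumption
    from asgiref.sync import async_to_sync
    from channels.layers import get_channel_layer


    def push_job_update(task_id: str, payload: str) -> None:
        channel_layer = get_channel_layer()
        async_to_sync(channel_layer.group_send)(
            f"task_{task_id}",
            # "type" routes to the consumer's send_job_update handler
            {"type": "send_job_update", "message": payload},
        )

The consumer side would pair this with group_add(f"task_{task_id}", self.channel_name) once the client announces which task it wants to follow.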