feat: trainer can actively push back to clients by group identifier
jazelly committed Jun 17, 2024
1 parent 483f46a commit af2f782
Showing 5 changed files with 70 additions and 45 deletions.
1 change: 1 addition & 0 deletions trainer/requirements.txt
@@ -2,6 +2,7 @@ django
 django-ratelimit
 python-dotenv
 channels
+channels_redis
 daphne
 watchgod
 psutil
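Pulling in channels_redis implies the channel layer is backed by Redis rather than the default in-memory layer, which is what lets the trainer's worker push to WebSocket consumers held by another process. A minimal settings sketch, assuming a local Redis on the default port (the address is an assumption, not part of this commit):

# settings.py -- illustrative only; the Redis host/port are assumptions.
CHANNEL_LAYERS = {
    "default": {
        "BACKEND": "channels_redis.core.RedisChannelLayer",
        "CONFIG": {"hosts": [("127.0.0.1", 6379)]},
    },
}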
31 changes: 15 additions & 16 deletions trainer/trainer_api/finetune/sft.py
@@ -408,6 +408,17 @@ def train(trainer, new_model=get_new_model_name()):
     # Save trained model
     trainer.model.save_pretrained(new_model)
 
+    def save_trained_model():
+        # Reload model in FP16 and merge it with LoRA weights
+        base_model = AutoModelForCausalLM.from_pretrained(
+            model_name,
+            low_cpu_mem_usage=True,
+            return_dict=True,
+            torch_dtype=torch.float16,
+            device_map=device_map,
+        )
+        model = PeftModel.from_pretrained(base_model, new_model)
+        model = model.merge_and_unload()
 
     model = get_quantized_model(model_name, bnb_config, device_map)

@@ -468,7 +479,7 @@ def train(trainer, new_model=get_new_model_name()):
         task="text-generation", model=model, tokenizer=tokenizer, max_length=200
     )
     result = pipe(f"<s>[INST] {prompt} [/INST]")
-    print(result[0]["generated_text"])
+    logger.info(result[0]["generated_text"])
 
     del model
     del pipe
@@ -479,19 +490,7 @@ def train(trainer, new_model=get_new_model_name()):
     gc.collect()
 
 
-    def save_trained_model():
-        # Reload model in FP16 and merge it with LoRA weights
-        base_model = AutoModelForCausalLM.from_pretrained(
-            model_name,
-            low_cpu_mem_usage=True,
-            return_dict=True,
-            torch_dtype=torch.float16,
-            device_map=device_map,
-        )
-        model = PeftModel.from_pretrained(base_model, new_model)
-        model = model.merge_and_unload()
-
-
-    save_trained_model()
+    if trainable:
+        save_trained_model()
+
     saved_tokenizer = tokenizer.save_pretrained(new_model)
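Since both the merged weights and the tokenizer are saved under new_model, the output directory ends up as a self-contained checkpoint. A hedged sketch of loading it back later (the directory name is a hypothetical value of get_new_model_name(), not taken from this commit):

# Illustrative only: consume the merged checkpoint after training.
from transformers import AutoModelForCausalLM, AutoTokenizer

merged_dir = "llama-2-7b-sft"  # hypothetical output of get_new_model_name()
model = AutoModelForCausalLM.from_pretrained(merged_dir, torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(merged_dir)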
42 changes: 27 additions & 15 deletions trainer/trainer_api/scheduler/task.py
@@ -1,3 +1,4 @@
+import json
 import subprocess
 import sys
 import os
@@ -7,12 +8,12 @@
 from trainer_api.utils.errors import InvalidArgumentError
 from trainer_api.utils.constants import (
     BASE_MODELS,
-    FINETUNE_SCRIPT_DIR,
+    FINETUNE_SCRIPT_PATH,
     TRAINING_METHODS,
 )
 from trainer_api.utils import logging
 
-FINETUNE_SCRIPT_PATH = os.path.join(FINETUNE_SCRIPT_DIR, "./sft.py")
+logger = logging.get_stream_logger("trainer_api.scheduler.task", "Task")
 
 
 class Task:
@@ -54,27 +55,38 @@ def run(self):
             process.wait()
 
     def _assemble_command(self):
-        r = ["python", FINETUNE_SCRIPT_PATH]
-        r.append("--model")
-        r.append(f"{self.model}")
-        r.append("--method")
-        r.append(f"{self.method}")
-        r.append("--dataset")
-        r.append(f"{self.dataset}")
+        r = [
+            "python",
+            FINETUNE_SCRIPT_PATH,
+            "--model",
+            self.model,
+            "--method",
+            self.method,
+            "--dataset",
+            self.dataset,
+        ]
         return r
 
     def _consume_logs_from_subprocess(self, process):
        for pipe in (process.stdout, process.stderr):
            for line in iter(pipe.readline, b""):
                if len(line) > 0:
-                    print(f"for ws: {line}")
+                    logger.info(f"for ws: {line}")
                     self.ws.send_message_to_client_sync(
-                        client_id=self.ws.scope["client"][1],
-                        responseJson={
-                            "message": line,
-                            "type": "log",
-                        },
+                        response=json.dumps(
+                            {
+                                "type": "info",
+                                "message": "new log",
+                                "data": {
+                                    "task_id": str(self.id),
+                                    "log": line,
+                                },
+                                "code": 200,
+                            }
+                        ),
                     )
+                # else:
+                #     logger.warning(f"an empty line: {line}")
 
     def __str__(self):
         return f"[Task] method: {self.method if hasattr(self, 'method') else 'None'} training | model: {self.model if hasattr(self, 'model') else 'None'}"
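Each non-empty subprocess line is now wrapped in a JSON envelope before being pushed to the client group. One caveat worth flagging: the pipes are iterated with a bytes sentinel (iter(pipe.readline, b"")), so line is presumably bytes and would need a .decode() before json.dumps can serialize it. Assuming decoded text, a client-side frame would look roughly like this (the task_id and log values are invented for the example):

# Illustrative only: parse one pushed log frame on the client side.
import json

frame = (
    '{"type": "info", "message": "new log",'
    ' "data": {"task_id": "42", "log": "step 10 | loss 1.93"}, "code": 200}'
)
payload = json.loads(frame)
assert payload["type"] == "info" and payload["code"] == 200
print(payload["data"]["log"])  # -> step 10 | loss 1.93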
1 change: 1 addition & 0 deletions trainer/trainer_api/utils/constants.py
@@ -17,6 +17,7 @@ class WorkerStates(Enum):
 
 
 FINETUNE_SCRIPT_DIR = os.path.join(settings.BASE_DIR, "./trainer_api/finetune")
+FINETUNE_SCRIPT_PATH = os.path.join(FINETUNE_SCRIPT_DIR, "./sft.py")
 
 LOG_DIR = log_path = os.path.join(settings.BASE_DIR, "trainer_api/logs/")

40 changes: 26 additions & 14 deletions trainer/trainer_api/ws/consumers.py
@@ -18,8 +18,9 @@ async def connect(self):
             f"----------NEW CONNECT COMING-----------------------"
         )
         training_consumer_logger.info(f"[SCOPE]: {self.scope}")
-        client_port = self.scope["client"][1]
-        await self.channel_layer.group_add(str(client_port), self.channel_name)
+        # TODO: do not use client_port as the group name; use jobId/taskId instead
+        self.client_port = self.scope["client"][1]
+        await self.channel_layer.group_add(str(self.client_port), self.channel_name)
 
         # Accept the WebSocket connection
         await self.accept()
@@ -30,11 +31,10 @@ async def connect(self):
 
     async def disconnect(self, close_code):
         # Clean up when the WebSocket closes
-        client_port = self.scope["client"][1]
         training_consumer_logger.info(
             f"Client disconnected: {self.scope['client']} disconnected with {self.channel_name} | Close Code: {close_code}"
         )
-        await self.channel_layer.group_discard(str(client_port), self.channel_name)
+        await self.channel_layer.group_discard(str(self.client_port), self.channel_name)

     async def receive(self, text_data):
         """
@@ -52,11 +52,10 @@ async def receive(self, text_data):
         data = text_data_json.get("data")
 
         training_consumer_logger.info(
-            f"Received message from {self.scope['client'][1]}: {type} | {message} | {data}"
+            f"Received message for {self.client_port}: {type} | {message} | {data}"
         )
 
-        if type == "start":
-
+        if type == "command":
             if (
                 not data
                 or data.get("baseModel") not in BASE_MODELS
@@ -79,7 +78,7 @@ async def receive(self, text_data):
                 return
 
             try:
-                # schedule the task and repond immediately
+                # schedule the task and respond immediately
                 training_consumer_logger.info("[Worker] Submitting task")
                 worker = Worker()

@@ -115,16 +114,29 @@
             )
         )
 
-    async def send_message_to_client(self, client_id, responseJson):
+    async def send_job_update(self, event):
+        message = event["message"]
+        assert isinstance(message, str), "job update must be str"
+
+        await self.send(text_data=message)
+
+    async def send_message_to_client(self, responseJson):
         channel_layer = get_channel_layer()
-        training_consumer_logger.info(f"Sending message to {client_id}: {responseJson}")
+        training_consumer_logger.info(
+            f"Sending message to {self.client_port}: {responseJson}"
+        )
         await channel_layer.group_send(
-            str(client_id), {"type": "targeted", **responseJson}
+            str(self.client_port),
+            {"type": "send_job_update", "message": responseJson},
         )
 
-    def send_message_to_client_sync(self, client_id, responseJson):
+    def send_message_to_client_sync(self, response):
+        assert isinstance(response, str), "response must be str"
         channel_layer = get_channel_layer()
-        training_consumer_logger.info(f"Sending message to {client_id}: {responseJson}")
+        training_consumer_logger.info(
+            f"Sending message to group {self.client_port}: {response}"
+        )
         async_to_sync(channel_layer.group_send)(
-            client_id, {"type": "targeted", **responseJson}
+            str(self.client_port),
+            {"type": "send_job_update", "message": response},
         )
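A note on the new "type" value: Channels dispatches a group_send payload by invoking the consumer method named in "type", so {"type": "send_job_update"} calls send_job_update on every consumer subscribed to the group. A minimal sketch of that push path under the standard Channels API (the group name is illustrative; today it is the client port, with jobId/taskId planned per the TODO):

# Illustrative push path: any process sharing the Redis channel layer can
# target a group; Channels routes "type" to each consumer's send_job_update().
import json
from asgiref.sync import async_to_sync
from channels.layers import get_channel_layer

def push_job_update(group_name: str, payload: dict) -> None:
    channel_layer = get_channel_layer()
    async_to_sync(channel_layer.group_send)(
        group_name,
        {"type": "send_job_update", "message": json.dumps(payload)},
    )

# e.g. push_job_update("54321", {"type": "info", "message": "new log", "code": 200})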
