-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtask_dispatcher.py
95 lines (74 loc) · 2.64 KB
/
task_dispatcher.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import argparse
import json
from multiprocessing import Queue
from threading import Thread
from typing import Any

import redis

from config import redis_url, redis_port, redis_password, redis_db, redis_topic, redis_fail
from local_worker import local_worker
from model import Task, TaskInfo
from pull_worker_router import pull_worker_router
from push_worker_router import push_worker_router
from util import new_task_handler
redis_conn = redis.StrictRedis(host=redis_url, port=redis_port, password=redis_password, db=redis_db)
def main():
    """Start the dispatcher: build worker hooks for the selected mode,
    re-enqueue previously failed tasks, then run the pub/sub listener and
    result-persistence loops on daemon threads until the worker exits.
    """
    args = parse_args()
    if args.mode == "local":
        hooks = new_task_handler(local_worker, num_processes=args.workers)
    elif args.mode == "push":
        hooks = new_task_handler(push_worker_router)
    elif args.mode == "pull":
        hooks = new_task_handler(pull_worker_router)
    else:
        print(f"TODO: Implement {args.mode} mode")
        # SystemExit instead of exit(): exit() is a site-module convenience
        # helper and is not guaranteed to exist in every run context.
        raise SystemExit(1)
    t, task_queue, result_queue = hooks

    # Re-enqueue tasks that were recorded as failed in Redis so they get
    # another chance before we start consuming new work.
    for task in recover_tasks():
        task_queue.put(task)

    # Daemon threads so they die with the main worker thread on shutdown.
    queue_tasks_thread = Thread(target=queue_tasks, args=(task_queue,), daemon=True)
    queue_tasks_thread.start()
    dequeue_results_thread = Thread(target=dequeue_results, args=(result_queue,), daemon=True)
    dequeue_results_thread.start()

    # Block until the worker/handler thread finishes.
    t.join()
def recover_tasks():
    """Load every task recorded in the Redis failure set.

    Returns:
        list[Task]: one Task per member of ``redis_fail`` whose serialized
        TaskInfo payload is still present in Redis. Entries whose payload
        key is missing (e.g. expired or deleted) are skipped with a warning
        instead of aborting recovery — mirrors get_task's tolerant handling.
    """
    recovered = []
    for task_id in redis_conn.smembers(redis_fail):
        task_info_json = redis_conn.get(task_id)
        if task_info_json is None:
            # Payload gone: json.loads(None) would raise and abort the whole
            # recovery pass, losing the remaining recoverable tasks.
            print("no task info found for failed task:", task_id)
            continue
        task_info = TaskInfo(**json.loads(task_info_json))
        recovered.append(Task(task_id=task_id, task_info=task_info))
    return recovered
def queue_tasks(task_queue):
    """Subscribe to ``redis_topic`` and forward each published task id,
    resolved to a Task via get_task, onto ``task_queue``. Runs forever
    (pubsub.listen() blocks); intended for a daemon thread.

    Args:
        task_queue: queue consumed by the worker handler.
    """
    pubsub = redis_conn.pubsub()
    pubsub.subscribe(redis_topic)
    for message in pubsub.listen():
        task = get_task(message)
        # Identity comparison per PEP 8; get_task returns None for
        # non-message events and lookup failures.
        if task is not None:
            task_queue.put(task)
def get_task(message: "dict[str, Any]") -> "Task | None":
    """Resolve a Redis pub/sub event into a Task.

    Args:
        message: raw pubsub event dict; only ``type == 'message'`` events
            carry a task id in ``message['data']``.

    Returns:
        The Task whose serialized TaskInfo is stored under the task id,
        or None for non-message events (subscribe acks etc.) or when the
        lookup/deserialization fails.
    """
    if message['type'] != 'message':
        return None
    task_id = message['data']
    try:
        task_info_json = redis_conn.get(task_id)
        task_info = TaskInfo(**json.loads(task_info_json))
        return Task(task_id=task_id, task_info=task_info)
    except Exception as e:
        # Best-effort: a bad/missing payload must not kill the listener loop.
        print("failed to get the task: ", str(e))
        return None
def dequeue_results(result_queue: Queue):
    """Persist completed tasks back to Redis, forever.

    Blocks on ``result_queue``; for each finished task, overwrites the
    stored TaskInfo JSON under the task id and removes the id from the
    failure set. Intended for a daemon thread.

    Args:
        result_queue: queue of Task objects produced by the workers.
    """
    while True:
        finished = result_queue.get()  # blocks until a worker reports back
        serialized = finished.task_info.model_dump_json()
        redis_conn.set(finished.task_id, serialized)
        # The task completed, so it no longer belongs in the failure set.
        redis_conn.srem(redis_fail, finished.task_id)
def parse_args():
    """Parse the dispatcher's command-line options.

    Returns:
        argparse.Namespace with ``mode`` (local/push/pull), ``port``,
        and ``workers`` attributes.
    """
    parser = argparse.ArgumentParser(description="Redis-backed task dispatcher")
    parser.add_argument('-m', '--mode', type=str, default="push",
                        help="dispatch mode: local, push, or pull (default: push)")
    parser.add_argument('-p', '--port', type=int, default=20000,
                        help="port for the worker router (default: 20000)")
    parser.add_argument('-w', '--workers', type=int, default=2,
                        help="number of local worker processes (default: 2)")
    args = parser.parse_args()
    print(args)  # echo the effective configuration at startup
    return args
if __name__ == '__main__':
main()