-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtrainer.py
126 lines (106 loc) · 5.35 KB
/
trainer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import json
import urllib.error
import pika
from train import Model
from dataset import get_dataset
import tensorflow as tf
import requests
import os
from numba import cuda
class Trainer:
    """Blocking RabbitMQ consumer that feeds queued messages to ``train_callback``."""

    # Class-level placeholders; ``__init__`` overwrites all three per instance.
    connection = None
    channel = None
    queue = ''

    def __init__(self, host, queue):
        """Connect to the broker at *host* and open a channel.

        The virtual host is taken from the ``VHOST`` environment variable,
        so a missing variable raises ``KeyError``. *queue* is the name of
        the queue that ``run`` will consume from.
        """
        params = pika.ConnectionParameters(
            host=host,
            virtual_host=os.environ['VHOST'],
        )
        self.connection = pika.BlockingConnection(params)
        self.channel = self.connection.channel()
        self.queue = queue

    def run(self):
        """Register ``train_callback`` on the queue and start consuming.

        The queue is not declared here. Messages are consumed with
        ``auto_ack=True`` and a prefetch count of 1. The call to
        ``start_consuming`` is the last statement, so this method returns
        only when that call does.
        """
        channel = self.channel
        channel.basic_qos(prefetch_count=1)
        channel.basic_consume(
            queue=self.queue,
            on_message_callback=train_callback,
            auto_ack=True,
        )
        print(' [*] Waiting for messages. To exit press CTRL+C')
        channel.start_consuming()
def _error_text(exc):
    """Return a plain-string description of *exc* that ``json.dumps`` accepts.

    The first positional argument is used when it is already a string;
    otherwise (no arguments, or a non-string first argument) ``str(exc)``
    is used.
    """
    if exc.args and isinstance(exc.args[0], str):
        return exc.args[0]
    return str(exc)


def train_callback(ch, method, props, body):
    """Run one training job described by a queue message.

    *body* is parsed as JSON. The keys read from it are ``train_id``,
    ``config``, ``user_id``, ``project_no`` and ``data_set`` (whose ``kind``
    is passed to ``model.fit`` and whose ``train_uri`` is quoted in the
    dataset-failure message). Progress is posted to the ``.../log`` endpoint
    and the final outcome (success or failure) to the ``.../reply`` endpoint
    of the API server named by the ``API_SERVER`` environment variable.

    *ch*, *method* and *props* are accepted because the consumer passes
    them, but they are not used.

    Raises:
        json.JSONDecodeError: if *body* is not valid JSON.
        KeyError: if a required key or ``API_SERVER`` is missing.
    """
    req_body = json.loads(body)
    headers = {
        'Content-Type': 'application/json; charset=utf-8',
        'train_id': str(req_body['train_id']),
    }

    def notify(endpoint, status_code, msg):
        # Single place that builds the status URL and payload for this job.
        url = (
            f'https://{os.environ["API_SERVER"]}/api/project/'
            f'{req_body["project_no"]}/train/{req_body["train_id"]}/{endpoint}'
        )
        payload = {
            'status_code': status_code,
            'msg': msg,
            'train_id': req_body['train_id'],
        }
        reply_request(url, payload, headers)

    try:
        model = Model(req_body['config'], req_body['user_id'],
                      req_body['train_id'], req_body['project_no'])
    except ValueError as e:
        notify('reply', 400, _error_text(e))
        return

    notify('log', 200, 'start loading dataset...')
    try:
        data, label = get_dataset(req_body['data_set'], model.model)
    except Exception:
        # ``Exception`` rather than a bare ``except`` so that CTRL+C and
        # interpreter shutdown are not reported as dataset failures.
        notify('reply', 400,
               f'failed to get dataset from {req_body["data_set"]["train_uri"]}')
        return
    notify('log', 200, 'loading dataset finished')

    try:
        model.fit(data, label, req_body['data_set']['kind'])
    except (tf.errors.InvalidArgumentError,
            tf.errors.AbortedError,
            tf.errors.FailedPreconditionError,
            tf.errors.UnknownError,
            ValueError) as e:
        # ``ValueError`` also covers ``json.JSONDecodeError``. The message
        # must be a string: the exception object itself cannot be
        # JSON-encoded when the payload is serialized.
        notify('reply', 500, _error_text(e))
        return
    except Exception:
        print("error occurred while training.")
        notify('reply', 500, 'internal server error while training')
        return

    try:
        model.save_model()
    except Exception:
        notify('reply', 500, 'OS error')
        return

    # NOTE: GPU memory release through numba
    # (``cuda.get_current_device().reset()``) is currently disabled.
    notify('reply', 200, 'Train finished successfully.')
def reply_request(url, data, headers, timeout=30):
    """POST *data* to *url* as UTF-8 encoded JSON.

    Args:
        url: Destination URL.
        data: JSON-compatible payload. Values that ``json`` cannot encode
            natively (for example exception objects) are sent as ``str(value)``.
        headers: HTTP headers to send with the request.
        timeout: Seconds to wait for the server before giving up; ``None``
            waits indefinitely.

    Returns:
        The ``requests`` response on success, or the caught exception
        instance when the request could not be completed.
    """
    # ``default=str`` keeps status reporting from crashing with ``TypeError``
    # when a caller places an exception object in the payload.
    payload = json.dumps(data, default=str).encode('utf-8')
    try:
        return requests.post(url, data=payload, headers=headers, timeout=timeout)
    except (requests.exceptions.RequestException, urllib.error.URLError) as e:
        # ``requests`` reports connection errors and timeouts as
        # ``RequestException`` subclasses; ``URLError`` is kept so anything
        # the previous handler caught is still returned rather than raised.
        return e