forked from albertnadal/tensorflowjs-mnist-live
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathserver.js
115 lines (88 loc) · 3.28 KB
/
server.js
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
/* You can disable Tensorflow warnings with the command 'export TF_CPP_MIN_LOG_LEVEL=2' */
const tf = require('@tensorflow/tfjs-node');
const fs = require("fs");
const data = require('./data');
const http = require("http");
const express = require("express");
const socketio = require("socket.io");
const argparse = require('argparse');
// TCP port the HTTP / websocket server listens on.
const SERVER_PORT = 1992;
// Dimensions, in pixels, of the single-channel image carried by each prediction request.
const IMAGE_HEIGHT = 28;
const IMAGE_WIDTH = 28;
// Ids of the sockets that are currently connected. The Set is only mutated
// (add/delete), never rebound, so the binding itself is constant.
const socketConnections = new Set();
// Trained model used to answer prediction requests; assigned once by startServer().
let model;
/**
 * Loads a previously trained layers model from disk.
 *
 * When no path is given, or no `model.json` exists under the given path, a
 * hint is printed and the process is terminated with a non-zero exit code.
 *
 * @param {?string} modelPath - Directory that contains the saved `model.json`.
 * @returns {Promise<Object>} Whatever `tf.loadLayersModel` resolves to.
 */
async function loadTrainedModel(modelPath) {
  if (modelPath == null) {
    console.error("Please, provide the path of the trained model using 'node server.js --model_path MODEL_PATH'");
    // `exit` is not a global in Node.js; it has to be reached through `process`.
    process.exit(1);
  }

  // Check the exact file that is about to be loaded, not just its directory,
  // so a directory without a saved model produces the friendly hint as well.
  const modelJsonPath = `${modelPath}/model.json`;
  if (!fs.existsSync(modelJsonPath)) {
    console.error("Model not found. Please, train the model using 'node trainer.js --epochs 1 --model_save_path trained_model'");
    process.exit(1);
  }

  return tf.loadLayersModel(`file://${modelJsonPath}`);
}
/**
 * Registers the handlers for a newly connected websocket client.
 *
 * - Tracks the socket id in `socketConnections` until the socket disconnects.
 * - On "predictionRequest", interprets the payload as IMAGE_HEIGHT x IMAGE_WIDTH
 *   bytes, runs the model on it, and emits "predictionResults" with a JSON
 *   string holding the three most probable `{number, probability}` entries.
 *
 * Malformed payloads and prediction failures are logged and produce no reply.
 *
 * @param {Object} socket - Connected socket; must provide `id`, `on` and `emit`.
 */
function onNewWebsocketConnection(socket) {
  console.info(`Socket ${socket.id} has connected.`);
  socketConnections.add(socket.id);

  socket.on("disconnect", () => {
    socketConnections.delete(socket.id);
    console.info(`Socket ${socket.id} has disconnected.`);
  });

  socket.on("predictionRequest", async (msg) => {
    const pixelCount = IMAGE_HEIGHT * IMAGE_WIDTH;

    // The payload comes from the client. `new Uint8Array(n)` with a number
    // allocates n bytes, so only real byte containers are accepted.
    const isByteContainer =
      Array.isArray(msg) || ArrayBuffer.isView(msg) || msg instanceof ArrayBuffer;
    if (!isByteContainer) {
      console.warn(`Socket ${socket.id} sent a prediction request that is not a byte buffer.`);
      return;
    }

    const pixels = new Uint8Array(msg);
    // A short payload would leave NaN entries in the model input.
    if (pixels.length < pixelCount) {
      console.warn(`Socket ${socket.id} sent ${pixels.length} bytes, expected at least ${pixelCount}.`);
      return;
    }

    /*START DRAW*/
    // Dump the bytes that are actually fed to the model, one image row per line.
    for (let y = 0; y < IMAGE_HEIGHT; y++) {
      const row = pixels.subarray(y * IMAGE_WIDTH, (y + 1) * IMAGE_WIDTH);
      process.stdout.write(`${row.join(" ")} \n`);
    }
    /*END DRAW*/

    let inputTensor;
    let outputTensor;
    try {
      inputTensor = tf.tensor4d(
        Float32Array.from(pixels.subarray(0, pixelCount)),
        [1, IMAGE_HEIGHT, IMAGE_WIDTH, 1]
      );
      outputTensor = model.predict(inputTensor, {verbose: true});
      const probabilities = await outputTensor.data();

      // Pair every probability with its index, then order from most to least likely.
      const results = Array.from(probabilities, (probability, number) => ({number, probability}));
      results.sort((a, b) => b.probability - a.probability);
      console.log("PREDICTION: ", results);

      socket.emit("predictionResults", JSON.stringify(results.slice(0, 3)));
    } catch (err) {
      // Without this, a failure in an async handler becomes an unhandled rejection.
      console.error(`Prediction failed for socket ${socket.id}:`, err);
    } finally {
      // Tensors are not garbage collected; release them after every request.
      if (inputTensor) {
        inputTensor.dispose();
      }
      if (outputTensor) {
        outputTensor.dispose();
      }
    }
  });
}
/**
 * Loads the trained model and starts the HTTP + websocket server.
 *
 * The loaded model is stored in the module-level `model` variable, the
 * "public" directory is served as static content, and every new websocket
 * connection is handed to onNewWebsocketConnection.
 *
 * @param {?string} modelPath - Forwarded to loadTrainedModel.
 * @returns {Promise<void>} Resolves once `server.listen` has been called.
 */
async function startServer(modelPath) {
  // Load Tensorflow model
  model = await loadTrainedModel(modelPath);
  console.log(`Trained model loaded.`);

  // Create a new express server
  const app = express();
  const server = http.createServer(app);
  const io = socketio(server);

  // Serve the public frontend
  app.use(express.static("public"));

  // Handler for every new websocket connection
  io.on("connection", onNewWebsocketConnection);

  server.listen(SERVER_PORT, () => {
    console.info(`Listening on port ${SERVER_PORT}`);
    // Build the URL from SERVER_PORT so the hint cannot drift from the real port.
    console.log(`Open http://localhost:${SERVER_PORT}/ in your browser to start playing!\n`);
  });
}
// Command-line interface: `node server.js --model_path MODEL_PATH`.
const parser = new argparse.ArgumentParser({
  description: 'TensorFlow.js-Node MNIST Server.',
  addHelp: true
});
parser.addArgument('--model_path', {
  type: 'string',
  // The model is read from this path, not written to it.
  help: 'Path from which the trained model will be loaded.'
});
const args = parser.parseArgs();

// Start server to receive prediction requests. The returned promise is
// handled so a startup failure is reported and ends the process with a
// non-zero exit code instead of surfacing as an unhandled rejection.
startServer(args.model_path).catch((err) => {
  console.error("Failed to start the server:", err);
  process.exit(1);
});