From 2c45d8ba0fea2c14718444daab1b926c069d6e3c Mon Sep 17 00:00:00 2001 From: Vit Brunner Date: Fri, 3 Aug 2018 19:59:05 +0200 Subject: [PATCH] Server to serve the network results --- __init__.py | 0 munch/convert_records.py | 2 +- munch/converter.py | 54 ++++++++++++++++++++++++++-------- munch/mutable.py | 2 +- run_server.py | 3 ++ server/predictor.py | 21 ++++++++++++++ server/results.py | 63 ++++++++++++++++++++++++++++++++++++++++ server/serve.py | 25 ++++++++++++++++ server/test.sh | 13 +++++++++ 9 files changed, 169 insertions(+), 14 deletions(-) create mode 100644 __init__.py create mode 100644 run_server.py create mode 100644 server/predictor.py create mode 100644 server/results.py create mode 100644 server/serve.py create mode 100644 server/test.sh diff --git a/__init__.py b/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/munch/convert_records.py b/munch/convert_records.py index b927403..9653823 100644 --- a/munch/convert_records.py +++ b/munch/convert_records.py @@ -1,4 +1,4 @@ -from mutable import Converter +from .mutable import Converter import numpy as np import h5py diff --git a/munch/converter.py b/munch/converter.py index b8fe4b5..bb00123 100644 --- a/munch/converter.py +++ b/munch/converter.py @@ -2,6 +2,8 @@ import textwrap from pprint import pprint +import numpy as np + columns = list('abcdefghi') def square(dim): @@ -73,7 +75,7 @@ def rotate(square, should_rotate = True): lst = [row[::-1] for row in square] return lst[::-1] -def convert_one(moves, expected): +def convert_in(moves, should_rotate): horizontal_walls = out_walls(moves, 'h') vertical_walls = out_walls(moves, 'v') @@ -89,12 +91,6 @@ def convert_one(moves, expected): x_walls_left = 10 - len(get_walls(x_moves)) o_walls_left = 10 - len(get_walls(o_moves)) - expected_horizontal, expected_vertical, expected_pawn = out_expected(expected) - - should_rotate = False - if len(moves) % 2 == 1: - should_rotate = True - if should_rotate: onturn_pawn = rotate(o_pawn) onturn_walls = o_walls_left @@ -106,18 +102,52 @@ def convert_one(moves, expected): other_pawn = o_pawn other_walls = o_walls_left - return "\n%s\n\n%s\n\n%s\n\n%s\n\n%s\n\n%s\n\n-\n\n%s\n\n%s\n\n%s\n" % ( - out_square(rotate(horizontal_walls, should_rotate)), - out_square(rotate(vertical_walls, should_rotate)), - out_square(onturn_pawn), + return ( + rotate(horizontal_walls, should_rotate), + rotate(vertical_walls, should_rotate), + onturn_pawn, onturn_walls, - out_square(other_pawn), + other_pawn, other_walls, + ) + +def convert_one(moves, expected): + should_rotate = len(moves) % 2 == 1 + hw, vw, tp, tw, op, ow = convert_in(moves, should_rotate) + + expected_horizontal, expected_vertical, expected_pawn = out_expected(expected) + + return "\n%s\n\n%s\n\n%s\n\n%s\n\n%s\n\n%s\n\n-\n\n%s\n\n%s\n\n%s\n" % ( + out_square(hw), + out_square(vw), + out_square(tp), + tw, + out_square(op), + ow, out_square(rotate(expected_horizontal, should_rotate)), out_square(rotate(expected_vertical, should_rotate)), out_square(rotate(expected_pawn, should_rotate)), ) +def pad_walls(walls): + return np.pad(np.array(walls), (0, 1), 'constant') + +def pad_wallcount(wallcount): + return np.full((9, 9), wallcount) + +def convert_record(moves): + should_rotate = len(moves) % 2 == 1 + hw, vw, tp, tw, op, ow = convert_in(moves, should_rotate) + + return np.array([ + pad_walls(hw), + pad_walls(vw), + tp, + pad_wallcount(tw), + op, + pad_wallcount(ow), + ]) + def convert(record): split = record.split(";") diff --git a/munch/mutable.py b/munch/mutable.py index 8efb93f..75fd41e 100644 --- a/munch/mutable.py +++ b/munch/mutable.py @@ -1,4 +1,4 @@ -from converter import * +from .converter import * from collections import namedtuple import numpy as np diff --git a/run_server.py b/run_server.py new file mode 100644 index 0000000..cec8605 --- /dev/null +++ b/run_server.py @@ -0,0 +1,3 @@ +from server.serve import run + +run(port=8008) diff --git a/server/predictor.py b/server/predictor.py new file mode 100644 index 0000000..b8cb83e --- /dev/null +++ b/server/predictor.py @@ -0,0 +1,21 @@ +import h5py +import numpy as np + +from keras.models import load_model +from keras_resnet import custom_objects + +from .results import order, rotate +from munch.converter import convert_record + +network = 'nn/01_first_nn.h5' +model = load_model(network) + +def predict(game_record): + moves = game_record.split(';') + nparr = np.array([convert_record(moves)]) + predictions = order(model.predict(nparr)[0].tolist()) + + if len(moves) % 2 == 1: + predictions = [rotate(p) for p in predictions] + + return predictions diff --git a/server/results.py b/server/results.py new file mode 100644 index 0000000..5003aef --- /dev/null +++ b/server/results.py @@ -0,0 +1,63 @@ +hp = list("abcdefghi") +vp = list("987654321") + +hw = list("abcdefghx") +vw = list("87654321x") + + +available_moves = [] +for v in vw: + for h in hw: + available_moves.append(h + v + "h") + +for v in vw: + for h in hw: + available_moves.append(h + v + "v") + +for v in vp: + for h in hp: + available_moves.append(h + v) + + +p_rotations = {} +for i, val in enumerate(hp): + p_rotations[val] = hp[8 - i] +for i, val in enumerate(vp): + p_rotations[val] = vp[8 - i] + +w_rotations = {} +for i, val in enumerate(hw[:8]): + w_rotations[val] = hw[7 - i] +for i, val in enumerate(vw[:8]): + w_rotations[val] = vw[7 - i] + + +def convert(pred): + prob, move = pred + return (move, int(prob * 1000)) + +def nonzero(pred): + prob, move = pred + return prob != 0 + +def order(predictions): + preds = sorted(list(zip(predictions, available_moves)), reverse=True) + filtered = filter(nonzero, map(convert, preds)) + + return list(filtered)[:20] + + +def rotate(pred): + move, prob = pred + + col = move[0] + row = move[1] + + if len(move) == 3: + newmove = w_rotations[col] + w_rotations[row] + move[2] + elif len(move) == 2: + newmove = p_rotations[col] + p_rotations[row] + else: + raise "oops" + + return (newmove, prob) diff --git a/server/serve.py b/server/serve.py new file mode 100644 index 0000000..bf14eff --- /dev/null +++ b/server/serve.py @@ -0,0 +1,25 @@ +from .predictor import predict + +from http.server import BaseHTTPRequestHandler, HTTPServer +from urllib.parse import urlparse, parse_qs +import json + +class S(BaseHTTPRequestHandler): + def _set_headers(self): + self.send_response(200) + self.send_header('Content-type', 'application/json') + self.end_headers() + + def do_GET(self): + self._set_headers() + url = urlparse(self.path) + game = url.query.replace("game=", "") + + encoded = json.dumps(predict(game)) + self.wfile.write((encoded + "\n").encode('utf-8')) + +def run(server_class=HTTPServer, handler_class=S, port=80): + server_address = ('', port) + httpd = server_class(server_address, handler_class) + print('Starting httpd...') + httpd.serve_forever() diff --git a/server/test.sh b/server/test.sh new file mode 100644 index 0000000..8980ca1 --- /dev/null +++ b/server/test.sh @@ -0,0 +1,13 @@ +#curl "http://localhost:8008?game=" + +#curl "http://localhost:8008?game=d1;d9;d2;d8;d3;d7;a1h;d6;c1h;d5;d4v;d8v;d2v;f5h;c5h;h5h;d4;e4v;b6v;a8h;c8v;c5;c4;b5;c5;b4v" +#curl "http://localhost:8008?game=d1;d9;d2;d8;d3;d7;a1h;d6;c1h;d5;d4v;d8v;d2v;f5h;c5h;h5h;d4;e4v;b6v;a8h;c8v;c5;c4;b5;c5" + +curl "http://localhost:8008?game=e2;e8;e3;f8;e4;d4h;d3h;e3v;d4;b4h;b3h;f7;f6h;g7;c4;h6v;g6v;f7;b4" +curl "http://localhost:8008?game=e2;e8;e3;f8;e4;d4h;d3h;e3v;d4;b4h;b3h;f7;f6h;g7;c4;h6v;g6v;f7;b4;a5h" +curl "http://localhost:8008?game=e2;e8;e3;f8;e4;d4h;d3h;e3v;d4;b4h;b3h;f7;f6h;g7;c4;h6v;g6v;f7;b4;a5h;e1v" +curl "http://localhost:8008?game=e2;e8;e3;f8;e4;d4h;d3h;e3v;d4;b4h;b3h;f7;f6h;g7;c4;h6v;g6v;f7;b4;a5h;e1v;e7" +curl "http://localhost:8008?game=e2;e8;e3;f8;e4;d4h;d3h;e3v;d4;b4h;b3h;f7;f6h;g7;c4;h6v;g6v;f7;b4;a5h;e1v;e7;a4" + +# clash in the middle +# curl "http://localhost:8008?game=e2;e8;e3;e7;e4;e6;e5"