From faf6a5df9597d59b1e9d9dd60a6c6b9cc9a3608a Mon Sep 17 00:00:00 2001 From: Povilas Kanapickas Date: Sat, 14 Dec 2024 20:34:34 +0200 Subject: [PATCH] websocket: Convert connection::read_http_upgrade_request() to use coros Using coroutines improves readability of the code. The function is not performance sensitive as well, as it's called only during connection establishment. Closes scylladb/seastar#2583 --- src/websocket/server.cc | 87 +++++++++++++++++++---------------------- 1 file changed, 41 insertions(+), 46 deletions(-) diff --git a/src/websocket/server.cc b/src/websocket/server.cc index f9c26bfe29..7dbaa6a696 100644 --- a/src/websocket/server.cc +++ b/src/websocket/server.cc @@ -147,55 +147,50 @@ static std::string sha1_base64(std::string_view source) { future<> connection::read_http_upgrade_request() { _http_parser.init(); - return _read_buf.consume(_http_parser).then([this] () mutable { - if (_http_parser.eof()) { - _done = true; - return make_ready_future<>(); - } - std::unique_ptr req = _http_parser.get_parsed_request(); - if (_http_parser.failed()) { - return make_exception_future<>(websocket::exception("Incorrect upgrade request")); - } + co_await _read_buf.consume(_http_parser); - sstring upgrade_header = req->get_header("Upgrade"); - if (upgrade_header != "websocket") { - return make_exception_future<>(websocket::exception("Upgrade header missing")); - } + if (_http_parser.eof()) { + _done = true; + co_return; + } + std::unique_ptr req = _http_parser.get_parsed_request(); + if (_http_parser.failed()) { + throw websocket::exception("Incorrect upgrade request"); + } - sstring subprotocol = req->get_header("Sec-WebSocket-Protocol"); - if (subprotocol.empty()) { - return make_exception_future<>(websocket::exception("Subprotocol header missing.")); - } + sstring upgrade_header = req->get_header("Upgrade"); + if (upgrade_header != "websocket") { + throw websocket::exception("Upgrade header missing"); + } - if (!_server.is_handler_registered(subprotocol)) { - return make_exception_future<>(websocket::exception("Subprotocol not supported.")); - } - this->_handler = this->_server._handlers[subprotocol]; - this->_subprotocol = subprotocol; - wlogger.debug("Sec-WebSocket-Protocol: {}", subprotocol); - - sstring sec_key = req->get_header("Sec-Websocket-Key"); - sstring sec_version = req->get_header("Sec-Websocket-Version"); - - sstring sha1_input = sec_key + magic_key_suffix; - - wlogger.debug("Sec-Websocket-Key: {}, Sec-Websocket-Version: {}", sec_key, sec_version); - - std::string sha1_output = sha1_base64(sha1_input); - wlogger.debug("SHA1 output: {} of size {}", sha1_output, sha1_output.size()); - - return _write_buf.write(http_upgrade_reply_template).then([this, sha1_output = std::move(sha1_output)] { - return _write_buf.write(sha1_output); - }).then([this] { - return _write_buf.write("\r\nSec-WebSocket-Protocol: ", 26); - }).then([this] { - return _write_buf.write(_subprotocol); - }).then([this] { - return _write_buf.write("\r\n\r\n", 4); - }).then([this] { - return _write_buf.flush(); - }); - }); + sstring subprotocol = req->get_header("Sec-WebSocket-Protocol"); + if (subprotocol.empty()) { + throw websocket::exception("Subprotocol header missing."); + } + + if (!_server.is_handler_registered(subprotocol)) { + throw websocket::exception("Subprotocol not supported."); + } + this->_handler = this->_server._handlers[subprotocol]; + this->_subprotocol = subprotocol; + wlogger.debug("Sec-WebSocket-Protocol: {}", subprotocol); + + sstring sec_key = req->get_header("Sec-Websocket-Key"); + sstring sec_version = req->get_header("Sec-Websocket-Version"); + + sstring sha1_input = sec_key + magic_key_suffix; + + wlogger.debug("Sec-Websocket-Key: {}, Sec-Websocket-Version: {}", sec_key, sec_version); + + std::string sha1_output = sha1_base64(sha1_input); + wlogger.debug("SHA1 output: {} of size {}", sha1_output, sha1_output.size()); + + co_await _write_buf.write(http_upgrade_reply_template); + co_await _write_buf.write(sha1_output); + co_await _write_buf.write("\r\nSec-WebSocket-Protocol: ", 26); + co_await _write_buf.write(_subprotocol); + co_await _write_buf.write("\r\n\r\n", 4); + co_await _write_buf.flush(); } future websocket_parser::operator()(