From f2c5363c9f2dd61ce34ca92d73a8757c69963d05 Mon Sep 17 00:00:00 2001 From: helintong Date: Tue, 16 Apr 2024 19:26:57 +0800 Subject: [PATCH] feat: add permessage deflate extension for websocket server --- include/cinatra/coro_http_connection.hpp | 65 +++++++++ include/cinatra/coro_http_request.hpp | 8 ++ include/cinatra/gzip.hpp | 162 +++++++++++++++++++++++ include/cinatra/websocket.hpp | 42 ++++++ 4 files changed, 277 insertions(+) diff --git a/include/cinatra/coro_http_connection.hpp b/include/cinatra/coro_http_connection.hpp index f0590379..1d3f07a9 100644 --- a/include/cinatra/coro_http_connection.hpp +++ b/include/cinatra/coro_http_connection.hpp @@ -21,6 +21,9 @@ #include "sha1.hpp" #include "string_resize.hpp" #include "websocket.hpp" +#ifdef CINATRA_ENABLE_GZIP +#include "gzip.hpp" +#endif #include "ylt/coro_io/coro_file.hpp" #include "ylt/coro_io/coro_io.hpp" @@ -132,6 +135,14 @@ class coro_http_connection if (body_len == 0) { if (parser_.method() == "GET"sv) { if (request_.is_upgrade()) { +#ifdef CINATRA_ENABLE_GZIP + if (request_.is_support_compressed()) { + is_client_ws_compressed_ = true; + } + else { + is_client_ws_compressed_ = false; + } +#endif // websocket build_ws_handshake_head(); bool ok = co_await reply(true); // response ws handshake @@ -551,6 +562,32 @@ class coro_http_connection async_simple::coro::Lazy write_websocket( std::string_view msg, opcode op = opcode::text) { +#ifdef CINATRA_ENABLE_GZIP + std::string dest_buf; + if (is_client_ws_compressed_ && data_length > 0) { + if (!cinatra::gzip_codec::deflate(std::string(msg), dest_buf)) { + CINATRA_LOG_ERROR << "compuress data error, data: " << msg; + co_return std::make_error_code(std::errc::protocol_error); + } + + auto header = ws_.format_compressed_header(dest_buf.length(), op); + std::vector buffers; + buffers.push_back(asio::buffer(header)); + buffers.push_back(asio::buffer(dest_buf)); + + auto [ec, sz] = co_await async_write(buffers); + co_return ec; + } + else { + auto header = ws_.format_header(msg.length(), op); + std::vector buffers; + buffers.push_back(asio::buffer(header)); + buffers.push_back(asio::buffer(msg)); + + auto [ec, sz] = co_await async_write(buffers); + co_return ec; + } +#else auto header = ws_.format_header(msg.length(), op); std::vector buffers; buffers.push_back(asio::buffer(header)); @@ -558,6 +595,7 @@ class coro_http_connection auto [ec, sz] = co_await async_write(buffers); co_return ec; +#endif } async_simple::coro::Lazy read_websocket() { @@ -612,8 +650,27 @@ class coro_http_connection break; case cinatra::ws_frame_type::WS_TEXT_FRAME: case cinatra::ws_frame_type::WS_BINARY_FRAME: { +#ifdef CINATRA_ENABLE_GZIP + std::string out; + if (is_client_ws_compressed_) + { + if (!cinatra::gzip_codec::inflate(std::string(payload.begin(), payload.end()), out)) + { + CINATRA_LOG_ERROR << "uncompuress data error"; + result.ec = std::make_error_code(std::errc::protocol_error); + break; + } + result.eof = true; + result.data = {out.data(), out.size()}; + } + else { + result.eof = true; + result.data = {payload.data(), payload.size()}; + } +#else result.eof = true; result.data = {payload.data(), payload.size()}; +#endif } break; case cinatra::ws_frame_type::WS_CLOSE_FRAME: { close_frame close_frame = @@ -803,6 +860,11 @@ class coro_http_connection if (!protocal_str.empty()) { response_.add_header("Sec-WebSocket-Protocol", std::string(protocal_str)); } +#ifdef CINATRA_ENABLE_GZIP + if (is_client_ws_compressed_) { + response_.add_header("Sec-WebSocket-Extensions", "permessage-deflate; client_no_context_takeover"); + } +#endif } private: @@ -825,6 +887,9 @@ class coro_http_connection std::atomic last_rwtime_; uint64_t max_part_size_ = 8 * 1024 * 1024; std::string resp_str_; +#ifdef CINATRA_ENABLE_GZIP + bool is_client_ws_compressed_ = false; +#endif websocket ws_; #ifdef CINATRA_ENABLE_SSL diff --git a/include/cinatra/coro_http_request.hpp b/include/cinatra/coro_http_request.hpp index 36309a12..ea6574fd 100644 --- a/include/cinatra/coro_http_request.hpp +++ b/include/cinatra/coro_http_request.hpp @@ -208,6 +208,14 @@ class coro_http_request { return true; } + bool is_support_compressed() { + auto extension_str = get_header_value("Sec-WebSocket-Extensions"); + if (extension_str.find("permessage-deflate") != std::string::npos) { + return true; + } + return false; + } + void set_aspect_data(std::string data) { aspect_data_.push_back(std::move(data)); } diff --git a/include/cinatra/gzip.hpp b/include/cinatra/gzip.hpp index 400ce6ff..dc13ba0a 100644 --- a/include/cinatra/gzip.hpp +++ b/include/cinatra/gzip.hpp @@ -140,4 +140,166 @@ inline int uncompress_file(const char *src_file, const char *out_file_name) { return 0; } + +bool inflate(const std::string& str_src, std::string& str_dest) +{ + int err = Z_DATA_ERROR; + // Create stream + z_stream zs = { 0 }; + // Set output data streams, do this here to avoid overwriting on recursive calls + const int OUTPUT_BUF_SIZE = 8192; + Bytef bytes_out[OUTPUT_BUF_SIZE] = { 0 }; + + // Initialise the z_stream + err = ::inflateInit2(&zs, -15); + if (err != Z_OK) + { + return false; + } + + // Use whatever input is provided + zs.next_in = (Bytef*)(str_src.c_str()); + zs.avail_in = str_src.length(); + + do { + try + { + // Initialise stream values + //zs->zalloc = (alloc_func)0; + //zs->zfree = (free_func)0; + //zs->opaque = (voidpf)0; + + zs.next_out = bytes_out; + zs.avail_out = OUTPUT_BUF_SIZE; + + // Try to unzip the data + err = ::inflate(&zs, Z_SYNC_FLUSH); + + // Is zip finished reading all currently available input and writing all generated output + if (err == Z_STREAM_END) + { + // Finish up + int kerr = ::inflateEnd(&zs); + + // Got a good result, set the size to the amount unzipped in this call (including all recursive calls) + + str_dest.append((const char*)bytes_out, OUTPUT_BUF_SIZE - zs.avail_out); + return true; + } + else if ((err == Z_OK) && (zs.avail_out == 0) && (zs.avail_in != 0)) + { + // Output array was not big enough, call recursively until there is enough space + + str_dest.append((const char*)bytes_out, OUTPUT_BUF_SIZE - zs.avail_out); + + continue; + } + else if ((err == Z_OK) && (zs.avail_in == 0)) + { + // All available input has been processed, everything ok. + // Set the size to the amount unzipped in this call (including all recursive calls) + str_dest.append((const char*)bytes_out, OUTPUT_BUF_SIZE - zs.avail_out); + + int kerr = ::inflateEnd(&zs); + + break; + } + else + { + return false; + } + } + catch (...) + { + return false; + } + } while (true); + + return err == Z_OK; +} + +bool deflate(const std::string& str_src, std::string& str_dest) +{ + int err = Z_DATA_ERROR; + // Create stream + z_stream zs = { 0 }; + // Set output data streams, do this here to avoid overwriting on recursive calls + const int OUTPUT_BUF_SIZE = 8192; + Bytef bytes_out[OUTPUT_BUF_SIZE] = { 0 }; + + // Initialise the z_stream + err = ::deflateInit2(&zs, 1, Z_DEFLATED, -15, 8, Z_DEFAULT_STRATEGY); + if (err != Z_OK) + { + return false; + } + // Use whatever input is provided + zs.next_in = (Bytef*)(str_src.c_str()); + zs.avail_in = str_src.length(); + + do { + try + { + // Initialise stream values + //zs->zalloc = (alloc_func)0; + //zs->zfree = (free_func)0; + //zs->opaque = (voidpf)0; + + zs.next_out = bytes_out; + zs.avail_out = OUTPUT_BUF_SIZE; + + // Try to unzip the data + err = ::deflate(&zs, Z_SYNC_FLUSH); + + // Is zip finished reading all currently available input and writing all generated output + if (err == Z_STREAM_END) + { + // Finish up + int kerr = ::deflateEnd(&zs); + + // Got a good result, set the size to the amount unzipped in this call (including all recursive calls) + + str_dest.append((const char*)bytes_out, OUTPUT_BUF_SIZE - zs.avail_out); + return true; + } + else if ((err == Z_OK) && (zs.avail_out == 0) && (zs.avail_in != 0)) + { + // Output array was not big enough, call recursively until there is enough space + + str_dest.append((const char*)bytes_out, OUTPUT_BUF_SIZE - zs.avail_out); + + continue; + } + else if ((err == Z_OK) && (zs.avail_in == 0)) + { + // All available input has been processed, everything ok. + // Set the size to the amount unzipped in this call (including all recursive calls) + str_dest.append((const char*)bytes_out, OUTPUT_BUF_SIZE - zs.avail_out); + + int kerr = ::deflateEnd(&zs); + + break; + } + else + { + return false; + } + } + catch (...) + { + return false; + } + } while (true); + + + if (err == Z_OK) + { + // subtract 4 to remove the extra 00 00 ff ff added to the end of the deflat function + str_dest = str_dest.substr(0, str_dest.length() - 4); + return true; + } + + return false; +} + } // namespace cinatra::gzip_codec \ No newline at end of file diff --git a/include/cinatra/websocket.hpp b/include/cinatra/websocket.hpp index 8327ef2e..1b33bd43 100644 --- a/include/cinatra/websocket.hpp +++ b/include/cinatra/websocket.hpp @@ -126,6 +126,48 @@ class websocket { return {msg_header_, header_length}; } + std::string format_compressed_header(size_t data_length, opcode code) { + + std::string destbuf; + char first_two_bytes[2] = { 0 }; + //FIN + first_two_bytes[0] |= 0x80; + + first_two_bytes[0] |= code; + + const char compress_flag = 0x40; + first_two_bytes[0] |= compress_flag; + + //mask = 0; + std::string send_data; + + if (data_length < 126) + { + first_two_bytes[1] = data_length; + send_data.append(first_two_bytes, 2); + } + else if (data_length <= UINT16_MAX) + { + first_two_bytes[1] = 126; + char extended_playload_length[2] = { 0 }; + uint16_t tmp = htons(data_length); + memcpy(&extended_playload_length, &tmp, 2); + send_data.append(first_two_bytes, 2); + send_data.append(extended_playload_length, 2); + } + else + { + first_two_bytes[1] = 127; + char extended_playload_length[8] = {0}; + uint64_t tmp = htobe64((uint64_t)data_length); + memcpy(&extended_playload_length, &tmp, 8); + send_data.append(first_two_bytes, 2); + send_data.append(extended_playload_length, 8); + } + + return send_data; + } + std::string encode_frame(std::span &data, opcode op, bool need_mask, bool eof = true) { std::string header;