From ba6845925d4465653a4ad1ad54d645bc4b8c8463 Mon Sep 17 00:00:00 2001 From: yhirose Date: Thu, 16 Jan 2025 22:36:07 -0500 Subject: [PATCH] Fix #2014 --- httplib.h | 30 ++++++++++++++++++++++++++---- test/test.cc | 48 +++++++++++++++++++++++++++++++++++++++++------- 2 files changed, 67 insertions(+), 11 deletions(-) diff --git a/httplib.h b/httplib.h index 096e944903..7ef7369782 100644 --- a/httplib.h +++ b/httplib.h @@ -2012,18 +2012,34 @@ inline void duration_to_sec_and_usec(const T &duration, U callback) { callback(static_cast(sec), static_cast(usec)); } +inline bool is_numeric(const std::string &str) { + return !str.empty() && std::all_of(str.begin(), str.end(), ::isdigit); +} + inline uint64_t get_header_value_u64(const Headers &headers, const std::string &key, uint64_t def, - size_t id) { + size_t id, bool &is_invalid_value) { + is_invalid_value = false; auto rng = headers.equal_range(key); auto it = rng.first; std::advance(it, static_cast(id)); if (it != rng.second) { - return std::strtoull(it->second.data(), nullptr, 10); + if (is_numeric(it->second)) { + return std::strtoull(it->second.data(), nullptr, 10); + } else { + is_invalid_value = true; + } } return def; } +inline uint64_t get_header_value_u64(const Headers &headers, + const std::string &key, uint64_t def, + size_t id) { + bool dummy = false; + return get_header_value_u64(headers, key, def, id, dummy); +} + } // namespace detail inline uint64_t Request::get_header_value_u64(const std::string &key, @@ -4433,8 +4449,14 @@ bool read_content(Stream &strm, T &x, size_t payload_max_length, int &status, } else if (!has_header(x.headers, "Content-Length")) { ret = read_content_without_length(strm, out); } else { - auto len = get_header_value_u64(x.headers, "Content-Length", 0, 0); - if (len > payload_max_length) { + auto is_invalid_value = false; + auto len = get_header_value_u64(x.headers, "Content-Length", + std::numeric_limits::max(), + 0, is_invalid_value); + + if (is_invalid_value) { + ret = false; + } else if (len > payload_max_length) { exceed_payload_max_length = true; skip_content_with_length(strm, len); ret = false; diff --git a/test/test.cc b/test/test.cc index ebc50f6f01..5372832754 100644 --- a/test/test.cc +++ b/test/test.cc @@ -511,6 +511,15 @@ TEST(GetHeaderValueTest, RegularValueInt) { EXPECT_EQ(100ull, val); } +TEST(GetHeaderValueTest, RegularInvalidValueInt) { + Headers headers = {{"Content-Length", "x"}}; + auto is_invalid_value = false; + auto val = detail::get_header_value_u64(headers, "Content-Length", 0, 0, + is_invalid_value); + EXPECT_EQ(0ull, val); + EXPECT_TRUE(is_invalid_value); +} + TEST(GetHeaderValueTest, Range) { { Headers headers = {make_range_header({{1, -1}})}; @@ -7496,9 +7505,9 @@ TEST(MultipartFormDataTest, CloseDelimiterWithoutCRLF) { "text2" "\r\n------------"; - std::string resonse; - ASSERT_TRUE(send_request(1, req, &resonse)); - ASSERT_EQ("200", resonse.substr(9, 3)); + std::string response; + ASSERT_TRUE(send_request(1, req, &response)); + ASSERT_EQ("200", response.substr(9, 3)); } TEST(MultipartFormDataTest, ContentLength) { @@ -7543,11 +7552,10 @@ TEST(MultipartFormDataTest, ContentLength) { "text2" "\r\n------------\r\n"; - std::string resonse; - ASSERT_TRUE(send_request(1, req, &resonse)); - ASSERT_EQ("200", resonse.substr(9, 3)); + std::string response; + ASSERT_TRUE(send_request(1, req, &response)); + ASSERT_EQ("200", response.substr(9, 3)); } - #endif TEST(TaskQueueTest, IncreaseAtomicInteger) { @@ -8007,6 +8015,32 @@ TEST(InvalidHeaderCharsTest, OnServer) { } } +TEST(InvalidHeaderValueTest, InvalidContentLength) { + auto handled = false; + + Server svr; + svr.Post("/test", [&](const Request &, Response &) { handled = true; }); + + thread t = thread([&] { svr.listen(HOST, PORT); }); + auto se = detail::scope_exit([&] { + svr.stop(); + t.join(); + ASSERT_FALSE(svr.is_running()); + ASSERT_FALSE(handled); + }); + + svr.wait_until_ready(); + + auto req = "POST /test HTTP/1.1\r\n" + "Content-Length: x\r\n" + "\r\n"; + + std::string response; + ASSERT_TRUE(send_request(1, req, &response)); + ASSERT_EQ("HTTP/1.1 400 Bad Request", + response.substr(0, response.find("\r\n"))); +} + #ifndef _WIN32 TEST(Expect100ContinueTest, ServerClosesConnection) { static constexpr char reject[] = "Unauthorized";