Skip to content

Commit

Permalink
Refactor streams: rename is_* to wait_* for clarity (#2069)
Browse files Browse the repository at this point in the history
- Replace is_readable() with wait_readable() and is_writable() with
  wait_writable() in the Stream interface.
- Implement a new is_readable() function with semantics that more
  closely reflect its name. It returns immediately whether data is
  available for reading, without waiting.
- Update call sites of is_writable(), removing redundant checks.
  • Loading branch information
falbrechtskirchinger authored Feb 20, 2025
1 parent a4b2c61 commit 550f728
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 33 deletions.
64 changes: 37 additions & 27 deletions httplib.h
Original file line number Diff line number Diff line change
Expand Up @@ -751,7 +751,8 @@ class Stream {
virtual ~Stream() = default;

virtual bool is_readable() const = 0;
virtual bool is_writable() const = 0;
virtual bool wait_readable() const = 0;
virtual bool wait_writable() const = 0;

virtual ssize_t read(char *ptr, size_t size) = 0;
virtual ssize_t write(const char *ptr, size_t size) = 0;
Expand Down Expand Up @@ -2466,7 +2467,8 @@ class BufferStream final : public Stream {
~BufferStream() override = default;

bool is_readable() const override;
bool is_writable() const override;
bool wait_readable() const override;
bool wait_writable() const override;
ssize_t read(char *ptr, size_t size) override;
ssize_t write(const char *ptr, size_t size) override;
void get_remote_ip_and_port(std::string &ip, int &port) const override;
Expand Down Expand Up @@ -3380,7 +3382,8 @@ class SocketStream final : public Stream {
~SocketStream() override;

bool is_readable() const override;
bool is_writable() const override;
bool wait_readable() const override;
bool wait_writable() const override;
ssize_t read(char *ptr, size_t size) override;
ssize_t write(const char *ptr, size_t size) override;
void get_remote_ip_and_port(std::string &ip, int &port) const override;
Expand Down Expand Up @@ -3416,7 +3419,8 @@ class SSLSocketStream final : public Stream {
~SSLSocketStream() override;

bool is_readable() const override;
bool is_writable() const override;
bool wait_readable() const override;
bool wait_writable() const override;
ssize_t read(char *ptr, size_t size) override;
ssize_t write(const char *ptr, size_t size) override;
void get_remote_ip_and_port(std::string &ip, int &port) const override;
Expand Down Expand Up @@ -4578,7 +4582,7 @@ inline bool write_content(Stream &strm, const ContentProvider &content_provider,

data_sink.write = [&](const char *d, size_t l) -> bool {
if (ok) {
if (strm.is_writable() && write_data(strm, d, l)) {
if (write_data(strm, d, l)) {
offset += l;
} else {
ok = false;
Expand All @@ -4587,10 +4591,10 @@ inline bool write_content(Stream &strm, const ContentProvider &content_provider,
return ok;
};

data_sink.is_writable = [&]() -> bool { return strm.is_writable(); };
data_sink.is_writable = [&]() -> bool { return strm.wait_writable(); };

while (offset < end_offset && !is_shutting_down()) {
if (!strm.is_writable()) {
if (!strm.wait_writable()) {
error = Error::Write;
return false;
} else if (!content_provider(offset, end_offset - offset, data_sink)) {
Expand Down Expand Up @@ -4628,17 +4632,17 @@ write_content_without_length(Stream &strm,
data_sink.write = [&](const char *d, size_t l) -> bool {
if (ok) {
offset += l;
if (!strm.is_writable() || !write_data(strm, d, l)) { ok = false; }
if (!write_data(strm, d, l)) { ok = false; }
}
return ok;
};

data_sink.is_writable = [&]() -> bool { return strm.is_writable(); };
data_sink.is_writable = [&]() -> bool { return strm.wait_writable(); };

data_sink.done = [&](void) { data_available = false; };

while (data_available && !is_shutting_down()) {
if (!strm.is_writable()) {
if (!strm.wait_writable()) {
return false;
} else if (!content_provider(offset, 0, data_sink)) {
return false;
Expand Down Expand Up @@ -4673,10 +4677,7 @@ write_content_chunked(Stream &strm, const ContentProvider &content_provider,
// Emit chunked response header and footer for each chunk
auto chunk =
from_i_to_hex(payload.size()) + "\r\n" + payload + "\r\n";
if (!strm.is_writable() ||
!write_data(strm, chunk.data(), chunk.size())) {
ok = false;
}
if (!write_data(strm, chunk.data(), chunk.size())) { ok = false; }
}
} else {
ok = false;
Expand All @@ -4685,7 +4686,7 @@ write_content_chunked(Stream &strm, const ContentProvider &content_provider,
return ok;
};

data_sink.is_writable = [&]() -> bool { return strm.is_writable(); };
data_sink.is_writable = [&]() -> bool { return strm.wait_writable(); };

auto done_with_trailer = [&](const Headers *trailer) {
if (!ok) { return; }
Expand All @@ -4705,8 +4706,7 @@ write_content_chunked(Stream &strm, const ContentProvider &content_provider,
if (!payload.empty()) {
// Emit chunked response header and footer for each chunk
auto chunk = from_i_to_hex(payload.size()) + "\r\n" + payload + "\r\n";
if (!strm.is_writable() ||
!write_data(strm, chunk.data(), chunk.size())) {
if (!write_data(strm, chunk.data(), chunk.size())) {
ok = false;
return;
}
Expand Down Expand Up @@ -4738,7 +4738,7 @@ write_content_chunked(Stream &strm, const ContentProvider &content_provider,
};

while (data_available && !is_shutting_down()) {
if (!strm.is_writable()) {
if (!strm.wait_writable()) {
error = Error::Write;
return false;
} else if (!content_provider(offset, 0, data_sink)) {
Expand Down Expand Up @@ -6029,6 +6029,10 @@ inline SocketStream::SocketStream(
inline SocketStream::~SocketStream() = default;

inline bool SocketStream::is_readable() const {
return read_buff_off_ < read_buff_content_size_;
}

inline bool SocketStream::wait_readable() const {
if (max_timeout_msec_ <= 0) {
return select_read(sock_, read_timeout_sec_, read_timeout_usec_) > 0;
}
Expand All @@ -6041,7 +6045,7 @@ inline bool SocketStream::is_readable() const {
return select_read(sock_, read_timeout_sec, read_timeout_usec) > 0;
}

inline bool SocketStream::is_writable() const {
inline bool SocketStream::wait_writable() const {
return select_write(sock_, write_timeout_sec_, write_timeout_usec_) > 0 &&
is_socket_alive(sock_);
}
Expand All @@ -6068,7 +6072,7 @@ inline ssize_t SocketStream::read(char *ptr, size_t size) {
}
}

if (!is_readable()) { return -1; }
if (!wait_readable()) { return -1; }

read_buff_off_ = 0;
read_buff_content_size_ = 0;
Expand All @@ -6093,7 +6097,7 @@ inline ssize_t SocketStream::read(char *ptr, size_t size) {
}

inline ssize_t SocketStream::write(const char *ptr, size_t size) {
if (!is_writable()) { return -1; }
if (!wait_writable()) { return -1; }

#if defined(_WIN32) && !defined(_WIN64)
size =
Expand Down Expand Up @@ -6124,7 +6128,9 @@ inline time_t SocketStream::duration() const {
// Buffer stream implementation
inline bool BufferStream::is_readable() const { return true; }

inline bool BufferStream::is_writable() const { return true; }
inline bool BufferStream::wait_readable() const { return true; }

inline bool BufferStream::wait_writable() const { return true; }

inline ssize_t BufferStream::read(char *ptr, size_t size) {
#if defined(_MSC_VER) && _MSC_VER < 1910
Expand Down Expand Up @@ -9161,6 +9167,10 @@ inline SSLSocketStream::SSLSocketStream(
inline SSLSocketStream::~SSLSocketStream() = default;

inline bool SSLSocketStream::is_readable() const {
return SSL_pending(ssl_) > 0;
}

inline bool SSLSocketStream::wait_readable() const {
if (max_timeout_msec_ <= 0) {
return select_read(sock_, read_timeout_sec_, read_timeout_usec_) > 0;
}
Expand All @@ -9173,15 +9183,15 @@ inline bool SSLSocketStream::is_readable() const {
return select_read(sock_, read_timeout_sec, read_timeout_usec) > 0;
}

inline bool SSLSocketStream::is_writable() const {
inline bool SSLSocketStream::wait_writable() const {
return select_write(sock_, write_timeout_sec_, write_timeout_usec_) > 0 &&
is_socket_alive(sock_) && !is_ssl_peer_could_be_closed(ssl_, sock_);
}

inline ssize_t SSLSocketStream::read(char *ptr, size_t size) {
if (SSL_pending(ssl_) > 0) {
return SSL_read(ssl_, ptr, static_cast<int>(size));
} else if (is_readable()) {
} else if (wait_readable()) {
auto ret = SSL_read(ssl_, ptr, static_cast<int>(size));
if (ret < 0) {
auto err = SSL_get_error(ssl_, ret);
Expand All @@ -9195,7 +9205,7 @@ inline ssize_t SSLSocketStream::read(char *ptr, size_t size) {
#endif
if (SSL_pending(ssl_) > 0) {
return SSL_read(ssl_, ptr, static_cast<int>(size));
} else if (is_readable()) {
} else if (wait_readable()) {
std::this_thread::sleep_for(std::chrono::microseconds{10});
ret = SSL_read(ssl_, ptr, static_cast<int>(size));
if (ret >= 0) { return ret; }
Expand All @@ -9212,7 +9222,7 @@ inline ssize_t SSLSocketStream::read(char *ptr, size_t size) {
}

inline ssize_t SSLSocketStream::write(const char *ptr, size_t size) {
if (is_writable()) {
if (wait_writable()) {
auto handle_size = static_cast<int>(
std::min<size_t>(size, (std::numeric_limits<int>::max)()));

Expand All @@ -9227,7 +9237,7 @@ inline ssize_t SSLSocketStream::write(const char *ptr, size_t size) {
#else
while (--n >= 0 && err == SSL_ERROR_WANT_WRITE) {
#endif
if (is_writable()) {
if (wait_writable()) {
std::this_thread::sleep_for(std::chrono::microseconds{10});
ret = SSL_write(ssl_, ptr, static_cast<int>(handle_size));
if (ret >= 0) { return ret; }
Expand Down
4 changes: 3 additions & 1 deletion test/fuzzing/server_fuzzer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,9 @@ class FuzzedStream : public httplib::Stream {

bool is_readable() const override { return true; }

bool is_writable() const override { return true; }
bool wait_readable() const override { return true; }

bool wait_writable() const override { return true; }

void get_remote_ip_and_port(std::string &ip, int &port) const override {
ip = "127.0.0.1";
Expand Down
10 changes: 5 additions & 5 deletions test/test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ TEST_F(UnixSocketTest, abstract) {
}
#endif

TEST(SocketStream, is_writable_UNIX) {
TEST(SocketStream, wait_writable_UNIX) {
int fds[2];
ASSERT_EQ(0, socketpair(AF_UNIX, SOCK_STREAM, 0, fds));

Expand All @@ -167,17 +167,17 @@ TEST(SocketStream, is_writable_UNIX) {
};
asSocketStream(fds[0], [&](Stream &s0) {
EXPECT_EQ(s0.socket(), fds[0]);
EXPECT_TRUE(s0.is_writable());
EXPECT_TRUE(s0.wait_writable());

EXPECT_EQ(0, close(fds[1]));
EXPECT_FALSE(s0.is_writable());
EXPECT_FALSE(s0.wait_writable());

return true;
});
EXPECT_EQ(0, close(fds[0]));
}

TEST(SocketStream, is_writable_INET) {
TEST(SocketStream, wait_writable_INET) {
sockaddr_in addr;
memset(&addr, 0, sizeof(addr));
addr.sin_family = AF_INET;
Expand Down Expand Up @@ -212,7 +212,7 @@ TEST(SocketStream, is_writable_INET) {
};
asSocketStream(disconnected_svr_sock, [&](Stream &ss) {
EXPECT_EQ(ss.socket(), disconnected_svr_sock);
EXPECT_FALSE(ss.is_writable());
EXPECT_FALSE(ss.wait_writable());

return true;
});
Expand Down

0 comments on commit 550f728

Please sign in to comment.