Test cases for reverse proxy asserting all buffers are flushed before teardown #1495

Open · wants to merge 2 commits into base: develop
28 changes: 24 additions & 4 deletions proxy/core/base/tcp_server.py
@@ -111,6 +111,9 @@ class BaseTcpServerHandler(Work[T]):
a. handle_data(data: memoryview) implementation
b. Optionally, also implement other Work methods,
e.g. initialize, is_inactive, shutdown
c. Optionally, override the has_buffer method so that the
connection is not shut down until any additional
buffers have also been flushed.
"""

def __init__(self, *args: Any, **kwargs: Any) -> None:
@@ -135,6 +138,24 @@ def handle_data(self, data: memoryview) -> Optional[bool]:
"""Optionally return True to close client connection."""
pass # pragma: no cover

def has_external_buffer(self) -> bool:
"""BaseTcpServerHandler makes sure that Work.buffers are flushed before shutting down.

Example, HttpProtocolHandler uses BaseTcpServerHandler with TcpClientConnection as the
Work class. So, automagically, BaseTcpServerHandler implementation makes sure that
pending TcpClientConnection buffers are flushed to the clients before tearing down
the connection.

But, imagine reverse proxy scenario where ReverseProxy has also opened an upstream
TcpServerConnection object. ReverseProxy would also want any pending buffer for
upstream to flush out before tearing down the connection. For such scenarios,
you must override the has_buffer to incorporate upstream buffers in the logic.
"""
return False

def has_buffer(self) -> bool:
return self.work.has_buffer() or self.has_external_buffer()

async def get_events(self) -> SelectableEvents:
events = {}
# We always want to read from client
@@ -143,7 +164,7 @@ async def get_events(self) -> SelectableEvents:
events[self.work.connection.fileno()] = selectors.EVENT_READ
# If there is pending buffer for client
# also register for EVENT_WRITE events
if self.work.has_buffer():
if self.has_buffer():
if self.work.connection.fileno() in events:
events[self.work.connection.fileno()] |= selectors.EVENT_WRITE
else:
@@ -174,8 +195,7 @@ async def handle_writables(self, writables: Writables) -> bool:
'Flushing buffer to client {0}'.format(self.work.address),
)
self.work.flush(self.flags.max_sendbuf_size)
if self.must_flush_before_shutdown is True and \
not self.work.has_buffer():
if self.must_flush_before_shutdown is True and not self.has_buffer():
teardown = True
self.must_flush_before_shutdown = False
return teardown
@@ -214,7 +234,7 @@ async def handle_readables(self, readables: Readables) -> bool:
self.work.address,
),
)
if self.work.has_buffer():
if self.has_buffer():
logger.debug(
'Client {0} has pending buffer, will be flushed before shutting down'.format(
self.work.address,
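
The new has_external_buffer hook is the extension point for handlers that also own a second connection. Below is a minimal sketch, not part of this diff, of a handler that forwards client bytes to a hypothetical upstream and reports its pending buffer via has_external_buffer, so teardown waits for both sides. The class name, the upstream attribute, and the forwarding logic are illustrative assumptions; the import paths are as recalled from proxy.py and may differ.

from typing import Any, Optional

from proxy.core.base import BaseTcpServerHandler
from proxy.core.connection import TcpClientConnection, TcpServerConnection


class UpstreamAwareHandler(BaseTcpServerHandler[TcpClientConnection]):
    """Sketch only: a handler that must also flush an upstream buffer before teardown."""

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        # Hypothetical upstream connection, opened by this handler elsewhere
        self.upstream: Optional[TcpServerConnection] = None

    def handle_data(self, data: memoryview) -> Optional[bool]:
        # Forward client bytes to the upstream; queue() only buffers them,
        # actual flushing happens from the event loop
        if self.upstream is not None:
            self.upstream.queue(data)
        return None

    def has_external_buffer(self) -> bool:
        # Teardown must now also wait until the upstream buffer has flushed
        return self.upstream is not None and self.upstream.has_buffer()
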
10 changes: 6 additions & 4 deletions proxy/http/handler.py
@@ -70,8 +70,10 @@ def initialize(self) -> None:
)

def is_inactive(self) -> bool:
if not self.work.has_buffer() and \
self._connection_inactive_for() > self.flags.timeout:
if (
not self.work.has_buffer()
and self._connection_inactive_for() > self.flags.timeout
):
return True
return False

@@ -87,8 +89,8 @@ def shutdown(self) -> None:
if self.plugin:
self.plugin.on_client_connection_close()
logger.debug(
'Closing client connection %s has buffer %s' %
(self.work.address, self.work.has_buffer()),
'Closing client connection %s has buffer %s'
% (self.work.address, self.work.has_buffer()),
)
conn = self.work.connection
# Unwrap if wrapped before shutdown.
3 changes: 2 additions & 1 deletion proxy/plugin/cache/cache_responses.py
@@ -23,7 +23,8 @@

br_installed = False
try:
import brotli
import brotli # type: ignore[import-untyped]

br_installed = True
except ModuleNotFoundError:
pass
5 changes: 4 additions & 1 deletion proxy/plugin/reverse_proxy.py
@@ -45,7 +45,10 @@ def routes(self) -> List[Union[str, Tuple[str, List[bytes]]]]:
# A static route
(
r'/get$',
[b'http://httpbingo.org/get', b'https://httpbingo.org/get'],
[
b'http://httpbingo.org/get',
b'https://httpbingo.org/get',
],
),
# A dynamic route to catch requests on "/get/<int>"
# See "handle_route" method below for what we do when
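
For context, plugins use the routes() interface shown above to declare which paths are reverse proxied and to which upstreams. The following is a rough sketch, not part of this diff, of a static-only plugin; the class name, route pattern, and upstream URL are made up, the import path for ReverseProxyBasePlugin is an assumption based on proxy.py's plugin layout, and dynamic (regex string) routes may additionally require a handle_route implementation, which this sketch omits.

from typing import List, Tuple, Union

# Assumed import path for the reverse proxy plugin base class
from proxy.http.server import ReverseProxyBasePlugin


class ExampleReverseProxyPlugin(ReverseProxyBasePlugin):
    """Sketch only: reverse proxy "/api" requests to a local upstream."""

    def routes(self) -> List[Union[str, Tuple[str, List[bytes]]]]:
        return [
            # Hypothetical static route: "/api" requests are forwarded to
            # the single upstream URL listed below
            (
                r'/api$',
                [b'http://127.0.0.1:8000/api'],
            ),
        ]
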
2 changes: 1 addition & 1 deletion proxy/socks/handler.py
@@ -25,4 +25,4 @@ def create(*args: Any) -> SocksClientConnection:
return SocksClientConnection(*args) # pragma: no cover

def handle_data(self, data: memoryview) -> Optional[bool]:
return super().handle_data(data) # pragma: no cover
pass # pragma: no cover
90 changes: 89 additions & 1 deletion tests/http/web/test_web_server.py
@@ -12,21 +12,25 @@
import gzip
import tempfile
import selectors
from typing import Any
from typing import Any, cast

import pytest
from unittest import mock

from pytest_mock import MockerFixture

from proxy.http import HttpProtocolHandler, HttpClientConnection
from proxy.http.url import Url
from proxy.common.flag import FlagParser
from proxy.http.parser import HttpParser, httpParserTypes, httpParserStates
from proxy.common.utils import bytes_, build_http_request, build_http_response
from proxy.common.plugins import Plugins
from proxy.http.responses import NOT_FOUND_RESPONSE_PKT
from proxy.http.server.web import HttpWebServerPlugin
from proxy.common.constants import (
CRLF, PROXY_PY_DIR, PLUGIN_PAC_FILE, PLUGIN_HTTP_PROXY, PLUGIN_WEB_SERVER,
)
from proxy.http.server.reverse import ReverseProxy
from ...test_assertions import Assertions


@@ -384,3 +388,87 @@
self.protocol_handler.work.buffer[0],
NOT_FOUND_RESPONSE_PKT,
)


class TestThreadedReverseProxyPlugin(Assertions):

@pytest.fixture(autouse=True) # type: ignore[misc]
def _setUp(self, mocker: MockerFixture) -> None:
self.mock_socket = mocker.patch('socket.socket')
self.mock_socket_dup = mocker.patch('socket.dup', side_effect=lambda fd: fd)
self.mock_selector = mocker.patch('selectors.DefaultSelector')
self.fileno = 10
self._addr = ('127.0.0.1', 54382)
self._conn = self.mock_socket.return_value
self.flags = FlagParser.initialize(
[
'--enable-reverse-proxy',
],
threaded=True,
plugins=[
b'proxy.plugin.ReverseProxyPlugin',
],
)
self.protocol_handler = HttpProtocolHandler(
HttpClientConnection(self._conn, self._addr),
flags=self.flags,
)
self.protocol_handler.initialize()
# Assert reverse proxy has loaded successfully
self.assertEqual(
self.protocol_handler.flags.plugins[b'HttpWebServerBasePlugin'][0].__name__,
'ReverseProxy',
)
# Assert reverse proxy plugins have loaded successfully
self.assertEqual(
self.protocol_handler.flags.plugins[b'ReverseProxyBasePlugin'][0].__name__,
'ReverseProxyPlugin',
)

@pytest.mark.asyncio # type: ignore[misc]
@mock.patch('proxy.core.connection.server.ssl.create_default_context')
async def test_reverse_proxy_works(
self,
mock_create_default_context: mock.MagicMock,
) -> None:
self.mock_selector.return_value.select.return_value = [
(
selectors.SelectorKey(
fileobj=self._conn.fileno(),
fd=self._conn.fileno(),
events=selectors.EVENT_READ,
data=None,
),
selectors.EVENT_READ,
),
]
self._conn.recv.return_value = CRLF.join(
[
b'GET /get HTTP/1.1',
CRLF,
],
)
await self.protocol_handler._run_once()
self.assertEqual(
self.protocol_handler.request.state,
httpParserStates.COMPLETE,
)
assert (
self.protocol_handler.plugin is not None
and self.protocol_handler.plugin.__class__.__name__ == 'HttpWebServerPlugin'
)
rproxy = cast(
ReverseProxy,
cast(HttpWebServerPlugin, self.protocol_handler.plugin).route,
)
choice = str(cast(Url, rproxy.choice))
options = ('http://httpbingo.org/get', 'https://httpbingo.org/get')
is_https = choice == options[1]
if is_https:
mock_create_default_context.assert_called_once()

self.assertEqual(choice in options, True)
upstream = rproxy.upstream
self.assertEqual(upstream.__class__.__name__, 'TcpServerConnection')
assert upstream
self.assertEqual(upstream.addr, ('httpbingo.org', 80 if not is_https else 443))
self.assertEqual(upstream.has_buffer(), True)
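
The scenario this test mocks can also be exercised against a live proxy. Below is a rough sketch, not part of this diff, of launching proxy.py with the same flag and plugin the test passes to FlagParser; it assumes proxy.py's embedded mode, where proxy.main() accepts flag names as keyword arguments, and the default listening port of 8899.

# Sketch only: run proxy.py with the reverse proxy plugin enabled, assuming
# proxy.main() accepts CLI flags as keyword arguments (embedded mode)
import proxy

if __name__ == '__main__':
    proxy.main(
        enable_reverse_proxy=True,
        plugins=[b'proxy.plugin.ReverseProxyPlugin'],
    )

# With the proxy running, "curl http://localhost:8899/get" should be served
# by one of the httpbingo.org upstreams listed in ReverseProxyPlugin.routes
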