diff --git a/proxy/core/base/tcp_server.py b/proxy/core/base/tcp_server.py
index 4d32949d55..48b4b54b79 100644
--- a/proxy/core/base/tcp_server.py
+++ b/proxy/core/base/tcp_server.py
@@ -111,6 +111,9 @@ class BaseTcpServerHandler(Work[T]):
     a. handle_data(data: memoryview) implementation
     b. Optionally, also implement other Work method
        e.g. initialize, is_inactive, shutdown
+    c. Optionally, override has_buffer method to avoid
+       shutting down the connection unless additional
+       buffers are also cleared up.
     """
 
     def __init__(self, *args: Any, **kwargs: Any) -> None:
@@ -135,6 +138,24 @@ def handle_data(self, data: memoryview) -> Optional[bool]:
         """Optionally return True to close client connection."""
         pass    # pragma: no cover
 
+    def has_external_buffer(self) -> bool:
+        """BaseTcpServerHandler makes sure that Work.buffers are flushed before shutting down.
+
+        Example, HttpProtocolHandler uses BaseTcpServerHandler with TcpClientConnection as the
+        Work class.  So, automagically, BaseTcpServerHandler implementation makes sure that
+        pending TcpClientConnection buffers are flushed to the clients before tearing down
+        the connection.
+
+        But, imagine reverse proxy scenario where ReverseProxy has also opened an upstream
+        TcpServerConnection object.  ReverseProxy would also want any pending buffer for
+        upstream to flush out before tearing down the connection.  For such scenarios,
+        you must override the has_buffer to incorporate upstream buffers in the logic.
+        """
+        return False
+
+    def has_buffer(self) -> bool:
+        return self.work.has_buffer() or self.has_external_buffer()
+
     async def get_events(self) -> SelectableEvents:
         events = {}
         # We always want to read from client
@@ -143,7 +164,7 @@ async def get_events(self) -> SelectableEvents:
             events[self.work.connection.fileno()] = selectors.EVENT_READ
         # If there is pending buffer for client
         # also register for EVENT_WRITE events
-        if self.work.has_buffer():
+        if self.has_buffer():
             if self.work.connection.fileno() in events:
                 events[self.work.connection.fileno()] |= selectors.EVENT_WRITE
             else:
@@ -174,8 +195,7 @@ async def handle_writables(self, writables: Writables) -> bool:
                 'Flushing buffer to client {0}'.format(self.work.address),
             )
             self.work.flush(self.flags.max_sendbuf_size)
-            if self.must_flush_before_shutdown is True and \
-                    not self.work.has_buffer():
+            if self.must_flush_before_shutdown is True and not self.has_buffer():
                 teardown = True
                 self.must_flush_before_shutdown = False
         return teardown
@@ -214,7 +234,7 @@ async def handle_readables(self, readables: Readables) -> bool:
                     self.work.address,
                 ),
             )
-            if self.work.has_buffer():
+            if self.has_buffer():
                 logger.debug(
                     'Client {0} has pending buffer, will be flushed before shutting down'.format(
                         self.work.address,
diff --git a/proxy/http/handler.py b/proxy/http/handler.py
index 5120d5b32c..d8f892da78 100644
--- a/proxy/http/handler.py
+++ b/proxy/http/handler.py
@@ -70,8 +70,10 @@ def initialize(self) -> None:
         )
 
     def is_inactive(self) -> bool:
-        if not self.work.has_buffer() and \
-                self._connection_inactive_for() > self.flags.timeout:
+        if (
+            not self.work.has_buffer()
+            and self._connection_inactive_for() > self.flags.timeout
+        ):
             return True
         return False
 
@@ -87,8 +89,8 @@ def shutdown(self) -> None:
         if self.plugin:
             self.plugin.on_client_connection_close()
         logger.debug(
-            'Closing client connection %s has buffer %s' %
-            (self.work.address, self.work.has_buffer()),
+            'Closing client connection %s has buffer %s'
+            % (self.work.address, self.work.has_buffer()),
         )
         conn = self.work.connection
         # Unwrap if wrapped before shutdown.
diff --git a/proxy/plugin/cache/cache_responses.py b/proxy/plugin/cache/cache_responses.py
index 51fea3a1aa..ab963ad376 100644
--- a/proxy/plugin/cache/cache_responses.py
+++ b/proxy/plugin/cache/cache_responses.py
@@ -23,7 +23,8 @@
 
 br_installed = False
 try:
-    import brotli
+    import brotli  # type: ignore[import-untyped]
+
     br_installed = True
 except ModuleNotFoundError:
     pass
diff --git a/proxy/plugin/reverse_proxy.py b/proxy/plugin/reverse_proxy.py
index 7b7a5a4b38..d81eabf0d1 100644
--- a/proxy/plugin/reverse_proxy.py
+++ b/proxy/plugin/reverse_proxy.py
@@ -45,7 +45,10 @@ def routes(self) -> List[Union[str, Tuple[str, List[bytes]]]]:
             # A static route
             (
                 r'/get$',
-                [b'http://httpbingo.org/get', b'https://httpbingo.org/get'],
+                [
+                    b'http://httpbingo.org/get',
+                    b'https://httpbingo.org/get',
+                ],
             ),
             # A dynamic route to catch requests on "/get/"
             # See "handle_route" method below for what we do when
diff --git a/proxy/socks/handler.py b/proxy/socks/handler.py
index ab501ebe25..b2f07fbd2b 100644
--- a/proxy/socks/handler.py
+++ b/proxy/socks/handler.py
@@ -25,4 +25,4 @@ def create(*args: Any) -> SocksClientConnection:
         return SocksClientConnection(*args)    # pragma: no cover
 
     def handle_data(self, data: memoryview) -> Optional[bool]:
-        return super().handle_data(data)    # pragma: no cover
+        pass    # pragma: no cover
diff --git a/tests/http/web/test_web_server.py b/tests/http/web/test_web_server.py
index 8100d995bd..11a025deb0 100644
--- a/tests/http/web/test_web_server.py
+++ b/tests/http/web/test_web_server.py
@@ -12,21 +12,25 @@
 import gzip
 import tempfile
 import selectors
-from typing import Any
+from typing import Any, cast
 
 import pytest
+from unittest import mock
 from pytest_mock import MockerFixture
 
 from proxy.http import HttpProtocolHandler, HttpClientConnection
+from proxy.http.url import Url
 from proxy.common.flag import FlagParser
 from proxy.http.parser import HttpParser, httpParserTypes, httpParserStates
 from proxy.common.utils import bytes_, build_http_request, build_http_response
 from proxy.common.plugins import Plugins
 from proxy.http.responses import NOT_FOUND_RESPONSE_PKT
+from proxy.http.server.web import HttpWebServerPlugin
 from proxy.common.constants import (
     CRLF, PROXY_PY_DIR, PLUGIN_PAC_FILE, PLUGIN_HTTP_PROXY, PLUGIN_WEB_SERVER,
 )
+from proxy.http.server.reverse import ReverseProxy
 
 from ...test_assertions import Assertions
 
@@ -384,3 +388,87 @@ async def test_default_web_server_returns_404(self) -> None:
             self.protocol_handler.work.buffer[0],
             NOT_FOUND_RESPONSE_PKT,
         )
+
+
+class TestThreadedReverseProxyPlugin(Assertions):
+
+    @pytest.fixture(autouse=True)   # type: ignore[misc]
+    def _setUp(self, mocker: MockerFixture) -> None:
+        self.mock_socket = mocker.patch('socket.socket')
+        self.mock_socket_dup = mocker.patch('socket.dup', side_effect=lambda fd: fd)
+        self.mock_selector = mocker.patch('selectors.DefaultSelector')
+        self.fileno = 10
+        self._addr = ('127.0.0.1', 54382)
+        self._conn = self.mock_socket.return_value
+        self.flags = FlagParser.initialize(
+            [
+                '--enable-reverse-proxy',
+            ],
+            threaded=True,
+            plugins=[
+                b'proxy.plugin.ReverseProxyPlugin',
+            ],
+        )
+        self.protocol_handler = HttpProtocolHandler(
+            HttpClientConnection(self._conn, self._addr),
+            flags=self.flags,
+        )
+        self.protocol_handler.initialize()
+        # Assert reverse proxy has loaded successfully
+        self.assertEqual(
+            self.protocol_handler.flags.plugins[b'HttpWebServerBasePlugin'][0].__name__,
+            'ReverseProxy',
+        )
+        # Assert reverse proxy plugins have loaded successfully
+        self.assertEqual(
+            self.protocol_handler.flags.plugins[b'ReverseProxyBasePlugin'][0].__name__,
+            'ReverseProxyPlugin',
+        )
+
+    @pytest.mark.asyncio    # type: ignore[misc]
+    @mock.patch('proxy.core.connection.server.ssl.create_default_context')
+    async def test_reverse_proxy_works(
+        self,
+        mock_create_default_context: mock.MagicMock,
+    ) -> None:
+        self.mock_selector.return_value.select.return_value = [
+            (
+                selectors.SelectorKey(
+                    fileobj=self._conn.fileno(),
+                    fd=self._conn.fileno(),
+                    events=selectors.EVENT_READ,
+                    data=None,
+                ),
+                selectors.EVENT_READ,
+            ),
+        ]
+        self._conn.recv.return_value = CRLF.join(
+            [
+                b'GET /get HTTP/1.1',
+                CRLF,
+            ],
+        )
+        await self.protocol_handler._run_once()
+        self.assertEqual(
+            self.protocol_handler.request.state,
+            httpParserStates.COMPLETE,
+        )
+        assert (
+            self.protocol_handler.plugin is not None
+            and self.protocol_handler.plugin.__class__.__name__ == 'HttpWebServerPlugin'
+        )
+        rproxy = cast(
+            ReverseProxy,
+            cast(HttpWebServerPlugin, self.protocol_handler.plugin).route,
+        )
+        choice = str(cast(Url, rproxy.choice))
+        options = ('http://httpbingo.org/get', 'https://httpbingo.org/get')
+        is_https = choice == options[1]
+        if is_https:
+            mock_create_default_context.assert_called_once()
+        self.assertEqual(choice in options, True)
+        upstream = rproxy.upstream
+        self.assertEqual(upstream.__class__.__name__, 'TcpServerConnection')
+        assert upstream
+        self.assertEqual(upstream.addr, ('httpbingo.org', 80 if not is_https else 443))
+        self.assertEqual(upstream.has_buffer(), True)
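Note, not part of the patch above: the new has_external_buffer() hook is intended to be overridden by handlers that also own an upstream connection, as described in the docstring added to BaseTcpServerHandler. Below is a minimal, hypothetical sketch of such an override; the class name UpstreamAwareHandler and the self.upstream attribute are invented for illustration and are not introduced by this change.

from typing import Any, Optional

from proxy.core.base import BaseTcpServerHandler
from proxy.core.connection import TcpClientConnection, TcpServerConnection


class UpstreamAwareHandler(BaseTcpServerHandler[TcpClientConnection]):
    """Hypothetical handler that relays client data to an upstream server."""

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        # Assumed to be connected elsewhere (e.g. within initialize());
        # kept optional here for brevity.
        self.upstream: Optional[TcpServerConnection] = None

    @staticmethod
    def create(*args: Any) -> TcpClientConnection:
        return TcpClientConnection(*args)

    def has_external_buffer(self) -> bool:
        # Report pending upstream data so that BaseTcpServerHandler.has_buffer()
        # defers teardown until the upstream buffer is also flushed.
        return self.upstream is not None and self.upstream.has_buffer()

    def handle_data(self, data: memoryview) -> Optional[bool]:
        # Queue client data for the upstream; flushing the upstream buffer and
        # registering its fd for write events is omitted in this sketch.
        if self.upstream is not None:
            self.upstream.queue(data)
        return None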