diff --git a/uvicorn/_subprocess.py b/uvicorn/_subprocess.py index 1c06844de5..df1d22eb36 100644 --- a/uvicorn/_subprocess.py +++ b/uvicorn/_subprocess.py @@ -9,19 +9,57 @@ import os import sys from multiprocessing.context import SpawnProcess -from socket import socket +import socket from typing import Callable from uvicorn.config import Config +import sys multiprocessing.allow_connection_pickling() spawn = multiprocessing.get_context("spawn") +class SocketSharePickle: + def __init__(self, sock: socket.socket): + self._sock = sock + + def get(self) -> socket.socket: + return self._sock + +if (sys.platform == "linux" and hasattr(socket, "SO_REUSEPORT")) or hasattr(socket, "SO_REUSEPORT_LB"): + + class SocketShareRebind: + + def __init__(self, sock: socket.socket): + sock.setsockopt(socket.SOL_SOCKET, getattr(socket, "SO_REUSEPORT_LB", socket.SO_REUSEPORT), 1) + self._family = sock.family + self._sockname = sock.getsockname() + + def get(self) -> socket.socket: + try: + sock = socket.socket(family=self._family) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.setsockopt(socket.SOL_SOCKET, getattr(socket, "SO_REUSEPORT_LB", socket.SO_REUSEPORT), 1) + + sock.bind(self._sockname) + return sock + except BaseException: + sock.close() + raise + + +else: + class SocketShareRebind: + def __init__(self, sock: socket.socket): + raise RuntimeError("socket_load_balance not supported") + + def get(self) -> socket.socket: + raise RuntimeError("socket_load_balance not supported") + def get_subprocess( config: Config, target: Callable[..., None], - sockets: list[socket], + sockets: list[socket.socket], ) -> SpawnProcess: """ Called in the parent process, to instantiate a new child process instance. @@ -41,10 +79,15 @@ def get_subprocess( except (AttributeError, OSError): stdin_fileno = None + socket_shares: list[SocketShareRebind] | list[SocketSharePickle] + if config.socket_load_balance: + socket_shares = [SocketShareRebind(s) for s in sockets] + else: + socket_shares = [SocketSharePickle(s) for s in sockets] kwargs = { "config": config, "target": target, - "sockets": sockets, + "sockets": socket_shares, "stdin_fileno": stdin_fileno, } @@ -54,7 +97,7 @@ def get_subprocess( def subprocess_started( config: Config, target: Callable[..., None], - sockets: list[socket], + sockets: list[SocketSharePickle] | list[SocketShareRebind], stdin_fileno: int | None, ) -> None: """ @@ -77,7 +120,7 @@ def subprocess_started( try: # Now we can call into `Server.run(sockets=sockets)` - target(sockets=sockets) + target(sockets=[s.get() for s in sockets]) except KeyboardInterrupt: # pragma: no cover # supress the exception to avoid a traceback from subprocess.Popen # the parent already expects us to end, so no vital information is lost diff --git a/uvicorn/config.py b/uvicorn/config.py index 9aff8c968e..54a2f9b3fe 100644 --- a/uvicorn/config.py +++ b/uvicorn/config.py @@ -223,6 +223,7 @@ def __init__( headers: list[tuple[str, str]] | None = None, factory: bool = False, h11_max_incomplete_event_size: int | None = None, + socket_load_balance: bool = False, ): self.app = app self.host = host @@ -268,6 +269,7 @@ def __init__( self.encoded_headers: list[tuple[bytes, bytes]] = [] self.factory = factory self.h11_max_incomplete_event_size = h11_max_incomplete_event_size + self.socket_load_balance = socket_load_balance self.loaded = False self.configure_logging() diff --git a/uvicorn/main.py b/uvicorn/main.py index 43956622db..755fb4a06b 100644 --- a/uvicorn/main.py +++ b/uvicorn/main.py @@ -360,6 +360,13 @@ def print_version(ctx: click.Context, param: click.Parameter, value: bool) -> No help="Treat APP as an application factory, i.e. a () -> callable.", show_default=True, ) +@click.option( + "--socket-load-balance", + is_flag=True, + default=False, + help="Use kernel support for socket load balancing", + show_default=True, +) def main( app: str, host: str, @@ -408,6 +415,7 @@ def main( app_dir: str, h11_max_incomplete_event_size: int | None, factory: bool, + socket_load_balance: bool = False, ) -> None: run( app, @@ -457,6 +465,7 @@ def main( factory=factory, app_dir=app_dir, h11_max_incomplete_event_size=h11_max_incomplete_event_size, + socket_load_balance=socket_load_balance, ) @@ -509,6 +518,7 @@ def run( app_dir: str | None = None, factory: bool = False, h11_max_incomplete_event_size: int | None = None, + socket_load_balance: bool = False, ) -> None: if app_dir is not None: sys.path.insert(0, app_dir) @@ -560,6 +570,7 @@ def run( use_colors=use_colors, factory=factory, h11_max_incomplete_event_size=h11_max_incomplete_event_size, + socket_load_balance=socket_load_balance, ) server = Server(config=config) @@ -570,11 +581,11 @@ def run( try: if config.should_reload: - sock = config.bind_socket() - ChangeReload(config, target=server.run, sockets=[sock]).run() + with config.bind_socket() as sock: + ChangeReload(config, target=server.run, sockets=[sock]).run() elif config.workers > 1: - sock = config.bind_socket() - Multiprocess(config, target=server.run, sockets=[sock]).run() + with config.bind_socket() as sock: + Multiprocess(config, target=server.run, sockets=[sock]).run() else: server.run() except KeyboardInterrupt: