Skip to content

Commit

Permalink
add socket-load-balance flag
Browse files Browse the repository at this point in the history
  • Loading branch information
graingert committed Sep 29, 2024
1 parent c7668ce commit c873d00
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 9 deletions.
53 changes: 48 additions & 5 deletions uvicorn/_subprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,57 @@
import os
import sys
from multiprocessing.context import SpawnProcess
from socket import socket
import socket
from typing import Callable

from uvicorn.config import Config
import sys

multiprocessing.allow_connection_pickling()
spawn = multiprocessing.get_context("spawn")

class SocketSharePickle:
def __init__(self, sock: socket.socket):
self._sock = sock

def get(self) -> socket.socket:
return self._sock

if (sys.platform == "linux" and hasattr(socket, "SO_REUSEPORT")) or hasattr(socket, "SO_REUSEPORT_LB"):

class SocketShareRebind:

def __init__(self, sock: socket.socket):
sock.setsockopt(socket.SOL_SOCKET, getattr(socket, "SO_REUSEPORT_LB", socket.SO_REUSEPORT), 1)
self._family = sock.family
self._sockname = sock.getsockname()

def get(self) -> socket.socket:
try:
sock = socket.socket(family=self._family)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
sock.setsockopt(socket.SOL_SOCKET, getattr(socket, "SO_REUSEPORT_LB", socket.SO_REUSEPORT), 1)

sock.bind(self._sockname)
return sock
except BaseException:
sock.close()
raise


else:
class SocketShareRebind:
def __init__(self, sock: socket.socket):
raise RuntimeError("socket_load_balance not supported")

def get(self) -> socket.socket:
raise RuntimeError("socket_load_balance not supported")


def get_subprocess(
config: Config,
target: Callable[..., None],
sockets: list[socket],
sockets: list[socket.socket],
) -> SpawnProcess:
"""
Called in the parent process, to instantiate a new child process instance.
Expand All @@ -41,10 +79,15 @@ def get_subprocess(
except (AttributeError, OSError):
stdin_fileno = None

socket_shares: list[SocketShareRebind] | list[SocketSharePickle]
if config.socket_load_balance:
socket_shares = [SocketShareRebind(s) for s in sockets]
else:
socket_shares = [SocketSharePickle(s) for s in sockets]
kwargs = {
"config": config,
"target": target,
"sockets": sockets,
"sockets": socket_shares,
"stdin_fileno": stdin_fileno,
}

Expand All @@ -54,7 +97,7 @@ def get_subprocess(
def subprocess_started(
config: Config,
target: Callable[..., None],
sockets: list[socket],
sockets: list[SocketSharePickle] | list[SocketShareRebind],
stdin_fileno: int | None,
) -> None:
"""
Expand All @@ -77,7 +120,7 @@ def subprocess_started(

try:
# Now we can call into `Server.run(sockets=sockets)`
target(sockets=sockets)
target(sockets=[s.get() for s in sockets])
except KeyboardInterrupt: # pragma: no cover
# supress the exception to avoid a traceback from subprocess.Popen
# the parent already expects us to end, so no vital information is lost
Expand Down
2 changes: 2 additions & 0 deletions uvicorn/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,7 @@ def __init__(
headers: list[tuple[str, str]] | None = None,
factory: bool = False,
h11_max_incomplete_event_size: int | None = None,
socket_load_balance: bool = False,
):
self.app = app
self.host = host
Expand Down Expand Up @@ -268,6 +269,7 @@ def __init__(
self.encoded_headers: list[tuple[bytes, bytes]] = []
self.factory = factory
self.h11_max_incomplete_event_size = h11_max_incomplete_event_size
self.socket_load_balance = socket_load_balance

self.loaded = False
self.configure_logging()
Expand Down
19 changes: 15 additions & 4 deletions uvicorn/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,13 @@ def print_version(ctx: click.Context, param: click.Parameter, value: bool) -> No
help="Treat APP as an application factory, i.e. a () -> <ASGI app> callable.",
show_default=True,
)
@click.option(
"--socket-load-balance",
is_flag=True,
default=False,
help="Use kernel support for socket load balancing",
show_default=True,
)
def main(
app: str,
host: str,
Expand Down Expand Up @@ -408,6 +415,7 @@ def main(
app_dir: str,
h11_max_incomplete_event_size: int | None,
factory: bool,
socket_load_balance: bool = False,
) -> None:
run(
app,
Expand Down Expand Up @@ -457,6 +465,7 @@ def main(
factory=factory,
app_dir=app_dir,
h11_max_incomplete_event_size=h11_max_incomplete_event_size,
socket_load_balance=socket_load_balance,
)


Expand Down Expand Up @@ -509,6 +518,7 @@ def run(
app_dir: str | None = None,
factory: bool = False,
h11_max_incomplete_event_size: int | None = None,
socket_load_balance: bool = False,
) -> None:
if app_dir is not None:
sys.path.insert(0, app_dir)
Expand Down Expand Up @@ -560,6 +570,7 @@ def run(
use_colors=use_colors,
factory=factory,
h11_max_incomplete_event_size=h11_max_incomplete_event_size,
socket_load_balance=socket_load_balance,
)
server = Server(config=config)

Expand All @@ -570,11 +581,11 @@ def run(

try:
if config.should_reload:
sock = config.bind_socket()
ChangeReload(config, target=server.run, sockets=[sock]).run()
with config.bind_socket() as sock:
ChangeReload(config, target=server.run, sockets=[sock]).run()
elif config.workers > 1:
sock = config.bind_socket()
Multiprocess(config, target=server.run, sockets=[sock]).run()
with config.bind_socket() as sock:
Multiprocess(config, target=server.run, sockets=[sock]).run()
else:
server.run()
except KeyboardInterrupt:
Expand Down

0 comments on commit c873d00

Please sign in to comment.