Skip to content

Commit

Permalink
Merge pull request #13 from justmobilize/close-all-and-counts
Browse files Browse the repository at this point in the history
Close all and counts
  • Loading branch information
dhalbert authored Apr 28, 2024
2 parents fc33375 + 601ee66 commit 1531496
Show file tree
Hide file tree
Showing 8 changed files with 381 additions and 141 deletions.
207 changes: 122 additions & 85 deletions adafruit_connection_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@


if not sys.implementation.name == "circuitpython":
from typing import Optional, Tuple
from typing import List, Optional, Tuple

from circuitpython_typing.socket import (
CircuitPythonSocketType,
Expand Down Expand Up @@ -64,15 +64,14 @@ def connect(self, address: Tuple[str, int]) -> None:
try:
return self._socket.connect(address, self._mode)
except RuntimeError as error:
raise OSError(errno.ENOMEM) from error
raise OSError(errno.ENOMEM, str(error)) from error


class _FakeSSLContext:
def __init__(self, iface: InterfaceType) -> None:
self._iface = iface

# pylint: disable=unused-argument
def wrap_socket(
def wrap_socket( # pylint: disable=unused-argument
self, socket: CircuitPythonSocketType, server_hostname: Optional[str] = None
) -> _FakeSSLSocket:
"""Return the same socket"""
Expand All @@ -99,7 +98,8 @@ def create_fake_ssl_context(
return _FakeSSLContext(iface)


_global_socketpool = {}
_global_connection_managers = {}
_global_socketpools = {}
_global_ssl_contexts = {}


Expand All @@ -113,7 +113,7 @@ def get_radio_socketpool(radio):
* Using a WIZ5500 (Like the Adafruit Ethernet FeatherWing)
"""
class_name = radio.__class__.__name__
if class_name not in _global_socketpool:
if class_name not in _global_socketpools:
if class_name == "Radio":
import ssl # pylint: disable=import-outside-toplevel

Expand Down Expand Up @@ -151,10 +151,10 @@ def get_radio_socketpool(radio):
else:
raise AttributeError(f"Unsupported radio class: {class_name}")

_global_socketpool[class_name] = pool
_global_socketpools[class_name] = pool
_global_ssl_contexts[class_name] = ssl_context

return _global_socketpool[class_name]
return _global_socketpools[class_name]


def get_radio_ssl_context(radio):
Expand Down Expand Up @@ -183,42 +183,75 @@ def __init__(
) -> None:
self._socket_pool = socket_pool
# Hang onto open sockets so that we can reuse them.
self._available_socket = {}
self._open_sockets = {}

def _free_sockets(self) -> None:
available_sockets = []
for socket, free in self._available_socket.items():
if free:
available_sockets.append(socket)
self._available_sockets = set()
self._key_by_managed_socket = {}
self._managed_socket_by_key = {}

def _free_sockets(self, force: bool = False) -> None:
# cloning lists since items are being removed
available_sockets = list(self._available_sockets)
for socket in available_sockets:
self.close_socket(socket)
if force:
open_sockets = list(self._managed_socket_by_key.values())
for socket in open_sockets:
self.close_socket(socket)

def _get_key_for_socket(self, socket):
def _get_connected_socket( # pylint: disable=too-many-arguments
self,
addr_info: List[Tuple[int, int, int, str, Tuple[str, int]]],
host: str,
port: int,
timeout: float,
is_ssl: bool,
ssl_context: Optional[SSLContextType] = None,
):
try:
return next(
key for key, value in self._open_sockets.items() if value == socket
)
except StopIteration:
return None
socket = self._socket_pool.socket(addr_info[0], addr_info[1])
except (OSError, RuntimeError) as exc:
return exc

if is_ssl:
socket = ssl_context.wrap_socket(socket, server_hostname=host)
connect_host = host
else:
connect_host = addr_info[-1][0]
socket.settimeout(timeout) # socket read timeout

try:
socket.connect((connect_host, port))
except (MemoryError, OSError) as exc:
socket.close()
return exc

return socket

@property
def available_socket_count(self) -> int:
"""Get the count of freeable open sockets"""
return len(self._available_sockets)

@property
def managed_socket_count(self) -> int:
"""Get the count of open sockets"""
return len(self._managed_socket_by_key)

def close_socket(self, socket: SocketType) -> None:
"""Close a previously opened socket."""
if socket not in self._open_sockets.values():
if socket not in self._managed_socket_by_key.values():
raise RuntimeError("Socket not managed")
key = self._get_key_for_socket(socket)
socket.close()
del self._available_socket[socket]
del self._open_sockets[key]
key = self._key_by_managed_socket.pop(socket)
del self._managed_socket_by_key[key]
if socket in self._available_sockets:
self._available_sockets.remove(socket)

def free_socket(self, socket: SocketType) -> None:
"""Mark a previously opened socket as available so it can be reused if needed."""
if socket not in self._open_sockets.values():
if socket not in self._managed_socket_by_key.values():
raise RuntimeError("Socket not managed")
self._available_socket[socket] = True
self._available_sockets.add(socket)

# pylint: disable=too-many-branches,too-many-locals,too-many-statements
def get_socket(
self,
host: str,
Expand All @@ -234,10 +267,10 @@ def get_socket(
if session_id:
session_id = str(session_id)
key = (host, port, proto, session_id)
if key in self._open_sockets:
socket = self._open_sockets[key]
if self._available_socket[socket]:
self._available_socket[socket] = False
if key in self._managed_socket_by_key:
socket = self._managed_socket_by_key[key]
if socket in self._available_sockets:
self._available_sockets.remove(socket)
return socket

raise RuntimeError(f"Socket already connected to {proto}//{host}:{port}")
Expand All @@ -253,64 +286,68 @@ def get_socket(
host, port, 0, self._socket_pool.SOCK_STREAM
)[0]

try_count = 0
socket = None
last_exc = None
while try_count < 2 and socket is None:
try_count += 1
if try_count > 1:
if any(
socket
for socket, free in self._available_socket.items()
if free is True
):
self._free_sockets()
else:
break

try:
socket = self._socket_pool.socket(addr_info[0], addr_info[1])
except OSError as exc:
last_exc = exc
continue
except RuntimeError as exc:
last_exc = exc
continue

if is_ssl:
socket = ssl_context.wrap_socket(socket, server_hostname=host)
connect_host = host
else:
connect_host = addr_info[-1][0]
socket.settimeout(timeout) # socket read timeout

try:
socket.connect((connect_host, port))
except MemoryError as exc:
last_exc = exc
socket.close()
socket = None
except OSError as exc:
last_exc = exc
socket.close()
socket = None

if socket is None:
raise RuntimeError(f"Error connecting socket: {last_exc}") from last_exc

self._available_socket[socket] = False
self._open_sockets[key] = socket
return socket
first_exception = None
result = self._get_connected_socket(
addr_info, host, port, timeout, is_ssl, ssl_context
)
if isinstance(result, Exception):
# Got an error, if there are any available sockets, free them and try again
if self.available_socket_count:
first_exception = result
self._free_sockets()
result = self._get_connected_socket(
addr_info, host, port, timeout, is_ssl, ssl_context
)
if isinstance(result, Exception):
last_result = f", first error: {first_exception}" if first_exception else ""
raise RuntimeError(
f"Error connecting socket: {result}{last_result}"
) from result

self._key_by_managed_socket[result] = key
self._managed_socket_by_key[key] = result
return result


# global helpers


_global_connection_manager = {}
def connection_manager_close_all(
socket_pool: Optional[SocketpoolModuleType] = None, release_references: bool = False
) -> None:
"""Close all open sockets for pool"""
if socket_pool:
socket_pools = [socket_pool]
else:
socket_pools = _global_connection_managers.keys()

for pool in socket_pools:
connection_manager = _global_connection_managers.get(pool, None)
if connection_manager is None:
raise RuntimeError("SocketPool not managed")

connection_manager._free_sockets(force=True) # pylint: disable=protected-access

if release_references:
radio_key = None
for radio_check, pool_check in _global_socketpools.items():
if pool == pool_check:
radio_key = radio_check
break

if radio_key:
if radio_key in _global_socketpools:
del _global_socketpools[radio_key]

if radio_key in _global_ssl_contexts:
del _global_ssl_contexts[radio_key]

if pool in _global_connection_managers:
del _global_connection_managers[pool]


def get_connection_manager(socket_pool: SocketpoolModuleType) -> ConnectionManager:
"""Get the ConnectionManager singleton for the given pool"""
if socket_pool not in _global_connection_manager:
_global_connection_manager[socket_pool] = ConnectionManager(socket_pool)
return _global_connection_manager[socket_pool]
if socket_pool not in _global_connection_managers:
_global_connection_managers[socket_pool] = ConnectionManager(socket_pool)
return _global_connection_managers[socket_pool]
32 changes: 28 additions & 4 deletions examples/connectionmanager_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,38 @@

# get request session
requests = adafruit_requests.Session(pool, ssl_context)
connection_manager = adafruit_connection_manager.get_connection_manager(pool)
print("-" * 40)
print("Nothing yet opened")
print(f"Open Sockets: {connection_manager.managed_socket_count}")
print(f"Freeable Open Sockets: {connection_manager.available_socket_count}")

# make request
print("-" * 40)
print(f"Fetching from {TEXT_URL}")
print(f"Fetching from {TEXT_URL} in a context handler")
with requests.get(TEXT_URL) as response:
response_text = response.text
print(f"Text Response {response_text}")

print("-" * 40)
print("1 request, opened and freed")
print(f"Open Sockets: {connection_manager.managed_socket_count}")
print(f"Freeable Open Sockets: {connection_manager.available_socket_count}")

print("-" * 40)
print(f"Fetching from {TEXT_URL} not in a context handler")
response = requests.get(TEXT_URL)
response_text = response.text
response.close()

print(f"Text Response {response_text}")
print("-" * 40)
print("1 request, opened but not freed")
print(f"Open Sockets: {connection_manager.managed_socket_count}")
print(f"Freeable Open Sockets: {connection_manager.available_socket_count}")

print("-" * 40)
print("Closing everything in the pool")
adafruit_connection_manager.connection_manager_close_all(pool)

print("-" * 40)
print("Everything closed")
print(f"Open Sockets: {connection_manager.managed_socket_count}")
print(f"Freeable Open Sockets: {connection_manager.available_socket_count}")
8 changes: 4 additions & 4 deletions tests/close_socket_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,13 @@ def test_close_socket():
socket = connection_manager.get_socket(mocket.MOCK_HOST_1, 80, "http:")
key = (mocket.MOCK_HOST_1, 80, "http:", None)
assert socket == mock_socket_1
assert socket in connection_manager._available_socket
assert key in connection_manager._open_sockets
assert socket not in connection_manager._available_sockets
assert key in connection_manager._managed_socket_by_key

# validate socket is no longer tracked
connection_manager.close_socket(socket)
assert socket not in connection_manager._available_socket
assert key not in connection_manager._open_sockets
assert socket not in connection_manager._available_sockets
assert key not in connection_manager._managed_socket_by_key


def test_close_socket_not_managed():
Expand Down
6 changes: 5 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,11 @@ def adafruit_wiznet5k_with_ssl_socket_module():
@pytest.fixture(autouse=True)
def reset_connection_manager(monkeypatch):
monkeypatch.setattr(
"adafruit_connection_manager._global_socketpool",
"adafruit_connection_manager._global_connection_managers",
{},
)
monkeypatch.setattr(
"adafruit_connection_manager._global_socketpools",
{},
)
monkeypatch.setattr(
Expand Down
Loading

0 comments on commit 1531496

Please sign in to comment.