-
-
Notifications
You must be signed in to change notification settings - Fork 92
/
Copy pathconftest.py
104 lines (81 loc) · 2.78 KB
/
conftest.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import asyncio
import multiprocessing as mp
import socket
from contextlib import asynccontextmanager, closing
from functools import partial
from pathlib import Path
import pytest
from granian import Granian
def _serve(**kwargs):
    """Build a Granian server for the matching test app and run it.

    Used as the ``multiprocessing`` target by ``_server``. The import target
    is ``tests.apps.<interface>:app``, and every keyword argument (including
    ``interface`` itself) is forwarded unchanged to ``Granian``.
    """
    app_target = 'tests.apps.{}:app'.format(kwargs['interface'])
    Granian(app_target, **kwargs).serve()
def _stop_process(proc, timeout=2):
    """Terminate *proc*, falling back to ``kill()`` if it outlives *timeout* seconds."""
    proc.terminate()
    proc.join(timeout=timeout)
    if proc.is_alive():
        proc.kill()
        # Reap the killed child so it does not linger as a zombie.
        proc.join(timeout=timeout)


async def _port_accepts(port, timeout=1):
    """Return True if a TCP connection to 127.0.0.1:*port* succeeds within *timeout* seconds.

    Uses asyncio streams so the running event loop is not blocked while probing.
    """
    try:
        _, writer = await asyncio.wait_for(
            asyncio.open_connection('127.0.0.1', port), timeout
        )
    except (OSError, asyncio.TimeoutError):
        return False
    writer.close()
    try:
        await writer.wait_closed()
    except OSError:
        # The connection was accepted, which is all the probe cares about;
        # errors while tearing it down are irrelevant.
        pass
    return True


async def _wait_until_ready(proc, port, attempts=3, delay=1.5):
    """Poll *port* until the server running in *proc* accepts a connection.

    Each attempt first sleeps *delay* seconds. Returns False after *attempts*
    failed probes, or as soon as *proc* is found to have exited: a dead child
    cannot serve, and a listener found on the port would belong to some
    other process.
    """
    for _ in range(attempts):
        await asyncio.sleep(delay)
        if not proc.is_alive():
            return False
        if await _port_accepts(port):
            return True
    return False


@asynccontextmanager
async def _server(interface, port, runtime_mode, tls=False, task_impl='asyncio'):
    """Spawn a Granian test server in a child process and yield its port once reachable.

    Args:
        interface: interface name (the fixtures use 'asgi', 'rsgi', 'wsgi'); it
            also selects the ``tests.apps.<interface>`` app module.
        port: TCP port the server is expected to listen on.
        runtime_mode: forwarded to Granian as ``runtime_mode``.
        tls: falsy for plain TCP; ``'private'`` for the password-protected key
            pair (``pcert.pem``/``pkey.pem``); any other truthy value for the
            plain key pair (``cert.pem``/``key.pem``).
        task_impl: forwarded to Granian as ``task_impl``.

    Yields:
        *port*, once a TCP connection to it has succeeded.

    Raises:
        RuntimeError: if no spawned server became reachable after three spawns.
    """
    # Resolved relative to the working directory, so tests are expected to
    # run from the repository root (the app import path assumes the same).
    certs_path = Path.cwd() / 'tests' / 'fixtures' / 'tls'
    kwargs = {
        'interface': interface,
        'port': port,
        'loop': 'asyncio',
        'blocking_threads': 1,
        'runtime_mode': runtime_mode,
        'task_impl': task_impl,
    }
    if tls == 'private':
        kwargs['ssl_cert'] = certs_path / 'pcert.pem'
        kwargs['ssl_key'] = certs_path / 'pkey.pem'
        kwargs['ssl_key_password'] = 'foobar'  # noqa: S105
    elif tls:
        kwargs['ssl_cert'] = certs_path / 'cert.pem'
        kwargs['ssl_key'] = certs_path / 'key.pem'

    # 'spawn' gives the child a fresh interpreter instead of a fork of this one.
    ctx = mp.get_context('spawn')
    for _ in range(3):
        proc = ctx.Process(target=_serve, kwargs=kwargs)
        proc.start()
        if await _wait_until_ready(proc, port):
            break
        _stop_process(proc)
    else:
        raise RuntimeError('Cannot bind server')

    try:
        yield port
    finally:
        _stop_process(proc)
@pytest.fixture(scope='function')
def server_port():
    """Return a TCP port on 127.0.0.1 that was free when the fixture ran.

    The OS picks the port (bind to port 0). The probe socket is closed before
    returning so the server under test can bind the same port. Another
    process may still take the port in the window before that happens.
    """
    with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
        # Socket options influence bind(), so they must be set before it;
        # setting SO_REUSEADDR after bind() has no effect on that bind.
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        # Bind the same loopback address the readiness probe connects to,
        # rather than whatever 'localhost' resolves to on this host.
        sock.bind(('127.0.0.1', 0))
        return sock.getsockname()[1]
@pytest.fixture(scope='function')
def asgi_server(server_port, **extras):
    """Provide a factory that starts an ASGI test server on ``server_port``.

    The returned callable takes the remaining ``_server`` arguments
    (``runtime_mode``, and optionally ``tls`` / ``task_impl``) and returns
    the ``_server`` async context manager.
    """
    # NOTE(review): pytest resolves fixture arguments from named parameters,
    # so ``extras`` is presumably always empty here; confirm before relying on it.
    interface = 'asgi'
    return partial(_server, interface, server_port, **extras)
@pytest.fixture(scope='function')
def rsgi_server(server_port):
    """Provide a factory that starts an RSGI test server on ``server_port``.

    The returned callable takes the remaining ``_server`` arguments
    (``runtime_mode``, and optionally ``tls`` / ``task_impl``).
    """
    interface = 'rsgi'
    return partial(_server, interface, server_port)
@pytest.fixture(scope='function')
def wsgi_server(server_port):
    """Provide a factory that starts a WSGI test server on ``server_port``.

    The returned callable takes the remaining ``_server`` arguments
    (``runtime_mode``, and optionally ``tls`` / ``task_impl``).
    """
    interface = 'wsgi'
    return partial(_server, interface, server_port)
@pytest.fixture(scope='function')
def server(server_port, request):
    """Provide a server factory for the interface given by the fixture's parameter.

    Must be parametrized (e.g. indirectly) so that ``request.param`` holds an
    interface name such as 'asgi', 'rsgi' or 'wsgi'.
    """
    interface = request.param
    return partial(_server, interface, server_port)
@pytest.fixture(scope='function')
def server_tls(server_port, request):
    """Provide a TLS-enabled server factory for the parametrized interface.

    Like ``server`` but always passes ``tls=True`` to ``_server``, which
    selects the non-password-protected certificate pair. Must be
    parametrized so that ``request.param`` holds the interface name.
    """
    interface = request.param
    options = {'tls': True}
    return partial(_server, interface, server_port, **options)