Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(CORS): set allow private network header #2383

Draft
wants to merge 7 commits into
base: master
Choose a base branch
from
Draft
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
136 changes: 136 additions & 0 deletions middleware.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
from __future__ import annotations

from typing import Any, Iterable, Optional, Union

from .request import Request
from .response import Response


class CORSMiddleware(object):
"""CORS Middleware.

This middleware provides a simple out-of-the box CORS policy, including handling
of preflighted requests from the browser.

See also:

* https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS
* https://www.w3.org/TR/cors/#resource-processing-model

Keyword Arguments:
allow_origins (Union[str, Iterable[str]]): List of origins to allow (case
sensitive). The string ``'*'`` acts as a wildcard, matching every origin.
(default ``'*'``).
expose_headers (Optional[Union[str, Iterable[str]]]): List of additional
response headers to expose via the ``Access-Control-Expose-Headers``
header. These headers are in addition to the CORS-safelisted ones:
``Cache-Control``, ``Content-Language``, ``Content-Length``,
``Content-Type``, ``Expires``, ``Last-Modified``, ``Pragma``.
(default ``None``).

See also:
https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Expose-Headers
allow_credentials (Optional[Union[str, Iterable[str]]]): List of origins
(case sensitive) for which to allow credentials via the
``Access-Control-Allow-Credentials`` header.
The string ``'*'`` acts as a wildcard, matching every allowed origin,
while ``None`` disallows all origins. This parameter takes effect only
if the origin is allowed by the ``allow_origins`` argument.
(Default ``None``).

"""

def __init__(
self,
allow_origins: Union[str, Iterable[str]] = '*',
expose_headers: Optional[Union[str, Iterable[str]]] = None,
allow_credentials: Optional[Union[str, Iterable[str]]] = None,
allow_private_network: bool = False,
):

if allow_origins == '*':
self.allow_origins = allow_origins
else:
if isinstance(allow_origins, str):
allow_origins = [allow_origins]
self.allow_origins = frozenset(allow_origins)
if '*' in self.allow_origins:
raise ValueError(
'The wildcard string "*" may only be passed to allow_origins as a '
'string literal, not inside an iterable.'
)

if expose_headers is not None and not isinstance(expose_headers, str):
expose_headers = ', '.join(expose_headers)
self.expose_headers = expose_headers

if allow_credentials is None:
allow_credentials = frozenset()
elif allow_credentials != '*':
if isinstance(allow_credentials, str):
allow_credentials = [allow_credentials]
allow_credentials = frozenset(allow_credentials)
if '*' in allow_credentials:
raise ValueError(
'The wildcard string "*" may only be passed to allow_credentials '
'as a string literal, not inside an iterable.'
)
self.allow_credentials = allow_credentials

self.allow_private_network = allow_private_network

def process_response(
self, req: Request, resp: Response, resource: object, req_succeeded: bool
) -> None:
"""Implement the CORS policy for all routes.

This middleware provides a simple out-of-the box CORS policy,
including handling of preflighted requests from the browser.

See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS

See also: https://www.w3.org/TR/cors/#resource-processing-model
"""

origin = req.get_header('Origin')
if origin is None:
return

if self.allow_origins != '*' and origin not in self.allow_origins:
return

if resp.get_header('Access-Control-Allow-Origin') is None:
set_origin = '*' if self.allow_origins == '*' else origin
if self.allow_credentials == '*' or origin in self.allow_credentials:
set_origin = origin
resp.set_header('Access-Control-Allow-Credentials', 'true')
resp.set_header('Access-Control-Allow-Origin', set_origin)

if self.expose_headers:
resp.set_header('Access-Control-Expose-Headers', self.expose_headers)

if (
req_succeeded
and req.method == 'OPTIONS'
and req.get_header('Access-Control-Request-Method')
):
# NOTE(kgriffs): This is a CORS preflight request. Patch the
# response accordingly.

allow = resp.get_header('Allow')
resp.delete_header('Allow')

allow_headers = req.get_header(
'Access-Control-Request-Headers', default='*'
)

resp.set_header('Access-Control-Allow-Methods', allow)
resp.set_header('Access-Control-Allow-Headers', allow_headers)
resp.set_header('Access-Control-Max-Age', '86400') # 24 hours

if self.allow_private_network and req.get_header('Access-Control-Request-Private-Network') == 'true':
resp.set_header('Access-Control-Allow-Private-Network', 'true')


async def process_response_async(self, *args: Any) -> None:
self.process_response(*args)
Loading