Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

dan/per-10281-error-details-in-python-sdk-exceptions #106

Merged
merged 7 commits into from
Oct 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 31 additions & 9 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,15 +1,37 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.2.0
rev: v5.0.0
hooks:
- id: check-yaml
- id: end-of-file-fixer
- id: trailing-whitespace
- repo: https://github.com/psf/black
rev: 22.3.0
- id: end-of-file-fixer
- id: check-added-large-files
- id: check-case-conflict
- id: check-executables-have-shebangs
- id: check-json
- id: check-toml
- id: check-yaml
- id: check-xml
- id: check-merge-conflict
- id: mixed-line-ending
args: [ --fix=lf ]

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.6.9
hooks:
- id: black
- repo: https://github.com/pycqa/isort
rev: 5.12.0
- id: ruff
args: [--fix]
files: \.py$
types: [ file ]
- id: ruff-format
files: \.py$
types: [ file ]

- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.11.2
hooks:
- id: isort
- id: mypy
pass_filenames: false
additional_dependencies:
- pydantic
files: \.py$
types: [ file ]
3 changes: 2 additions & 1 deletion permit/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .api.models import *
# ruff: noqa: F401
from .api.models import * # noqa: F403
from .config import PermitConfig
from .enforcement.enforcer import Action, Resource, User
from .enforcement.interfaces import (
Expand Down
123 changes: 26 additions & 97 deletions permit/api/base.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
import functools
from contextlib import contextmanager
from typing import Callable, Iterable, Optional, Type, TypeVar, Union
from typing import Optional, Type, TypeVar, Union

import aiohttp
from aiohttp import ClientTimeout
from loguru import logger
from typing_extensions import Self

from ..utils.pydantic_version import PYDANTIC_VERSION

Expand All @@ -19,46 +16,10 @@
from .context import API_ACCESS_LEVELS, ApiContextLevel, ApiKeyAccessLevel
from .models import APIKeyScopeRead

T = TypeVar("T", bound=Callable)
TModel = TypeVar("TModel", bound=BaseModel)
TData = TypeVar("TData", bound=BaseModel)


def required_permissions(access_level: ApiKeyAccessLevel):
def decorator(func: T) -> T:
@functools.wraps(func)
async def wrapped(self: BasePermitApi, *args, **kwargs):
await self._ensure_access_level(access_level)
return await func(self, *args, **kwargs)

return wrapped

return decorator


def required_context(context: ApiContextLevel):
"""
a decorator that ensures that an API endpoint is called only after the SDK has initialized
an API context (authorization level) by inferring it from the API key or manually by the user.

Args:
call_level: The required API key level for the endpoint.

Raises:
PermitContextError: If the API context does not match the required endpoint context.
"""

def decorator(func: T) -> T:
@functools.wraps(func)
async def wrapped(self: BasePermitApi, *args, **kwargs):
await self._ensure_context(context)
return await func(self, *args, **kwargs)

return wrapped

return decorator


def pagination_params(page: int, per_page: int) -> dict:
return {"page": page, "per_page": per_page}

Expand All @@ -79,25 +40,19 @@ class SimpleHttpClient:
wraps aiohttp client to reduce boilerplace
"""

def __init__(
self, client_config: dict, base_url: str = "", timeout: Optional[int] = None
):
def __init__(self, client_config: dict, base_url: str = "", timeout: Optional[int] = None):
self._client_config = client_config
self._base_url = base_url
if timeout is not None:
self._client_config["timeout"] = ClientTimeout(total=timeout)

def _log_request(self, url: str, method: str) -> None:
logger.debug("Sending HTTP request: {} {}".format(method, url))
logger.debug(f"Sending HTTP request: {method} {url}")

def _log_response(self, url: str, method: str, status: int) -> None:
logger.debug(
"Received HTTP response: {} {}, status: {}".format(method, url, status)
)
logger.debug(f"Received HTTP response: {method} {url}, status: {status}")

def _prepare_json(
self, json: Optional[Union[TData, dict, list]] = None
) -> Optional[dict]:
def _prepare_json(self, json: Optional[Union[TData, dict, list]] = None) -> Optional[Union[dict, list]]:
if json is None:
return None

Expand Down Expand Up @@ -131,9 +86,7 @@ async def post(
url = f"{self._base_url}{url}"
async with aiohttp.ClientSession(**self._client_config) as client:
self._log_request(url, "POST")
async with client.post(
url, json=self._prepare_json(json), **kwargs
) as response:
async with client.post(url, json=self._prepare_json(json), **kwargs) as response:
await handle_api_error(response)
self._log_response(url, "POST", response.status)
data = await response.json()
Expand All @@ -150,9 +103,7 @@ async def put(
url = f"{self._base_url}{url}"
async with aiohttp.ClientSession(**self._client_config) as client:
self._log_request(url, "PUT")
async with client.put(
url, json=self._prepare_json(json), **kwargs
) as response:
async with client.put(url, json=self._prepare_json(json), **kwargs) as response:
await handle_api_error(response)
self._log_response(url, "PUT", response.status)
data = await response.json()
Expand All @@ -169,9 +120,7 @@ async def patch(
url = f"{self._base_url}{url}"
async with aiohttp.ClientSession(**self._client_config) as client:
self._log_request(url, "PATCH")
async with client.patch(
url, json=self._prepare_json(json), **kwargs
) as response:
async with client.patch(url, json=self._prepare_json(json), **kwargs) as response:
await handle_api_error(response)
self._log_response(url, "PATCH", response.status)
data = await response.json()
Expand All @@ -188,9 +137,7 @@ async def delete(
url = f"{self._base_url}{url}"
async with aiohttp.ClientSession(**self._client_config) as client:
self._log_request(url, "DELETE")
async with client.delete(
url, json=self._prepare_json(json), **kwargs
) as response:
async with client.delete(url, json=self._prepare_json(json), **kwargs) as response:
await handle_api_error(response)
self._log_response(url, "DELETE", response.status)
if model is None:
Expand All @@ -214,9 +161,7 @@ def __init__(self, config: PermitConfig):
self.config = config
self.__api_keys = self._build_http_client("/v2/api-key")

def _build_http_client(
self, endpoint_url: str = "", *, use_pdp: bool = False, **kwargs
):
def _build_http_client(self, endpoint_url: str = "", *, use_pdp: bool = False, **kwargs):
optional_headers = {}
if self.config.proxy_facts_via_pdp and self.config.facts_sync_timeout:
optional_headers["X-Wait-Timeout"] = str(self.config.facts_sync_timeout)
Expand All @@ -228,10 +173,10 @@ def _build_http_client(
**optional_headers,
},
)
client_config = client_config.dict()
client_config.update(kwargs)
client_config_dict = client_config.dict()
client_config_dict.update(kwargs)
danyi1212 marked this conversation as resolved.
Show resolved Hide resolved
return SimpleHttpClient(
client_config,
client_config_dict,
base_url=endpoint_url,
timeout=self.config.api_timeout,
)
Expand All @@ -247,14 +192,8 @@ async def _set_context_from_api_key(self) -> None:
# saves the permitted access level by that api key
self.config.api_context._save_api_key_accessible_scope(
org=str(scope.organization_id),
project=(
str(scope.project_id) if scope.project_id is not None else None
),
environment=(
str(scope.environment_id)
if scope.environment_id is not None
else None
),
project=(str(scope.project_id) if scope.project_id is not None else None),
environment=(str(scope.environment_id) if scope.environment_id is not None else None),
)

if scope.project_id is not None:
Expand All @@ -268,22 +207,16 @@ async def _set_context_from_api_key(self) -> None:
return

# Set project level context
self.config.api_context.set_project_level_context(
str(scope.organization_id), str(scope.project_id)
)
self.config.api_context.set_project_level_context(str(scope.organization_id), str(scope.project_id))
return

# Set org level context
self.config.api_context.set_organization_level_context(
str(scope.organization_id)
)
self.config.api_context.set_organization_level_context(str(scope.organization_id))
return

raise PermitContextError("Could not set API context level")

async def _ensure_access_level(
self, required_access_level: ApiKeyAccessLevel
) -> None:
async def _ensure_access_level(self, required_access_level: ApiKeyAccessLevel) -> None:
"""
Ensure that the API Key has the necessary permissions to successfully call the API endpoint.

Expand All @@ -298,8 +231,7 @@ async def _ensure_access_level(
# should only happen once in the lifetime of the sdk
if (
self.config.api_context.level == ApiContextLevel.WAIT_FOR_INIT
or self.config.api_context.permitted_access_level
== ApiKeyAccessLevel.WAIT_FOR_INIT
or self.config.api_context.permitted_access_level == ApiKeyAccessLevel.WAIT_FOR_INIT
):
await self._set_context_from_api_key()

Expand All @@ -308,18 +240,16 @@ async def _ensure_access_level(
self.config.api_context.permitted_access_level
):
raise PermitContextError(
f"You're trying to use an SDK method that requires an API Key with access level: {required_access_level}, "
+ f"however the SDK is running with an API key with level {self.config.api_context.permitted_access_level}."
f"You're trying to use an SDK method that requires an API Key "
f"with access level: {required_access_level}, however the SDK is running "
f"with an API key with level {self.config.api_context.permitted_access_level}."
)
return

if (
self.config.api_context.permitted_access_level.value
< required_access_level.value
):
if self.config.api_context.permitted_access_level.value < required_access_level.value:
raise PermitContextError(
f"You're trying to use an SDK method that requires an api context of {required_context.name}, "
+ f"however the SDK is running in a less specific context level: {self.config.api_context.level}."
f"You're trying to use an SDK method that requires an api context of {required_access_level.name}, "
f"however the SDK is running in a less specific context level: {self.config.api_context.level}."
)

async def _ensure_context(self, required_context: ApiContextLevel) -> None:
Expand All @@ -335,8 +265,7 @@ async def _ensure_context(self, required_context: ApiContextLevel) -> None:
# should only happen once in the lifetime of the sdk
if (
self.config.api_context.level == ApiContextLevel.WAIT_FOR_INIT
or self.config.api_context.permitted_access_level
== ApiKeyAccessLevel.WAIT_FOR_INIT
or self.config.api_context.permitted_access_level == ApiKeyAccessLevel.WAIT_FOR_INIT
):
await self._set_context_from_api_key()

Expand Down
37 changes: 15 additions & 22 deletions permit/api/condition_set_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,12 @@
if PYDANTIC_VERSION < (2, 0):
from pydantic import validate_arguments
else:
from pydantic.v1 import validate_arguments # type: ignore
from pydantic.v1 import validate_arguments

from .base import (
BasePermitApi,
SimpleHttpClient,
pagination_params,
required_context,
required_permissions,
)
from .context import ApiContextLevel, ApiKeyAccessLevel
from .models import ConditionSetRuleCreate, ConditionSetRuleRead, ConditionSetRuleRemove
Expand All @@ -22,15 +20,10 @@ class ConditionSetRulesApi(BasePermitApi):
@property
def __condition_set_rules(self) -> SimpleHttpClient:
return self._build_http_client(
"/v2/facts/{proj_id}/{env_id}/set_rules".format(
proj_id=self.config.api_context.project,
env_id=self.config.api_context.environment,
)
f"/v2/facts/{self.config.api_context.project}/{self.config.api_context.environment}/set_rules"
)

@required_permissions(ApiKeyAccessLevel.ENVIRONMENT_LEVEL_API_KEY)
@required_context(ApiContextLevel.ENVIRONMENT)
@validate_arguments
@validate_arguments # type: ignore[operator]
async def list(
self,
user_set_key: Optional[str] = None,
Expand All @@ -57,22 +50,22 @@ async def list(
PermitApiError: If the API returns an error HTTP status code.
PermitContextError: If the configured ApiContext does not match the required endpoint context.
"""
await self._ensure_access_level(ApiKeyAccessLevel.ENVIRONMENT_LEVEL_API_KEY)
await self._ensure_context(ApiContextLevel.ENVIRONMENT)
params = pagination_params(page, per_page)
if user_set_key is not None:
params.update(dict(user_set=user_set_key))
params.update(user_set=user_set_key)
if permission_key is not None:
params.update(dict(permission=permission_key))
params.update(permission=permission_key)
if resource_set_key is not None:
params.update(dict(resource_set=resource_set_key))
params.update(resource_set=resource_set_key)
return await self.__condition_set_rules.get(
"",
model=List[ConditionSetRuleRead],
params=params,
)

@required_permissions(ApiKeyAccessLevel.ENVIRONMENT_LEVEL_API_KEY)
@required_context(ApiContextLevel.ENVIRONMENT)
@validate_arguments
@validate_arguments # type: ignore[operator]
async def create(self, rule: ConditionSetRuleCreate) -> List[ConditionSetRuleRead]:
"""
Creates a new condition set rule.
Expand All @@ -87,13 +80,11 @@ async def create(self, rule: ConditionSetRuleCreate) -> List[ConditionSetRuleRea
PermitApiError: If the API returns an error HTTP status code.
PermitContextError: If the configured ApiContext does not match the required endpoint context.
"""
return await self.__condition_set_rules.post(
"", model=List[ConditionSetRuleRead], json=rule
)
await self._ensure_access_level(ApiKeyAccessLevel.ENVIRONMENT_LEVEL_API_KEY)
await self._ensure_context(ApiContextLevel.ENVIRONMENT)
return await self.__condition_set_rules.post("", model=List[ConditionSetRuleRead], json=rule)

@required_permissions(ApiKeyAccessLevel.ENVIRONMENT_LEVEL_API_KEY)
@required_context(ApiContextLevel.ENVIRONMENT)
@validate_arguments
@validate_arguments # type: ignore[operator]
async def delete(self, rule: ConditionSetRuleRemove) -> None:
"""
Deletes a condition set rule.
Expand All @@ -105,4 +96,6 @@ async def delete(self, rule: ConditionSetRuleRemove) -> None:
PermitApiError: If the API returns an error HTTP status code.
PermitContextError: If the configured ApiContext does not match the required endpoint context.
"""
await self._ensure_access_level(ApiKeyAccessLevel.ENVIRONMENT_LEVEL_API_KEY)
await self._ensure_context(ApiContextLevel.ENVIRONMENT)
return await self.__condition_set_rules.delete("", json=rule)
Loading
Loading