diff --git a/src/app.py b/src/app.py index e0711b57..319608e1 100644 --- a/src/app.py +++ b/src/app.py @@ -1,6 +1,7 @@ from contextlib import asynccontextmanager import os +from enum import Enum, auto import uvicorn # type: ignore from pathlib import Path from fastapi import ( @@ -66,6 +67,22 @@ def with_plugins() -> Iterable[PluginManager]: plugins.cleanup() +# See docs/users.md for documentation on the permission model. +class UserRole(Enum): + # "role groups", used to grant a collection of "action roles" + nobody = auto() # no permissions granted + user = auto() # can run yara queries and read the state + admin = auto() # can manage the system (and do everything else) + + # "action roles", used to give permission to a specific thing + can_manage_all_queries = auto() + can_manage_queries = auto() + can_list_all_queries = auto() + can_list_queries = auto() + can_view_queries = auto() + can_download_files = auto() + + class User: def __init__(self, token: Optional[Dict]) -> None: self.__token = token @@ -80,11 +97,12 @@ def name(self) -> str: return "anonymous" return self.__token.get("preferred_username", "unknown") - def roles(self, client_id: Optional[str]) -> List[str]: + def roles(self, client_id: Optional[str]) -> List[UserRole]: if self.__token is None: return [] try: - return self.__token["resource_access"][client_id]["roles"] + role_names = self.__token["resource_access"][client_id]["roles"] + return [UserRole[name] for name in role_names] except KeyError: return [] @@ -141,7 +159,7 @@ async def add_headers(request: Request, call_next: Callable) -> Response: class RoleChecker: - def __init__(self, need_permissions: List[str]) -> None: + def __init__(self, need_permissions: List[UserRole]) -> None: self.need_permissions = need_permissions def __call__(self, user: User = Depends(current_user)): @@ -150,7 +168,6 @@ def __call__(self, user: User = Depends(current_user)): return all_roles = get_user_roles(user) - if not any(role in self.need_permissions for role in all_roles): message = ( f"Operation not allowed for user {user.name} " @@ -164,45 +181,48 @@ def __call__(self, user: User = Depends(current_user)): ) -# See docs/users.md for documentation on the permission model. -is_admin = RoleChecker(["admin"]) -is_user = RoleChecker(["user"]) -can_view_queries = RoleChecker(["can_view_queries"]) -can_manage_queries = RoleChecker(["can_manage_queries"]) -can_list_queries = RoleChecker(["can_list_queries"]) -can_download_files = RoleChecker(["can_download_files"]) +is_admin = RoleChecker([UserRole.admin]) +is_user = RoleChecker([UserRole.user]) +can_view_queries = RoleChecker([UserRole.can_view_queries]) +can_manage_queries = RoleChecker([UserRole.can_manage_queries]) +can_list_queries = RoleChecker([UserRole.can_list_queries]) +can_download_files = RoleChecker([UserRole.can_download_files]) -def get_user_roles(user: User) -> List[str]: +def get_user_roles(user: User) -> List[UserRole]: + """Get all roles assigned to user, taking into account the + system configuration (like default configured roles)""" client_id = db.get_mquery_config_key("openid_client_id") user_roles = user.roles(client_id) auth_default_roles = db.get_mquery_config_key("auth_default_roles") if not auth_default_roles: auth_default_roles = "admin" - default_roles = [role.strip() for role in auth_default_roles.split(",")] + default_roles = [ + UserRole[role.strip()] for role in auth_default_roles.split(",") + ] all_roles = set(user_roles + default_roles) return sum((expand_role(role) for role in all_roles), []) -def expand_role(role: str) -> List[str]: +def expand_role(role: UserRole) -> List[UserRole]: """Some roles imply other roles or permissions. For example, admin role also gives permissions for all user permissions. """ - role_implications: Dict = { - "nobody": [], - "admin": [ - "user", - "can_list_all_queries", - "can_manage_all_queries", + role_implications: Dict[UserRole, List[UserRole]] = { + UserRole.nobody: [], + UserRole.admin: [ + UserRole.user, + UserRole.can_list_all_queries, + UserRole.can_manage_all_queries, ], - "user": [ - "can_view_queries", - "can_manage_queries", - "can_list_queries", - "can_download_files", + UserRole.user: [ + UserRole.can_view_queries, + UserRole.can_manage_queries, + UserRole.can_list_queries, + UserRole.can_download_files, ], - "can_manage_all_queries": ["can_manage_queries"], - "can_list_all_queries": ["can_list_queries"], + UserRole.can_manage_all_queries: [UserRole.can_manage_queries], + UserRole.can_list_all_queries: [UserRole.can_list_queries], } implied_roles = [role] for subrole in role_implications.get(role, []): @@ -517,7 +537,7 @@ def job_cancel( job_id: str, user: User = Depends(current_user) ) -> StatusSchema: """Cancels the job with a provided `job_id`.""" - if "can_manage_all_queries" not in get_user_roles(user): + if UserRole.can_manage_all_queries not in get_user_roles(user): job = db.get_job(job_id) if job.rule_author != user.name: raise HTTPException( @@ -540,7 +560,7 @@ def job_statuses(user: User = Depends(current_user)) -> JobsSchema: when there are a lot of them. """ username_filter: Optional[str] = user.name - if "can_list_all_queries" in get_user_roles(user): + if UserRole.can_list_all_queries in get_user_roles(user): username_filter = None jobs = db.get_valid_jobs(username_filter) return JobsSchema(jobs=jobs) @@ -554,7 +574,7 @@ def job_statuses(user: User = Depends(current_user)) -> JobsSchema: def query_remove( job_id: str, user: User = Depends(current_user) ) -> StatusSchema: - if "can_manage_all_queries" not in get_user_roles(user): + if UserRole.can_manage_all_queries not in get_user_roles(user): job = db.get_job(job_id) if job.rule_author != user.name: raise HTTPException(