Skip to content

Commit

Permalink
Refactor 2: make callers responsible for sessions
Browse files Browse the repository at this point in the history
  • Loading branch information
msm-code committed May 31, 2024
1 parent e02a197 commit 4da8b52
Show file tree
Hide file tree
Showing 4 changed files with 342 additions and 289 deletions.
133 changes: 89 additions & 44 deletions src/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from starlette.requests import Request # type: ignore
from starlette.responses import Response, FileResponse, StreamingResponse # type: ignore
from starlette.staticfiles import StaticFiles # type: ignore
from sqlmodel import Session
from zmq import Again
from typing import Any, Callable, List, Union, Dict, Iterable, Optional
import tempfile
Expand Down Expand Up @@ -58,6 +59,11 @@ def with_plugins() -> Iterable[PluginManager]:
plugins.cleanup()


def with_session() -> Iterable[Session]:
with db.session() as session:
yield session


class User:
def __init__(self, token: Optional[Dict]) -> None:
self.__token = token
Expand All @@ -81,8 +87,11 @@ def roles(self, client_id: Optional[str]) -> List[str]:
return []


async def current_user(authorization: Optional[str] = Header(None)) -> User:
auth_enabled = db.get_mquery_config_key("auth_enabled")
async def current_user(
authorization: Optional[str] = Header(None),
session: Session = Depends(with_session),
) -> User:
auth_enabled = db.get_mquery_config_key(session, "auth_enabled")
if not auth_enabled or auth_enabled == "false":
return User(None)

Expand All @@ -102,7 +111,7 @@ async def current_user(authorization: Optional[str] = Header(None)) -> User:

_bearer, token = token_parts

secret = db.get_mquery_config_key("openid_secret")
secret = db.get_mquery_config_key(session, "openid_secret")
if secret is None:
raise RuntimeError("Invalid configuration - missing_openid_secret.")

Expand Down Expand Up @@ -136,12 +145,16 @@ class RoleChecker:
def __init__(self, need_permissions: List[str]) -> None:
self.need_permissions = need_permissions

def __call__(self, user: User = Depends(current_user)):
auth_enabled = db.get_mquery_config_key("auth_enabled")
def __call__(
self,
user: User = Depends(current_user),
session: Session = Depends(with_session),
):
auth_enabled = db.get_mquery_config_key(session, "auth_enabled")
if not auth_enabled or auth_enabled == "false":
return

all_roles = get_user_roles(user)
all_roles = get_user_roles(session, user)

if not any(role in self.need_permissions for role in all_roles):
message = (
Expand All @@ -165,10 +178,12 @@ def __call__(self, user: User = Depends(current_user)):
can_download_files = RoleChecker(["can_download_files"])


def get_user_roles(user: User) -> List[str]:
client_id = db.get_mquery_config_key("openid_client_id")
def get_user_roles(session: Session, user: User) -> List[str]:
client_id = db.get_mquery_config_key(session, "openid_client_id")
user_roles = user.roles(client_id)
auth_default_roles = db.get_mquery_config_key("auth_default_roles")
auth_default_roles = db.get_mquery_config_key(
session, "auth_default_roles"
)
if not auth_default_roles:
auth_default_roles = "admin"
default_roles = [role.strip() for role in auth_default_roles.split(",")]
Expand Down Expand Up @@ -211,12 +226,14 @@ def expand_role(role: str) -> List[str]:
tags=["internal"],
dependencies=[Depends(is_admin)],
)
def config_list() -> List[ConfigSchema]:
def config_list(
session: Session = Depends(with_session),
) -> List[ConfigSchema]:
"""Returns the current database configuration.
This endpoint is not stable and may be subject to change in the future.
"""
return db.get_config()
return db.get_config(session)


@app.post(
Expand All @@ -225,12 +242,15 @@ def config_list() -> List[ConfigSchema]:
tags=["internal"],
dependencies=[Depends(is_admin)],
)
def config_edit(data: RequestConfigEdit = Body(...)) -> StatusSchema:
def config_edit(
data: RequestConfigEdit = Body(...),
session: Session = Depends(with_session),
) -> StatusSchema:
"""Change a given configuration key to a specified value.
This endpoint is not stable and may be subject to change in the future.
"""
db.set_config_key(data.plugin, data.key, data.value)
db.set_config_key(session, data.plugin, data.key, data.value)
return StatusSchema(status="ok")


Expand All @@ -240,7 +260,9 @@ def config_edit(data: RequestConfigEdit = Body(...)) -> StatusSchema:
tags=["internal"],
dependencies=[Depends(is_admin)],
)
def backend_status() -> BackendStatusSchema:
def backend_status(
session: Session = Depends(with_session),
) -> BackendStatusSchema:
"""Gets the current status of backend services, and returns it. Intended to
be used by the webpage.
Expand All @@ -250,7 +272,7 @@ def backend_status() -> BackendStatusSchema:
components = {
"mquery": mquery_version(),
}
for name, agent_spec in db.get_active_agents().items():
for name, agent_spec in db.get_active_agents(session).items():
try:
ursa = UrsaDb(agent_spec.ursadb_url)
status = ursa.status()
Expand Down Expand Up @@ -280,7 +302,9 @@ def backend_status() -> BackendStatusSchema:
tags=["internal"],
dependencies=[Depends(can_view_queries)],
)
def backend_status_datasets() -> BackendStatusDatasetsSchema:
def backend_status_datasets(
session: Session = Depends(with_session),
) -> BackendStatusDatasetsSchema:
"""Returns a combined list of datasets from all agents.
Caveat: In case of collision of dataset ids when there are multiple agents,
Expand All @@ -290,7 +314,7 @@ def backend_status_datasets() -> BackendStatusDatasetsSchema:
This endpoint is not stable and may be subject to change in the future.
"""
datasets: Dict[str, int] = {}
for agent_spec in db.get_active_agents().values():
for agent_spec in db.get_active_agents(session).values():
try:
ursa = UrsaDb(agent_spec.ursadb_url)
datasets.update(ursa.topology()["result"]["datasets"])
Expand All @@ -314,6 +338,7 @@ def download(
ordinal: int,
file_path: str,
plugins: PluginManager = Depends(with_plugins),
session: Session = Depends(with_session),
) -> Response:
"""Sends a file from given `file_path`. This path should come from
results of one of the previous searches.
Expand All @@ -322,7 +347,7 @@ def download(
(index of the file in that job), to ensure that user can't download
arbitrary files (for example "/etc/passwd").
"""
if not db.job_contains(job_id, ordinal, file_path):
if not db.job_contains(session, job_id, ordinal, file_path):
return Response("No such file in result set.", status_code=404)

attach_name, ext = os.path.splitext(os.path.basename(file_path))
Expand All @@ -341,12 +366,14 @@ def download(
@app.get(
"/api/download/hashes/{job_id}", dependencies=[Depends(can_view_queries)]
)
def download_hashes(job_id: str) -> Response:
def download_hashes(
job_id: str, session: Session = Depends(with_session)
) -> Response:
"""Returns a list of job matches as a sha256 strings joined with newlines."""

hashes = "\n".join(
d["meta"]["sha256"]["display_text"]
for d in db.get_job_matches(job_id).matches
for d in db.get_job_matches(session, job_id).matches
)
return Response(hashes + "\n")

Expand Down Expand Up @@ -378,9 +405,11 @@ def zip_files(
dependencies=[Depends(is_user), Depends(can_download_files)],
)
async def download_files(
job_id: str, plugins: PluginManager = Depends(with_plugins)
job_id: str,
plugins: PluginManager = Depends(with_plugins),
session: Session = Depends(with_session),
) -> StreamingResponse:
matches = db.get_job_matches(job_id).matches
matches = db.get_job_matches(session, job_id).matches
return StreamingResponse(zip_files(plugins, matches))


Expand All @@ -391,7 +420,9 @@ async def download_files(
dependencies=[Depends(can_manage_queries)],
)
def query(
data: QueryRequestSchema = Body(...), user: User = Depends(current_user)
data: QueryRequestSchema = Body(...),
user: User = Depends(current_user),
session: Session = Depends(with_session),
) -> Union[QueryResponseSchema, List[ParseResponseSchema]]:
"""Starts a new search. Response will contain a new job ID that can be used
to check the job status and download matched files.
Expand Down Expand Up @@ -420,7 +451,9 @@ def query(
]

degenerate_rules = [r.name for r in rules if r.parse().is_degenerate]
allow_slow = db.get_mquery_config_key("query_allow_slow") == "true"
allow_slow = (
db.get_mquery_config_key(session, "query_allow_slow") == "true"
)
if degenerate_rules and not (allow_slow and data.force_slow_queries):
if allow_slow:
# Warning: "You can force a slow query" literal is used to
Expand All @@ -441,7 +474,7 @@ def query(
),
)

active_agents = db.get_active_agents()
active_agents = db.get_active_agents(session)

for agent, agent_spec in active_agents.items():
missing = set(data.required_plugins).difference(
Expand All @@ -458,6 +491,7 @@ def query(
data.taints = []

job = db.create_search_task(
session,
rules[-1].name,
user.name,
data.raw_yara,
Expand All @@ -476,13 +510,16 @@ def query(
dependencies=[Depends(can_view_queries)],
)
def matches(
job_id: str, offset: int = Query(...), limit: int = Query(...)
job_id: str,
offset: int = Query(...),
limit: int = Query(...),
session: Session = Depends(with_session),
) -> MatchesSchema:
"""Returns a list of matched files, along with metadata tags and other
useful information. Results from this query can be used to download files
using the `/download` endpoint.
"""
return db.get_job_matches(job_id, offset, limit)
return db.get_job_matches(session, job_id, offset, limit)


@app.get(
Expand All @@ -491,11 +528,11 @@ def matches(
tags=["stable"],
dependencies=[Depends(can_view_queries)],
)
def job_info(job_id: str) -> Job:
def job_info(job_id: str, session: Session = Depends(with_session)) -> Job:
"""Returns a metadata for a single job. May be useful for monitoring
a job progress.
"""
return db.get_job(job_id)
return db.get_job(session, job_id)


@app.delete(
Expand All @@ -505,18 +542,20 @@ def job_info(job_id: str) -> Job:
dependencies=[Depends(can_manage_queries)],
)
def job_cancel(
job_id: str, user: User = Depends(current_user)
job_id: str,
user: User = Depends(current_user),
session: Session = Depends(with_session),
) -> StatusSchema:
"""Cancels the job with a provided `job_id`."""
if "can_manage_all_queries" not in get_user_roles(user):
job = db.get_job(job_id)
if "can_manage_all_queries" not in get_user_roles(session, user):
job = db.get_job(session, job_id)
if job.rule_author != user.name:
raise HTTPException(
status_code=400,
detail="You don't have enough permissions to cancel this job.",
)

db.cancel_job(job_id)
db.cancel_job(session, job_id)
return StatusSchema(status="ok")


Expand All @@ -526,14 +565,18 @@ def job_cancel(
tags=["stable"],
dependencies=[Depends(can_list_queries)],
)
def job_statuses(user: User = Depends(current_user)) -> JobsSchema:
def job_statuses(
user: User = Depends(current_user),
session: Session = Depends(with_session),
) -> JobsSchema:
"""Returns statuses of all the jobs in the system. May take some time (> 1s)
when there are a lot of them.
"""
jobs = [db.get_job(job) for job in db.get_job_ids()]
# TODO: rewrite this in a more ORM-friendly way
jobs = [db.get_job(session, job) for job in db.get_job_ids(session)]
jobs = sorted(jobs, key=lambda j: j.submitted, reverse=True)
jobs = [j for j in jobs if j.status != "removed"]
if "can_list_all_queries" not in get_user_roles(user):
if "can_list_all_queries" not in get_user_roles(session, user):
jobs = [j for j in jobs if j.rule_author == user.name]

return JobsSchema(jobs=jobs)
Expand All @@ -545,17 +588,19 @@ def job_statuses(user: User = Depends(current_user)) -> JobsSchema:
dependencies=[Depends(can_manage_queries)],
)
def query_remove(
job_id: str, user: User = Depends(current_user)
job_id: str,
user: User = Depends(current_user),
session: Session = Depends(with_session),
) -> StatusSchema:
if "can_manage_all_queries" not in get_user_roles(user):
job = db.get_job(job_id)
if "can_manage_all_queries" not in get_user_roles(session, user):
job = db.get_job(session, job_id)
if job.rule_author != user.name:
raise HTTPException(
status_code=400,
detail="You don't have enough permissions to remove this job.",
)

db.remove_query(job_id)
db.remove_query(session, job_id)
return StatusSchema(status="ok")


Expand All @@ -565,12 +610,12 @@ def query_remove(


@app.get("/api/server", response_model=ServerSchema, tags=["stable"])
def server() -> ServerSchema:
def server(session: Session = Depends(with_session)) -> ServerSchema:
return ServerSchema(
version=mquery_version(),
auth_enabled=db.get_mquery_config_key("auth_enabled"),
openid_url=db.get_mquery_config_key("openid_url"),
openid_client_id=db.get_mquery_config_key("openid_client_id"),
auth_enabled=db.get_mquery_config_key(session, "auth_enabled"),
openid_url=db.get_mquery_config_key(session, "openid_url"),
openid_client_id=db.get_mquery_config_key(session, "openid_client_id"),
about=app_config.mquery.about,
)

Expand Down
4 changes: 3 additions & 1 deletion src/daemon.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,9 @@ def main() -> None:

# Initial registration of the worker group.
# The goal is to make the web UI aware of this worker and its configuration.
tasks.make_agent(args.group_id).register()
tmp_agent = tasks.make_agent(args.group_id)
with tmp_agent.db.session() as session:
tmp_agent.register(session)

if args.scale > 1:
children = [
Expand Down
Loading

0 comments on commit 4da8b52

Please sign in to comment.