Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor 2: make callers responsible for sessions #390

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 89 additions & 44 deletions src/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from starlette.requests import Request # type: ignore
from starlette.responses import Response, FileResponse, StreamingResponse # type: ignore
from starlette.staticfiles import StaticFiles # type: ignore
from sqlmodel import Session
from zmq import Again
from typing import Any, Callable, List, Union, Dict, Iterable, Optional
import tempfile
Expand Down Expand Up @@ -58,6 +59,11 @@ def with_plugins() -> Iterable[PluginManager]:
plugins.cleanup()


def with_session() -> Iterable[Session]:
    """FastAPI dependency that provides a per-request database session.

    The session is obtained from ``db.session()``; the surrounding context
    manager closes it after the request handler is done with it.
    """
    with db.session() as request_session:
        yield request_session


class User:
def __init__(self, token: Optional[Dict]) -> None:
self.__token = token
Expand All @@ -81,8 +87,11 @@ def roles(self, client_id: Optional[str]) -> List[str]:
return []


async def current_user(authorization: Optional[str] = Header(None)) -> User:
auth_enabled = db.get_mquery_config_key("auth_enabled")
async def current_user(
authorization: Optional[str] = Header(None),
session: Session = Depends(with_session),
) -> User:
auth_enabled = db.get_mquery_config_key(session, "auth_enabled")
if not auth_enabled or auth_enabled == "false":
return User(None)

Expand All @@ -102,7 +111,7 @@ async def current_user(authorization: Optional[str] = Header(None)) -> User:

_bearer, token = token_parts

secret = db.get_mquery_config_key("openid_secret")
secret = db.get_mquery_config_key(session, "openid_secret")
if secret is None:
raise RuntimeError("Invalid configuration - missing_openid_secret.")

Expand Down Expand Up @@ -136,12 +145,16 @@ class RoleChecker:
def __init__(self, need_permissions: List[str]) -> None:
self.need_permissions = need_permissions

def __call__(self, user: User = Depends(current_user)):
auth_enabled = db.get_mquery_config_key("auth_enabled")
def __call__(
self,
user: User = Depends(current_user),
session: Session = Depends(with_session),
):
auth_enabled = db.get_mquery_config_key(session, "auth_enabled")
if not auth_enabled or auth_enabled == "false":
return

all_roles = get_user_roles(user)
all_roles = get_user_roles(session, user)

if not any(role in self.need_permissions for role in all_roles):
message = (
Expand All @@ -165,10 +178,12 @@ def __call__(self, user: User = Depends(current_user)):
can_download_files = RoleChecker(["can_download_files"])


def get_user_roles(user: User) -> List[str]:
client_id = db.get_mquery_config_key("openid_client_id")
def get_user_roles(session: Session, user: User) -> List[str]:
client_id = db.get_mquery_config_key(session, "openid_client_id")
user_roles = user.roles(client_id)
auth_default_roles = db.get_mquery_config_key("auth_default_roles")
auth_default_roles = db.get_mquery_config_key(
session, "auth_default_roles"
)
if not auth_default_roles:
auth_default_roles = "admin"
default_roles = [role.strip() for role in auth_default_roles.split(",")]
Expand Down Expand Up @@ -211,12 +226,14 @@ def expand_role(role: str) -> List[str]:
tags=["internal"],
dependencies=[Depends(is_admin)],
)
def config_list(
    session: Session = Depends(with_session),
) -> List[ConfigSchema]:
    """Returns the current database configuration.

    This endpoint is not stable and may be subject to change in the future.
    """
    # Read-only: simply forwards the stored configuration entries.
    return db.get_config(session)


@app.post(
Expand All @@ -225,12 +242,15 @@ def config_list() -> List[ConfigSchema]:
tags=["internal"],
dependencies=[Depends(is_admin)],
)
def config_edit(
    data: RequestConfigEdit = Body(...),
    session: Session = Depends(with_session),
) -> StatusSchema:
    """Change a given configuration key to a specified value.

    This endpoint is not stable and may be subject to change in the future.
    """
    # Persist the new value for the (plugin, key) pair, then acknowledge.
    db.set_config_key(session, data.plugin, data.key, data.value)
    return StatusSchema(status="ok")


Expand All @@ -240,7 +260,9 @@ def config_edit(data: RequestConfigEdit = Body(...)) -> StatusSchema:
tags=["internal"],
dependencies=[Depends(is_admin)],
)
def backend_status() -> BackendStatusSchema:
def backend_status(
session: Session = Depends(with_session),
) -> BackendStatusSchema:
"""Gets the current status of backend services, and returns it. Intended to
be used by the webpage.

Expand All @@ -250,7 +272,7 @@ def backend_status() -> BackendStatusSchema:
components = {
"mquery": mquery_version(),
}
for name, agent_spec in db.get_active_agents().items():
for name, agent_spec in db.get_active_agents(session).items():
try:
ursa = UrsaDb(agent_spec.ursadb_url)
status = ursa.status()
Expand Down Expand Up @@ -280,7 +302,9 @@ def backend_status() -> BackendStatusSchema:
tags=["internal"],
dependencies=[Depends(can_view_queries)],
)
def backend_status_datasets() -> BackendStatusDatasetsSchema:
def backend_status_datasets(
session: Session = Depends(with_session),
) -> BackendStatusDatasetsSchema:
"""Returns a combined list of datasets from all agents.

Caveat: In case of collision of dataset ids when there are multiple agents,
Expand All @@ -290,7 +314,7 @@ def backend_status_datasets() -> BackendStatusDatasetsSchema:
This endpoint is not stable and may be subject to change in the future.
"""
datasets: Dict[str, int] = {}
for agent_spec in db.get_active_agents().values():
for agent_spec in db.get_active_agents(session).values():
try:
ursa = UrsaDb(agent_spec.ursadb_url)
datasets.update(ursa.topology()["result"]["datasets"])
Expand All @@ -314,6 +338,7 @@ def download(
ordinal: int,
file_path: str,
plugins: PluginManager = Depends(with_plugins),
session: Session = Depends(with_session),
) -> Response:
"""Sends a file from given `file_path`. This path should come from
results of one of the previous searches.
Expand All @@ -322,7 +347,7 @@ def download(
(index of the file in that job), to ensure that user can't download
arbitrary files (for example "/etc/passwd").
"""
if not db.job_contains(job_id, ordinal, file_path):
if not db.job_contains(session, job_id, ordinal, file_path):
return Response("No such file in result set.", status_code=404)

attach_name, ext = os.path.splitext(os.path.basename(file_path))
Expand All @@ -341,12 +366,14 @@ def download(
@app.get(
"/api/download/hashes/{job_id}", dependencies=[Depends(can_view_queries)]
)
def download_hashes(
    job_id: str, session: Session = Depends(with_session)
) -> Response:
    """Returns a list of job matches as a sha256 strings joined with newlines."""
    # Collect the display hash of every match for this job.
    sha256_values = [
        match["meta"]["sha256"]["display_text"]
        for match in db.get_job_matches(session, job_id).matches
    ]
    # Trailing newline keeps the output friendly to line-based tooling.
    return Response("\n".join(sha256_values) + "\n")

Expand Down Expand Up @@ -378,9 +405,11 @@ def zip_files(
dependencies=[Depends(is_user), Depends(can_download_files)],
)
async def download_files(
    job_id: str,
    plugins: PluginManager = Depends(with_plugins),
    session: Session = Depends(with_session),
) -> StreamingResponse:
    # Stream a zip archive built from every file matched by this job;
    # zip_files lazily reads each match through the plugin manager.
    job_matches = db.get_job_matches(session, job_id).matches
    return StreamingResponse(zip_files(plugins, job_matches))


Expand All @@ -391,7 +420,9 @@ async def download_files(
dependencies=[Depends(can_manage_queries)],
)
def query(
data: QueryRequestSchema = Body(...), user: User = Depends(current_user)
data: QueryRequestSchema = Body(...),
user: User = Depends(current_user),
session: Session = Depends(with_session),
) -> Union[QueryResponseSchema, List[ParseResponseSchema]]:
"""Starts a new search. Response will contain a new job ID that can be used
to check the job status and download matched files.
Expand Down Expand Up @@ -420,7 +451,9 @@ def query(
]

degenerate_rules = [r.name for r in rules if r.parse().is_degenerate]
allow_slow = db.get_mquery_config_key("query_allow_slow") == "true"
allow_slow = (
db.get_mquery_config_key(session, "query_allow_slow") == "true"
)
if degenerate_rules and not (allow_slow and data.force_slow_queries):
if allow_slow:
# Warning: "You can force a slow query" literal is used to
Expand All @@ -441,7 +474,7 @@ def query(
),
)

active_agents = db.get_active_agents()
active_agents = db.get_active_agents(session)

for agent, agent_spec in active_agents.items():
missing = set(data.required_plugins).difference(
Expand All @@ -458,6 +491,7 @@ def query(
data.taints = []

job = db.create_search_task(
session,
rules[-1].name,
user.name,
data.raw_yara,
Expand All @@ -476,13 +510,16 @@ def query(
dependencies=[Depends(can_view_queries)],
)
def matches(
    job_id: str,
    offset: int = Query(...),
    limit: int = Query(...),
    session: Session = Depends(with_session),
) -> MatchesSchema:
    """Returns a list of matched files, along with metadata tags and other
    useful information. Results from this query can be used to download files
    using the `/download` endpoint.
    """
    # Pagination (offset/limit) is delegated to the database layer.
    return db.get_job_matches(session, job_id, offset, limit)


@app.get(
Expand All @@ -491,11 +528,11 @@ def matches(
tags=["stable"],
dependencies=[Depends(can_view_queries)],
)
def job_info(job_id: str, session: Session = Depends(with_session)) -> Job:
    """Returns a metadata for a single job. May be useful for monitoring
    a job progress.
    """
    job = db.get_job(session, job_id)
    return job


@app.delete(
Expand All @@ -505,18 +542,20 @@ def job_info(job_id: str) -> Job:
dependencies=[Depends(can_manage_queries)],
)
def job_cancel(
    job_id: str,
    user: User = Depends(current_user),
    session: Session = Depends(with_session),
) -> StatusSchema:
    """Cancels the job with a provided `job_id`."""
    # Users without the global management role may only cancel jobs
    # that they authored themselves.
    roles = get_user_roles(session, user)
    if "can_manage_all_queries" not in roles:
        job = db.get_job(session, job_id)
        if job.rule_author != user.name:
            raise HTTPException(
                status_code=400,
                detail="You don't have enough permissions to cancel this job.",
            )

    db.cancel_job(session, job_id)
    return StatusSchema(status="ok")


Expand All @@ -526,14 +565,18 @@ def job_cancel(
tags=["stable"],
dependencies=[Depends(can_list_queries)],
)
def job_statuses(
    user: User = Depends(current_user),
    session: Session = Depends(with_session),
) -> JobsSchema:
    """Returns statuses of all the jobs in the system. May take some time (> 1s)
    when there are a lot of them.
    """
    # TODO: rewrite this in a more ORM-friendly way
    all_jobs = [
        db.get_job(session, job_id) for job_id in db.get_job_ids(session)
    ]
    # Newest submissions first; removed jobs are never shown.
    all_jobs.sort(key=lambda j: j.submitted, reverse=True)
    visible_jobs = [j for j in all_jobs if j.status != "removed"]
    if "can_list_all_queries" not in get_user_roles(session, user):
        # Regular users only see the jobs they submitted themselves.
        visible_jobs = [j for j in visible_jobs if j.rule_author == user.name]

    return JobsSchema(jobs=visible_jobs)
Expand All @@ -545,17 +588,19 @@ def job_statuses(user: User = Depends(current_user)) -> JobsSchema:
dependencies=[Depends(can_manage_queries)],
)
def query_remove(
    job_id: str,
    user: User = Depends(current_user),
    session: Session = Depends(with_session),
) -> StatusSchema:
    # Users without the global management role may only remove jobs
    # that they authored themselves.
    if "can_manage_all_queries" not in get_user_roles(session, user):
        target_job = db.get_job(session, job_id)
        if target_job.rule_author != user.name:
            raise HTTPException(
                status_code=400,
                detail="You don't have enough permissions to remove this job.",
            )

    db.remove_query(session, job_id)
    return StatusSchema(status="ok")


Expand All @@ -565,12 +610,12 @@ def query_remove(


@app.get("/api/server", response_model=ServerSchema, tags=["stable"])
def server(session: Session = Depends(with_session)) -> ServerSchema:
    # Public server metadata consumed by the frontend: version banner
    # plus the OpenID parameters needed to configure the login flow.
    auth_enabled = db.get_mquery_config_key(session, "auth_enabled")
    openid_url = db.get_mquery_config_key(session, "openid_url")
    openid_client_id = db.get_mquery_config_key(session, "openid_client_id")
    return ServerSchema(
        version=mquery_version(),
        auth_enabled=auth_enabled,
        openid_url=openid_url,
        openid_client_id=openid_client_id,
        about=app_config.mquery.about,
    )

Expand Down
4 changes: 3 additions & 1 deletion src/daemon.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,9 @@ def main() -> None:

# Initial registration of the worker group.
# The goal is to make the web UI aware of this worker and its configuration.
tasks.make_agent(args.group_id).register()
tmp_agent = tasks.make_agent(args.group_id)
with tmp_agent.db.session() as session:
tmp_agent.register(session)

if args.scale > 1:
children = [
Expand Down
Loading
Loading