diff --git a/src/app.py b/src/app.py
index 2d283051..f71bac7e 100644
--- a/src/app.py
+++ b/src/app.py
@@ -13,6 +13,7 @@
 from starlette.requests import Request  # type: ignore
 from starlette.responses import Response, FileResponse, StreamingResponse  # type: ignore
 from starlette.staticfiles import StaticFiles  # type: ignore
+from sqlmodel import Session
 from zmq import Again
 from typing import Any, Callable, List, Union, Dict, Iterable, Optional
 import tempfile
@@ -58,6 +59,11 @@ def with_plugins() -> Iterable[PluginManager]:
         plugins.cleanup()
 
 
+def with_session() -> Iterable[Session]:
+    with db.session() as session:
+        yield session
+
+
 class User:
     def __init__(self, token: Optional[Dict]) -> None:
         self.__token = token
@@ -81,8 +87,11 @@ def roles(self, client_id: Optional[str]) -> List[str]:
         return []
 
 
-async def current_user(authorization: Optional[str] = Header(None)) -> User:
-    auth_enabled = db.get_mquery_config_key("auth_enabled")
+async def current_user(
+    authorization: Optional[str] = Header(None),
+    session: Session = Depends(with_session),
+) -> User:
+    auth_enabled = db.get_mquery_config_key(session, "auth_enabled")
     if not auth_enabled or auth_enabled == "false":
         return User(None)
 
@@ -102,7 +111,7 @@ async def current_user(authorization: Optional[str] = Header(None)) -> User:
 
     _bearer, token = token_parts
 
-    secret = db.get_mquery_config_key("openid_secret")
+    secret = db.get_mquery_config_key(session, "openid_secret")
     if secret is None:
         raise RuntimeError("Invalid configuration - missing_openid_secret.")
 
@@ -136,12 +145,16 @@ class RoleChecker:
     def __init__(self, need_permissions: List[str]) -> None:
         self.need_permissions = need_permissions
 
-    def __call__(self, user: User = Depends(current_user)):
-        auth_enabled = db.get_mquery_config_key("auth_enabled")
+    def __call__(
+        self,
+        user: User = Depends(current_user),
+        session: Session = Depends(with_session),
+    ):
+        auth_enabled = db.get_mquery_config_key(session, "auth_enabled")
         if not auth_enabled or auth_enabled == "false":
             return
 
-        all_roles = get_user_roles(user)
+        all_roles = get_user_roles(session, user)
 
         if not any(role in self.need_permissions for role in all_roles):
             message = (
@@ -165,10 +178,12 @@ def __call__(self, user: User = Depends(current_user)):
 can_download_files = RoleChecker(["can_download_files"])
 
 
-def get_user_roles(user: User) -> List[str]:
-    client_id = db.get_mquery_config_key("openid_client_id")
+def get_user_roles(session: Session, user: User) -> List[str]:
+    client_id = db.get_mquery_config_key(session, "openid_client_id")
     user_roles = user.roles(client_id)
-    auth_default_roles = db.get_mquery_config_key("auth_default_roles")
+    auth_default_roles = db.get_mquery_config_key(
+        session, "auth_default_roles"
+    )
     if not auth_default_roles:
         auth_default_roles = "admin"
     default_roles = [role.strip() for role in auth_default_roles.split(",")]
@@ -211,12 +226,14 @@ def expand_role(role: str) -> List[str]:
     tags=["internal"],
     dependencies=[Depends(is_admin)],
 )
-def config_list() -> List[ConfigSchema]:
+def config_list(
+    session: Session = Depends(with_session),
+) -> List[ConfigSchema]:
     """Returns the current database configuration.
 
     This endpoint is not stable and may be subject to change in the future.
     """
-    return db.get_config()
+    return db.get_config(session)
 
 
 @app.post(
@@ -225,12 +242,15 @@ def config_list() -> List[ConfigSchema]:
     tags=["internal"],
     dependencies=[Depends(is_admin)],
 )
-def config_edit(data: RequestConfigEdit = Body(...)) -> StatusSchema:
+def config_edit(
+    data: RequestConfigEdit = Body(...),
+    session: Session = Depends(with_session),
+) -> StatusSchema:
     """Change a given configuration key to a specified value.
 
     This endpoint is not stable and may be subject to change in the future.
     """
-    db.set_config_key(data.plugin, data.key, data.value)
+    db.set_config_key(session, data.plugin, data.key, data.value)
     return StatusSchema(status="ok")
 
 
@@ -240,7 +260,9 @@ def config_edit(data: RequestConfigEdit = Body(...)) -> StatusSchema:
     tags=["internal"],
     dependencies=[Depends(is_admin)],
 )
-def backend_status() -> BackendStatusSchema:
+def backend_status(
+    session: Session = Depends(with_session),
+) -> BackendStatusSchema:
     """Gets the current status of backend services, and returns it. Intended
     to be used by the webpage.
 
     This endpoint is not stable and may be subject to change in the future.
     """
     components = {
         "mquery": mquery_version(),
     }
-    for name, agent_spec in db.get_active_agents().items():
+    for name, agent_spec in db.get_active_agents(session).items():
         try:
             ursa = UrsaDb(agent_spec.ursadb_url)
             status = ursa.status()
@@ -280,7 +302,9 @@ def backend_status() -> BackendStatusSchema:
     tags=["internal"],
     dependencies=[Depends(can_view_queries)],
 )
-def backend_status_datasets() -> BackendStatusDatasetsSchema:
+def backend_status_datasets(
+    session: Session = Depends(with_session),
+) -> BackendStatusDatasetsSchema:
     """Returns a combined list of datasets from all agents.
 
     Caveat: In case of collision of dataset ids when there are multiple agents,
     only one of the colliding datasets will be returned (which one is undefined).
 
     This endpoint is not stable and may be subject to change in the future.
     """
     datasets: Dict[str, int] = {}
-    for agent_spec in db.get_active_agents().values():
+    for agent_spec in db.get_active_agents(session).values():
         try:
             ursa = UrsaDb(agent_spec.ursadb_url)
             datasets.update(ursa.topology()["result"]["datasets"])
@@ -314,6 +338,7 @@ def download(
     ordinal: int,
     file_path: str,
     plugins: PluginManager = Depends(with_plugins),
+    session: Session = Depends(with_session),
 ) -> Response:
     """Sends a file from given `file_path`. This path should come from
     results of one of the previous searches.
 
     This endpoint needs `job_id` of the search that found the file, and `ordinal`
     (index of the file in that job), to ensure that the user can't download
     arbitrary files (for example "/etc/passwd").
     """
-    if not db.job_contains(job_id, ordinal, file_path):
+    if not db.job_contains(session, job_id, ordinal, file_path):
         return Response("No such file in result set.", status_code=404)
 
     attach_name, ext = os.path.splitext(os.path.basename(file_path))
 
@@ -341,12 +366,14 @@ def download(
 @app.get(
     "/api/download/hashes/{job_id}", dependencies=[Depends(can_view_queries)]
 )
-def download_hashes(job_id: str) -> Response:
+def download_hashes(
+    job_id: str, session: Session = Depends(with_session)
+) -> Response:
     """Returns a list of job matches as sha256 strings joined with newlines."""
     hashes = "\n".join(
         d["meta"]["sha256"]["display_text"]
-        for d in db.get_job_matches(job_id).matches
+        for d in db.get_job_matches(session, job_id).matches
     )
     return Response(hashes + "\n")
 
@@ -378,9 +405,11 @@ def zip_files(
     dependencies=[Depends(is_user), Depends(can_download_files)],
 )
 async def download_files(
-    job_id: str, plugins: PluginManager = Depends(with_plugins)
+    job_id: str,
+    plugins: PluginManager = Depends(with_plugins),
+    session: Session = Depends(with_session),
 ) -> StreamingResponse:
-    matches = db.get_job_matches(job_id).matches
+    matches = db.get_job_matches(session, job_id).matches
     return StreamingResponse(zip_files(plugins, matches))
 
 
@@ -391,7 +420,9 @@ async def download_files(
     dependencies=[Depends(can_manage_queries)],
 )
 def query(
-    data: QueryRequestSchema = Body(...), user: User = Depends(current_user)
+    data: QueryRequestSchema = Body(...),
+    user: User = Depends(current_user),
+    session: Session = Depends(with_session),
 ) -> Union[QueryResponseSchema, List[ParseResponseSchema]]:
     """Starts a new search. Response will contain a new job ID that can be
     used to check the job status and download matched files.
@@ -420,7 +451,9 @@ def query(
     ]
     degenerate_rules = [r.name for r in rules if r.parse().is_degenerate]
-    allow_slow = db.get_mquery_config_key("query_allow_slow") == "true"
+    allow_slow = (
+        db.get_mquery_config_key(session, "query_allow_slow") == "true"
+    )
     if degenerate_rules and not (allow_slow and data.force_slow_queries):
         if allow_slow:
             # Warning: "You can force a slow query" literal is used to
@@ -441,7 +474,7 @@ def query(
             ),
         )
 
-    active_agents = db.get_active_agents()
+    active_agents = db.get_active_agents(session)
 
     for agent, agent_spec in active_agents.items():
         missing = set(data.required_plugins).difference(
@@ -458,6 +491,7 @@ def query(
         data.taints = []
 
     job = db.create_search_task(
+        session,
         rules[-1].name,
         user.name,
         data.raw_yara,
@@ -476,13 +510,16 @@ def query(
     dependencies=[Depends(can_view_queries)],
 )
 def matches(
-    job_id: str, offset: int = Query(...), limit: int = Query(...)
+    job_id: str,
+    offset: int = Query(...),
+    limit: int = Query(...),
+    session: Session = Depends(with_session),
 ) -> MatchesSchema:
     """Returns a list of matched files, along with metadata tags and other
     useful information. Results from this query can be used to download files
     using the `/download` endpoint.
     """
-    return db.get_job_matches(job_id, offset, limit)
+    return db.get_job_matches(session, job_id, offset, limit)
 
 
 @app.get(
@@ -491,11 +528,11 @@ def matches(
     tags=["stable"],
     dependencies=[Depends(can_view_queries)],
 )
-def job_info(job_id: str) -> Job:
+def job_info(job_id: str, session: Session = Depends(with_session)) -> Job:
     """Returns metadata for a single job. May be useful for monitoring
     job progress.
     """
-    return db.get_job(job_id)
+    return db.get_job(session, job_id)
 
 
 @app.delete(
@@ -505,18 +542,20 @@ def job_info(job_id: str) -> Job:
     dependencies=[Depends(can_manage_queries)],
 )
 def job_cancel(
-    job_id: str, user: User = Depends(current_user)
+    job_id: str,
+    user: User = Depends(current_user),
+    session: Session = Depends(with_session),
 ) -> StatusSchema:
     """Cancels the job with the provided `job_id`."""
-    if "can_manage_all_queries" not in get_user_roles(user):
-        job = db.get_job(job_id)
+    if "can_manage_all_queries" not in get_user_roles(session, user):
+        job = db.get_job(session, job_id)
         if job.rule_author != user.name:
             raise HTTPException(
                 status_code=400,
                 detail="You don't have enough permissions to cancel this job.",
             )
 
-    db.cancel_job(job_id)
+    db.cancel_job(session, job_id)
     return StatusSchema(status="ok")
 
 
@@ -526,14 +565,18 @@ def job_cancel(
     tags=["stable"],
     dependencies=[Depends(can_list_queries)],
 )
-def job_statuses(user: User = Depends(current_user)) -> JobsSchema:
+def job_statuses(
+    user: User = Depends(current_user),
+    session: Session = Depends(with_session),
+) -> JobsSchema:
     """Returns statuses of all the jobs in the system. May take some time
     (> 1s) when there are a lot of them.
     """
-    jobs = [db.get_job(job) for job in db.get_job_ids()]
+    # TODO: rewrite this in a more ORM-friendly way
+    jobs = [db.get_job(session, job) for job in db.get_job_ids(session)]
     jobs = sorted(jobs, key=lambda j: j.submitted, reverse=True)
     jobs = [j for j in jobs if j.status != "removed"]
-    if "can_list_all_queries" not in get_user_roles(user):
+    if "can_list_all_queries" not in get_user_roles(session, user):
         jobs = [j for j in jobs if j.rule_author == user.name]
     return JobsSchema(jobs=jobs)
 
@@ -545,17 +588,19 @@ def job_statuses(user: User = Depends(current_user)) -> JobsSchema:
     dependencies=[Depends(can_manage_queries)],
 )
 def query_remove(
-    job_id: str, user: User = Depends(current_user)
+    job_id: str,
+    user: User = Depends(current_user),
+    session: Session = Depends(with_session),
 ) -> StatusSchema:
-    if "can_manage_all_queries" not in get_user_roles(user):
-        job = db.get_job(job_id)
+    if "can_manage_all_queries" not in get_user_roles(session, user):
+        job = db.get_job(session, job_id)
         if job.rule_author != user.name:
             raise HTTPException(
                 status_code=400,
                 detail="You don't have enough permissions to remove this job.",
             )
 
-    db.remove_query(job_id)
+    db.remove_query(session, job_id)
     return StatusSchema(status="ok")
 
 
@@ -565,12 +610,12 @@ def query_remove(
 
 
 @app.get("/api/server", response_model=ServerSchema, tags=["stable"])
-def server() -> ServerSchema:
+def server(session: Session = Depends(with_session)) -> ServerSchema:
     return ServerSchema(
         version=mquery_version(),
-        auth_enabled=db.get_mquery_config_key("auth_enabled"),
-        openid_url=db.get_mquery_config_key("openid_url"),
-        openid_client_id=db.get_mquery_config_key("openid_client_id"),
+        auth_enabled=db.get_mquery_config_key(session, "auth_enabled"),
+        openid_url=db.get_mquery_config_key(session, "openid_url"),
+        openid_client_id=db.get_mquery_config_key(session, "openid_client_id"),
         about=app_config.mquery.about,
     )
 
diff --git a/src/daemon.py b/src/daemon.py
index cc1e7ee1..d86250b5 100644
--- a/src/daemon.py
+++ b/src/daemon.py
@@ -44,7 +44,9 @@ def main() -> None:
 
     # Initial registration of the worker group.
     # The goal is to make the web UI aware of this worker and its configuration.
-    tasks.make_agent(args.group_id).register()
+    tmp_agent = tasks.make_agent(args.group_id)
+    with tmp_agent.db.session() as session:
+        tmp_agent.register(session)
 
     if args.scale > 1:
         children = [
diff --git a/src/db.py b/src/db.py
index f9270ef3..ee6a07b1 100644
--- a/src/db.py
+++ b/src/db.py
@@ -56,176 +56,164 @@ def __schedule(self, agent: str, task: Any, *args: Any) -> None:
     def session(self):
         with Session(self.engine) as session:
             yield session
+            session.commit()
 
-    def get_job_ids(self) -> List[JobId]:
+    def get_job_ids(self, session: Session) -> List[JobId]:
         """Gets IDs of all jobs in the database."""
-        with self.session() as session:
-            jobs = session.exec(select(Job)).all()
-            return [j.id for j in jobs]
+        jobs = session.exec(select(Job)).all()
+        return [j.id for j in jobs]
 
-    def cancel_job(self, job: JobId, error=None) -> None:
+    def cancel_job(self, session: Session, job: JobId, error=None) -> None:
         """Sets the job status to cancelled, with an optional error message."""
-        with self.session() as session:
-            session.execute(
-                update(Job)
-                .where(Job.id == job)
-                .values(status="cancelled", finished=int(time()), error=error)
-            )
-            session.commit()
+        session.execute(
+            update(Job)
+            .where(Job.id == job)
+            .values(status="cancelled", finished=int(time()), error=error)
+        )
 
-    def fail_job(self, job: JobId, message: str) -> None:
+    def fail_job(self, session: Session, job: JobId, message: str) -> None:
         """Sets the job status to cancelled with the provided error message."""
-        self.cancel_job(job, message)
-
-    def __get_job(self, session: Session, job: JobId) -> Job:
-        """Internal helper to get a job from the database."""
-        return session.exec(select(Job).where(Job.id == job)).one()
+        self.cancel_job(session, job, message)
 
-    def get_job(self, job: JobId) -> Job:
+    def get_job(self, session: Session, job: JobId) -> Job:
         """Retrieves a job from the database."""
-        with self.session() as session:
-            return self.__get_job(session, job)
+        return session.exec(select(Job).where(Job.id == job)).one()
 
-    def remove_query(self, job: JobId) -> None:
+    def remove_query(self, session: Session, job: JobId) -> None:
         """Sets the job status to removed."""
-        with self.session() as session:
-            session.execute(
-                update(Job).where(Job.id == job).values(status="removed")
-            )
-            session.commit()
+        session.execute(
+            update(Job).where(Job.id == job).values(status="removed")
+        )
 
-    def add_match(self, job: JobId, match: Match) -> None:
-        with self.session() as session:
-            job_object = self.__get_job(session, job)
-            match.job = job_object
-            session.add(match)
-            session.commit()
+    def add_match(self, session: Session, job: JobId, match: Match) -> None:
+        job_object = self.get_job(session, job)
+        match.job = job_object
+        session.add(match)
 
-    def job_contains(self, job: JobId, ordinal: int, file_path: str) -> bool:
+    def job_contains(
+        self, session: Session, job: JobId, ordinal: int, file_path: str
+    ) -> bool:
         """Make sure that the file path is in the job results."""
-        with self.session() as session:
-            job_object = self.__get_job(session, job)
-            statement = select(Match).where(
-                and_(Match.job == job_object, Match.file == file_path)
-            )
-            entry = session.exec(statement).one_or_none()
-            return entry is not None
+        job_object = self.get_job(session, job)
+        statement = select(Match).where(
+            and_(Match.job == job_object, Match.file == file_path)
+        )
+        entry = session.exec(statement).one_or_none()
+        return entry is not None
 
-    def job_start_work(self, job: JobId, in_progress: int) -> None:
+    def job_start_work(
+        self, session: Session, job: JobId, in_progress: int
+    ) -> None:
         """Updates the number of files being processed right now.
 
         :param job: ID of the job being updated.
         :param in_progress: Number of files in the current work unit.
         """
-        with self.session() as session:
-            session.execute(
-                update(Job)
-                .where(Job.id == job)
-                .values(files_in_progress=Job.files_in_progress + in_progress)
-            )
-            session.commit()
+        session.execute(
+            update(Job)
+            .where(Job.id == job)
+            .values(files_in_progress=Job.files_in_progress + in_progress)
+        )
 
-    def agent_finish_job(self, job: Job) -> None:
+    def agent_finish_job(self, session: Session, job: Job) -> None:
         """Decrements the number of active agents in the given job. If there
         are no more agents, the job status is changed to done.
         """
-        with self.session() as session:
-            (agents_left,) = session.execute(
+        (agents_left,) = session.execute(
+            update(Job)
+            .where(Job.internal_id == job.internal_id)
+            .values(agents_left=Job.agents_left - 1)
+            .returning(Job.agents_left)
+        ).one()
+        if agents_left == 0:
+            session.execute(
                 update(Job)
                 .where(Job.internal_id == job.internal_id)
-                .values(agents_left=Job.agents_left - 1)
-                .returning(Job.agents_left)
-            ).one()
-            if agents_left == 0:
-                session.execute(
-                    update(Job)
-                    .where(Job.internal_id == job.internal_id)
-                    .values(finished=int(time()), status="done")
-                )
-            session.commit()
+                .values(finished=int(time()), status="done")
+            )
 
-    def init_jobagent(self, job: Job, agent_id: int, tasks: int) -> None:
+    def init_jobagent(
+        self, session: Session, job: Job, agent_id: int, tasks: int
+    ) -> None:
         """Creates a new JobAgent object. If tasks==0 then finishes the job
         immediately."""
-        with self.session() as session:
-            obj = JobAgent(
-                task_in_progress=tasks,
-                job_id=job.internal_id,
-                agent_id=agent_id,
-            )
-            session.add(obj)
-            session.commit()
+        obj = JobAgent(
+            task_in_progress=tasks,
+            job_id=job.internal_id,
+            agent_id=agent_id,
+        )
+        session.add(obj)
         if tasks == 0:
-            self.agent_finish_job(job)
+            self.agent_finish_job(session, job)
 
-    def agent_add_tasks_in_progress(
-        self, job: Job, agent_id: int, tasks: int
+    def add_tasks_in_progress(
+        self, session: Session, job: Job, agent_id: int, tasks: int
     ) -> None:
         """Increments (or decrements, for negative values) the number of
         tasks that are in progress for the agent. The number of tasks in
         progress should always stay positive for jobs in status inprogress.
         This function will automatically call agent_finish_job if the agent
         has no more tasks left.
         """
-        with self.session() as session:
-            (tasks_left,) = session.execute(
-                update(JobAgent)
-                .where(JobAgent.job_id == job.internal_id)
-                .where(JobAgent.agent_id == agent_id)
-                .values(task_in_progress=JobAgent.task_in_progress + tasks)
-                .returning(JobAgent.task_in_progress)
-            ).one()
-            session.commit()
+        (tasks_left,) = session.execute(
+            update(JobAgent)
+            .where(JobAgent.job_id == job.internal_id)
+            .where(JobAgent.agent_id == agent_id)
+            .values(task_in_progress=JobAgent.task_in_progress + tasks)
+            .returning(JobAgent.task_in_progress)
+        ).one()
         assert tasks_left >= 0
         if tasks_left == 0:
-            self.agent_finish_job(job)
+            self.agent_finish_job(session, job)
 
     def job_update_work(
-        self, job: JobId, processed: int, matched: int, errored: int
+        self,
+        session: Session,
+        job: JobId,
+        processed: int,
+        matched: int,
+        errored: int,
     ) -> int:
         """Updates progress for the job. This will update the counts of
         processed, in-progress, errored and matched files. Returns the
         number of processed files after the operation.
""" - with self.session() as session: - (files_processed,) = session.execute( - update(Job) - .where(Job.id == job) - .values( - files_processed=Job.files_processed + processed, - files_in_progress=Job.files_in_progress - processed, - files_matched=Job.files_matched + matched, - files_errored=Job.files_errored + errored, - ) - .returning(Job.files_processed) - ).one() - session.commit() - return files_processed + (files_processed,) = session.execute( + update(Job) + .where(Job.id == job) + .values( + files_processed=Job.files_processed + processed, + files_in_progress=Job.files_in_progress - processed, + files_matched=Job.files_matched + matched, + files_errored=Job.files_errored + errored, + ) + .returning(Job.files_processed) + ).one() + return files_processed - def init_job_datasets(self, job: JobId, num_datasets: int) -> None: + def init_job_datasets( + self, session: Session, job: JobId, num_datasets: int + ) -> None: """Sets total_datasets and datasets_left, and status to processing.""" - with self.session() as session: - session.execute( - update(Job) - .where(Job.id == job) - .values( - total_datasets=num_datasets, - datasets_left=num_datasets, - status="processing", - ) + session.execute( + update(Job) + .where(Job.id == job) + .values( + total_datasets=num_datasets, + datasets_left=num_datasets, + status="processing", ) - session.commit() + ) - def dataset_query_done(self, job: JobId): + def dataset_query_done(self, session: Session, job: JobId): """Decrements the number of datasets left by one.""" - with self.session() as session: - session.execute( - update(Job) - .where(Job.id == job) - .values(datasets_left=Job.datasets_left - 1) - ) - session.commit() + session.execute( + update(Job) + .where(Job.id == job) + .values(datasets_left=Job.datasets_left - 1) + ) def create_search_task( self, + session: Session, rule_name: str, rule_author: str, raw_yara: str, @@ -239,28 +227,27 @@ def create_search_task( random.choice(string.ascii_uppercase + string.digits) for _ in range(12) ) - with self.session() as session: - obj = Job( - id=job, - status="new", - rule_name=rule_name, - rule_author=rule_author, - raw_yara=raw_yara, - submitted=int(time()), - files_limit=files_limit, - reference=reference, - files_in_progress=0, - files_processed=0, - files_matched=0, - files_errored=0, - total_files=0, - agents_left=len(agents), - datasets_left=0, - total_datasets=0, - taints=taints, - ) - session.add(obj) - session.commit() + obj = Job( + id=job, + status="new", + rule_name=rule_name, + rule_author=rule_author, + raw_yara=raw_yara, + submitted=int(time()), + files_limit=files_limit, + reference=reference, + files_in_progress=0, + files_processed=0, + files_matched=0, + files_errored=0, + total_files=0, + agents_left=len(agents), + datasets_left=0, + total_datasets=0, + taints=taints, + ) + session.add(obj) + session.commit() from . 
import tasks @@ -269,30 +256,34 @@ def create_search_task( return job def get_job_matches( - self, job_id: JobId, offset: int = 0, limit: Optional[int] = None + self, + session: Session, + job_id: JobId, + offset: int = 0, + limit: Optional[int] = None, ) -> MatchesSchema: - with self.session() as session: - job = self.__get_job(session, job_id) - if limit is None: - matches = job.matches[offset:] - else: - matches = job.matches[offset : offset + limit] - return MatchesSchema(job=job, matches=matches) - - def update_job_files(self, job: JobId, total_files: int) -> int: + job = self.get_job(session, job_id) + if limit is None: + matches = job.matches[offset:] + else: + matches = job.matches[offset : offset + limit] + return MatchesSchema(job=job, matches=matches) + + def update_job_files( + self, session: Session, job: JobId, total_files: int + ) -> int: """Add total_files to the specified job, and return a new total.""" - with self.session() as session: - (total_files,) = session.execute( - update(Job) - .where(Job.id == job) - .values(total_files=Job.total_files + total_files) - .returning(Job.total_files) - ).one() - session.commit() + (total_files,) = session.execute( + update(Job) + .where(Job.id == job) + .values(total_files=Job.total_files + total_files) + .returning(Job.total_files) + ).one() return total_files def register_active_agent( self, + session: Session, group_id: str, ursadb_url: str, plugins_spec: Dict[str, Dict[str, str]], @@ -303,22 +294,18 @@ def register_active_agent( # Currently this is done by workers when starting. In the future, # this should be configured by the admin, and workers should just read # their configuration from the database. - with self.session() as session: - entry = session.exec( - select(AgentGroup).where(AgentGroup.name == group_id) - ).one_or_none() - if not entry: - entry = AgentGroup(name=group_id) - entry.ursadb_url = ursadb_url - entry.plugins_spec = plugins_spec - entry.active_plugins = active_plugins - session.add(entry) - session.commit() - - def get_active_agents(self) -> Dict[str, AgentGroup]: - with self.session() as session: - agents = session.exec(select(AgentGroup)).all() - + entry = session.exec( + select(AgentGroup).where(AgentGroup.name == group_id) + ).one_or_none() + if not entry: + entry = AgentGroup(name=group_id) + entry.ursadb_url = ursadb_url + entry.plugins_spec = plugins_spec + entry.active_plugins = active_plugins + session.add(entry) + + def get_active_agents(self, session: Session) -> Dict[str, AgentGroup]: + agents = session.exec(select(AgentGroup)).all() return {agent.name: agent for agent in agents} def get_core_config(self) -> Dict[str, str]: @@ -335,12 +322,12 @@ def get_core_config(self) -> Dict[str, str]: "query_allow_slow": "Allow users to run queries that will end up scanning the whole malware collection", } - def get_config(self) -> List[ConfigSchema]: + def get_config(self, session: Session) -> List[ConfigSchema]: # { plugin_name: { field: description } } config_fields: Dict[str, Dict[str, str]] = defaultdict(dict) config_fields[MQUERY_PLUGIN_NAME] = self.get_core_config() # Merge all config fields - for agent_spec in self.get_active_agents().values(): + for agent_spec in self.get_active_agents(session).values(): for plugin, fields in agent_spec.plugins_spec.items(): config_fields[plugin].update(fields) # Transform fields into ConfigSchema @@ -356,7 +343,7 @@ def get_config(self) -> List[ConfigSchema]: } # Get configuration values for each plugin for plugin, spec in plugin_configs.items(): - config = 
self.get_plugin_config(plugin) + config = self.get_plugin_config(session, plugin) for key, value in config.items(): if key in plugin_configs[plugin]: plugin_configs[plugin][key].value = value @@ -367,37 +354,39 @@ def get_config(self) -> List[ConfigSchema]: for key in sorted(plugin_configs[plugin].keys()) ] - def get_plugin_config(self, plugin_name: str) -> Dict[str, str]: - with self.session() as session: - entries = session.exec( - select(ConfigEntry).where(ConfigEntry.plugin == plugin_name) - ).all() - return {e.key: e.value for e in entries} - - def get_mquery_config_key(self, key: str) -> Optional[str]: - with self.session() as session: - statement = select(ConfigEntry).where( - and_( - ConfigEntry.plugin == MQUERY_PLUGIN_NAME, - ConfigEntry.key == key, - ) + def get_plugin_config( + self, session: Session, plugin_name: str + ) -> Dict[str, str]: + entries = session.exec( + select(ConfigEntry).where(ConfigEntry.plugin == plugin_name) + ).all() + return {e.key: e.value for e in entries} + + def get_mquery_config_key( + self, session: Session, key: str + ) -> Optional[str]: + statement = select(ConfigEntry).where( + and_( + ConfigEntry.plugin == MQUERY_PLUGIN_NAME, + ConfigEntry.key == key, ) - entry = session.exec(statement).one_or_none() - return entry.value if entry else None - - def set_config_key(self, plugin_name: str, key: str, value: str) -> None: - with self.session() as session: - entry = session.exec( - select(ConfigEntry).where( - ConfigEntry.plugin == plugin_name, - ConfigEntry.key == key, - ) - ).one_or_none() - if not entry: - entry = ConfigEntry(plugin=plugin_name, key=key) - entry.value = value - session.add(entry) - session.commit() + ) + entry = session.exec(statement).one_or_none() + return entry.value if entry else None + + def set_config_key( + self, session: Session, plugin_name: str, key: str, value: str + ) -> None: + entry = session.exec( + select(ConfigEntry).where( + ConfigEntry.plugin == plugin_name, + ConfigEntry.key == key, + ) + ).one_or_none() + if not entry: + entry = ConfigEntry(plugin=plugin_name, key=key) + entry.value = value + session.add(entry) def init_db() -> None: diff --git a/src/tasks.py b/src/tasks.py index edb9a426..126f8a66 100644 --- a/src/tasks.py +++ b/src/tasks.py @@ -3,6 +3,7 @@ from rq import get_current_job, Queue # type: ignore from redis import Redis from contextlib import contextmanager +from sqlmodel import Session import yara # type: ignore from .db import Database, JobId @@ -25,7 +26,7 @@ def __init__(self, group_id: str) -> None: """ self.group_id = group_id self.ursa_url = app_config.mquery.backend - self.__db_object = None # set before starting first task + self.__db_id = None # set before starting first task self.db = Database(app_config.redis.host, app_config.redis.port) self.ursa = UrsaDb(self.ursa_url) self.plugins = PluginManager(app_config.mquery.plugins, self.db) @@ -34,13 +35,12 @@ def __init__(self, group_id: str) -> None: connection=Redis(app_config.redis.host, app_config.redis.port), ) - @property - def db_id(self): - if self.__db_object is None: - self.__db_object = self.db.get_active_agents()[self.group_id] - return cast(int, self.__db_object.id) + def db_id(self, session: Session) -> int: + self.__db_id = self.db.get_active_agents(session)[self.group_id].id + assert self.__db_id is not None + return cast(int, self.__db_id) - def register(self) -> None: + def register(self, session: Session) -> None: """Register the agent in the database. Should happen when starting the worker process. 
""" @@ -49,6 +49,7 @@ def register(self) -> None: for plugin_class in self.plugins.plugin_classes } self.db.register_active_agent( + session, self.group_id, self.ursa_url, plugins_spec, @@ -68,7 +69,12 @@ def get_datasets(self) -> List[str]: return list(result["result"]["datasets"].keys()) def update_metadata( - self, job: JobId, orig_name: str, path: str, matches: List[str] + self, + session: Session, + job: JobId, + orig_name: str, + path: str, + matches: List[str], ) -> None: """Saves matches to the database, and runs appropriate metadata plugins. @@ -94,14 +100,16 @@ def update_metadata( # Update the database. match = Match(file=orig_name, meta=metadata, matches=matches) - self.db.add_match(job, match) + self.db.add_match(session, job, match) - def execute_yara(self, job: Job, files: List[str]) -> None: + def execute_yara( + self, session: Session, job: Job, files: List[str] + ) -> None: rule = yara.compile(source=job.raw_yara) num_matches = 0 num_errors = 0 num_files = len(files) - self.db.job_start_work(job.id, num_files) + self.db.job_start_work(session, job.id, num_files) for orig_name in files: try: @@ -111,7 +119,11 @@ def execute_yara(self, job: Job, files: List[str]) -> None: matches = rule.match(path) if matches: self.update_metadata( - job.id, orig_name, path, [r.rule for r in matches] + session, + job.id, + orig_name, + path, + [r.rule for r in matches], ) num_matches += 1 except yara.Error: @@ -126,7 +138,7 @@ def execute_yara(self, job: Job, files: List[str]) -> None: self.plugins.cleanup() new_processed = self.db.job_update_work( - job.id, num_files, num_matches, num_errors + session, job.id, num_files, num_matches, num_errors ) yara_limit = app_config.mquery.yara_limit if yara_limit != 0 and new_processed > yara_limit: @@ -134,30 +146,35 @@ def execute_yara(self, job: Job, files: List[str]) -> None: scanned_datasets = job.total_datasets - job.datasets_left dataset_percent = scanned_datasets / job.total_datasets self.db.fail_job( + session, job.id, f"Configured limit of {yara_limit} YARA matches exceeded. " f"Scanned {new_processed}/{job.total_files} ({scan_percent:.0%}) of candidates " f"in {scanned_datasets}/{job.total_datasets} ({dataset_percent:.0%}) of datasets.", ) - def init_search(self, job: Job, tasks: int) -> None: - self.db.init_jobagent(job, self.db_id, tasks) + def init_search(self, session: Session, job: Job, tasks: int) -> None: + self.db.init_jobagent(session, job, self.db_id(session), tasks) - def add_tasks_in_progress(self, job: Job, tasks: int) -> None: - """See documentation of db.agent_add_tasks_in_progress.""" - self.db.agent_add_tasks_in_progress(job, self.db_id, tasks) + def add_tasks_in_progress( + self, session: Session, job: Job, tasks: int + ) -> None: + """See documentation of db.add_tasks_in_progress.""" + self.db.add_tasks_in_progress(session, job, self.db_id(session), tasks) @contextmanager def job_context(job_id: JobId): """Small error-handling context manager. 
Fails the job on exception.""" agent = make_agent() - try: - yield agent - except Exception as e: - logging.exception("Failed to execute %s.", job_id) - agent.db.fail_job(job_id, str(e)) - raise + with agent.db.session() as session: + try: + yield agent, session + except Exception as e: + logging.exception("Failed to execute %s.", job_id) + agent.db.fail_job(session, job_id, str(e)) + session.commit() + raise def make_agent(group_override: Optional[str] = None): @@ -180,19 +197,19 @@ def start_search(job_id: JobId) -> None: """Initialises a search task - checks available datasets and schedules smaller units of work. """ - with job_context(job_id) as agent: - job = agent.db.get_job(job_id) + with job_context(job_id) as (agent, session): + job = agent.db.get_job(session, job_id) if job.status == "cancelled": logging.info("Job was cancelled, returning...") return datasets = agent.get_datasets() - agent.db.init_job_datasets(job_id, len(datasets)) + agent.db.init_job_datasets(session, job_id, len(datasets)) # Sets the number of datasets in progress. # Caveat: if no datasets, this call is still important, because it # will let the db know that this agent has nothing more to do. - agent.init_search(job, len(datasets)) + agent.init_search(session, job, len(datasets)) rules = parse_yara(job.raw_yara) parsed = combine_rules(rules) @@ -230,8 +247,8 @@ def __get_batch_sizes(file_count: int) -> List[int]: def query_ursadb(job_id: JobId, dataset_id: str, ursadb_query: str) -> None: """Queries ursadb and creates yara scans tasks with file batches.""" - with job_context(job_id) as agent: - job = agent.db.get_job(job_id) + with job_context(job_id) as (agent, session): + job = agent.db.get_job(session, job_id) if job.status == "cancelled": logging.info("Job was cancelled, returning...") return @@ -244,7 +261,7 @@ def query_ursadb(job_id: JobId, dataset_id: str, ursadb_query: str) -> None: iterator = result["iterator"] logging.info(f"Iterator {iterator} contains {file_count} files") - total_files = agent.db.update_job_files(job_id, file_count) + total_files = agent.db.update_job_files(session, job_id, file_count) if job.files_limit and total_files > job.files_limit: raise RuntimeError( f"Too many candidates after prefiltering (limit: {job.files_limit}). " @@ -253,7 +270,7 @@ def query_ursadb(job_id: JobId, dataset_id: str, ursadb_query: str) -> None: batches = __get_batch_sizes(file_count) # add len(batches) new tasks, -1 to account for this task - agent.add_tasks_in_progress(job, len(batches) - 1) + agent.add_tasks_in_progress(session, job, len(batches) - 1) for batch in batches: agent.queue.enqueue( @@ -264,13 +281,13 @@ def query_ursadb(job_id: JobId, dataset_id: str, ursadb_query: str) -> None: job_timeout=app_config.rq.job_timeout, ) - agent.db.dataset_query_done(job_id) + agent.db.dataset_query_done(session, job_id) def run_yara_batch(job_id: JobId, iterator: str, batch_size: int) -> None: """Actually scans files, and updates a database with the results.""" - with job_context(job_id) as agent: - job = agent.db.get_job(job_id) + with job_context(job_id) as (agent, session): + job = agent.db.get_job(session, job_id) if job.status == "cancelled": logging.info("Job was cancelled, returning...") return @@ -288,5 +305,5 @@ def run_yara_batch(job_id: JobId, iterator: str, batch_size: int) -> None: ) return - agent.execute_yara(job, pop_result.files) - agent.add_tasks_in_progress(job, -1) + agent.execute_yara(session, job, pop_result.files) + agent.add_tasks_in_progress(session, job, -1)
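
Review note: the change above is mechanical but pervasive. Every db.* helper now takes an explicit Session as its first argument instead of opening its own; the web layer obtains one session per request through the with_session FastAPI dependency, and workers obtain one session per task through job_context. For reference, here is a minimal sketch of the same session-per-request pattern. It is not part of the patch; the sqlite:// engine URL and the Item model are illustrative assumptions standing in for the real engine and the Job/Match models.

from typing import Iterable, List, Optional

from fastapi import Depends, FastAPI
from sqlmodel import Field, Session, SQLModel, create_engine, select

# Assumption: an in-memory database; the real code builds its engine elsewhere.
engine = create_engine("sqlite://")


class Item(SQLModel, table=True):  # hypothetical model, not in the patch
    id: Optional[int] = Field(default=None, primary_key=True)
    name: str


SQLModel.metadata.create_all(engine)
app = FastAPI()


def with_session() -> Iterable[Session]:
    # Mirrors db.session() in the patch: one session per request,
    # committed once after the handler finishes without raising.
    with Session(engine) as session:
        yield session
        session.commit()


@app.get("/items")
def list_items(session: Session = Depends(with_session)) -> List[str]:
    # Handlers never open their own sessions; all reads and writes within
    # a single request share the injected session's transaction.
    return [item.name for item in session.exec(select(Item)).all()]

The design consequence visible throughout the diff: because Database.session() now commits once after the yield, the individual helpers drop their session.commit() calls. The only explicit commits left are in create_search_task (the job row must be visible to workers before tasks are enqueued) and in job_context's error path (the failure status must be persisted before the exception is re-raised).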