Skip to content

Commit

Permalink
Draft: fix: rewrite query_ursadb not to use iterators
Browse files Browse the repository at this point in the history
  • Loading branch information
mickol34 committed Oct 9, 2024
1 parent 015a887 commit bab9068
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 21 deletions.
15 changes: 12 additions & 3 deletions src/lib/ursadb.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@
import zmq # type: ignore
from typing import Dict, Any, List, Optional

from config import app_config
from models.queryresult import QueryResult
from db import Database, JobId


Json = Dict[str, Any]

Expand Down Expand Up @@ -37,6 +41,7 @@ def __str__(self) -> str:
class UrsaDb:
def __init__(self, backend: str) -> None:
self.backend = backend
self.redis_db = Database(app_config.redis.host, app_config.redis.port)

def __execute(self, command: str, recv_timeout: int = 2000) -> Json:
context = zmq.Context()
Expand All @@ -53,6 +58,7 @@ def __execute(self, command: str, recv_timeout: int = 2000) -> Json:
def query(
self,
query: str,
job_id: JobId,
taints: List[str] | None = None,
dataset: Optional[str] = None,
) -> Json:
Expand All @@ -63,7 +69,7 @@ def query(
command += f"with taints {taints_whole_str} "
if dataset:
command += f'with datasets ["{dataset}"] '
command += f"into iterator {query};"
command += f"{query};"

start = time.perf_counter()
res = self.__execute(command, recv_timeout=-1)
Expand All @@ -73,10 +79,13 @@ def query(
error = res.get("error", {}).get("message", "(no message)")
return {"error": f"ursadb failed: {error}"}

with self.redis_db.session() as session:
obj = QueryResult(job_id=job_id, files=res['result']['files'])
session.add(obj)
session.commit()

return {
"time": (end - start),
"iterator": res["result"]["iterator"],
"file_count": res["result"]["file_count"],
}

def pop(self, iterator: str, count: int) -> PopResult:
Expand Down
7 changes: 7 additions & 0 deletions src/models/queryresult.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from sqlmodel import Field, SQLModel, ARRAY, Column, String
from typing import List


class QueryResult(SQLModel, table=True):
    """Files matched by an ursadb query, stored per job.

    One row per job: written by ``UrsaDb.query`` after the database
    responds, consumed (and shrunk batch-by-batch) by ``run_yara_batch``.
    """

    # Primary key and foreign key to Job.internal_id — at most one
    # result row can exist for a given job.
    job_id: str = Field(foreign_key="job.internal_id", primary_key=True)
    # Matched file paths still awaiting a YARA scan.
    # NOTE(review): sqlalchemy ARRAY is only supported on PostgreSQL —
    # confirm the configured backend before merging.
    files: List[str] = Field(sa_column=Column(ARRAY(String)))
41 changes: 23 additions & 18 deletions src/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from rq import get_current_job, Queue # type: ignore
from redis import Redis
from contextlib import contextmanager
from sqlalchemy import delete, update
from sqlmodel import select
import yara # type: ignore

from .db import Database, JobId
Expand All @@ -11,6 +13,7 @@
from .plugins import PluginManager
from .models.job import Job
from .models.match import Match
from .models.queryresult import QueryResult
from .lib.yaraparse import parse_yara, combine_rules
from .lib.ursadb import Json, UrsaDb
from .metadata import Metadata
Expand Down Expand Up @@ -236,13 +239,13 @@ def query_ursadb(job_id: JobId, dataset_id: str, ursadb_query: str) -> None:
logging.info("Job was cancelled, returning...")
return

result = agent.ursa.query(ursadb_query, job.taints, dataset_id)
result = agent.ursa.query(ursadb_query, job_id, job.taints, dataset_id)
if "error" in result:
raise RuntimeError(result["error"])

file_count = result["file_count"]
iterator = result["iterator"]
logging.info(f"Iterator {iterator} contains {file_count} files")
with agent.db.session() as session:
result = session.exec(select(QueryResult).where(QueryResult.job_id == job_id)).one()
file_count = len(result.files)

total_files = agent.db.update_job_files(job_id, file_count)
if job.files_limit and total_files > job.files_limit:
Expand All @@ -259,34 +262,36 @@ def query_ursadb(job_id: JobId, dataset_id: str, ursadb_query: str) -> None:
agent.queue.enqueue(
run_yara_batch,
job_id,
iterator,
result,
batch,
job_timeout=app_config.rq.job_timeout,
)

agent.db.dataset_query_done(job_id)


def run_yara_batch(job_id: JobId, result: QueryResult, batch_size: int) -> None:
    """Actually scans files, and updates a database with the results.

    Takes the first ``batch_size`` file paths from ``result``, removes
    them from the stored ``QueryResult`` row, deletes the row once it is
    empty, and runs YARA on the batch.

    NOTE(review): ``result`` arrives through an rq enqueue, so it is a
    detached snapshot — concurrent workers slicing the same snapshot
    could scan overlapping batches; confirm only one batch worker per
    job runs at a time.
    """
    with job_context(job_id) as agent:
        job = agent.db.get_job(job_id)
        if job.status == "cancelled":
            logging.info("Job was cancelled, returning...")
            return

        # 1. Take the first batch_size files from the result.
        batch_files = result.files[:batch_size]

        # 2. Remove the batch files from the stored result.
        with agent.db.session() as session:
            session.execute(
                update(QueryResult)
                .where(QueryResult.job_id == result.job_id)
                # Off-by-one fix: the batch is files[:batch_size], so the
                # remainder is files[batch_size:]. The original
                # files[batch_size + 1:] silently dropped one file (the
                # element at index batch_size) from every batch.
                .values(files=result.files[batch_size:])
            )

            # 3. If the result has no files left, delete the row.
            # NOTE(review): comparing an ARRAY column to a Python [] —
            # verify this renders as an empty-array comparison on the
            # configured backend.
            session.execute(
                delete(QueryResult)
                .where(QueryResult.job_id == job_id)
                .where(QueryResult.files == [])
            )
            session.commit()

        agent.execute_yara(job, batch_files)
        agent.add_tasks_in_progress(job, -1)

0 comments on commit bab9068

Please sign in to comment.