Add script to verify responses against reference set of responses
jbaiter committed Jun 10, 2024
1 parent 8790326 commit 90b577a
Showing 1 changed file with 188 additions and 0 deletions.
188 changes: 188 additions & 0 deletions example/verify_responses.py
@@ -0,0 +1,188 @@
#!/usr/bin/env python3
"""
Utility to compare the responses from a given deployment of the example setup against
a reference set of responses. Use it to test for regressions between releases.
Not part of the CI process, since the indexing is fairly heavy-weight.
A reference set of responses can be generated by running the following command:
./bench.py --iterations 1 --save-responses responses.jsonl.gz
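The saved responses can then be checked against a running deployment with:
./verify_responses.py responses.jsonl.gz --solr-handler http://localhost:8983/solr/ocr/select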
"""

import argparse
import difflib
import gzip
import hashlib
import json
import sys
import textwrap
from pathlib import Path
from typing import Iterable
from urllib.parse import urlencode
from urllib.request import urlopen, Request


try:
from rich import print
from rich.console import Console
from rich.syntax import Syntax
except ImportError:
print(
"[red]Please install the [blue]rich[/blue] package to use this script ([blue]python3-rich[/blue] on Debian/Ubuntu)[/]",
file=sys.stderr,
)
sys.exit(1)


def run_query(solr_handler: str, query_params: dict) -> dict:
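    """Run the query against the given Solr handler and return the parsed
    response, with the full query URL added to its header for later display."""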
req = Request(f"{solr_handler}?{urlencode(query_params)}")
with urlopen(req) as http_resp:
solr_resp = json.load(http_resp)
solr_resp["responseHeader"]["queryUrl"] = req.full_url
return solr_resp


def normalize_response(response: dict) -> dict:
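    """Reduce a response to its comparable parts: recursively sort all mapping
    keys and keep only the "response" and "ocrHighlighting" sections."""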
def _sort_dict(d):
return {
k: _sort_dict(v) if isinstance(v, dict) else v for k, v in sorted(d.items())
}

return {
k: vs
for k, vs in _sort_dict(response).items()
if k in ("response", "ocrHighlighting")
}


def _hash_snippet(snip: dict) -> str:
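    """Compute a stable content hash for a snippet, ignoring its score."""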
sha = hashlib.sha256()
sha.update(snip["text"].encode("utf-8"))
sha.update(json.dumps(snip["pages"], sort_keys=True).encode("utf-8"))
sha.update(json.dumps(snip["regions"], sort_keys=True).encode("utf-8"))
sha.update(json.dumps(snip["highlights"], sort_keys=True).encode("utf-8"))
return sha.hexdigest()


def _only_scores_differ(expected: dict, actual: dict) -> bool:
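    """Check whether two responses contain the same snippets and totals for
    every document, i.e. whether any difference is limited to the scores."""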
snippets_expected = {
docid: data["ocr_text"]["snippets"]
for docid, data in expected.get("ocrHighlighting", {}).items()
}
snippets_actual = {
docid: data["ocr_text"]["snippets"]
for docid, data in actual.get("ocrHighlighting", {}).items()
}
total_expected = {
docid: data["ocr_text"]["numTotal"]
for docid, data in expected.get("ocrHighlighting", {}).items()
}
total_actual = {
docid: data["ocr_text"]["numTotal"]
for docid, data in actual.get("ocrHighlighting", {}).items()
}
for docid, expected_snips in snippets_expected.items():
expected_hashed = {_hash_snippet(snip): snip for snip in expected_snips}
actual_hashed = {
_hash_snippet(snip): snip for snip in snippets_actual.get(docid, [])
}
if len(expected_hashed) != len(actual_hashed):
return False
if total_expected[docid] != total_actual[docid]:
return False
missing = [k for k in expected_hashed if k not in actual_hashed]
if len(missing) == 0:
continue
if (
len(actual_hashed) < total_expected[docid]
and len(missing) == 1
and _hash_snippet(expected_snips[-1]) == missing[0]
):
# Last snippet is missing, this is likely due to the updated sort order
continue
return False

return True


def check_response(
solr_handler_url: str, expected_response: dict, no_colors: bool = False
):
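    """Re-run the query from a reference response and interactively report any
    mismatch between the stored and the current response."""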
actual_response = run_query(
solr_handler_url, expected_response["responseHeader"]["params"]
)
actual_norm = normalize_response(actual_response)
expected_norm = normalize_response(expected_response)
actual_json = json.dumps(actual_norm, indent=2)
expected_json = json.dumps(expected_norm, indent=2)
query = expected_response["responseHeader"]["params"]["q"]
if actual_json != expected_json:
if _only_scores_differ(expected_response, actual_response):
print(f"[yellow]Response match (only scores differ)[/] for query {query}")
return
print("=====")
print(f"Response mismatch for query {query}")
print_unified_diff(expected_json, actual_json, no_colors)
print(f"Query URL: {actual_response['responseHeader']['queryUrl']}")
print("=====")
print("Continue? Y/n")
if input().strip().lower() != "y":
sys.exit(1)
else:
print(f"[green]✓ Response match[/] for query {query}")


def print_unified_diff(expected_json: str, actual_json: str, no_colors: bool = False):
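    """Print a unified diff between the expected and the actual JSON, using a
    pager if the diff does not fit on the screen."""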
    expected_lines = expected_json.splitlines()
    actual_lines = actual_json.splitlines()
    difflines = list(difflib.unified_diff(expected_lines, actual_lines, lineterm=""))
    diff = "\n".join(difflines)

with Console() as console:
if no_colors:
syntax = diff
else:
syntax = Syntax(
diff, "diff", theme="monokai", line_numbers=False, word_wrap=True
)
        if len(difflines) > console.height:
            with console.pager(styles=True):
                console.print(syntax)
        else:
            console.print(syntax)


def load_reference_responses(path: Path) -> Iterable[dict]:
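    """Load reference responses from a JSON-lines file (optionally gzipped),
    yielding the response object stored on each line."""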
if path.suffix == ".gz":
with gzip.open(path, "rt") as f:
yield from (next(iter(json.loads(line).values())) for line in f)
else:
with path.open() as f:
yield from (next(iter(json.loads(line).values())) for line in f)


def main():
parser = argparse.ArgumentParser(
description=__doc__, formatter_class=argparse.RawTextHelpFormatter
)
parser.add_argument(
"responses_file", help="Path to the reference responses file", type=str
)
parser.add_argument(
"--solr-handler",
help="Solr handler URL to query",
type=str,
default="http://localhost:8983/solr/ocr/select",
)
parser.add_argument(
"--no-diff-colors", help="Disable colored output", action="store_true"
)
args = parser.parse_args()

reference = load_reference_responses(Path(args.responses_file))
for ref in reference:
        check_response(args.solr_handler, ref, args.no_diff_colors)


if __name__ == "__main__":
main()
