-
Notifications
You must be signed in to change notification settings - Fork 13
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add script to verify responses against reference set of responses
- Loading branch information
Showing
1 changed file
with
188 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,188 @@ | ||
#!/usr/bin/env python3 | ||
""" | ||
Utility tool to compare the responses from a given deployment of the example setup against | ||
a reference set of responses. Should be used to test for regressions between releases. | ||
Not part of the CI process, since the indexing is kind of heavy-weight. | ||
A reference set of responses can be generated by running the following command: | ||
./bench.py --iterations 1 --save-responses responses.jsonl.gz | ||
""" | ||
|
||
import argparse | ||
import difflib | ||
import gzip | ||
import hashlib | ||
import json | ||
import sys | ||
import textwrap | ||
from pathlib import Path | ||
from typing import Iterable | ||
from urllib.parse import urlencode | ||
from urllib.request import urlopen, Request | ||
|
||
|
||
# Fail fast with an actionable hint when the optional 'rich' dependency is
# missing: the script relies on rich for colored output, syntax-highlighted
# diffs and paging. Note this rebinds the builtin `print` to rich's version
# for the rest of the module (which is why markup like [red]...[/] works).
try:
    from rich import print, get_console
    from rich.console import Console
    from rich.syntax import Syntax
except ImportError:
    print(
        "[red]Please install the [blue]rich[/blue] package to use this script ([blue]python3-rich[/blue] on Debian/Ubuntu)[/]",
        file=sys.stderr,
    )
    sys.exit(1)
|
||
|
||
def run_query(solr_handler: str, query_params: dict) -> dict:
    """Execute *query_params* against the Solr handler and return the parsed JSON.

    The full request URL is injected into the response under
    ``responseHeader.queryUrl`` so callers can report which query was run.
    """
    request = Request(f"{solr_handler}?{urlencode(query_params)}")
    with urlopen(request) as raw_response:
        payload = json.load(raw_response)
    payload["responseHeader"]["queryUrl"] = request.full_url
    return payload
|
||
|
||
def normalize_response(response: dict) -> dict:
    """Reduce a Solr response to its comparison-relevant parts.

    Only the ``response`` and ``ocrHighlighting`` sections are kept, and every
    nested dict is rebuilt with its keys in sorted order so that two
    equivalent responses serialize to identical JSON. Lists (and dicts inside
    lists) are passed through unchanged.
    """

    def _deep_sorted(value):
        # Recursively sort dict keys; anything that is not a dict is returned as-is.
        if not isinstance(value, dict):
            return value
        return {key: _deep_sorted(val) for key, val in sorted(value.items())}

    kept_sections = ("response", "ocrHighlighting")
    return {
        key: val
        for key, val in _deep_sorted(response).items()
        if key in kept_sections
    }
|
||
|
||
def _hash_snippet(snip: dict) -> str: | ||
sha = hashlib.sha256() | ||
sha.update(snip["text"].encode("utf-8")) | ||
sha.update(json.dumps(snip["pages"], sort_keys=True).encode("utf-8")) | ||
sha.update(json.dumps(snip["regions"], sort_keys=True).encode("utf-8")) | ||
sha.update(json.dumps(snip["highlights"], sort_keys=True).encode("utf-8")) | ||
return sha.hexdigest() | ||
|
||
|
||
def _only_scores_differ(expected: dict, actual: dict) -> bool:
    """Return True if two responses differ only in scoring, not in content.

    Snippets are compared by their score-independent hash (see
    ``_hash_snippet``). Two responses are considered score-only-different when
    every document has the same number of snippets and the same ``numTotal``,
    and at most the *last* expected snippet is missing from the actual set
    (which can happen when an updated sort order swaps equally-scored snippets
    at the page boundary).
    """
    snippets_expected = {
        docid: data["ocr_text"]["snippets"]
        for docid, data in expected.get("ocrHighlighting", {}).items()
    }
    snippets_actual = {
        docid: data["ocr_text"]["snippets"]
        for docid, data in actual.get("ocrHighlighting", {}).items()
    }
    total_expected = {
        docid: data["ocr_text"]["numTotal"]
        for docid, data in expected.get("ocrHighlighting", {}).items()
    }
    total_actual = {
        docid: data["ocr_text"]["numTotal"]
        for docid, data in actual.get("ocrHighlighting", {}).items()
    }
    for docid, expected_snips in snippets_expected.items():
        expected_hashed = {_hash_snippet(snip): snip for snip in expected_snips}
        actual_hashed = {
            _hash_snippet(snip): snip for snip in snippets_actual.get(docid, [])
        }
        if len(expected_hashed) != len(actual_hashed):
            return False
        if total_expected[docid] != total_actual[docid]:
            return False
        missing = [k for k in expected_hashed if k not in actual_hashed]
        if len(missing) == 0:
            continue
        if (
            len(actual_hashed) < total_expected[docid]
            and len(missing) == 1
            and _hash_snippet(expected_snips[-1]) == missing[0]
        ):
            # Last snippet is missing, this is likely due to the updated sort order
            continue
        # Removed a leftover `breakpoint()` debugging call here: it would have
        # dropped any mismatching run into pdb (or crashed non-interactive runs).
        return False

    return True
|
||
|
||
def check_reponse(
    solr_handler_url: str, expected_response: dict, no_colors: bool = False
):
    """Re-run the query stored in *expected_response* and compare the results.

    On an exact (normalized) match or a score-only difference a one-line status
    is printed. On a real mismatch a unified diff is shown and the user is asked
    whether to continue; any answer other than "y" or an empty reply (the
    advertised "Y" default) aborts the script with exit code 1.

    NOTE(review): the function name contains a typo ("reponse"); it is kept
    unchanged because main() calls it by this name.
    """
    actual_response = run_query(
        solr_handler_url, expected_response["responseHeader"]["params"]
    )
    actual_norm = normalize_response(actual_response)
    expected_norm = normalize_response(expected_response)
    actual_json = json.dumps(actual_norm, indent=2)
    expected_json = json.dumps(expected_norm, indent=2)
    query = expected_response["responseHeader"]["params"]["q"]
    if actual_json == expected_json:
        print(f"[green]✓ Response match[/] for query {query}")
        return
    if _only_scores_differ(expected_response, actual_response):
        print(f"[yellow]Response match (only scores differ)[/] for query {query}")
        return
    print("=====")
    print(f"Response mismatch for query {query}")
    print_unified_diff(expected_json, actual_json, no_colors)
    print(f"Query URL: {actual_response['responseHeader']['queryUrl']}")
    print("=====")
    print("Continue? Y/n")
    # "Y/n" advertises yes as the default, so an empty reply must continue;
    # previously just pressing enter aborted the script.
    if input().strip().lower() not in ("", "y"):
        sys.exit(1)
|
||
|
||
def print_unified_diff(expected_json: str, actual_json: str, no_colors: bool = False):
    """Print a unified diff between the expected and actual JSON responses.

    The diff is syntax-highlighted unless *no_colors* is set, and routed
    through a pager when it is taller than the terminal.
    """
    expected_lines = expected_json.splitlines(keepends=True)
    actual_lines = actual_json.splitlines(keepends=True)
    # Keep the default lineterm="\n": the content lines already carry newlines
    # (keepends=True), and the default makes unified_diff terminate the
    # "---"/"+++"/"@@" header lines too. With lineterm="" the headers have no
    # newline and the plain "".join below glued them onto a single line.
    difflines = list(difflib.unified_diff(expected_lines, actual_lines))
    diff = "".join(difflines)

    with Console() as console:
        if no_colors:
            syntax = diff
        else:
            syntax = Syntax(
                diff, "diff", theme="monokai", line_numbers=False, word_wrap=True
            )
        if len(difflines) > console.height:
            # Page long diffs. Previously the diff was unconditionally printed
            # a second time after the pager exited.
            with console.pager(styles=True):
                console.print(syntax)
        else:
            console.print(syntax)
|
||
|
||
def load_reference_responses(path: Path) -> Iterable[dict]:
    """Lazily yield reference responses from a (possibly gzipped) JSONL file.

    Each line is a single-entry JSON object; the (sole) value is the stored
    Solr response, which is what gets yielded.
    """
    # Pick the opener based on the suffix; both are used in text mode so the
    # parsing loop below is shared between plain and gzipped files.
    opener = gzip.open if path.suffix == ".gz" else open
    with opener(path, "rt") as fp:
        for line in fp:
            yield next(iter(json.loads(line).values()))
|
||
|
||
def main():
    """Entry point: parse CLI arguments and verify every reference response."""
    arg_parser = argparse.ArgumentParser(
        description=__doc__, formatter_class=argparse.RawTextHelpFormatter
    )
    arg_parser.add_argument(
        "responses_file", type=str, help="Path to the reference responses file"
    )
    arg_parser.add_argument(
        "--solr-handler",
        type=str,
        default="http://localhost:8983/solr/ocr/select",
        help="Solr handler URL to query",
    )
    arg_parser.add_argument(
        "--no-diff-colors", action="store_true", help="Disable colored output"
    )
    opts = arg_parser.parse_args()

    for reference_response in load_reference_responses(Path(opts.responses_file)):
        check_reponse(opts.solr_handler, reference_response, opts.no_diff_colors)
|
||
|
||
# Standard script entry point: only run when executed directly, not on import.
if __name__ == "__main__":
    main()