#!/usr/bin/env python3
"""
Utility tool to compare the responses from a given deployment of the example setup against
a reference set of responses. Should be used to test for regressions between releases.
Not part of the CI process, since the indexing is kind of heavy-weight.

A reference set of responses can be generated by running the following command:

    ./bench.py --iterations 1 --save-responses responses.jsonl.gz
"""

import argparse
import difflib
import gzip
import hashlib
import json
import sys
from pathlib import Path
from typing import Iterable
from urllib.parse import urlencode
from urllib.request import Request, urlopen


try:
    from rich import print
    from rich.console import Console
    from rich.syntax import Syntax
except ImportError:
    # NOTE: rich is not importable here, so plain builtin print is used --
    # no markup tags in the message (they would be shown verbatim).
    print(
        "Please install the 'rich' package to use this script "
        "(python3-rich on Debian/Ubuntu)",
        file=sys.stderr,
    )
    sys.exit(1)


def run_query(solr_handler: str, query_params: dict) -> dict:
    """Run a query against the given Solr handler and return the parsed JSON.

    The full request URL is stored in the response under
    ``responseHeader.queryUrl`` so mismatches can be reproduced manually.
    """
    req = Request(f"{solr_handler}?{urlencode(query_params)}")
    with urlopen(req) as http_resp:
        solr_resp = json.load(http_resp)
    solr_resp["responseHeader"]["queryUrl"] = req.full_url
    return solr_resp


def normalize_response(response: dict) -> dict:
    """Reduce a Solr response to a canonical, comparable form.

    Keeps only the ``response`` and ``ocrHighlighting`` sections and sorts
    all dictionary keys recursively so that JSON serializations of two
    equivalent responses are byte-identical.

    NOTE(review): dicts nested inside lists (e.g. individual snippets) are
    not re-sorted by this helper -- presumably Solr already emits those with
    stable key order; confirm if spurious diffs show up.
    """

    def _sort_dict(d):
        return {
            k: _sort_dict(v) if isinstance(v, dict) else v for k, v in sorted(d.items())
        }

    return {
        k: vs
        for k, vs in _sort_dict(response).items()
        if k in ("response", "ocrHighlighting")
    }


def _hash_snippet(snip: dict) -> str:
    """Hash a snippet's content, ignoring its score.

    Two snippets with equal text, pages, regions and highlights hash
    identically even if their relevance scores differ.
    """
    sha = hashlib.sha256()
    sha.update(snip["text"].encode("utf-8"))
    sha.update(json.dumps(snip["pages"], sort_keys=True).encode("utf-8"))
    sha.update(json.dumps(snip["regions"], sort_keys=True).encode("utf-8"))
    sha.update(json.dumps(snip["highlights"], sort_keys=True).encode("utf-8"))
    return sha.hexdigest()


def _only_scores_differ(expected: dict, actual: dict) -> bool:
    """Check whether two responses differ only in snippet scores/ordering.

    Snippets are compared via a score-agnostic content hash, so responses
    containing the same snippets in a different order count as equivalent.
    A single missing snippet is tolerated when fewer snippets than
    ``numTotal`` were returned and the missing one is the expected last
    snippet: a changed sort order can push a different snippet past the
    snippet-count cutoff.
    """
    snippets_expected = {
        docid: data["ocr_text"]["snippets"]
        for docid, data in expected.get("ocrHighlighting", {}).items()
    }
    snippets_actual = {
        docid: data["ocr_text"]["snippets"]
        for docid, data in actual.get("ocrHighlighting", {}).items()
    }
    total_expected = {
        docid: data["ocr_text"]["numTotal"]
        for docid, data in expected.get("ocrHighlighting", {}).items()
    }
    total_actual = {
        docid: data["ocr_text"]["numTotal"]
        for docid, data in actual.get("ocrHighlighting", {}).items()
    }
    for docid, expected_snips in snippets_expected.items():
        expected_hashed = {_hash_snippet(snip): snip for snip in expected_snips}
        actual_hashed = {
            _hash_snippet(snip): snip for snip in snippets_actual.get(docid, [])
        }
        if len(expected_hashed) != len(actual_hashed):
            return False
        if total_expected[docid] != total_actual[docid]:
            return False
        missing = [k for k in expected_hashed if k not in actual_hashed]
        if not missing:
            continue
        if (
            len(actual_hashed) < total_expected[docid]
            and len(missing) == 1
            and _hash_snippet(expected_snips[-1]) == missing[0]
        ):
            # Last snippet is missing, this is likely due to the updated sort order
            continue
        return False

    return True


def check_response(
    solr_handler_url: str, expected_response: dict, no_colors: bool = False
):
    """Re-run one reference query and compare the live response against it.

    On a mismatch that is not explained by scores alone, a unified diff is
    printed and the user is asked whether to continue (Enter or "y" means
    yes); any other answer aborts with exit code 1.
    """
    actual_response = run_query(
        solr_handler_url, expected_response["responseHeader"]["params"]
    )
    actual_norm = normalize_response(actual_response)
    expected_norm = normalize_response(expected_response)
    actual_json = json.dumps(actual_norm, indent=2)
    expected_json = json.dumps(expected_norm, indent=2)
    query = expected_response["responseHeader"]["params"]["q"]
    if actual_json != expected_json:
        if _only_scores_differ(expected_response, actual_response):
            print(f"[yellow]Response match (only scores differ)[/] for query {query}")
            return
        print("=====")
        print(f"Response mismatch for query {query}")
        print_unified_diff(expected_json, actual_json, no_colors)
        print(f"Query URL: {actual_response['responseHeader']['queryUrl']}")
        print("=====")
        print("Continue? Y/n")
        # "Y/n" convention: empty input counts as yes.
        if input().strip().lower() not in ("", "y"):
            sys.exit(1)
    else:
        print(f"[green]✓ Response match[/] for query {query}")


def print_unified_diff(expected_json: str, actual_json: str, no_colors: bool = False):
    """Print a unified diff of the two JSON strings, paging long diffs.

    With ``lineterm=""`` the diff lines carry no trailing newlines, so the
    pieces are joined with explicit "\\n" (joining keepends-lines with ""
    would glue the ---/+++/@@ header lines together).
    """
    expected_lines = expected_json.splitlines()
    actual_lines = actual_json.splitlines()
    difflines = list(difflib.unified_diff(expected_lines, actual_lines, lineterm=""))
    diff = "\n".join(difflines)

    console = Console()
    if no_colors:
        syntax = diff
    else:
        syntax = Syntax(
            diff, "diff", theme="monokai", line_numbers=False, word_wrap=True
        )
    if len(difflines) > console.height:
        # Diff does not fit on screen: page it (and print only once).
        with console.pager(styles=True):
            console.print(syntax)
    else:
        console.print(syntax)


def load_reference_responses(path: Path) -> Iterable[dict]:
    """Yield reference responses from a (possibly gzip-compressed) JSONL file.

    Each line holds a JSON object with a single entry; its sole value is the
    recorded Solr response.
    """
    opener = gzip.open if path.suffix == ".gz" else open
    with opener(path, "rt") as f:
        for line in f:
            yield next(iter(json.loads(line).values()))


def main():
    parser = argparse.ArgumentParser(
        description=__doc__, formatter_class=argparse.RawTextHelpFormatter
    )
    parser.add_argument(
        "responses_file", help="Path to the reference responses file", type=str
    )
    parser.add_argument(
        "--solr-handler",
        help="Solr handler URL to query",
        type=str,
        default="http://localhost:8983/solr/ocr/select",
    )
    parser.add_argument(
        "--no-diff-colors", help="Disable colored output", action="store_true"
    )
    args = parser.parse_args()

    reference = load_reference_responses(Path(args.responses_file))
    for ref in reference:
        check_response(args.solr_handler, ref, args.no_diff_colors)


if __name__ == "__main__":
    main()