diff --git a/.gitignore b/.gitignore
index da6c9c2..667cbbf 100644
--- a/.gitignore
+++ b/.gitignore
@@ -147,3 +147,5 @@ cython_debug/
 # Mac OS
 .DS_Store
 poetry.lock
+*.prof
+*.prof_stats
diff --git a/tests/common.py b/tests/common.py
index c229a94..1c68ac9 100644
--- a/tests/common.py
+++ b/tests/common.py
@@ -1,3 +1,5 @@
+import cProfile
+from pstats import Stats
 from typing import Optional
 
 from transformers import AutoTokenizer, PreTrainedTokenizerBase
@@ -34,7 +36,7 @@ def assert_parser_with_string_direct(string: str, parser: CharacterLevelParser,
         raise ValueError("Parser did not reach end state when it should have")
 
 
-def assert_parser_with_string_token_enforcer(string: str, parser: CharacterLevelParser, expect_success: bool):
+def assert_parser_with_string_token_enforcer(string: str, parser: CharacterLevelParser, expect_success: bool, profile_file_path: Optional[str]):
     global _tokenizer
     if _tokenizer is None:
         model_id = 'TheBloke/Llama-2-7b-Chat-GPTQ'
@@ -56,6 +58,11 @@ def assert_parser_with_string_token_enforcer(string: str, parser: CharacterLevel
     # The token enforcer is stateful - it keeps track of the parsing state as tokens arrive on a token by token basis.
     # We simulate a language model that "chooses" the next token in the encoded sequence, and check that it is in the
     # allowed list at every timestep.
+    profiler: Optional[cProfile.Profile] = None
+    if profile_file_path:
+        profiler = cProfile.Profile()
+        profiler.enable()
+
     for prefix_length in range(len(initial_token_array), len(target_token_array) + 1):
         prefix = target_token_array[:prefix_length]
         allowed_tokens = token_enforcer.get_allowed_tokens(prefix)
@@ -65,9 +72,9 @@ def assert_parser_with_string_token_enforcer(string: str, parser: CharacterLevel
                 if expect_success:
                     decoded_before = _tokenizer.decode(prefix, skip_special_tokens=True)
                     decoded_after = _tokenizer.decode(prefix + [next_token], skip_special_tokens=True)
-                    next_char = decoded_after[len(decoded_before)]
+                    next_token_chars = decoded_after[len(decoded_before):]
                     next_idx = len(decoded_before) - len(prompt)
-                    raise CharacterNotAllowedException(f"Parser does not allow '{next_char}' at index {next_idx}")
+                    raise CharacterNotAllowedException(f"Parser does not allow '{next_token_chars}' at index {next_idx}")
                 else:
                     return  # Test success
         else:
@@ -77,8 +84,17 @@ def assert_parser_with_string_token_enforcer(string: str, parser: CharacterLevel
                 raise ValueError("Parser succeeded when it should have failed")
     if not can_end and expect_success:
         raise ValueError("Parser did not reach end state when it should have")
+
+    if profiler and profile_file_path:
+        profiler.disable()
+        with open(profile_file_path, 'w') as stream:
+            stats = Stats(profiler, stream=stream)
+            stats.strip_dirs()
+            stats.sort_stats('time')
+            stats.dump_stats(profile_file_path + '.prof_stats')
+            stats.print_stats()
 
 
-def assert_parser_with_string(string: str, parser: CharacterLevelParser, expect_success: bool):
+def assert_parser_with_string(string: str, parser: CharacterLevelParser, expect_success: bool, profile_file_path: Optional[str] = None):
     assert_parser_with_string_direct(string, parser, expect_success)
-    assert_parser_with_string_token_enforcer(string, parser, expect_success)
+    assert_parser_with_string_token_enforcer(string, parser, expect_success, profile_file_path)
diff --git a/tests/test_jsonschemaparser.py b/tests/test_jsonschemaparser.py
index e606444..17eeaa7 100644
--- a/tests/test_jsonschemaparser.py
+++ b/tests/test_jsonschemaparser.py
@@ -9,9 +9,9 @@
 from .common import assert_parser_with_string, CharacterNotAllowedException
 
 
-def _test_json_schema_parsing_with_string(string: str, schema_dict: Optional[dict], expect_success: bool):
+def _test_json_schema_parsing_with_string(string: str, schema_dict: Optional[dict], expect_success: bool, profile_file_path: Optional[str] = None):
     parser = JsonSchemaParser(schema_dict)
-    assert_parser_with_string(string, parser, expect_success)
+    assert_parser_with_string(string, parser, expect_success, profile_file_path)
     if expect_success:
         # If expecting success, also check minified and pretty-printed
         minified = json.dumps(json.loads(string), separators=(',', ':'))
@@ -243,6 +243,22 @@ def test_any_json_object():
     _test_json_schema_parsing_with_string('"str"', None, True)
 
 
+def test_long_json_object():
+    from urllib.request import urlopen
+    import json
+    json_url = 'https://microsoftedge.github.io/Demos/json-dummy-data/64KB.json'
+    json_text = urlopen(json_url).read().decode('utf-8')
+    # These are several "hacks" on top of the json file in order to bypass some shortcomings of the unit testing method.
+    json_text = ''.join(c for c in json_text if 0 < ord(c) < 127)
+    json_text = json_text.replace('.",', '",')
+    json_text = json_text.replace(' ",', '",')
+    json_text = json_text.replace('.",', '",')
+    json_text = json.dumps(json.loads(json_text)[:20])
+
+    profile_file_path = None  # '64KB.prof'
+    _test_json_schema_parsing_with_string(json_text, None, True, profile_file_path=profile_file_path)
+
+
 def test_union():
     class SchemaWithUnion(BaseModel):
         key: Union[int, str]
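Note: the profiling block added to tests/common.py produces two artifacts per run: a human-readable report written to profile_file_path via stats.print_stats(), and a binary dump written to profile_file_path + '.prof_stats' via stats.dump_stats(). A minimal sketch of how that binary dump could be inspected afterwards, assuming a run with the commented-out profile_file_path = '64KB.prof' (the filename here is purely illustrative):

    from pstats import Stats

    # Load the binary stats file written by stats.dump_stats() in tests/common.py.
    stats = Stats('64KB.prof.prof_stats')
    # Same ordering the test's text report uses: hottest functions by internal time.
    stats.strip_dirs().sort_stats('time').print_stats(10)  # show the top 10 entries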