
Commit

Added option to profile unit tests (#46)
noamgat authored Dec 19, 2023
1 parent 114d15b commit 9f78179
Showing 3 changed files with 41 additions and 7 deletions.
2 changes: 2 additions & 0 deletions .gitignore
@@ -147,3 +147,5 @@ cython_debug/
 # Mac OS
 .DS_Store
 poetry.lock
+*.prof
+*.prof_stats
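The two new ignore patterns keep profiler artifacts out of version control: running a test with, e.g., profile_file_path='64KB.prof' (the commented-out value in the new test below) makes the helper in tests/common.py write a plain-text report to 64KB.prof and binary stats to 64KB.prof.prof_stats, matching *.prof and *.prof_stats respectively.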
26 changes: 21 additions & 5 deletions tests/common.py
@@ -1,3 +1,5 @@
+import cProfile
+from pstats import Stats
 from typing import Optional
 from transformers import AutoTokenizer, PreTrainedTokenizerBase
 
@@ -34,7 +36,7 @@ def assert_parser_with_string_direct(string: str, parser: CharacterLevelParser,
         raise ValueError("Parser did not reach end state when it should have")
 
 
-def assert_parser_with_string_token_enforcer(string: str, parser: CharacterLevelParser, expect_success: bool):
+def assert_parser_with_string_token_enforcer(string: str, parser: CharacterLevelParser, expect_success: bool, profile_file_path: Optional[str]):
     global _tokenizer
     if _tokenizer is None:
         model_id = 'TheBloke/Llama-2-7b-Chat-GPTQ'
@@ -56,6 +58,11 @@ def assert_parser_with_string_token_enforcer(string: str, parser: CharacterLevel
     # The token enforcer is stateful - it keeps track of the parsing state as tokens arrive on a token by token basis.
     # We simulate a language model that "chooses" the next token in the encoded sequence, and check that it is in the
     # allowed list at every timestep.
+    profiler: Optional[cProfile.Profile] = None
+    if profile_file_path:
+        profiler = cProfile.Profile()
+        profiler.enable()
+
     for prefix_length in range(len(initial_token_array), len(target_token_array) + 1):
         prefix = target_token_array[:prefix_length]
         allowed_tokens = token_enforcer.get_allowed_tokens(prefix)
@@ -65,9 +72,9 @@ def assert_parser_with_string_token_enforcer(string: str, parser: CharacterLevel
             if expect_success:
                 decoded_before = _tokenizer.decode(prefix, skip_special_tokens=True)
                 decoded_after = _tokenizer.decode(prefix + [next_token], skip_special_tokens=True)
-                next_char = decoded_after[len(decoded_before)]
+                next_token_chars = decoded_after[len(decoded_before):]
                 next_idx = len(decoded_before) - len(prompt)
-                raise CharacterNotAllowedException(f"Parser does not allow '{next_char}' at index {next_idx}")
+                raise CharacterNotAllowedException(f"Parser does not allow '{next_token_chars}' at index {next_idx}")
             else:
                 return # Test success
         else:
@@ -77,8 +84,17 @@ def assert_parser_with_string_token_enforcer(string: str, parser: CharacterLevel
             raise ValueError("Parser succeeded when it should have failed")
     if not can_end and expect_success:
         raise ValueError("Parser did not reach end state when it should have")
+
+    if profiler and profile_file_path:
+        profiler.disable()
+        with open(profile_file_path, 'w') as stream:
+            stats = Stats(profiler, stream=stream)
+            stats.strip_dirs()
+            stats.sort_stats('time')
+            stats.dump_stats(profile_file_path + '.prof_stats')
+            stats.print_stats()
 
 
-def assert_parser_with_string(string: str, parser: CharacterLevelParser, expect_success: bool):
+def assert_parser_with_string(string: str, parser: CharacterLevelParser, expect_success: bool, profile_file_path: Optional[str] = None):
     assert_parser_with_string_direct(string, parser, expect_success)
-    assert_parser_with_string_token_enforcer(string, parser, expect_success)
+    assert_parser_with_string_token_enforcer(string, parser, expect_success, profile_file_path)
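Beyond the plain-text report that print_stats writes to profile_file_path, the dump_stats call above also saves binary stats that can be reloaded later. A minimal sketch of re-reading them with the standard pstats module, assuming a prior run used profile_file_path='64KB.prof' (so the binary file is '64KB.prof.prof_stats'):

from pstats import Stats

# Load the marshalled profiler data written by dump_stats() in tests/common.py.
stats = Stats('64KB.prof.prof_stats')
stats.strip_dirs()              # shorten file paths in the report
stats.sort_stats('cumulative')  # re-sort by cumulative time instead of 'time'
stats.print_stats(10)           # show only the 10 most expensive functions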
20 changes: 18 additions & 2 deletions tests/test_jsonschemaparser.py
@@ -9,9 +9,9 @@
 from .common import assert_parser_with_string, CharacterNotAllowedException
 
 
-def _test_json_schema_parsing_with_string(string: str, schema_dict: Optional[dict], expect_success: bool):
+def _test_json_schema_parsing_with_string(string: str, schema_dict: Optional[dict], expect_success: bool, profile_file_path: Optional[str] = None):
     parser = JsonSchemaParser(schema_dict)
-    assert_parser_with_string(string, parser, expect_success)
+    assert_parser_with_string(string, parser, expect_success, profile_file_path)
     if expect_success:
         # If expecting success, also check minified and pretty-printed
         minified = json.dumps(json.loads(string), separators=(',', ':'))
@@ -243,6 +243,22 @@ def test_any_json_object():
     _test_json_schema_parsing_with_string('"str"', None, True)
 
 
+def test_long_json_object():
+    from urllib.request import urlopen
+    import json
+    json_url = 'https://microsoftedge.github.io/Demos/json-dummy-data/64KB.json'
+    json_text = urlopen(json_url).read().decode('utf-8')
+    # These are several "hacks" on top of the json file in order to bypass some shortcomings of the unit testing method.
+    json_text = ''.join(c for c in json_text if 0 < ord(c) < 127)
+    json_text = json_text.replace('.",', '",')
+    json_text = json_text.replace(' ",', '",')
+    json_text = json_text.replace('.",', '",')
+    json_text = json.dumps(json.loads(json_text)[:20])
+
+    profile_file_path = None  # '64KB.prof'
+    _test_json_schema_parsing_with_string(json_text, None, True, profile_file_path=profile_file_path)
+
+
 def test_union():
     class SchemaWithUnion(BaseModel):
         key: Union[int, str]
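For context, a hypothetical standalone usage of the new profiling hook, outside pytest. The import paths are assumptions, not part of the commit (the package's documented name is lmformatenforcer, and the relative imports in the test files suggest tests is a package); note the token-enforcer path downloads the test tokenizer on first use:

from lmformatenforcer import JsonSchemaParser
from tests.common import assert_parser_with_string  # assumed importable from the repo root

# A None schema means "accept any JSON", as in test_any_json_object above.
parser = JsonSchemaParser(None)
assert_parser_with_string('{"key": "value"}', parser, True, profile_file_path='demo.prof')
# Afterwards: 'demo.prof' holds the text report, 'demo.prof.prof_stats' the binary stats.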
