Add support for string min/max length to json_freetext (#43)
* Add support for min/max string length to json_freetext

* Cleanup and document JsonFreetextTokenCache, add attribution
elonen authored Dec 28, 2023
1 parent 39ad002 commit 2b7deb9
Showing 6 changed files with 197 additions and 77 deletions.
1 change: 1 addition & 0 deletions LICENSE
@@ -1,6 +1,7 @@
MIT License

Copyright (c) 2023 Noam Gat
2023 Jarno Elonen

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
6 changes: 3 additions & 3 deletions lmformatenforcer/characterlevelparser.py
@@ -31,12 +31,12 @@ def can_end(self) -> bool:
"""Return True if the parser is in a state where it can end (potentially finished parsing the desired structure), and False otherwise."""
raise NotImplementedError()

def shortcut_key(self) -> Optional[str]:
"""Optional. Return a string that denotes that this state is a repeating state, full tree traversal should be avoided."""
def shortcut_key(self) -> Optional[Hashable]:
"""Optional. Return a key that denotes that this state is a repeating state, full tree traversal should be avoided."""
return None

def cache_key(self) -> Optional[Hashable]:
"""Optional. Return a string that denotes that this state is a repeating state, and if it is visited again, results can be cached."""
"""Optional. Return a key that denotes that this state is a repeating state, and if it is visited again, results can be cached."""
return None

@property
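To illustrate the widened return type: a shortcut key can now be any hashable value, for example a tuple used directly as a dictionary key. A minimal sketch, with illustrative names that are not part of the library:

from typing import Hashable, Optional

# Illustrative only: the key carries extra constraint state as a tuple.
def example_shortcut_key(cur_len: int, min_len: int, max_len: int) -> Optional[Hashable]:
    return ('json_freetext', cur_len, min_len, max_len)

allowlists: dict = {}
key = example_shortcut_key(0, 3, 10)
allowlists[key] = [101, 102]  # tuples are hashable, so they work as cache keys
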
15 changes: 10 additions & 5 deletions lmformatenforcer/jsonschemaparser.py
@@ -1,6 +1,7 @@
from copy import deepcopy
import enum
from typing import Any, List, Optional, Union, cast
import sys
from typing import Hashable, List, Optional, Union, cast


from .external.jsonschemaobject import JsonSchemaObject, json_schema_data_formats
@@ -97,16 +98,20 @@ def get_allowed_characters(self) -> str:
def can_end(self) -> bool:
return all(parser.can_end() for parser in self.object_stack)

def shortcut_key(self) -> Optional[str]:
def shortcut_key(self) -> Optional[Hashable]:
if self.object_stack:
current_parser = self.object_stack[-1]
if isinstance(current_parser, StringParsingState):
if not current_parser.allowed_strings and current_parser.seen_opening_quote and not current_parser.seen_closing_quote \
and current_parser.min_length is None and current_parser.max_length is None:
if not current_parser.allowed_strings and current_parser.seen_opening_quote and not current_parser.seen_closing_quote:
# Performance optimization: When we are parsing a string that is not from a list of allowed strings, most tokens
# are legal. The exploration can be more costly than the LM itself for large tokenizers (because this is pure python),
# so we signal that we are in a "freetext" mode, and reuse the allowed token list throughout the run.
return 'json_freetext'
cur_len = len(current_parser.parsed_string)
min_len = current_parser.min_length or 0
max_len = current_parser.max_length or sys.maxsize
assert min_len <= max_len, "Invalid schema for str: min length is larger than max length"
if cur_len < max_len:
return ('json_freetext', cur_len, min_len, max_len)
return None


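As a concrete illustration of the new shortcut key, the sketch below (standalone, not library code) mirrors the logic above for a string field with "minLength": 3 and "maxLength": 10 after two characters have already been generated:

import sys

parsed_string = 'ab'            # characters generated inside the quotes so far (assumed)
min_length, max_length = 3, 10  # from the JSON schema; None would map to 0 / sys.maxsize

cur_len = len(parsed_string)
min_len = min_length or 0
max_len = max_length or sys.maxsize
assert min_len <= max_len, "Invalid schema for str: min length is larger than max length"

key = ('json_freetext', cur_len, min_len, max_len) if cur_len < max_len else None
print(key)  # ('json_freetext', 2, 3, 10)
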
17 changes: 12 additions & 5 deletions lmformatenforcer/tokenenforcer.py
@@ -1,4 +1,5 @@
from dataclasses import dataclass, field
import sys
from typing import Callable, Dict, Hashable, List, Optional, Tuple
import logging

@@ -115,7 +116,7 @@ def _compute_allowed_tokens(self, state_tokens: Tuple, state: 'TokenEnforcer.Out
"CharacterLevelParser parameters")
state.allowed_tokens = [self.eos_token_id]

def _collect_allowed_tokens(self, parser: CharacterLevelParser, tree_node: TokenizerPrefixTreeNode, allowed_tokens: List[int], shortcut_key: Optional[str]):
def _collect_allowed_tokens(self, parser: CharacterLevelParser, tree_node: TokenizerPrefixTreeNode, allowed_tokens: List[int], shortcut_key: Optional[Hashable]):
allowed_tokens.extend(tree_node.tokens)
allowed_characters = parser.get_allowed_characters()
relevant_characters = tree_node.children.keys()
@@ -125,12 +126,19 @@ def _collect_allowed_tokens(self, parser: CharacterLevelParser, tree_node: Token
# Performance optimization: If we are in JSON freetext, all of the tokens that don't contain a quote, or end with a quote, are legal, so we take
# their cached list. If the quote character is allowed, we only need to dynamically explore the cases where the string starts with a quote.
# This breaks the elegance of the API, but otherwise it is a huge performance hit.
if shortcut_key == 'json_freetext':
allowed_tokens.extend(self.tokenizer_tree.json_freetext_tokens)
if isinstance(shortcut_key, tuple) and shortcut_key[0] == 'json_freetext':
assert len(shortcut_key) == 4
_, cur_len, min_len, max_len = shortcut_key
cache = self.tokenizer_tree.json_freetext_tokens

min_remaining = min(cache.max_token_len, max(0, min_len - cur_len)) # no " allowed before this many chars
max_allowed_len = min(cache.max_token_len, max_len - cur_len) # max new characters allowed (before ")

allowed_tokens.extend(cache.lookup_allowed_tokens(min_remaining, max_allowed_len))
characters_to_explore = characters_to_explore.intersection(['"'])

for character in characters_to_explore:
next_parser = parser.add_character(character )
next_parser = parser.add_character(character)
next_tree_node = tree_node.children[character]
self._collect_allowed_tokens(next_parser, next_tree_node, allowed_tokens, None)

@@ -155,4 +163,3 @@ def _apply_new_characters(self, state: 'TokenEnforcer.OutputTensorState', token_
return new_state



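To see how the key is consumed, the arithmetic below (a sketch, not library code) continues the hypothetical key ('json_freetext', 2, 3, 10) with a tokenizer whose longest cached token string is assumed to be 16 characters:

max_token_len = 16                    # assumed longest token string in the cache
cur_len, min_len, max_len = 2, 3, 10  # unpacked from the hypothetical shortcut key

# Tokens ending with '"' must first contribute at least this many characters,
# otherwise the string would close before reaching minLength.
min_remaining = min(max_token_len, max(0, min_len - cur_len))  # 1

# No token may push the string past maxLength before the closing quote.
max_allowed_len = min(max_token_len, max_len - cur_len)        # 8

# The pair (min_remaining, max_allowed_len) indexes the precalculated allowlist:
# cache.lookup_allowed_tokens(min_remaining, max_allowed_len)
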
103 changes: 87 additions & 16 deletions lmformatenforcer/tokenizerprefixtree.py
@@ -1,37 +1,108 @@
from collections import OrderedDict
from typing import Dict, List, Set, Tuple
import json

class TokenizerPrefixTreeNode:
def __init__(self):
def __init__(self) -> None:
self.tokens: List[int] = []
self.children: Dict[str, TokenizerPrefixTreeNode] = {}


class JsonFreetextTokenCache:
"""
A JSON string can contain almost any Unicode character, so creating a list of allowed tokens is very expensive.
The list can be cached, but JSON Schema also allows 'minLength' and 'maxLength' constraints on the string,
which make some tokens illegal depending on how long the generated string already is. This class precalculates
a separate allowlist for every possible constraint state up to the maximum token length (16 for Llama, for example).
After deduplication, this results in about 75 lists for the Llama tokenizer.
"""
def __init__(self) -> None:
self.token_str_to_num: Dict[str, int] = {}
self.allowlist_cache: Dict[Tuple[int, int], Tuple[int, ...]] = {}
self.max_token_len = 0


def add_token(self, token_str: str, token_int: int):
assert not self.allowlist_cache, "Cannot add more tokens after allowlists were precalculated"

has_non_trailing_backslash = "\\" in token_str[:-1]
has_quote_before_end = '"' in token_str[0:-1]
has_newline = "\n" in token_str or "\r" in token_str
if has_non_trailing_backslash or has_quote_before_end or has_newline:
try:
json.loads(f'"{token_str}"')
except json.decoder.JSONDecodeError:
return # Illegal inside JSON string, skip this token

self.token_str_to_num[token_str] = token_int
self.max_token_len = max(self.max_token_len, len(token_str))


def lookup_allowed_tokens(self, min_remaining: int, max_len: int) -> Tuple[int, ...]:
"""
Get the list of tokens that are allowed within a JSON string, such that:
1. all candidate tokens are at most `max_len` characters long (excluding the trailing quote), and
2. if a token ends with a quote, it's at least `min_remaining` chars long (excluding the quote).
"""
return self.allowlist_cache[(min_remaining, max_len)]


def freeze(self) -> None:
"""
Precalculate token allowlists for all valid combinations of `min_remaining` and `max_len`
based on the tokens that were added with `add_token()`.
"""
all_tokens: List[str] = sorted(self.token_str_to_num.keys())
assert all_tokens, "Cannot precalculate allowlists for an empty token list"
assert not any(t == '' for t in all_tokens), "Tokenizer must not contain empty tokens"

def _valid_for_min_remaining(token, min_remaining):
return not token.endswith('"') or len(token.rstrip('"')) >= min_remaining

def _valid_for_max_len(token, max_len):
return len(token.rstrip('"')) <= max_len

# Make a 2D array of constrained allowlists, indexed by tuple `(min_remaining, max_len)`
token_lists = {}
for min_remaining in range(self.max_token_len + 1):
for max_len in range(self.max_token_len + 1):
if max_len >= min_remaining: # Skip combinations that are never used
token_lists[(min_remaining, max_len)] = tuple(sorted([
token for token in all_tokens
if _valid_for_min_remaining(token, min_remaining) and _valid_for_max_len(token, max_len)
]))

# Deduplicate the lists to save RAM as many of them will be identical
unique_lists = set(token_lists.values())
for key, lst in token_lists.items():
for uniq in unique_lists:
if len(uniq) == len(lst) and uniq == lst:
token_lists[key] = uniq
break

# Turn token strings into token numbers
self.allowlist_cache = {
key: tuple(self.token_str_to_num[t] for t in lst)
for key, lst in token_lists.items()
}
del self.token_str_to_num


class TokenizerPrefixTree:
def __init__(self, regular_tokens: List[Tuple[int, str, bool]]):
self.root = TokenizerPrefixTreeNode()
self.json_freetext_tokens: List[int] = []
self.json_freetext_tokens = JsonFreetextTokenCache()
self.new_word_tokens: Set[int] = set()
self.tokens_to_strs = {token_idx: token_str for token_idx, token_str, _ in regular_tokens}
for token_idx, decoded, is_new_word in regular_tokens:
self._add_token_to_tree(decoded, token_idx, self.root)
# Performance optimization - cache the tokens of all the strings that don't contain a quote in the middle, or a line break.
# When we are in a JSON freetext string field, they will all be permitted and this will save a lot of tree iterations.
has_quote_before_end = '"' in decoded[0:-1]
has_newline = "\n" in decoded or "\r" in decoded

if not (has_quote_before_end or has_newline):
if '\\' in decoded[:-1]:
# If there is a backslash that is not trailing, we might be in an illegal json territory. Need to verify
# that is is a legal json character streak
try:
json.loads(f'"{decoded}"')
except json.decoder.JSONDecodeError:
continue
self.json_freetext_tokens.append(token_idx)
self.json_freetext_tokens.add_token(decoded, token_idx)
if is_new_word:
self.new_word_tokens.add(token_idx)

self.json_freetext_tokens.freeze()


def _add_token_to_tree(self, token_str: str, token_idx: int, node: TokenizerPrefixTreeNode):
for character in token_str:
if character not in node.children:
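The cache's lifecycle is add_token() for every tokenizer entry, a single freeze(), then dictionary lookups via lookup_allowed_tokens(). A small usage sketch with a made-up three-token vocabulary (token strings and ids are illustrative), assuming the class is imported from lmformatenforcer.tokenizerprefixtree as added in this commit:

from lmformatenforcer.tokenizerprefixtree import JsonFreetextTokenCache

cache = JsonFreetextTokenCache()
cache.add_token('Hello', 101)
cache.add_token('Hi"', 102)   # ends the JSON string when emitted
cache.add_token('\n', 103)    # skipped: a raw newline is illegal inside a JSON string
cache.freeze()                # precalculates deduplicated allowlists; no add_token() afterwards

# With no length pressure, both surviving tokens are allowed.
print(cache.lookup_allowed_tokens(0, cache.max_token_len))  # (101, 102)

# If at least 3 more characters are required before the closing quote,
# 'Hi"' (only 2 characters before its quote) is filtered out.
print(cache.lookup_allowed_tokens(3, cache.max_token_len))  # (101,)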
