diff --git a/lmformatenforcer/integrations/trtllm.py b/lmformatenforcer/integrations/trtllm.py
index 3847b44..9d1fca7 100644
--- a/lmformatenforcer/integrations/trtllm.py
+++ b/lmformatenforcer/integrations/trtllm.py
@@ -15,7 +15,8 @@ def __init__(self, token_enforcer: TokenEnforcer, eos_token_id, analyze):
         self.eos_token_id = eos_token_id
 
     def _trim(self, input):
-        return [x for x in input.tolist() if x != self.eos_token_id]
+        return [x for x in input.tolist() if x not in
+                (self.eos_token_id if isinstance(self.eos_token_id, list) else [self.eos_token_id])]
 
     def __call__(self, step: int, batch_input_ids: List[List[int]], logits: torch.Tensor) -> torch.Tensor:
         for idx in range(len(batch_input_ids)):
diff --git a/lmformatenforcer/jsonschemaparser.py b/lmformatenforcer/jsonschemaparser.py
index 86df79f..a8082cf 100644
--- a/lmformatenforcer/jsonschemaparser.py
+++ b/lmformatenforcer/jsonschemaparser.py
@@ -438,12 +438,16 @@ def __init__(
         self.allow_floating_point = allow_floating_point
         self.seen_decimal_point = False
         self.seen_whitespace_after_digits = False
+        self.seen_exponent = False
+        self.seen_digit = False
 
     def _clone(self) -> "NumberParsingState":
         clone = NumberParsingState(self.root, self.allow_floating_point)
         clone.parsed_string = self.parsed_string
         clone.seen_decimal_point = self.seen_decimal_point
         clone.seen_whitespace_after_digits = self.seen_whitespace_after_digits
+        clone.seen_exponent = self.seen_exponent
+        clone.seen_digit = self.seen_digit
         return clone
 
     def add_character(self, new_character: str) -> CharacterLevelParser:
@@ -455,7 +459,17 @@ def add_character(self, new_character: str) -> CharacterLevelParser:
             self.seen_whitespace_after_digits = True
             return self
         if new_character == ".":
+            if len(self.parsed_string) <= 1:
+                raise LMFormatEnforcerException("Numbers cannot start with a decimal point.")
+            if self.seen_decimal_point:
+                raise LMFormatEnforcerException("Numbers cannot contain more than one decimal point.")
             self.seen_decimal_point = True
+        elif new_character in "eE":
+            if self.seen_exponent or not self.seen_digit:
+                raise LMFormatEnforcerException("Invalid number format")
+            self.seen_exponent = True
+        elif new_character.isdigit():
+            self.seen_digit = True
         return self
 
     def get_allowed_characters(self) -> str:
@@ -464,13 +478,23 @@
         if self.seen_whitespace_after_digits:
             return WHITESPACE_CHARACTERS
         allowed_characters = "0123456789"
         if not self.parsed_string:
             allowed_characters += "-" + WHITESPACE_CHARACTERS
-        if self.allow_floating_point and not self.seen_decimal_point:
+        if self.parsed_string == "0":
+            allowed_characters = WHITESPACE_CHARACTERS
+        if self.parsed_string == "-0":
+            allowed_characters = "." + WHITESPACE_CHARACTERS
+        if self.parsed_string and self.parsed_string[-1] in "eE":
+            allowed_characters += "-+"
+        if self.seen_digit and not self.seen_exponent:
+            allowed_characters += "eE"
+        if self.allow_floating_point and not self.seen_decimal_point and self.seen_digit and not self.seen_exponent:
             allowed_characters += "."
         if self.parsed_string and self.parsed_string[-1].isdigit():
             allowed_characters += WHITESPACE_CHARACTERS
         return allowed_characters
 
     def can_end(self) -> bool:
+        if self.seen_exponent and self.parsed_string[-1] in "eE+-":
+            return False
         return bool(self.parsed_string) and (self.parsed_string[-1].isdigit() or self.seen_whitespace_after_digits)
diff --git a/lmformatenforcer/tokenenforcer.py b/lmformatenforcer/tokenenforcer.py
index c720241..6b2534e 100644
--- a/lmformatenforcer/tokenenforcer.py
+++ b/lmformatenforcer/tokenenforcer.py
@@ -1,6 +1,6 @@
 from dataclasses import dataclass, field
 import sys
-from typing import Callable, Dict, Hashable, List, Optional, Tuple
+from typing import Callable, Dict, Hashable, List, Optional, Tuple, Union
 import logging
 
 from .exceptions import LMFormatEnforcerException
@@ -14,13 +14,13 @@ class TokenEnforcerTokenizerData:
     def __init__(self,
                  regular_tokens: List[Tuple[int, str, bool]],
                  decoder: Callable[[List[int]], str],
-                 eos_token_id: int):
+                 eos_token_id: Union[int, List[int]]):
         """
         Create the tokenizer data that the TokenEnforcer needs. This can be reused for multiple TokenEnforcers if they work with the same tokenizer.
         :param regular_tokens: A list of tuples (token_id, token_string, is_new_word_token) for all the regular (not special) tokens in the tokenizer vocabulary.
         Note that token_string is expected to include leading / trailing whitespaces if relevant.
         :param decoder: A function that decodes a list of token ids into a string.
-        :param eos_token_id: The token id of the end-of-string token.
+        :param eos_token_id: The token id(s) of the end-of-string token(s).
         """
         self.regular_tokens = regular_tokens
         self.tokenizer_tree = TokenizerPrefixTree(regular_tokens)
@@ -95,7 +95,7 @@ def _compute_allowed_tokens(self, state_tokens: Tuple, state: 'TokenEnforcer.Out
             shortcut_key = state.parser.shortcut_key()
             self._collect_allowed_tokens(state.parser, self.tokenizer_tree.root, allowed_tokens, shortcut_key)
             if state.parser.can_end():
-                allowed_tokens.append(self.eos_token_id)
+                allowed_tokens.extend(self.eos_token_id if isinstance(self.eos_token_id, list) else [self.eos_token_id])
             if not allowed_tokens:
                 raise ValueError(f"Parser reached state with no allowed tokens")
             # root_state = next(state for state in self.prefix_states.values() if state.parser == self.root_parser)
@@ -115,7 +115,7 @@
                               "Terminating the parser. Please open an issue at \n"
                               "https://github.com/noamgat/lm-format-enforcer/issues with the prefix and "
                               "CharacterLevelParser parameters")
-            state.allowed_tokens = [self.eos_token_id]
+            state.allowed_tokens = self.eos_token_id if isinstance(self.eos_token_id, list) else [self.eos_token_id]
 
     def _collect_allowed_tokens(self, parser: CharacterLevelParser, tree_node: TokenizerPrefixTreeNode, allowed_tokens: List[int], shortcut_key: Optional[Hashable]):
         allowed_tokens.extend(tree_node.tokens)
diff --git a/tests/common.py b/tests/common.py
index 2ff8294..014e1b4 100644
--- a/tests/common.py
+++ b/tests/common.py
@@ -55,8 +55,8 @@ def assert_parser_with_string_token_enforcer(string: str, parser: CharacterLevel
     # the parser the most.
     target_token_array = _tokenizer.encode(prompt + string)
     eos_token_id = _tokenizer.eos_token_id
-    if eos_token_id is None:
-        raise ValueError("Tokenizer does not have an EOS token")
+    if not eos_token_id:
+        raise ValueError(f"Tokenizer does not have {'an EOS token' if eos_token_id is None else 'EOS tokens'}")
     token_enforcer = TokenEnforcer(_tokenizer_data, parser)
 
     # The token enforcer is stateful - it keeps track of the parsing state as tokens arrive on a token by token basis.
@@ -82,7 +82,7 @@ def assert_parser_with_string_token_enforcer(string: str, parser: CharacterLevel
             return # Test success
         else:
             # Reached the end of the sequence, check that ending state matches expected ending state
-            can_end = eos_token_id in allowed_tokens
+            can_end = any(token in allowed_tokens for token in (eos_token_id if isinstance(eos_token_id, list) else [eos_token_id]))
             if can_end and not expect_success:
                 raise ValueError("Parser succeeded when it should have failed")
             if not can_end and expect_success:
diff --git a/tests/test_jsonschemaparser.py b/tests/test_jsonschemaparser.py
index 16321d0..6fe74a2 100644
--- a/tests/test_jsonschemaparser.py
+++ b/tests/test_jsonschemaparser.py
@@ -673,3 +673,105 @@ def test_top_level_object_inheritance():
 
     valid_object = '{"parent": {"child": "test"}}'
     _test_json_schema_parsing_with_string(valid_object, schema, True)
+
+class NumberSchema(BaseModel):
+    value: float = Field(..., type="number")
+
+
+schema = NumberSchema.model_json_schema()
+
+@pytest.mark.parametrize("test_input", [
+    '{"value": 0}',
+    '{"value": 1}',
+    '{"value": 10}',
+    '{"value": 0.1}',
+    '{"value": 1.01}',
+    '{"value": -1}',
+    '{"value": -0.1}',
+    '{"value": 1e5}',
+    '{"value": 1.5e-5}',
+    '{"value": 1.5e5}',
+    '{"value": 1.5e+5}',
+    '{"value": -1.5e5}',
+    '{"value": -1.5e-5}',
+    '{"value": -1.5e+5}',
+    '{"value": 0.0}',
+    '{"value": -0.0}',
+    '{"value": 1.0}',
+    '{"value": -1.0}',
+    '{"value": 1.5e0}',
+    '{"value": -1.5e0}',
+    '{"value": 9007199254740991}',
+    '{"value": -9007199254740991}',
+    '{"value": 1e-323}',
+    '{"value": 1.7976931348623157e+308}',
+    '{"value": 5e-324}',
+    '{"value": 2.2250738585072014e-308}',
+])
+def test_valid_number_formats(test_input):
+    _test_json_schema_parsing_with_string(test_input, schema, True)
+
+
+@pytest.mark.parametrize("test_input", [
+    '{"value": 01}',
+    '{"value": 00.1}',
+    '{"value": 01.01}',
+    '{"value": -01}',
+    '{"value": -00.1}',
+    '{"value": 01e5}',
+    '{"value": 00}',
+    '{"value": 00.0}',
+    '{"value": 00.0e5}',
+    '{"value": -00.0e5}',
+    '{"value": 0123}',
+    '{"value": -0123}',
+    '{"value": 01.23e45}',
+])
+def test_invalid_number_formats_with_leading_zeros(test_input):
+    _test_json_schema_parsing_with_string(test_input, schema, False)
+
+
+@pytest.mark.parametrize("test_input, expected_success", [
+    ('{"value": .1}', False),
+    ('{"value": -.1}', False),
+    ('{"value": 1.}', False),
+    ('{"value": +1}', False),
+    ('{"value": 1e}', False),
+    ('{"value": 1e+}', False),
+    ('{"value": .}', False),
+    ('{"value": -.}', False),
+    ('{"value": e5}', False),
+    ('{"value": .e5}', False),
+    ('{"value": -.e5}', False),
+    ('{"value": 1.5e}', False),
+    ('{"value": 1.5e+}', False),
+    ('{"value": -1.5e}', False),
+    ('{"value": -1.5e+}', False),
+    ('{"value": 1.5e-}', False),
+    ('{"value": -1.5e-}', False),
+    ('{"value": 1e-}', False),
+    ('{"value": -1e-}', False),
+    ('{"value": 1e+1e2}', False),
+    ('{"value": 1e1.5}', False),
+    ('{"value": 1e-1.5}', False),
+    ('{"value": 1e1a}', False),
+    ('{"value": 1e-1a}', False),
+    ('{"value": 0x123}', False),
+    ('{"value": 0b1010}', False),
+ ('{"value": 0o123}', False), + ('{"value": Infinity}', False), + ('{"value": -Infinity}', False), + ('{"value": NaN}', False), + ('{"value": 1,000}', False), + ('{"value": 1_000}', False), + ('{"value": 1.2.3}', False), + ('{"value": 1e2e3}', False), + ('{"value": 1e+2e-3}', False), + ('{"value": --1}', False), + ('{"value": ++1}', False), + ('{"value": +-1}', False), + ('{"value": 9007199254740992}', True), + ('{"value": -9007199254740992}', True), +]) +def test_number_edge_cases(test_input, expected_success): + _test_json_schema_parsing_with_string(test_input, schema, expected_success) \ No newline at end of file