diff --git a/lmformatenforcer/jsonschemaparser.py b/lmformatenforcer/jsonschemaparser.py
index ebd7d19..bb91f70 100644
--- a/lmformatenforcer/jsonschemaparser.py
+++ b/lmformatenforcer/jsonschemaparser.py
@@ -10,7 +10,8 @@ from .consts import BACKSLASH, BACKSLASH_ESCAPING_CHARACTERS, MAX_CONSECUTIVE_WHITESPACES, WHITESPACE_CHARACTERS
 
 
 
-_ANY_JSON_SCHEMA_DICT = {'anyOf': [{'type': type} for type in json_schema_data_formats.keys()]}
+# No need to include the 'integer' option in the anyOf, as it is a subset of 'number'
+_ANY_JSON_SCHEMA_DICT = {'anyOf': [{'type': type} for type in json_schema_data_formats.keys() if type != 'integer']}
 
 class JsonSchemaParser(CharacterLevelParser):
     ANY_JSON_OBJECT_SCHEMA: JsonSchemaObject = JsonSchemaObject(**_ANY_JSON_SCHEMA_DICT)
@@ -53,6 +54,7 @@ def __init__(self,
         self.last_non_whitespace_character = ""
 
     def add_character(self, new_character: str) -> CharacterLevelParser:
+        self.context.active_parser = self
         # Assumption: The top-most parser that can accept the character is the one that should accept it.
         # This is different from the SequenceParser, in which we need to split (union) into all options.
         receiving_idx = len(self.object_stack) - 1
@@ -73,9 +75,28 @@ def add_character(self, new_character: str) -> CharacterLevelParser:
         else:
             updated_parser.num_consecutive_whitespaces = 0
             updated_parser.last_non_whitespace_character = new_character
+
+        if isinstance(updated_parser.object_stack[-1], UnionParser) and \
+            any(isinstance(parser, (ObjectParsingState, ListParsingState)) for parser in updated_parser.object_stack[-1].parsers):
+            # If the top parser is a union parser with "advanced" (=parsers that modify the object stack) parsers inside,
+            # we need to split the top level parser into the different options,
+            # As each "fork" can live with a different object stack, and we need to make sure they have their own ones.
+            option_json_schema_parsers = []
+            for option_parser in updated_parser.object_stack[-1].parsers:
+                option_stack = updated_parser.object_stack[:-1] + [option_parser]
+                option_parser = JsonSchemaParser(self.context, self.config, option_stack, updated_parser.num_consecutive_whitespaces)
+                option_parser.context.active_parser = option_parser
+                option_parser.last_parsed_string = last_parsed_string
+                # A fork is a brand new parser whose last_non_whitespace_character starts empty, so copy the value computed above.
+                option_parser.last_non_whitespace_character = updated_parser.last_non_whitespace_character
+                option_json_schema_parsers.append(option_parser)
+            return UnionParser(option_json_schema_parsers)
+
         return updated_parser
 
     def get_allowed_characters(self) -> str:
+        self.context.active_parser = self
+
         allowed_character_strs = []
         for parser in reversed(self.object_stack):
             # Similar to SequenceParser, if the top object can end, we need to know to accept the next character of parser below, etc.
@@ -535,7 +556,11 @@ def add_character(self, new_character: str) -> "ListParsingState":
                 parser_to_push = item_parser
             else:
                 # If we don't require items, we can also end immediately, the Union + ForceStopParser combination achieves this
-                parser_to_push = UnionParser([item_parser, ForceStopParser()])
+                if isinstance(item_parser, UnionParser):
+                    item_parser.parsers.append(ForceStopParser())
+                    parser_to_push = item_parser
+                else:
+                    parser_to_push = UnionParser([item_parser, ForceStopParser()])
             self.root.context.active_parser.object_stack.append(parser_to_push)
         elif new_character == "]":
             self.seen_list_closer = True
diff --git a/tests/test_jsonschemaparser.py b/tests/test_jsonschemaparser.py
index e447d15..a030b95 100644
--- a/tests/test_jsonschemaparser.py
+++ b/tests/test_jsonschemaparser.py
@@ -303,3 +303,19 @@ def test_more_string_constraints():
         'min_4_max_6': '12_'
     }.items():
         _test_json_schema_parsing_with_string(f'{{"{k}": "{v}"}}', StringConstraints.model_json_schema(), False)
+
+
+def test_union_typed_arrays():
+    """A list of Union[AppleSchema, BananaSchema] accepts either model per item, but not one item mixing both models' fields."""
+    class AppleSchema(BaseModel):
+        apple_type: int
+
+    class BananaSchema(BaseModel):
+        is_ripe: bool
+
+    class FruitsSchema(BaseModel):
+        fruits: List[Union[AppleSchema, BananaSchema]]
+
+    _test_json_schema_parsing_with_string('{"fruits": [{"apple_type": 1}, {"apple_type": 2}] }', FruitsSchema.model_json_schema(), True)
+    _test_json_schema_parsing_with_string('{"fruits": [{"apple_type": 1}, {"is_ripe": true}] }', FruitsSchema.model_json_schema(), True)
+    _test_json_schema_parsing_with_string('{"fruits": [{"apple_type": 1, "is_ripe": true}] }', FruitsSchema.model_json_schema(), False)