Skip to content

Commit

Permalink
Feature/union typed arrays (#58)
Browse files Browse the repository at this point in the history
Solves #53 

* Added failing test based on issue #53

* Improved support for union typed objects in arrays
  • Loading branch information
noamgat authored Jan 10, 2024
1 parent a02b342 commit 48d2c4d
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 2 deletions.
27 changes: 25 additions & 2 deletions lmformatenforcer/jsonschemaparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
from .consts import BACKSLASH, BACKSLASH_ESCAPING_CHARACTERS, MAX_CONSECUTIVE_WHITESPACES, WHITESPACE_CHARACTERS


_ANY_JSON_SCHEMA_DICT = {'anyOf': [{'type': type} for type in json_schema_data_formats.keys()]}
# No need to include the 'integer' option in the anyOf, as it is a subset of 'number'
_ANY_JSON_SCHEMA_DICT = {'anyOf': [{'type': type} for type in json_schema_data_formats.keys() if type != 'integer']}

class JsonSchemaParser(CharacterLevelParser):
ANY_JSON_OBJECT_SCHEMA: JsonSchemaObject = JsonSchemaObject(**_ANY_JSON_SCHEMA_DICT)
Expand Down Expand Up @@ -53,6 +54,7 @@ def __init__(self,
self.last_non_whitespace_character = ""

def add_character(self, new_character: str) -> CharacterLevelParser:
self.context.active_parser = self
# Assumption: The top-most parser that can accept the character is the one that should accept it.
# This is different from the SequenceParser, in which we need to split (union) into all options.
receiving_idx = len(self.object_stack) - 1
Expand All @@ -73,9 +75,26 @@ def add_character(self, new_character: str) -> CharacterLevelParser:
else:
updated_parser.num_consecutive_whitespaces = 0
updated_parser.last_non_whitespace_character = new_character

if isinstance(updated_parser.object_stack[-1], UnionParser) and \
any(isinstance(parser, (ObjectParsingState, ListParsingState)) for parser in updated_parser.object_stack[-1].parsers):
# If the top parser is a union parser with "advanced" (=parsers that modify the object stack) parsers inside,
# we need to split the top level parser into the different options,
# As each "fork" can live with a different object stack, and we need to make sure they have their own ones.
option_json_schema_parsers = []
for option_parser in updated_parser.object_stack[-1].parsers:
option_stack = updated_parser.object_stack[:-1] + [option_parser]
option_parser = JsonSchemaParser(self.context, self.config, option_stack, updated_parser.num_consecutive_whitespaces)
option_parser.context.active_parser = option_parser
option_parser.last_parsed_string = last_parsed_string
option_json_schema_parsers.append(option_parser)
return UnionParser(option_json_schema_parsers)

return updated_parser

def get_allowed_characters(self) -> str:
self.context.active_parser = self

allowed_character_strs = []
for parser in reversed(self.object_stack):
# Similar to SequenceParser, if the top object can end, we need to know to accept the next character of parser below, etc.
Expand Down Expand Up @@ -535,7 +554,11 @@ def add_character(self, new_character: str) -> "ListParsingState":
parser_to_push = item_parser
else:
# If we don't require items, we can also end immediately, the Union + ForceStopParser combination achieves this
parser_to_push = UnionParser([item_parser, ForceStopParser()])
if isinstance(item_parser, UnionParser):
item_parser.parsers.append(ForceStopParser())
parser_to_push = item_parser
else:
parser_to_push = UnionParser([item_parser, ForceStopParser()])
self.root.context.active_parser.object_stack.append(parser_to_push)
elif new_character == "]":
self.seen_list_closer = True
Expand Down
15 changes: 15 additions & 0 deletions tests/test_jsonschemaparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,3 +303,18 @@ def test_more_string_constraints():
'min_4_max_6': '12_'
}.items():
_test_json_schema_parsing_with_string(f'{{"{k}": "{v}"}}', StringConstraints.model_json_schema(), False)


def test_union_typed_arrays():
class AppleSchema(BaseModel):
apple_type: int

class BananaSchema(BaseModel):
is_ripe: bool

class FruitsSchema(BaseModel):
fruits: List[Union[AppleSchema, BananaSchema]]

_test_json_schema_parsing_with_string('{"fruits": [{"apple_type": 1}, {"apple_type": 2}] }', FruitsSchema.model_json_schema(), True)
_test_json_schema_parsing_with_string('{"fruits": [{"apple_type": 1}, {"is_ripe": true}] }', FruitsSchema.model_json_schema(), True)
_test_json_schema_parsing_with_string('{"fruits": [{"apple_type": 1, "is_ripe": true}] }', FruitsSchema.model_json_schema(), False)

0 comments on commit 48d2c4d

Please sign in to comment.