Skip to content

Commit

Permalink
v0.9.10 - JsonSchemaParser supports anyOf (#96)
Browse files Browse the repository at this point in the history
* Added multi function schema support to json schema parser

* Cleaning up json schema parser so there aren't erroneous logs that affect debugging

* v0.9.10
  • Loading branch information
noamgat authored May 3, 2024
1 parent 1e1cd77 commit 9ef1c90
Show file tree
Hide file tree
Showing 4 changed files with 161 additions and 5 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
# LM Format Enforcer Changelog

## v0.9.10
- [#95] Added anyOf support to JsonSchemaParser, making function calls possible.

## v0.9.9
- Updated README with vLLM OpenAI Server Inference integration

Expand Down
32 changes: 29 additions & 3 deletions lmformatenforcer/jsonschemaparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def add_character(self, new_character: str) -> CharacterLevelParser:
# This is different from the SequenceParser, in which we need to split (union) into all options.
receiving_idx = len(self.object_stack) - 1
last_parsed_string = self.last_parsed_string
while new_character not in self.object_stack[receiving_idx].get_allowed_characters():
while receiving_idx >= 0 and new_character not in self.object_stack[receiving_idx].get_allowed_characters():
finished_receiver = self.object_stack[receiving_idx]
if isinstance(finished_receiver, StringParsingState):
last_parsed_string = finished_receiver.parsed_string
Expand All @@ -70,15 +70,16 @@ def add_character(self, new_character: str) -> CharacterLevelParser:
updated_parser = JsonSchemaParser(self.context, self.config, updated_stack, self.num_consecutive_whitespaces)
updated_parser.context.active_parser = updated_parser
updated_parser.last_parsed_string = last_parsed_string
updated_parser.object_stack[receiving_idx] = updated_parser.object_stack[receiving_idx].add_character(new_character)
if receiving_idx >= 0:
updated_parser.object_stack[receiving_idx] = updated_parser.object_stack[receiving_idx].add_character(new_character)
if new_character in WHITESPACE_CHARACTERS:
updated_parser.num_consecutive_whitespaces += 1
updated_parser.last_non_whitespace_character = self.last_non_whitespace_character
else:
updated_parser.num_consecutive_whitespaces = 0
updated_parser.last_non_whitespace_character = new_character

if isinstance(updated_parser.object_stack[-1], UnionParser) and \
if updated_parser.object_stack and isinstance(updated_parser.object_stack[-1], UnionParser) and \
any(isinstance(parser, (ObjectParsingState, ListParsingState)) for parser in updated_parser.object_stack[-1].parsers):
# If the top parser is a union parser with "advanced" (=parsers that modify the object stack) parsers inside,
# we need to split the top level parser into the different options,
Expand Down Expand Up @@ -150,6 +151,18 @@ def __init__(self, root: JsonSchemaParser):
self.root = root


def _merge_object_schemas(base_schema: JsonSchemaObject, option_schema: JsonSchemaObject) -> JsonSchemaObject:
for property_name, property_value in base_schema.properties.items():
# We assume that if a property exists in both base and option, the option version will be
# more specific, therefore we only take missing entries
if property_name not in option_schema.properties:
option_schema.properties[property_name] = property_value
for required_property in base_schema.required:
if required_property not in option_schema.required:
option_schema.required.append(required_property)
return option_schema


def get_parser(
parsing_state: JsonSchemaParser,
value_schema: JsonSchemaObject
Expand All @@ -159,6 +172,13 @@ def get_parser(
if value_schema.anyOf:
parsers = [get_parser(parsing_state, schema) for schema in value_schema.anyOf]
return UnionParser(parsers)
if value_schema.extras and 'const' in value_schema.extras:
allowed_value = value_schema.extras['const']
is_string = type(allowed_value) == str
return StringParsingState(parsing_state,
[allowed_value],
require_opening_quote=is_string,
require_closing_quote=is_string)
if value_schema.type == "string":
return StringParsingState(
parsing_state,
Expand All @@ -169,6 +189,12 @@ def get_parser(
pattern=value_schema.pattern,
)
elif value_schema.type == "object":
if value_schema.oneOf:
# We create a combined object schema for each option that includes the information from the parent
# And then create a UnionParser based on the combined options
merged_schemas = [_merge_object_schemas(value_schema, option_schema) for option_schema in value_schema.oneOf]
object_parsing_options = [ObjectParsingState(merged_schema, parsing_state) for merged_schema in merged_schemas]
return UnionParser(object_parsing_options)
return ObjectParsingState(value_schema, parsing_state)
elif value_schema.type == None and value_schema.ref:
value_class_name = value_schema.ref.split('/')[-1]
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "lm-format-enforcer"
version = "0.9.9"
version = "0.9.10"
description = "Enforce the output format (JSON Schema, Regex etc) of a language model"
authors = ["Noam Gat <[email protected]>"]
license = "MIT"
Expand Down
129 changes: 128 additions & 1 deletion tests/test_jsonschemaparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,4 +370,131 @@ def test_comma_cannot_start_list_2():
}"""
class FlightRoute(BaseModel):
airports: List[str]
_test_json_schema_parsing_with_string(output_notok, FlightRoute.model_json_schema(), False)
_test_json_schema_parsing_with_string(output_notok, FlightRoute.model_json_schema(), False)


def test_multi_function_schema():
# https://github.com/noamgat/lm-format-enforcer/issues/95
_multi_function_schema = {
"type": "object",
"properties": {
"name": {
"type": "string",
"enum": [
"sums",
"concat"
]
}
},
"oneOf": [
{
"properties": {
"name": {
"const": "sums"
},
"arguments": {
"properties": {
"a": {
"type": "integer"
},
"b": {
"default": 1,
"type": "integer"
}
},
"required": [
"a"
],
"type": "object"
}
}
},
{
"properties": {
"name": {
"const": "concat"
},
"arguments": {
"properties": {
"c": {
"type": "string"
},
"d": {
"default": 1,
"type": "string"
}
},
"required": [
"c"
],
"type": "object"
}
}
}
],
"required": [
"name",
"arguments"
]
}
valid_examples = [
"""{"name": "concat", "arguments": {"c": "hello", "d": "world"}}""",
"""{"name": "sums", "arguments": {"a": 1}}""",
]
invalid_examples = [
"""{"name": "concat", "arguments": {"b": 1}}""",
"""{"name": "concat", "arguments": {"a": 1}}""",
"""{"name": "concat"}""",
"""{"name": "badname", "arguments": {"c": "hello", "b": "world"}}""",
]
for example in valid_examples:
_test_json_schema_parsing_with_string(example, _multi_function_schema, True)
for example in invalid_examples:
_test_json_schema_parsing_with_string(example, _multi_function_schema, False)


def test_top_level_array_object():
test_schema = {
"type": "array",
"items": {
"type": "object",
"properties": {
"arguments": {
"type": "object"
},
"name": {
"type": "string"
}
},
"required": [
"name",
"arguments"
]
},
"minItems": 1
}
valid_result = """[
{
"name": "sums",
"arguments": {
"a": 5,
"b": 6
}
},
{
"name": "sums",
"arguments": {
"a": 2,
"b": 7
}
},
{
"name": "subtraction",
"arguments": {
"c": 3,
"d": 3
}
}]"""
invalid_result = valid_result[:-1]
_test_json_schema_parsing_with_string(valid_result, test_schema, True)
_test_json_schema_parsing_with_string(invalid_result, test_schema, False)

0 comments on commit 9ef1c90

Please sign in to comment.