From 9ef1c90a016d939c7209a7130e0e8aef41c92d93 Mon Sep 17 00:00:00 2001 From: Noam Gat Date: Fri, 3 May 2024 11:58:01 +0300 Subject: [PATCH] v0.9.10 - JsonSchemaParser supports anyOf (#96) * Added multi function schema support to json schema parser * Cleaning up json schema parser so there aren't erroneous logs that affect debugging * v0.9.10 --- CHANGELOG.md | 3 + lmformatenforcer/jsonschemaparser.py | 32 ++++++- pyproject.toml | 2 +- tests/test_jsonschemaparser.py | 129 ++++++++++++++++++++++++++- 4 files changed, 161 insertions(+), 5 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5a42de7..7c7fd05 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,8 @@ # LM Format Enforcer Changelog +## v0.9.10 +- [#95] Added anyOf support to JsonSchemaParser, making function calls possible. + ## v0.9.9 - Updated README with vLLM OpenAI Server Inference integration diff --git a/lmformatenforcer/jsonschemaparser.py b/lmformatenforcer/jsonschemaparser.py index 1dedfe4..6d895d5 100644 --- a/lmformatenforcer/jsonschemaparser.py +++ b/lmformatenforcer/jsonschemaparser.py @@ -60,7 +60,7 @@ def add_character(self, new_character: str) -> CharacterLevelParser: # This is different from the SequenceParser, in which we need to split (union) into all options. receiving_idx = len(self.object_stack) - 1 last_parsed_string = self.last_parsed_string - while new_character not in self.object_stack[receiving_idx].get_allowed_characters(): + while receiving_idx >= 0 and new_character not in self.object_stack[receiving_idx].get_allowed_characters(): finished_receiver = self.object_stack[receiving_idx] if isinstance(finished_receiver, StringParsingState): last_parsed_string = finished_receiver.parsed_string @@ -70,7 +70,8 @@ def add_character(self, new_character: str) -> CharacterLevelParser: updated_parser = JsonSchemaParser(self.context, self.config, updated_stack, self.num_consecutive_whitespaces) updated_parser.context.active_parser = updated_parser updated_parser.last_parsed_string = last_parsed_string - updated_parser.object_stack[receiving_idx] = updated_parser.object_stack[receiving_idx].add_character(new_character) + if receiving_idx >= 0: + updated_parser.object_stack[receiving_idx] = updated_parser.object_stack[receiving_idx].add_character(new_character) if new_character in WHITESPACE_CHARACTERS: updated_parser.num_consecutive_whitespaces += 1 updated_parser.last_non_whitespace_character = self.last_non_whitespace_character @@ -78,7 +79,7 @@ def add_character(self, new_character: str) -> CharacterLevelParser: updated_parser.num_consecutive_whitespaces = 0 updated_parser.last_non_whitespace_character = new_character - if isinstance(updated_parser.object_stack[-1], UnionParser) and \ + if updated_parser.object_stack and isinstance(updated_parser.object_stack[-1], UnionParser) and \ any(isinstance(parser, (ObjectParsingState, ListParsingState)) for parser in updated_parser.object_stack[-1].parsers): # If the top parser is a union parser with "advanced" (=parsers that modify the object stack) parsers inside, # we need to split the top level parser into the different options, @@ -150,6 +151,18 @@ def __init__(self, root: JsonSchemaParser): self.root = root +def _merge_object_schemas(base_schema: JsonSchemaObject, option_schema: JsonSchemaObject) -> JsonSchemaObject: + for property_name, property_value in base_schema.properties.items(): + # We assume that if a property exists in both base and option, the option version will be + # more specific, therefore we only take missing entries + if property_name not in option_schema.properties: + option_schema.properties[property_name] = property_value + for required_property in base_schema.required: + if required_property not in option_schema.required: + option_schema.required.append(required_property) + return option_schema + + def get_parser( parsing_state: JsonSchemaParser, value_schema: JsonSchemaObject @@ -159,6 +172,13 @@ def get_parser( if value_schema.anyOf: parsers = [get_parser(parsing_state, schema) for schema in value_schema.anyOf] return UnionParser(parsers) + if value_schema.extras and 'const' in value_schema.extras: + allowed_value = value_schema.extras['const'] + is_string = type(allowed_value) == str + return StringParsingState(parsing_state, + [allowed_value], + require_opening_quote=is_string, + require_closing_quote=is_string) if value_schema.type == "string": return StringParsingState( parsing_state, @@ -169,6 +189,12 @@ def get_parser( pattern=value_schema.pattern, ) elif value_schema.type == "object": + if value_schema.oneOf: + # We create a combined object schema for each option that includes the information from the parent + # And then create a UnionParser based on the combined options + merged_schemas = [_merge_object_schemas(value_schema, option_schema) for option_schema in value_schema.oneOf] + object_parsing_options = [ObjectParsingState(merged_schema, parsing_state) for merged_schema in merged_schemas] + return UnionParser(object_parsing_options) return ObjectParsingState(value_schema, parsing_state) elif value_schema.type == None and value_schema.ref: value_class_name = value_schema.ref.split('/')[-1] diff --git a/pyproject.toml b/pyproject.toml index ed5a86f..e179ffd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "lm-format-enforcer" -version = "0.9.9" +version = "0.9.10" description = "Enforce the output format (JSON Schema, Regex etc) of a language model" authors = ["Noam Gat "] license = "MIT" diff --git a/tests/test_jsonschemaparser.py b/tests/test_jsonschemaparser.py index ac5b027..68f7cd3 100644 --- a/tests/test_jsonschemaparser.py +++ b/tests/test_jsonschemaparser.py @@ -370,4 +370,131 @@ def test_comma_cannot_start_list_2(): }""" class FlightRoute(BaseModel): airports: List[str] - _test_json_schema_parsing_with_string(output_notok, FlightRoute.model_json_schema(), False) \ No newline at end of file + _test_json_schema_parsing_with_string(output_notok, FlightRoute.model_json_schema(), False) + + +def test_multi_function_schema(): + # https://github.com/noamgat/lm-format-enforcer/issues/95 + _multi_function_schema = { + "type": "object", + "properties": { + "name": { + "type": "string", + "enum": [ + "sums", + "concat" + ] + } + }, + "oneOf": [ + { + "properties": { + "name": { + "const": "sums" + }, + "arguments": { + "properties": { + "a": { + "type": "integer" + }, + "b": { + "default": 1, + "type": "integer" + } + }, + "required": [ + "a" + ], + "type": "object" + } + } + }, + { + "properties": { + "name": { + "const": "concat" + }, + "arguments": { + "properties": { + "c": { + "type": "string" + }, + "d": { + "default": 1, + "type": "string" + } + }, + "required": [ + "c" + ], + "type": "object" + } + } + } + ], + "required": [ + "name", + "arguments" + ] + } + valid_examples = [ + """{"name": "concat", "arguments": {"c": "hello", "d": "world"}}""", + """{"name": "sums", "arguments": {"a": 1}}""", + ] + invalid_examples = [ + """{"name": "concat", "arguments": {"b": 1}}""", + """{"name": "concat", "arguments": {"a": 1}}""", + """{"name": "concat"}""", + """{"name": "badname", "arguments": {"c": "hello", "b": "world"}}""", + ] + for example in valid_examples: + _test_json_schema_parsing_with_string(example, _multi_function_schema, True) + for example in invalid_examples: + _test_json_schema_parsing_with_string(example, _multi_function_schema, False) + + +def test_top_level_array_object(): + test_schema = { + "type": "array", + "items": { + "type": "object", + "properties": { + "arguments": { + "type": "object" + }, + "name": { + "type": "string" + } + }, + "required": [ + "name", + "arguments" + ] + }, + "minItems": 1 + } + valid_result = """[ + { + "name": "sums", + "arguments": { + "a": 5, + "b": 6 + } + }, + { + "name": "sums", + "arguments": { + "a": 2, + "b": 7 + } + }, + { + "name": "subtraction", + "arguments": { + "c": 3, + "d": 3 + } + }]""" + invalid_result = valid_result[:-1] + _test_json_schema_parsing_with_string(valid_result, test_schema, True) + _test_json_schema_parsing_with_string(invalid_result, test_schema, False)