diff --git a/lmformatenforcer/characterlevelparser.py b/lmformatenforcer/characterlevelparser.py index 1958138..186eab9 100644 --- a/lmformatenforcer/characterlevelparser.py +++ b/lmformatenforcer/characterlevelparser.py @@ -127,8 +127,11 @@ def get_allowed_characters(self) -> str: def can_end(self) -> bool: return any([parser.can_end() for parser in self.parsers]) - def shortcut_key(self) -> Optional[str]: - return self.parsers[0].shortcut_key() if len(self.parsers) == 1 else None + def shortcut_key(self) -> Optional[Hashable]: + unique_shortcut_keys = set(parser.shortcut_key() for parser in self.parsers) + if len(unique_shortcut_keys) == 1: + return next(iter(unique_shortcut_keys)) + return None def cache_key(self) -> Optional[Hashable]: all_cache_keys = tuple(parser.cache_key() for parser in self.parsers) diff --git a/lmformatenforcer/jsonschemaparser.py b/lmformatenforcer/jsonschemaparser.py index a8082cf..d5d67a9 100644 --- a/lmformatenforcer/jsonschemaparser.py +++ b/lmformatenforcer/jsonschemaparser.py @@ -160,7 +160,8 @@ def __init__(self, root: JsonSchemaParser): def _merge_object_schemas(base_schema: JsonSchemaObject, option_schema: JsonSchemaObject) -> JsonSchemaObject: - for property_name, property_value in base_schema.properties.items(): + base_schema_properties = base_schema.properties or {} + for property_name, property_value in base_schema_properties.items(): # We assume that if a property exists in both base and option, the option version will be # more specific, therefore we only take missing entries if property_name not in option_schema.properties: @@ -201,13 +202,13 @@ def get_parser( max_length=value_schema.maxLength, pattern=value_schema.pattern, ) + if value_schema.oneOf: + # We create a combined object schema for each option that includes the information from the parent + # And then create a UnionParser based on the combined options + merged_schemas = [_merge_object_schemas(value_schema, option_schema) for option_schema in value_schema.oneOf] + object_parsing_options = [ObjectParsingState(merged_schema, parsing_state) for merged_schema in merged_schemas] + return UnionParser(object_parsing_options) elif value_schema.type == "object": - if value_schema.oneOf: - # We create a combined object schema for each option that includes the information from the parent - # And then create a UnionParser based on the combined options - merged_schemas = [_merge_object_schemas(value_schema, option_schema) for option_schema in value_schema.oneOf] - object_parsing_options = [ObjectParsingState(merged_schema, parsing_state) for merged_schema in merged_schemas] - return UnionParser(object_parsing_options) return ObjectParsingState(value_schema, parsing_state) elif value_schema.type == None and value_schema.ref: value_class_name = value_schema.ref.split('/')[-1] diff --git a/tests/test_jsonschemaparser.py b/tests/test_jsonschemaparser.py index e58bd6b..554f274 100644 --- a/tests/test_jsonschemaparser.py +++ b/tests/test_jsonschemaparser.py @@ -792,4 +792,9 @@ def test_invalid_number_formats_with_leading_zeros(test_input): ('{"value": -9007199254740992}', True), ]) def test_number_edge_cases(test_input, expected_success): - _test_json_schema_parsing_with_string(test_input, schema, expected_success) \ No newline at end of file + _test_json_schema_parsing_with_string(test_input, schema, expected_success) + +def test_chinese_oneof_schema(): + test_schema = { "$schema": "http://json-schema.org/draft-07/schema#", "type": "array", "items": { "oneOf": [ { "type": "object", "properties": { "trigger": { "type": "string" }, "event_type": { "enum": [ "公司上市" ] }, "arguments": { "type": "array", "items": { "type": "object", "properties": { "role": { "enum": [ "上市公司", "证券代码", "环节", "披露时间", "发行价格", "事件时间", "市值", "募资金额" ] }, "argument": { "type": "string" } }, "required": [ "role", "argument" ] } } }, "required": [ "trigger", "event_type", "arguments" ] }, { "type": "object", "properties": { "trigger": { "type": "string" }, "event_type": { "enum": [ "被约谈" ] }, "arguments": { "type": "array", "items": { "type": "object", "properties": { "role": { "enum": [ "公司名称", "披露时间", "被约谈时间", "约谈机构" ] }, "argument": { "type": "string" } }, "required": [ "role", "argument" ] } } }, "required": [ "trigger", "event_type", "arguments" ] } ] } } + correct_output = """[{"trigger": "IPO", "event_type": "公司上市", "arguments": [{"role": "上市公司", "argument": "理想汽车"}, {"role": "披露时间", "argument": "30日"}, {"role": "发行价格", "argument": "8-10美元"}, {"role": "环节", "argument": "筹备上市"}]}]""" + _test_json_schema_parsing_with_string(correct_output, test_schema, True)