diff --git a/outlines/text/json_schema.py b/outlines/text/json_schema.py index 0c597b84d..614227531 100644 --- a/outlines/text/json_schema.py +++ b/outlines/text/json_schema.py @@ -63,6 +63,9 @@ def to_regex(resolver: Resolver, instance: dict): - Handle types defined as a list - Handle constraints on numbers - Handle special patterns: `date`, `uri`, etc. + - Handle optional fields (not in `required`) + + This does not support recursive definitions. Parameters ---------- @@ -116,12 +119,14 @@ def to_regex(resolver: Resolver, instance: dict): # The enum keyword is used to restrict a value to a fixed set of values. It # must be an array with at least one element, where each element is unique. elif "enum" in instance: - if instance["type"] == "string": - choices = [f'"{re.escape(choice)}"' for choice in instance["enum"]] - return f"({'|'.join(choices)})" - else: - choices = [re.escape(str(choice)) for choice in instance["enum"]] - return f"({'|'.join(choices)})" + choices = [] + for choice in instance["enum"]: + if type(choice) in [int, float, bool, None]: + choices.append(re.escape(str(choice))) + elif type(choice) == str: + choices.append(f'"{re.escape(choice)}"') + + return f"({'|'.join(choices)})" elif "$ref" in instance: path = f"{instance['$ref']}" @@ -134,8 +139,8 @@ def to_regex(resolver: Resolver, instance: dict): # the name of one of the basic types, and each element is unique. In this # case, the JSON snippet is valid if it matches any of the given types. elif "type" in instance: - type = instance["type"] - if type == "string": + instance_type = instance["type"] + if instance_type == "string": if "maxLength" in instance or "minLength" in instance: max_length = instance.get("maxLength", "") min_length = instance.get("minLength", "") @@ -156,13 +161,13 @@ def to_regex(resolver: Resolver, instance: dict): else: return type_to_regex["string"] - elif type == "number": + elif instance_type == "number": return type_to_regex["number"] - elif type == "integer": + elif instance_type == "integer": return type_to_regex["integer"] - elif type == "array": + elif instance_type == "array": if "items" in instance: items_regex = to_regex(resolver, instance["items"]) return rf"\[({items_regex})(,({items_regex}))*\]" @@ -180,17 +185,19 @@ def to_regex(resolver: Resolver, instance: dict): regexes = [to_regex(resolver, t) for t in types] return rf"\[({'|'.join(regexes)})(,({'|'.join(regexes)}))*\]" - elif type == "boolean": + elif instance_type == "boolean": return type_to_regex["boolean"] - elif type == "null": + elif instance_type == "null": return type_to_regex["null"] - elif isinstance(type, list): + elif isinstance(instance_type, list): # Here we need to make the choice to exclude generating an object # if the specification of the object is not give, even though a JSON # object that contains an object here would be valid under the specification. - regexes = [to_regex(resolver, {"type": t}) for t in type if t != "object"] + regexes = [ + to_regex(resolver, {"type": t}) for t in instance_type if t != "object" + ] return rf"({'|'.join(regexes)})" raise NotImplementedError(