Skip to content

Commit

Permalink
Support enums with different types
Browse files Browse the repository at this point in the history
  • Loading branch information
rlouf committed Nov 8, 2023
1 parent 1e628b3 commit 3753612
Showing 1 changed file with 22 additions and 15 deletions.
37 changes: 22 additions & 15 deletions outlines/text/json_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,9 @@ def to_regex(resolver: Resolver, instance: dict):
- Handle types defined as a list
- Handle constraints on numbers
- Handle special patterns: `date`, `uri`, etc.
- Handle optional fields (not in `required`)
This does not support recursive definitions.
Parameters
----------
Expand Down Expand Up @@ -116,12 +119,14 @@ def to_regex(resolver: Resolver, instance: dict):
# The enum keyword is used to restrict a value to a fixed set of values. It
# must be an array with at least one element, where each element is unique.
elif "enum" in instance:
if instance["type"] == "string":
choices = [f'"{re.escape(choice)}"' for choice in instance["enum"]]
return f"({'|'.join(choices)})"
else:
choices = [re.escape(str(choice)) for choice in instance["enum"]]
return f"({'|'.join(choices)})"
choices = []
for choice in instance["enum"]:
if type(choice) in [int, float, bool, None]:
choices.append(re.escape(str(choice)))
elif type(choice) == str:
choices.append(f'"{re.escape(choice)}"')

return f"({'|'.join(choices)})"

elif "$ref" in instance:
path = f"{instance['$ref']}"
Expand All @@ -134,8 +139,8 @@ def to_regex(resolver: Resolver, instance: dict):
# the name of one of the basic types, and each element is unique. In this
# case, the JSON snippet is valid if it matches any of the given types.
elif "type" in instance:
type = instance["type"]
if type == "string":
instance_type = instance["type"]
if instance_type == "string":
if "maxLength" in instance or "minLength" in instance:
max_length = instance.get("maxLength", "")
min_length = instance.get("minLength", "")
Expand All @@ -156,13 +161,13 @@ def to_regex(resolver: Resolver, instance: dict):
else:
return type_to_regex["string"]

elif type == "number":
elif instance_type == "number":
return type_to_regex["number"]

elif type == "integer":
elif instance_type == "integer":
return type_to_regex["integer"]

elif type == "array":
elif instance_type == "array":
if "items" in instance:
items_regex = to_regex(resolver, instance["items"])
return rf"\[({items_regex})(,({items_regex}))*\]"
Expand All @@ -180,17 +185,19 @@ def to_regex(resolver: Resolver, instance: dict):
regexes = [to_regex(resolver, t) for t in types]
return rf"\[({'|'.join(regexes)})(,({'|'.join(regexes)}))*\]"

elif type == "boolean":
elif instance_type == "boolean":
return type_to_regex["boolean"]

elif type == "null":
elif instance_type == "null":
return type_to_regex["null"]

elif isinstance(type, list):
elif isinstance(instance_type, list):
# Here we need to make the choice to exclude generating an object
# if the specification of the object is not given, even though a JSON
# object that contains an object here would be valid under the specification.
regexes = [to_regex(resolver, {"type": t}) for t in type if t != "object"]
regexes = [
to_regex(resolver, {"type": t}) for t in instance_type if t != "object"
]
return rf"({'|'.join(regexes)})"

raise NotImplementedError(
Expand Down

0 comments on commit 3753612

Please sign in to comment.