Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow Non-Strict mode in generate.json #1159

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions docs/reference/generation/json.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,14 @@ print(result)
generator = generate.json(model, User, whitespace_pattern=r"[\n\t ]*")
```

!!! Note "Non-Strict Mode"
Because models may exhaust their context window before a valid schema is generated, an error resulting from from an invalid generation may occur. This is particularly troublesome when an error interrupts a batch workload. To ensure `generate.json` returns a dict containing error details for invalid sequences rather than raising an error, use the following:

```python
generator = generate.json(model, User, strict=False)
```


!!! Note "Performance"

`generation.json` computes an index that helps Outlines guide generation. This can take some time, but only needs to be done once. If you want to generate several times with the same schema make sure that you only call `generate.json` once.
Expand Down
29 changes: 26 additions & 3 deletions outlines/generate/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def json(
schema_object: Union[str, object, Callable],
sampler: Sampler = multinomial(),
whitespace_pattern: Optional[str] = None,
strict=True,
) -> SequenceGeneratorAdapter:
"""
Generate structured JSON data with a `Transformer` model based on a specified JSON Schema.
Expand All @@ -36,28 +37,50 @@ def json(
whitespace_pattern
Pattern to use for JSON syntactic whitespace (doesn't impact string literals)
Example: allow only a single space or newline with `whitespace_pattern=r"[\n ]?"`
strict
If strict mode is enabled, generations which don't conform to the schema or aren't
valid JSON will result in an error. Outlines guarantees generation complies with a schema,
but schemas often allow for infinite repetition and exhaust the model_max_length.
Returns
-------
A `SequenceGenerator` instance that generates text constrained by the schema_object and
transforms the result if BaseModel is used.
"""

def maybe_strict_formatter(formatter):
"""If strict, use normal formatter. Otherwise, return error dict on failure"""
if strict:
return formatter

def allow_fail_formatter(generated_output):
try:
return formatter(generated_output)
except Exception as e:
return {
"error": str(e),
"error_type": type(e).__name__,
"output": generated_output,
}

return allow_fail_formatter

if isinstance(schema_object, type(BaseModel)):
schema = pyjson.dumps(schema_object.model_json_schema())
regex_str = build_regex_from_schema(schema, whitespace_pattern)
generator = regex(model, regex_str, sampler)
generator.format_sequence = lambda x: schema_object.parse_raw(x)
generator.format_sequence = maybe_strict_formatter(schema_object.parse_raw)
elif callable(schema_object):
schema = pyjson.dumps(get_schema_from_signature(schema_object))
regex_str = build_regex_from_schema(schema, whitespace_pattern)
generator = regex(model, regex_str, sampler)
generator.format_sequence = lambda x: pyjson.loads(x)
generator.format_sequence = maybe_strict_formatter(pyjson.loads)
elif isinstance(schema_object, str):
schema = schema_object
regex_str = build_regex_from_schema(schema, whitespace_pattern)
generator = regex(model, regex_str, sampler)
generator.format_sequence = lambda x: pyjson.loads(x)
generator.format_sequence = maybe_strict_formatter(pyjson.loads)
else:
raise ValueError(
f"Cannot parse schema {schema_object}. The schema must be either "
Expand Down
115 changes: 115 additions & 0 deletions tests/generate/test_generate_json.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
import json
import string

import pytest
from pydantic import BaseModel, ValidationError

from outlines import generate


class MockCharacterTokenizer:
def __init__(self):
characters = set(
string.ascii_letters
+ string.digits
+ string.punctuation
+ string.whitespace
)
self.vocabulary = {tok: tok_id for tok_id, tok in enumerate(characters)}
self.vocabulary["eos"] = len(characters)
self.special_tokens = {"eos"}
self.eos_token_id = len(characters)

def convert_token_to_string(self, token):
return token


class MockModel:
def __init__(self, generated):
self.generated = generated
self.tokenizer = MockCharacterTokenizer()

def generate(self, *args, **kwargs):
return self.generated


mock_json_schema = json.dumps(
{
"type": "object",
"properties": {"message": {"type": "string"}},
"required": ["message"],
"additionalProperties": False,
}
)


class MockPydanticModel(BaseModel):
message: str


@pytest.mark.parametrize("schema", [mock_json_schema, MockPydanticModel])
def test_generate_strict_success(schema):
model = MockModel(generated='{"message": "foo"}')
generator = generate.json(model, schema)
generator("hi")


@pytest.mark.parametrize("schema", [mock_json_schema, MockPydanticModel])
def test_generate_strict_success_batch(schema):
model = MockModel(
generated=[
'{"message": "foo"}',
'{"message": "basteuhotuhnoethunoteuhntoeuhntoehuotn"}',
]
)
generator = generate.json(model, schema)
for output in generator("hi"):
pass


@pytest.mark.parametrize("schema", [mock_json_schema, MockPydanticModel])
def test_generate_strict_fail(schema):
model = MockModel(generated='{"message": "foo')
generator = generate.json(model, schema)
with pytest.raises((json.decoder.JSONDecodeError, ValidationError)):
generator("hi")


@pytest.mark.parametrize("schema", [mock_json_schema, MockPydanticModel])
def test_generate_strict_fail_batch(schema):
model = MockModel(
generated=[
'{"message": "foo"}',
'{"message": "basteuhotuhnoethunoteuhntoeuhntoehuotn"',
]
)
generator = generate.json(model, schema)
with pytest.raises((json.decoder.JSONDecodeError, ValidationError)):
generator("hi")


@pytest.mark.parametrize("schema", [mock_json_schema, MockPydanticModel])
def test_generate_non_strict_evade_failure(schema):
model = MockModel(generated='{"message": "foo')
generator = generate.json(model, schema, strict=False)
result = generator("hi")
assert result["error_type"] in ("JSONDecodeError", "ValidationError")
assert result["output"] == model.generated


@pytest.mark.parametrize("schema", [mock_json_schema, MockPydanticModel])
def test_generate_non_strict_evade_failure_batch(schema):
model = MockModel(
generated=[
'{"message": "foo"}',
'{"message": "basteuhotuhnoethunoteuhntoeuhntoehuotn"',
]
)
generator = generate.json(model, schema, strict=False)
result = generator("hi")
if isinstance(schema, str):
assert result[0] == json.loads(model.generated[0])
else:
assert result[0] == schema.parse_raw(model.generated[0])
assert result[1]["error_type"] in ("JSONDecodeError", "ValidationError")
assert result[1]["output"] == model.generated[1]
Loading