diff --git a/outlines/generate/json.py b/outlines/generate/json.py index 703447958..f7c7278cd 100644 --- a/outlines/generate/json.py +++ b/outlines/generate/json.py @@ -1,9 +1,9 @@ import json as pyjson from enum import Enum from functools import singledispatch -from typing import Callable, Optional, Union +from typing import Callable, Optional, Type, Union, cast -from pydantic import BaseModel +from pydantic import BaseModel, TypeAdapter from outlines.fsm.json_schema import ( build_regex_from_schema, @@ -20,7 +20,7 @@ @singledispatch def json( model, - schema_object: Union[str, object, Callable], + schema_object: Union[str, Type[BaseModel], TypeAdapter, Type[Enum], Callable], sampler: Sampler = multinomial(), whitespace_pattern: Optional[str] = None, ) -> SequenceGeneratorAdapter: @@ -52,12 +52,17 @@ def json( schema = pyjson.dumps(schema_object.model_json_schema()) regex_str = build_regex_from_schema(schema, whitespace_pattern) generator = regex(model, regex_str, sampler) - generator.format_sequence = lambda x: schema_object.parse_raw(x) + generator.format_sequence = lambda x: schema_object.model_validate_json(x) elif isinstance(schema_object, type(Enum)): schema = pyjson.dumps(get_schema_from_enum(schema_object)) regex_str = build_regex_from_schema(schema, whitespace_pattern) generator = regex(model, regex_str, sampler) generator.format_sequence = lambda x: pyjson.loads(x) + elif isinstance(schema_object, TypeAdapter): + schema = pyjson.dumps(schema_object.json_schema()) + regex_str = build_regex_from_schema(schema, whitespace_pattern) + generator = regex(model, regex_str, sampler) + generator.format_sequence = lambda x: schema_object.validate_json(x) elif callable(schema_object): schema = pyjson.dumps(get_schema_from_signature(schema_object)) regex_str = build_regex_from_schema(schema, whitespace_pattern) @@ -80,7 +85,9 @@ def json( @json.register(OpenAI) def json_openai( - model, schema_object: Union[str, object], sampler: Sampler = multinomial() + model, + schema_object: Union[str, Type[BaseModel], TypeAdapter], + sampler: Sampler = multinomial(), ): if not isinstance(sampler, multinomial): raise NotImplementedError( @@ -90,7 +97,10 @@ def json_openai( if isinstance(schema_object, type(BaseModel)): schema = pyjson.dumps(schema_object.model_json_schema()) - format_sequence = lambda x: schema_object.parse_raw(x) + format_sequence = lambda x: schema_object.model_validate_json(x) + elif isinstance(schema_object, TypeAdapter): + schema = pyjson.dumps(schema_object.json_schema()) + format_sequence = lambda x: cast(TypeAdapter, schema_object).validate_json(x) elif isinstance(schema_object, str): schema = schema_object format_sequence = lambda x: pyjson.loads(x)