Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
Needs to set required fields as required

Handle cases dataclasses, TypedDict and Pydantic models
  • Loading branch information
Mattias Wigh authored and Mattias Wigh committed Oct 3, 2024
1 parent f30c36d commit 2801c74
Showing 1 changed file with 24 additions and 1 deletion.
25 changes: 24 additions & 1 deletion google/generativeai/types/content_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import typing
from typing import Any, Callable, Union, get_type_hints, get_origin, get_args
from typing_extensions import TypedDict, is_typeddict
import dataclasses

import pydantic

Expand Down Expand Up @@ -334,9 +335,11 @@ def to_contents(contents: ContentsType) -> list[protos.Content]:
return contents


def _schema_for_class(cls: TypedDict) -> dict[str, Any]:
def _schema_for_class(cls: type) -> dict[str, Any]:
schema = _build_schema("dummy", {"dummy": (cls, pydantic.Field())})
properties = schema["properties"]["dummy"]

# Handling TypedDict
if is_typeddict(cls):
required_keys = []
type_hints = get_type_hints(cls)
Expand All @@ -347,6 +350,26 @@ def _schema_for_class(cls: TypedDict) -> dict[str, Any]:
continue
required_keys.append(key)
properties["required"] = required_keys

# Handling dataclasses
elif dataclasses.is_dataclass(cls):
required_keys = []
for field in dataclasses.fields(cls):
if field.default is dataclasses.MISSING and field.default_factory is dataclasses.MISSING:
required_keys.append(field.name) # Field is required if it has no default value
properties["required"] = required_keys

# Handling Pydantic models
elif issubclass(cls, pydantic.BaseModel):
required_keys = [name for name, field in cls.__fields__.items() if field.is_required()]
properties["required"] = required_keys

# Bug that it sets default values in case default exists
# TODO: Should be handled in the schema generation or not be allowed

for key in properties["properties"]:
if 'default' in properties["properties"][key]:
properties["properties"][key].pop('default')
return properties


Expand Down

0 comments on commit 2801c74

Please sign in to comment.