Skip to content

Commit

Permalink
fix: fixed missing location during configuration parsing
Browse files Browse the repository at this point in the history
  • Loading branch information
adubovik committed Feb 7, 2025
1 parent 8024c64 commit 45e04d1
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 17 deletions.
44 changes: 30 additions & 14 deletions aidial_sdk/chat_completion/configuration.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from dataclasses import dataclass
from typing import Any, Dict, Generic, List, Literal, Optional, Type, TypeVar

from pydantic.v1.fields import FieldInfo
from pydantic.v1.main import ModelMetaclass
from pydantic.v1.validators import make_literal_validator

from aidial_sdk.pydantic_v1 import BaseModel, root_validator
from aidial_sdk.pydantic_v1 import BaseModel, validator

_T = TypeVar("_T")

Expand All @@ -28,7 +30,33 @@ def schema(self) -> dict:
}


class Configuration(BaseModel):
class _ConfigurationMetaclass(ModelMetaclass):
def __new__(mcs, name, bases, namespace: dict, **kwargs):
validators = {}
for field_name, field_info in namespace.items():
if not isinstance(field_info, FieldInfo):
continue

buttons = field_info.extra.get("buttons") or []
assert all(isinstance(button, Button) for button in buttons)

consts = tuple(button.const for button in buttons)
literal_type = Literal[consts]
literal_validator = make_literal_validator(literal_type)

def _validate(value, values, config, field):
return literal_validator(value)

validators[f"_validate_{field_name}"] = validator(
field_name, allow_reuse=True
)(_validate)

namespace.update(validators)

return super().__new__(mcs, name, bases, namespace, **kwargs)


class Configuration(BaseModel, metaclass=_ConfigurationMetaclass):
class Config:
extra = "forbid"

Expand All @@ -37,18 +65,6 @@ def schema_extra(schema, model: Type["Configuration"]):
model._handle_top_level_extensions(schema)
model._handle_buttons_extension(schema)

@root_validator()
def _validate_button_value(cls, values: Dict[str, Any]) -> Dict[str, Any]:
for field_name, field in cls.__fields__.items():
buttons = field.field_info.extra.get("buttons")
if buttons and field_name in values:
value = values[field_name]
type = Literal[tuple(button.const for button in buttons)]
validator = make_literal_validator(type)
validator(value)

return values

@classmethod
def _handle_top_level_extensions(cls, schema: Dict[str, Any]) -> None:
if (
Expand Down
8 changes: 5 additions & 3 deletions tests/test_configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@


class StaticConfiguration(Configuration):
"My model"
"Static application configuration"

_dial_chatMessageInputDisabled = True

Expand Down Expand Up @@ -42,7 +42,7 @@ def test_configuration_schema():
actual_schema = StaticConfiguration.schema()
assert actual_schema == {
"title": "StaticConfiguration",
"description": "My model",
"description": "Static application configuration",
"type": "object",
"properties": {
"int_field": {
Expand Down Expand Up @@ -122,9 +122,11 @@ def test_configuration_parsing_fail():
except ValidationError as e:
assert e.errors() == [
{
"loc": ("__root__",), # FIXME
"loc": ("int_button_field",),
"msg": "unexpected value; permitted: 10, 20",
"type": "value_error.const",
"ctx": {"given": 11, "permitted": (10, 20)},
}
]
else:
assert False, "Expected ValidationError"

0 comments on commit 45e04d1

Please sign in to comment.