Skip to content

Commit

Permalink
add model_rebuild_calls for result_types.py
Browse files Browse the repository at this point in the history
  • Loading branch information
Minister944 committed Feb 23, 2024
1 parent 193df20 commit ebd40d0
Show file tree
Hide file tree
Showing 60 changed files with 252 additions and 2 deletions.
13 changes: 11 additions & 2 deletions ariadne_codegen/client_generators/result_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@
generate_module,
generate_pass,
generate_pydantic_field,
generate_method_call,
generate_expr,
)
from ..exceptions import NotSupported, ParsingError
from ..plugins.manager import PluginManager
Expand All @@ -61,6 +63,7 @@
TYPENAME_FIELD_NAME,
TYPING_MODULE,
UNION,
MODEL_REBUILD_METHOD,
)
from .result_fields import FieldContext, is_union, parse_operation_field
from .scalars import ScalarData, generate_scalar_imports
Expand Down Expand Up @@ -152,8 +155,14 @@ def _get_operation_type_name(self, definition: ExecutableDefinitionNode) -> str:
raise NotSupported(f"Not supported operation type: {definition}")

def generate(self) -> ast.Module:
module_body = cast(List[ast.stmt], self._imports) + cast(
List[ast.stmt], self._class_defs
model_rebuild_calls = [
generate_expr(generate_method_call(class_def.name, MODEL_REBUILD_METHOD))
for class_def in self._class_defs
]
module_body = (
cast(List[ast.stmt], self._imports)
+ cast(List[ast.stmt], self._class_defs)
+ cast(List[ast.stmt], model_rebuild_calls)
)

module = generate_module(module_body)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,7 @@ class GetQueryA(BaseModel):

class GetQueryAQueryA(BaseModel):
field_a: int = Field(alias="fieldA")


GetQueryA.model_rebuild()
GetQueryAQueryA.model_rebuild()
3 changes: 3 additions & 0 deletions tests/main/clients/custom_config_file/expected_client/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,6 @@

class Test(BaseModel):
test_query: str = Field(alias="testQuery")


Test.model_rebuild()
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,7 @@ class GetQueryA(BaseModel):

class GetQueryAQueryA(BaseModel):
field_a: int = Field(alias="fieldA")


GetQueryA.model_rebuild()
GetQueryAQueryA.model_rebuild()
4 changes: 4 additions & 0 deletions tests/main/clients/custom_scalars/expected_client/get_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,7 @@ class GetATestQuery(BaseModel):
code: Annotated[Code, BeforeValidator(parse_code)]
id: int
other: Any


GetA.model_rebuild()
GetATestQuery.model_rebuild()
4 changes: 4 additions & 0 deletions tests/main/clients/example/expected_client/create_user.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,7 @@ class CreateUser(BaseModel):

class CreateUserUserCreate(BaseModel):
id: str


CreateUser.model_rebuild()
CreateUserUserCreate.model_rebuild()
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,6 @@

class GetUsersCounter(BaseModel):
users_counter: int = Field(alias="usersCounter")


GetUsersCounter.model_rebuild()
5 changes: 5 additions & 0 deletions tests/main/clients/example/expected_client/list_all_users.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,8 @@ class ListAllUsersUsers(BaseModel):

class ListAllUsersUsersLocation(BaseModel):
country: Optional[str]


ListAllUsers.model_rebuild()
ListAllUsersUsers.model_rebuild()
ListAllUsersUsersLocation.model_rebuild()
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,7 @@ class ListUsersByCountry(BaseModel):

class ListUsersByCountryUsers(BasicUser, UserPersonalData):
favourite_color: Optional[Color] = Field(alias="favouriteColor")


ListUsersByCountry.model_rebuild()
ListUsersByCountryUsers.model_rebuild()
3 changes: 3 additions & 0 deletions tests/main/clients/example/expected_client/upload_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,6 @@

class UploadFile(BaseModel):
file_upload: bool = Field(alias="fileUpload")


UploadFile.model_rebuild()
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,8 @@ class FragmentsWithMixinsQueryA(FragmentA, CommonMixin):

class FragmentsWithMixinsQueryB(FragmentB, CommonMixin):
pass


FragmentsWithMixins.model_rebuild()
FragmentsWithMixinsQueryA.model_rebuild()
FragmentsWithMixinsQueryB.model_rebuild()
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,7 @@ class GetQueryA(BaseModel):

class GetQueryAQueryA(BaseModel, MixinA, CommonMixin):
field_a: int = Field(alias="fieldA")


GetQueryA.model_rebuild()
GetQueryAQueryA.model_rebuild()
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,6 @@

class GetQueryAWithFragment(GetQueryAFragment):
pass


GetQueryAWithFragment.model_rebuild()
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,7 @@ class GetQueryB(BaseModel):

class GetQueryBQueryB(BaseModel, MixinB, CommonMixin):
field_b: str = Field(alias="fieldB")


GetQueryB.model_rebuild()
GetQueryBQueryB.model_rebuild()
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,8 @@ class QueryWithFragmentOnSubInterfaceQueryInterfaceBaseInterface(BaseModel):

class QueryWithFragmentOnSubInterfaceQueryInterfaceInterfaceA(FragmentA):
typename__: Literal["InterfaceA"] = Field(alias="__typename")


QueryWithFragmentOnSubInterface.model_rebuild()
QueryWithFragmentOnSubInterfaceQueryInterfaceBaseInterface.model_rebuild()
QueryWithFragmentOnSubInterfaceQueryInterfaceInterfaceA.model_rebuild()
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,9 @@ class QueryWithFragmentOnSubInterfaceWithInlineFragmentQueryInterfaceTypeA(BaseM
id: str
value_a: str = Field(alias="valueA")
another: str


QueryWithFragmentOnSubInterfaceWithInlineFragment.model_rebuild()
QueryWithFragmentOnSubInterfaceWithInlineFragmentQueryInterfaceBaseInterface.model_rebuild()
QueryWithFragmentOnSubInterfaceWithInlineFragmentQueryInterfaceInterfaceA.model_rebuild()
QueryWithFragmentOnSubInterfaceWithInlineFragmentQueryInterfaceTypeA.model_rebuild()
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,8 @@ class QueryWithFragmentOnUnionMemberQueryUnionTypeA(BaseModel):

class QueryWithFragmentOnUnionMemberQueryUnionTypeB(FragmentB):
typename__: Literal["TypeB"] = Field(alias="__typename")


QueryWithFragmentOnUnionMember.model_rebuild()
QueryWithFragmentOnUnionMemberQueryUnionTypeA.model_rebuild()
QueryWithFragmentOnUnionMemberQueryUnionTypeB.model_rebuild()
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,9 @@ class InterfaceAQueryITypeB(BaseModel):
typename__: Literal["TypeB"] = Field(alias="__typename")
id: str
field_b: str = Field(alias="fieldB")


InterfaceA.model_rebuild()
InterfaceAQueryIInterface.model_rebuild()
InterfaceAQueryITypeA.model_rebuild()
InterfaceAQueryITypeB.model_rebuild()
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,8 @@ class InterfaceBQueryITypeA(BaseModel):
typename__: Literal["TypeA"] = Field(alias="__typename")
id: str
field_a: str = Field(alias="fieldA")


InterfaceB.model_rebuild()
InterfaceBQueryIInterface.model_rebuild()
InterfaceBQueryITypeA.model_rebuild()
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,7 @@ class InterfaceCQueryI(BaseModel):
alias="__typename"
)
id: str


InterfaceC.model_rebuild()
InterfaceCQueryI.model_rebuild()
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,7 @@ class InterfaceWithTypenameQueryI(BaseModel):
alias="__typename"
)
id: str


InterfaceWithTypename.model_rebuild()
InterfaceWithTypenameQueryI.model_rebuild()
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,9 @@ class ListInterfaceQueryListITypeB(BaseModel):
typename__: Literal["TypeB"] = Field(alias="__typename")
id: str
field_b: str = Field(alias="fieldB")


ListInterface.model_rebuild()
ListInterfaceQueryListIInterface.model_rebuild()
ListInterfaceQueryListITypeA.model_rebuild()
ListInterfaceQueryListITypeB.model_rebuild()
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,9 @@ class ListUnionQueryListUTypeB(BaseModel):

class ListUnionQueryListUTypeC(BaseModel):
typename__: Literal["TypeC"] = Field(alias="__typename")


ListUnion.model_rebuild()
ListUnionQueryListUTypeA.model_rebuild()
ListUnionQueryListUTypeB.model_rebuild()
ListUnionQueryListUTypeC.model_rebuild()
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,9 @@ class QueryWithFragmentOnInterfaceQueryITypeB(BaseModel):
typename__: Literal["TypeB"] = Field(alias="__typename")
id: str
field_b: str = Field(alias="fieldB")


QueryWithFragmentOnInterface.model_rebuild()
QueryWithFragmentOnInterfaceQueryIInterface.model_rebuild()
QueryWithFragmentOnInterfaceQueryITypeA.model_rebuild()
QueryWithFragmentOnInterfaceQueryITypeB.model_rebuild()
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,6 @@

class QueryWithFragmentOnQueryWithInterface(FragmentOnQueryWithInterface):
pass


QueryWithFragmentOnQueryWithInterface.model_rebuild()
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,6 @@

class QueryWithFragmentOnQueryWithUnion(FragmentOnQueryWithUnion):
pass


QueryWithFragmentOnQueryWithUnion.model_rebuild()
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,9 @@ class QueryWithFragmentOnUnionQueryUTypeB(BaseModel):

class QueryWithFragmentOnUnionQueryUTypeC(BaseModel):
typename__: Literal["TypeC"] = Field(alias="__typename")


QueryWithFragmentOnUnion.model_rebuild()
QueryWithFragmentOnUnionQueryUTypeA.model_rebuild()
QueryWithFragmentOnUnionQueryUTypeB.model_rebuild()
QueryWithFragmentOnUnionQueryUTypeC.model_rebuild()
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,9 @@ class UnionAQueryUTypeB(BaseModel):

class UnionAQueryUTypeC(BaseModel):
typename__: Literal["TypeC"] = Field(alias="__typename")


UnionA.model_rebuild()
UnionAQueryUTypeA.model_rebuild()
UnionAQueryUTypeB.model_rebuild()
UnionAQueryUTypeC.model_rebuild()
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,9 @@ class UnionBQueryUTypeB(BaseModel):

class UnionBQueryUTypeC(BaseModel):
typename__: Literal["TypeC"] = Field(alias="__typename")


UnionB.model_rebuild()
UnionBQueryUTypeA.model_rebuild()
UnionBQueryUTypeB.model_rebuild()
UnionBQueryUTypeC.model_rebuild()
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,7 @@ class ExampleQuery1(BaseModel):

class ExampleQuery1ExampleQuery(MinimalA):
value: str


ExampleQuery1.model_rebuild()
ExampleQuery1ExampleQuery.model_rebuild()
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,7 @@ class ExampleQuery2(BaseModel):

class ExampleQuery2ExampleQuery(FullA):
pass


ExampleQuery2.model_rebuild()
ExampleQuery2ExampleQuery.model_rebuild()
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,7 @@ class ExampleQuery3(BaseModel):

class ExampleQuery3ExampleQuery(CompleteA):
pass


ExampleQuery3.model_rebuild()
ExampleQuery3ExampleQuery.model_rebuild()
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,6 @@

class GetA(BaseModel):
a: str


GetA.model_rebuild()
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,6 @@

class GetA2(BaseModel):
a: str


GetA2.model_rebuild()
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,6 @@

class GetB(BaseModel):
b: str


GetB.model_rebuild()
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,6 @@

class GetD(BaseModel):
d: str


GetD.model_rebuild()
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,6 @@

class GetE(BaseModel):
e: str


GetE.model_rebuild()
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,7 @@ class GetF(BaseModel):

class GetFF(BaseModel):
val: EnumF


GetF.model_rebuild()
GetFF.model_rebuild()
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,7 @@ class GetG(BaseModel):

class GetGG(FragmentG):
pass


GetG.model_rebuild()
GetGG.model_rebuild()
3 changes: 3 additions & 0 deletions tests/main/clients/operations/expected_client/c_mutation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,6 @@

class CMutation(BaseModel):
const_mutation: int = Field(alias="constMutation")


CMutation.model_rebuild()
3 changes: 3 additions & 0 deletions tests/main/clients/operations/expected_client/c_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,6 @@

class CQuery(BaseModel):
const_query: str = Field(alias="constQuery")


CQuery.model_rebuild()
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,6 @@

class CSubscription(BaseModel):
const_subscription: float = Field(alias="constSubscription")


CSubscription.model_rebuild()
5 changes: 5 additions & 0 deletions tests/main/clients/operations/expected_client/get_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,8 @@ class GetAA(BaseModel):

class GetAAValueB(BaseModel):
value: str


GetA.model_rebuild()
GetAA.model_rebuild()
GetAAValueB.model_rebuild()
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,8 @@ class GetAWithFragmentA(BaseModel):

class GetAWithFragmentAValueB(FragmentB):
pass


GetAWithFragment.model_rebuild()
GetAWithFragmentA.model_rebuild()
GetAWithFragmentAValueB.model_rebuild()
4 changes: 4 additions & 0 deletions tests/main/clients/operations/expected_client/get_s.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,7 @@ class GetS(BaseModel):

class GetSS(BaseModel):
id: int


GetS.model_rebuild()
GetSS.model_rebuild()
Loading

0 comments on commit ebd40d0

Please sign in to comment.