Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove model_rebuild calls #241

Merged
merged 4 commits into from
Nov 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
# CHANGELOG

## UNRELEASED

- Removed `model_rebuild` calls for generated input, fragment and result models.


## 0.10.0 (2023-11-15)

- Fixed generating results for nullable fields with nullable directives.
Expand Down
29 changes: 0 additions & 29 deletions EXAMPLE.md
Original file line number Diff line number Diff line change
Expand Up @@ -353,12 +353,6 @@ class NotificationsPreferencesInput(BaseModel):
receive_push_notifications: bool = Field(alias="receivePushNotifications")
receive_sms: bool = Field(alias="receiveSms")
title: str


UserCreateInput.model_rebuild()
LocationInput.model_rebuild()
UserPreferencesInput.model_rebuild()
NotificationsPreferencesInput.model_rebuild()
```

### Enums
Expand Down Expand Up @@ -400,10 +394,6 @@ class CreateUser(BaseModel):

class CreateUserUserCreate(BaseModel):
id: str


CreateUser.model_rebuild()
CreateUserUserCreate.model_rebuild()
```

```py
Expand All @@ -430,11 +420,6 @@ class ListAllUsersUsers(BaseModel):

class ListAllUsersUsersLocation(BaseModel):
country: Optional[str]


ListAllUsers.model_rebuild()
ListAllUsersUsers.model_rebuild()
ListAllUsersUsersLocation.model_rebuild()
```

```py
Expand All @@ -455,10 +440,6 @@ class ListUsersByCountry(BaseModel):

class ListUsersByCountryUsers(BasicUser, UserPersonalData):
favourite_color: Optional[Color] = Field(alias="favouriteColor")


ListUsersByCountry.model_rebuild()
ListUsersByCountryUsers.model_rebuild()
```

```py
Expand All @@ -471,9 +452,6 @@ from .base_model import BaseModel

class GetUsersCounter(BaseModel):
users_counter: int = Field(alias="usersCounter")


GetUsersCounter.model_rebuild()
```

```py
Expand All @@ -486,9 +464,6 @@ from .base_model import BaseModel

class UploadFile(BaseModel):
file_upload: bool = Field(alias="fileUpload")


UploadFile.model_rebuild()
```

### Fragments file
Expand All @@ -513,10 +488,6 @@ class BasicUser(BaseModel):
class UserPersonalData(BaseModel):
first_name: Optional[str] = Field(alias="firstName")
last_name: Optional[str] = Field(alias="lastName")


BasicUser.model_rebuild()
UserPersonalData.model_rebuild()
```

### Init file
Expand Down
1 change: 0 additions & 1 deletion ariadne_codegen/client_generators/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@
FIELD_CLASS = "Field"
ALIAS_KEYWORD = "alias"
DISCRIMINATOR_KEYWORD = "discriminator"
MODEL_REBUILD_METHOD = "model_rebuild"
MODEL_VALIDATE_METHOD = "model_validate"
PLAIN_SERIALIZER = "PlainSerializer"
BEFORE_VALIDATOR = "BeforeValidator"
Expand Down
13 changes: 3 additions & 10 deletions ariadne_codegen/client_generators/fragments.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@

from graphql import FragmentDefinitionNode, GraphQLSchema

from ..codegen import generate_expr, generate_method_call, generate_module
from ..codegen import generate_module
from ..plugins.manager import PluginManager
from .constants import BASE_MODEL_IMPORT, MODEL_REBUILD_METHOD
from .constants import BASE_MODEL_IMPORT
from .result_types import ResultTypesGenerator
from .scalars import ScalarData

Expand Down Expand Up @@ -59,15 +59,8 @@ def generate(self, exclude_names: Optional[Set[str]] = None) -> ast.Module:
sorted_class_defs = self._get_sorted_class_defs(
class_defs_dict=class_defs_dict, dependencies_dict=dependencies_dict
)
model_rebuild_calls = [
generate_expr(generate_method_call(c.name, MODEL_REBUILD_METHOD))
for c in sorted_class_defs
]

module = generate_module(
body=cast(List[ast.stmt], imports)
+ cast(List[ast.stmt], sorted_class_defs)
+ cast(List[ast.stmt], model_rebuild_calls)
body=cast(List[ast.stmt], imports) + cast(List[ast.stmt], sorted_class_defs)
)
if self.plugin_manager:
module = self.plugin_manager.generate_fragments_module(
Expand Down
13 changes: 2 additions & 11 deletions ariadne_codegen/client_generators/input_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,8 @@
generate_ann_assign,
generate_class_def,
generate_constant,
generate_expr,
generate_import_from,
generate_keyword,
generate_method_call,
generate_module,
generate_pydantic_field,
)
Expand All @@ -30,7 +28,6 @@
BASE_MODEL_IMPORT,
FIELD_CLASS,
LIST,
MODEL_REBUILD_METHOD,
OPTIONAL,
PLAIN_SERIALIZER,
PYDANTIC_MODULE,
Expand Down Expand Up @@ -84,14 +81,8 @@ def generate(self) -> ast.Module:
scalar_data = self.custom_scalars[scalar_name]
self._imports.extend(generate_scalar_imports(scalar_data))

model_rebuild_calls = [
generate_expr(generate_method_call(c.name, MODEL_REBUILD_METHOD))
for c in self._class_defs
]
module_body = (
cast(List[ast.stmt], self._imports)
+ cast(List[ast.stmt], self._class_defs)
+ cast(List[ast.stmt], model_rebuild_calls)
module_body = cast(List[ast.stmt], self._imports) + cast(
List[ast.stmt], self._class_defs
)
module = generate_module(body=module_body)
if self.plugin_manager:
Expand Down
13 changes: 2 additions & 11 deletions ariadne_codegen/client_generators/result_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,7 @@
generate_ann_assign,
generate_class_def,
generate_constant,
generate_expr,
generate_import_from,
generate_method_call,
generate_module,
generate_pass,
generate_pydantic_field,
Expand All @@ -57,7 +55,6 @@
MIXIN_FROM_NAME,
MIXIN_IMPORT_NAME,
MIXIN_NAME,
MODEL_REBUILD_METHOD,
OPTIONAL,
PYDANTIC_MODULE,
TYPENAME_ALIAS,
Expand Down Expand Up @@ -155,14 +152,8 @@ def _get_operation_type_name(self, definition: ExecutableDefinitionNode) -> str:
raise NotSupported(f"Not supported operation type: {definition}")

def generate(self) -> ast.Module:
model_rebuild_calls = [
generate_expr(generate_method_call(class_def.name, MODEL_REBUILD_METHOD))
for class_def in self._class_defs
]
module_body = (
cast(List[ast.stmt], self._imports)
+ cast(List[ast.stmt], self._class_defs)
+ cast(List[ast.stmt], model_rebuild_calls)
module_body = cast(List[ast.stmt], self._imports) + cast(
List[ast.stmt], self._class_defs
)

module = generate_module(module_body)
Expand Down
60 changes: 0 additions & 60 deletions tests/client_generators/input_types_generator/test_method_calls.py

This file was deleted.

This file was deleted.

46 changes: 2 additions & 44 deletions tests/client_generators/test_fragments_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,10 @@
import pytest
from graphql import FragmentDefinitionNode, GraphQLSchema, build_schema, parse

from ariadne_codegen.client_generators.constants import (
ALIAS_KEYWORD,
MODEL_REBUILD_METHOD,
)
from ariadne_codegen.client_generators.constants import ALIAS_KEYWORD
from ariadne_codegen.client_generators.fragments import FragmentsGenerator

from ..utils import compare_ast, filter_ast_objects, filter_class_defs
from ..utils import compare_ast, filter_class_defs


@pytest.fixture
Expand Down Expand Up @@ -154,45 +151,6 @@ def test_generate_returns_module_without_models_for_excluded_fragments(
assert [c.name for c in generated_class_defs] == ["FragmentA"]


def test_generate_returns_module_with_update_refs_calls(
schema, fragment_a, test_fragment
):
expected_update_refs_calls = [
ast.Expr(
value=ast.Call(
func=ast.Attribute(
value=ast.Name(id="FragmentA"), attr=MODEL_REBUILD_METHOD
),
args=[],
keywords=[],
)
),
ast.Expr(
value=ast.Call(
func=ast.Attribute(
value=ast.Name(id="TestFragment"), attr=MODEL_REBUILD_METHOD
),
args=[],
keywords=[],
)
),
]
generator = FragmentsGenerator(
schema=schema,
enums_module_name="enums",
fragments_definitions={
"TestFragment": test_fragment,
"FragmentA": fragment_a,
},
convert_to_snake_case=True,
)

module = generator.generate()

generated_update_refs_calls = filter_ast_objects(module, ast.Expr)
assert compare_ast(generated_update_refs_calls, expected_update_refs_calls)


def test_generate_triggers_generate_fragments_module_hook(mocked_plugin_manager):
generator = FragmentsGenerator(
schema=GraphQLSchema(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,3 @@ class GetQueryA(BaseModel):

class GetQueryAQueryA(BaseModel):
field_a: int = Field(alias="fieldA")


GetQueryA.model_rebuild()
GetQueryAQueryA.model_rebuild()
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,3 @@

class inputA(BaseModel):
version: enumA


inputA.model_rebuild()
3 changes: 0 additions & 3 deletions tests/main/clients/custom_config_file/expected_client/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,3 @@

class Test(BaseModel):
test_query: str = Field(alias="testQuery")


Test.model_rebuild()
Loading