diff --git a/ariadne_codegen/client_generators/client.py b/ariadne_codegen/client_generators/client.py index 6319da52..956bb9be 100644 --- a/ariadne_codegen/client_generators/client.py +++ b/ariadne_codegen/client_generators/client.py @@ -135,6 +135,9 @@ def add_method( arguments, arguments_dict = self.arguments_generator.generate( definition.variable_definitions ) + + variable_names = self.get_variable_names(arguments) + operation_name = definition.name.value if definition.name else "" if definition.operation == OperationType.SUBSCRIPTION: if not async_: @@ -149,6 +152,7 @@ def add_method( arguments=arguments, arguments_dict=arguments_dict, operation_str=operation_str, + variable_names=variable_names, ) ) elif async_: @@ -159,6 +163,7 @@ def add_method( arguments_dict=arguments_dict, operation_str=operation_str, operation_name=operation_name, + variable_names=variable_names, ) else: method_def = self._generate_method( @@ -168,6 +173,7 @@ def add_method( arguments_dict=arguments_dict, operation_str=operation_str, operation_name=operation_name, + variable_names=variable_names, ) method_def.lineno = len(self._class_def.body) + 1 @@ -181,6 +187,23 @@ def add_method( generate_import_from(names=[return_type], from_=return_type_module, level=1) ) + def get_variable_names(self, arguments: ast.arguments) -> Dict[str, str]: + mapped_variable_names = [ + self._operation_str_variable, + self._variables_dict_variable, + self._response_variable, + self._data_variable, + ] + variable_names = {} + argument_names = set(arg.arg for arg in arguments.args) + + for variable in mapped_variable_names: + variable_names[variable] = ( + f"_{variable}" if variable in argument_names else variable + ) + + return variable_names + def _add_import(self, import_: Optional[ast.ImportFrom] = None): if not import_: return @@ -197,6 +220,7 @@ def _generate_subscription_method_def( arguments: ast.arguments, arguments_dict: ast.Dict, operation_str: str, + variable_names: Dict[str, str], ) -> ast.AsyncFunctionDef: return generate_async_method_definition( name=name, @@ -205,9 +229,11 @@ def _generate_subscription_method_def( value=generate_name(ASYNC_ITERATOR), slice_=generate_name(return_type) ), body=[ - self._generate_operation_str_assign(operation_str, 1), - self._generate_variables_assign(arguments_dict, 2), - self._generate_async_generator_loop(operation_name, return_type, 3), + self._generate_operation_str_assign(variable_names, operation_str, 1), + self._generate_variables_assign(variable_names, arguments_dict, 2), + self._generate_async_generator_loop( + variable_names, operation_name, return_type, 3 + ), ], ) @@ -219,17 +245,18 @@ def _generate_async_method( arguments_dict: ast.Dict, operation_str: str, operation_name: str, + variable_names: Dict[str, str], ) -> ast.AsyncFunctionDef: return generate_async_method_definition( name=name, arguments=arguments, return_type=generate_name(return_type), body=[ - self._generate_operation_str_assign(operation_str, 1), - self._generate_variables_assign(arguments_dict, 2), - self._generate_async_response_assign(operation_name, 3), - self._generate_data_retrieval(), - self._generate_return_parsed_obj(return_type), + self._generate_operation_str_assign(variable_names, operation_str, 1), + self._generate_variables_assign(variable_names, arguments_dict, 2), + self._generate_async_response_assign(variable_names, operation_name, 3), + self._generate_data_retrieval(variable_names), + self._generate_return_parsed_obj(variable_names, return_type), ], ) @@ -241,25 +268,26 @@ def _generate_method( arguments_dict: ast.Dict, operation_str: str, operation_name: str, + variable_names: Dict[str, str], ) -> ast.FunctionDef: return generate_method_definition( name=name, arguments=arguments, return_type=generate_name(return_type), body=[ - self._generate_operation_str_assign(operation_str, 1), - self._generate_variables_assign(arguments_dict, 2), - self._generate_response_assign(operation_name, 3), - self._generate_data_retrieval(), - self._generate_return_parsed_obj(return_type), + self._generate_operation_str_assign(variable_names, operation_str, 1), + self._generate_variables_assign(variable_names, arguments_dict, 2), + self._generate_response_assign(variable_names, operation_name, 3), + self._generate_data_retrieval(variable_names), + self._generate_return_parsed_obj(variable_names, return_type), ], ) def _generate_operation_str_assign( - self, operation_str: str, lineno: int = 1 + self, variable_names: Dict[str, str], operation_str: str, lineno: int = 1 ) -> ast.Assign: return generate_assign( - targets=[self._operation_str_variable], + targets=[variable_names[self._operation_str_variable]], value=generate_call( func=generate_name(self._gql_func_name), args=[ @@ -270,10 +298,10 @@ def _generate_operation_str_assign( ) def _generate_variables_assign( - self, arguments_dict: ast.Dict, lineno: int = 1 + self, variable_names: Dict[str, str], arguments_dict: ast.Dict, lineno: int = 1 ) -> ast.AnnAssign: return generate_ann_assign( - target=self._variables_dict_variable, + target=variable_names[self._variables_dict_variable], annotation=generate_subscript( generate_name(DICT), generate_tuple([generate_name("str"), generate_name("object")]), @@ -283,87 +311,107 @@ def _generate_variables_assign( ) def _generate_async_response_assign( - self, operation_name: str, lineno: int = 1 + self, variable_names: Dict[str, str], operation_name: str, lineno: int = 1 ) -> ast.Assign: return generate_assign( - targets=[self._response_variable], + targets=[variable_names[self._response_variable]], value=generate_await( - self._generate_execute_call(operation_name=operation_name) + self._generate_execute_call(variable_names, operation_name) ), lineno=lineno, ) def _generate_response_assign( - self, operation_name: str, lineno: int = 1 + self, + variable_names: Dict[str, str], + operation_name: str, + lineno: int = 1, ) -> ast.Assign: return generate_assign( - targets=[self._response_variable], - value=self._generate_execute_call(operation_name=operation_name), + targets=[variable_names[self._response_variable]], + value=self._generate_execute_call(variable_names, operation_name), lineno=lineno, ) - def _generate_execute_call(self, operation_name: str) -> ast.Call: + def _generate_execute_call( + self, variable_names: Dict[str, str], operation_name: str + ) -> ast.Call: return generate_call( func=generate_attribute(generate_name("self"), "execute"), keywords=[ generate_keyword( - value=generate_name(self._operation_str_variable), arg="query" + value=generate_name(variable_names[self._operation_str_variable]), + arg="query", ), generate_keyword( value=generate_constant(operation_name), arg="operation_name" ), generate_keyword( - value=generate_name(self._variables_dict_variable), arg="variables" + value=generate_name(variable_names[self._variables_dict_variable]), + arg="variables", ), generate_keyword(value=generate_name(KWARGS_NAMES)), ], ) - def _generate_data_retrieval(self) -> ast.Assign: + def _generate_data_retrieval(self, variable_names: Dict[str, str]) -> ast.Assign: return generate_assign( - targets=[self._data_variable], + targets=[variable_names[self._data_variable]], value=generate_call( func=generate_attribute(value=generate_name("self"), attr="get_data"), - args=[generate_name(self._response_variable)], + args=[generate_name(variable_names[self._response_variable])], ), ) - def _generate_return_parsed_obj(self, return_type: str) -> ast.Return: + def _generate_return_parsed_obj( + self, variable_names: Dict[str, str], return_type: str + ) -> ast.Return: return generate_return( generate_call( func=generate_attribute( generate_name(return_type), MODEL_VALIDATE_METHOD ), - args=[generate_name(self._data_variable)], + args=[generate_name(variable_names[self._data_variable])], ) ) def _generate_async_generator_loop( - self, operation_name: str, return_type: str, lineno: int = 1 + self, + variable_names: Dict[str, str], + operation_name: str, + return_type: str, + lineno: int = 1, ) -> ast.AsyncFor: return generate_async_for( - target=generate_name(self._data_variable), + target=generate_name(variable_names[self._data_variable]), iter_=generate_call( func=generate_attribute(value=generate_name("self"), attr="execute_ws"), keywords=[ generate_keyword( - value=generate_name(self._operation_str_variable), arg="query" + value=generate_name( + variable_names[self._operation_str_variable] + ), + arg="query", ), generate_keyword( value=generate_constant(operation_name), arg="operation_name" ), generate_keyword( - value=generate_name(self._variables_dict_variable), + value=generate_name( + variable_names[self._variables_dict_variable] + ), arg="variables", ), generate_keyword(value=generate_name(KWARGS_NAMES)), ], ), - body=[self._generate_yield_parsed_obj(return_type)], + body=[self._generate_yield_parsed_obj(variable_names, return_type)], lineno=lineno, ) - def _generate_yield_parsed_obj(self, return_type: str) -> ast.Expr: + def _generate_yield_parsed_obj( + self, variable_names: Dict[str, str], return_type: str + ) -> ast.Expr: return generate_expr( generate_yield( generate_call( @@ -371,7 +419,7 @@ def _generate_yield_parsed_obj(self, return_type: str) -> ast.Expr: value=generate_name(return_type), attr=MODEL_VALIDATE_METHOD, ), - args=[generate_name(self._data_variable)], + args=[generate_name(variable_names[self._data_variable])], ) ) ) diff --git a/pyproject.toml b/pyproject.toml index 620c5aba..bddcc760 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -69,6 +69,7 @@ disable = [ "duplicate-code", "no-name-in-module", "too-many-locals", + "too-many-lines", ] [tool.pytest.ini_options] diff --git a/tests/client_generators/test_client_generator.py b/tests/client_generators/test_client_generator.py index c3e4b410..1f59080d 100644 --- a/tests/client_generators/test_client_generator.py +++ b/tests/client_generators/test_client_generator.py @@ -617,3 +617,458 @@ def test_add_method_triggers_generate_client_method_hook( generator.generate() assert mocked_plugin_manager.generate_client_method.called + + +def test_add_method_generates_correct_method_body_for_shadowed_variables( + base_client_import, +): + schema_str = """ + schema { query: Query } + type Query { xyz(query: String!, variables: String!, response: String!, data: String!): String } + """ + query_str = """ + query GetXyz($query: String!, $variables: String!, $response: String!, $data: String! ) { + xyz(query: $query, variables: $variables, response: $response, data: $data) + } + """ + generator = ClientGenerator( + base_client_import=base_client_import, + arguments_generator=ArgumentsGenerator(schema=build_schema(schema_str)), + ) + method_name = "list_xyz" + return_type = "GetXyz" + return_type_module_name = method_name + expected_method_body = [ + ast.Assign( + targets=[ast.Name(id="_query")], + value=ast.Call( + func=ast.Name("gql"), + keywords=[], + args=[[ast.Constant(value=l + "\n") for l in query_str.splitlines()]], + ), + ), + ast.AnnAssign( + target=ast.Name(id="_variables"), + annotation=ast.Subscript( + value=ast.Name(id="Dict"), + slice=ast.Tuple(elts=[ast.Name(id="str"), ast.Name(id="object")]), + ), + value=ast.Dict( + keys=[ + ast.Constant(value="query"), + ast.Constant(value="variables"), + ast.Constant(value="response"), + ast.Constant(value="data"), + ], + values=[ + ast.Name(id="query"), + ast.Name(id="variables"), + ast.Name(id="response"), + ast.Name(id="data"), + ], + ), + simple=1, + ), + ast.Assign( + targets=[ast.Name(id="_response")], + value=ast.Call( + func=ast.Attribute(value=ast.Name(id="self"), attr="execute"), + args=[], + keywords=[ + ast.keyword(arg="query", value=ast.Name(id="_query")), + ast.keyword( + arg="operation_name", value=ast.Constant(value="GetXyz") + ), + ast.keyword(arg="variables", value=ast.Name(id="_variables")), + ast.keyword(value=ast.Name(id="kwargs")), + ], + ), + ), + ast.Assign( + targets=[ast.Name(id="_data")], + value=ast.Call( + func=ast.Attribute(value=ast.Name(id="self"), attr="get_data"), + args=[ast.Name(id="_response")], + keywords=[], + ), + ), + ast.Return( + value=ast.Call( + func=ast.Attribute(value=ast.Name(id="GetXyz"), attr="model_validate"), + args=[ast.Name(id="_data")], + keywords=[], + ) + ), + ] + + generator.add_method( + definition=cast(OperationDefinitionNode, parse(query_str).definitions[0]), + name=method_name, + return_type=return_type, + return_type_module=return_type_module_name, + operation_str=query_str, + async_=False, + ) + module = generator.generate() + + class_def = get_class_def(module) + assert class_def + method_def = class_def.body[0] + assert isinstance(method_def, ast.FunctionDef) + assert compare_ast(method_def.body, expected_method_body) + + +def test_add_method_generates_correct_method_body_for_shadowed_query_variable( + base_client_import, +): + schema_str = """ + schema { query: Query } + type Query { xyz(query: String, name: String): String } + """ + query_str = """ + query GetXyz($query: String, $name: String ) { + xyz(query: $query, name: $name) + } + """ + generator = ClientGenerator( + base_client_import=base_client_import, + arguments_generator=ArgumentsGenerator(schema=build_schema(schema_str)), + ) + method_name = "list_xyz" + return_type = "GetXyz" + return_type_module_name = method_name + expected_method_body = [ + ast.Assign( + targets=[ast.Name(id="_query")], + value=ast.Call( + func=ast.Name("gql"), + keywords=[], + args=[[ast.Constant(value=l + "\n") for l in query_str.splitlines()]], + ), + ), + ast.AnnAssign( + target=ast.Name(id="variables"), + annotation=ast.Subscript( + value=ast.Name(id="Dict"), + slice=ast.Tuple(elts=[ast.Name(id="str"), ast.Name(id="object")]), + ), + value=ast.Dict( + keys=[ast.Constant(value="query"), ast.Constant(value="name")], + values=[ast.Name(id="query"), ast.Name(id="name")], + ), + simple=1, + ), + ast.Assign( + targets=[ast.Name(id="response")], + value=ast.Call( + func=ast.Attribute(value=ast.Name(id="self"), attr="execute"), + args=[], + keywords=[ + ast.keyword(arg="query", value=ast.Name(id="_query")), + ast.keyword( + arg="operation_name", value=ast.Constant(value="GetXyz") + ), + ast.keyword(arg="variables", value=ast.Name(id="variables")), + ast.keyword(value=ast.Name(id="kwargs")), + ], + ), + ), + ast.Assign( + targets=[ast.Name(id="data")], + value=ast.Call( + func=ast.Attribute(value=ast.Name(id="self"), attr="get_data"), + args=[ast.Name(id="response")], + keywords=[], + ), + ), + ast.Return( + value=ast.Call( + func=ast.Attribute(value=ast.Name(id="GetXyz"), attr="model_validate"), + args=[ast.Name(id="data")], + keywords=[], + ) + ), + ] + + generator.add_method( + definition=cast(OperationDefinitionNode, parse(query_str).definitions[0]), + name=method_name, + return_type=return_type, + return_type_module=return_type_module_name, + operation_str=query_str, + async_=False, + ) + module = generator.generate() + + class_def = get_class_def(module) + assert class_def + method_def = class_def.body[0] + assert isinstance(method_def, ast.FunctionDef) + assert compare_ast(method_def.body, expected_method_body) + + +def test_add_method_generates_correct_method_body_for_shadowed_variables_variable( + base_client_import, +): + schema_str = """ + schema { query: Query } + type Query { xyz(variables: String, name: String): String } + """ + query_str = """ + query GetXyz($variables: String, $name: String ) { + xyz(variables: $variables, name: $name) + } + """ + generator = ClientGenerator( + base_client_import=base_client_import, + arguments_generator=ArgumentsGenerator(schema=build_schema(schema_str)), + ) + method_name = "list_xyz" + return_type = "GetXyz" + return_type_module_name = method_name + expected_method_body = [ + ast.Assign( + targets=[ast.Name(id="query")], + value=ast.Call( + func=ast.Name("gql"), + keywords=[], + args=[[ast.Constant(value=l + "\n") for l in query_str.splitlines()]], + ), + ), + ast.AnnAssign( + target=ast.Name(id="_variables"), + annotation=ast.Subscript( + value=ast.Name(id="Dict"), + slice=ast.Tuple(elts=[ast.Name(id="str"), ast.Name(id="object")]), + ), + value=ast.Dict( + keys=[ast.Constant(value="variables"), ast.Constant(value="name")], + values=[ast.Name(id="variables"), ast.Name(id="name")], + ), + simple=1, + ), + ast.Assign( + targets=[ast.Name(id="response")], + value=ast.Call( + func=ast.Attribute(value=ast.Name(id="self"), attr="execute"), + args=[], + keywords=[ + ast.keyword(arg="query", value=ast.Name(id="query")), + ast.keyword( + arg="operation_name", value=ast.Constant(value="GetXyz") + ), + ast.keyword(arg="variables", value=ast.Name(id="_variables")), + ast.keyword(value=ast.Name(id="kwargs")), + ], + ), + ), + ast.Assign( + targets=[ast.Name(id="data")], + value=ast.Call( + func=ast.Attribute(value=ast.Name(id="self"), attr="get_data"), + args=[ast.Name(id="response")], + keywords=[], + ), + ), + ast.Return( + value=ast.Call( + func=ast.Attribute(value=ast.Name(id="GetXyz"), attr="model_validate"), + args=[ast.Name(id="data")], + keywords=[], + ) + ), + ] + + generator.add_method( + definition=cast(OperationDefinitionNode, parse(query_str).definitions[0]), + name=method_name, + return_type=return_type, + return_type_module=return_type_module_name, + operation_str=query_str, + async_=False, + ) + module = generator.generate() + + class_def = get_class_def(module) + assert class_def + method_def = class_def.body[0] + assert isinstance(method_def, ast.FunctionDef) + assert compare_ast(method_def.body, expected_method_body) + + +def test_add_method_generates_correct_method_body_for_shadowed_response_variable( + base_client_import, +): + schema_str = """ + schema { query: Query } + type Query { xyz(response: String, name: String): String } + """ + query_str = """ + query GetXyz($response: String, $name: String ) { + xyz(response: $response, name: $name) + } + """ + generator = ClientGenerator( + base_client_import=base_client_import, + arguments_generator=ArgumentsGenerator(schema=build_schema(schema_str)), + ) + method_name = "list_xyz" + return_type = "GetXyz" + return_type_module_name = method_name + expected_method_body = [ + ast.Assign( + targets=[ast.Name(id="query")], + value=ast.Call( + func=ast.Name("gql"), + keywords=[], + args=[[ast.Constant(value=l + "\n") for l in query_str.splitlines()]], + ), + ), + ast.AnnAssign( + target=ast.Name(id="variables"), + annotation=ast.Subscript( + value=ast.Name(id="Dict"), + slice=ast.Tuple(elts=[ast.Name(id="str"), ast.Name(id="object")]), + ), + value=ast.Dict( + keys=[ast.Constant(value="response"), ast.Constant(value="name")], + values=[ast.Name(id="response"), ast.Name(id="name")], + ), + simple=1, + ), + ast.Assign( + targets=[ast.Name(id="_response")], + value=ast.Call( + func=ast.Attribute(value=ast.Name(id="self"), attr="execute"), + args=[], + keywords=[ + ast.keyword(arg="query", value=ast.Name(id="query")), + ast.keyword( + arg="operation_name", value=ast.Constant(value="GetXyz") + ), + ast.keyword(arg="variables", value=ast.Name(id="variables")), + ast.keyword(value=ast.Name(id="kwargs")), + ], + ), + ), + ast.Assign( + targets=[ast.Name(id="data")], + value=ast.Call( + func=ast.Attribute(value=ast.Name(id="self"), attr="get_data"), + args=[ast.Name(id="_response")], + keywords=[], + ), + ), + ast.Return( + value=ast.Call( + func=ast.Attribute(value=ast.Name(id="GetXyz"), attr="model_validate"), + args=[ast.Name(id="data")], + keywords=[], + ) + ), + ] + + generator.add_method( + definition=cast(OperationDefinitionNode, parse(query_str).definitions[0]), + name=method_name, + return_type=return_type, + return_type_module=return_type_module_name, + operation_str=query_str, + async_=False, + ) + module = generator.generate() + + class_def = get_class_def(module) + assert class_def + method_def = class_def.body[0] + assert isinstance(method_def, ast.FunctionDef) + assert compare_ast(method_def.body, expected_method_body) + + +def test_add_method_generates_correct_method_body_for_shadowed_data_variable( + base_client_import, +): + schema_str = """ + schema { query: Query } + type Query { xyz(data: String, name: String): String } + """ + query_str = """ + query GetXyz($data: String, $name: String ) { + xyz(data: $data, name: $name) + } + """ + generator = ClientGenerator( + base_client_import=base_client_import, + arguments_generator=ArgumentsGenerator(schema=build_schema(schema_str)), + ) + method_name = "list_xyz" + return_type = "GetXyz" + return_type_module_name = method_name + expected_method_body = [ + ast.Assign( + targets=[ast.Name(id="query")], + value=ast.Call( + func=ast.Name("gql"), + keywords=[], + args=[[ast.Constant(value=l + "\n") for l in query_str.splitlines()]], + ), + ), + ast.AnnAssign( + target=ast.Name(id="variables"), + annotation=ast.Subscript( + value=ast.Name(id="Dict"), + slice=ast.Tuple(elts=[ast.Name(id="str"), ast.Name(id="object")]), + ), + value=ast.Dict( + keys=[ast.Constant(value="data"), ast.Constant(value="name")], + values=[ast.Name(id="data"), ast.Name(id="name")], + ), + simple=1, + ), + ast.Assign( + targets=[ast.Name(id="response")], + value=ast.Call( + func=ast.Attribute(value=ast.Name(id="self"), attr="execute"), + args=[], + keywords=[ + ast.keyword(arg="query", value=ast.Name(id="query")), + ast.keyword( + arg="operation_name", value=ast.Constant(value="GetXyz") + ), + ast.keyword(arg="variables", value=ast.Name(id="variables")), + ast.keyword(value=ast.Name(id="kwargs")), + ], + ), + ), + ast.Assign( + targets=[ast.Name(id="_data")], + value=ast.Call( + func=ast.Attribute(value=ast.Name(id="self"), attr="get_data"), + args=[ast.Name(id="response")], + keywords=[], + ), + ), + ast.Return( + value=ast.Call( + func=ast.Attribute(value=ast.Name(id="GetXyz"), attr="model_validate"), + args=[ast.Name(id="_data")], + keywords=[], + ) + ), + ] + + generator.add_method( + definition=cast(OperationDefinitionNode, parse(query_str).definitions[0]), + name=method_name, + return_type=return_type, + return_type_module=return_type_module_name, + operation_str=query_str, + async_=False, + ) + module = generator.generate() + + class_def = get_class_def(module) + assert class_def + method_def = class_def.body[0] + assert isinstance(method_def, ast.FunctionDef) + assert compare_ast(method_def.body, expected_method_body)