diff --git a/torchgen/dest/register_dispatch_key.py b/torchgen/dest/register_dispatch_key.py index c4b24b1dce4a94..dde06ac3472e75 100644 --- a/torchgen/dest/register_dispatch_key.py +++ b/torchgen/dest/register_dispatch_key.py @@ -577,7 +577,6 @@ def gen_class_set_output_functions( set_output_super = "" def gen_set_output_function(name: str, maybe_create_proxy: bool) -> str: - maybe_star = "*" if k is SchemaKind.functional else "" return f""" void set_output_{name}( int64_t output_idx, IntArrayRef sizes, IntArrayRef strides, @@ -585,7 +584,7 @@ def gen_set_output_function(name: str, maybe_create_proxy: bool) -> str: ) override {{ {textwrap.indent(self.gen_class_set_output_body(k, maybe_create_proxy), " ")} if (!names.empty()) {{ - namedinference::propagate_names({maybe_star}outputs_[output_idx], names); + namedinference::propagate_names(outputs_[output_idx], names); }} // super must happen after, so that downstream can use maybe_get_output // to retrieve the output @@ -621,7 +620,7 @@ def gen_class_set_output_body(self, k: SchemaKind, maybe_create_proxy: bool) -> create_proxy = """ auto maybe_proxy = maybe_create_proxy(out, sizes, strides, options); if (C10_UNLIKELY(maybe_proxy.has_value())) { - proxy_outputs_[output_idx] = c10::ExclusivelyOwned(std::move(maybe_proxy).value()); + proxy_outputs_[output_idx] = std::move(maybe_proxy).value(); } """ else: @@ -683,17 +682,17 @@ def gen_class( generate_super: bool, ) -> str: if k is SchemaKind.functional: - output_type = "c10::ExclusivelyOwned" - output_value = "*outputs_[output_idx]" + output_type = "Tensor" + output_value = "outputs_[output_idx]" proxy_field = "" elif k is SchemaKind.inplace: output_type = "std::reference_wrapper" - output_value = "proxy_outputs_[output_idx].has_value() ? **proxy_outputs_[output_idx] : outputs_[output_idx].get()" - proxy_field = f"std::array>, {len(f.func.returns)}> proxy_outputs_;" + output_value = "proxy_outputs_[output_idx].has_value() ? *proxy_outputs_[output_idx] : outputs_[output_idx].get()" + proxy_field = f"std::array, {len(f.func.returns)}> proxy_outputs_;" elif k is SchemaKind.out: output_type = "std::reference_wrapper" - output_value = "proxy_outputs_[output_idx].has_value() ? **proxy_outputs_[output_idx] : outputs_[output_idx].get()" - proxy_field = f"std::array>, {len(f.func.returns)}> proxy_outputs_;" + output_value = "proxy_outputs_[output_idx].has_value() ? *proxy_outputs_[output_idx] : outputs_[output_idx].get()" + proxy_field = f"std::array, {len(f.func.returns)}> proxy_outputs_;" if self.backend_index.dispatch_key == DispatchKey.CUDA: if self.rocm: @@ -886,8 +885,7 @@ def generate_defn(cpp_sig: CppSignature) -> str: if k is SchemaKind.out: expr = f"op.maybe_get_output({i})" else: - maybe_star = "*" if k is SchemaKind.functional else "" - expr = f"{maybe_star}op.outputs_[{i}]" + expr = f"op.outputs_[{i}]" context.append( Expr( @@ -942,17 +940,17 @@ def generate_defn(cpp_sig: CppSignature) -> str: if k is SchemaKind.out or k is SchemaKind.inplace: for i in range(len(f.func.returns)): sig_body.append( - f"if (op.proxy_outputs_[{i}].has_value()) op.outputs_[{i}].get().copy_(**op.proxy_outputs_[{i}]);" + f"if (op.proxy_outputs_[{i}].has_value()) op.outputs_[{i}].get().copy_(*op.proxy_outputs_[{i}]);" ) # Destructively return the final tensors # TODO: Do this in translate instead if k is SchemaKind.functional: if len(f.func.returns) == 1: - ret_expr = "std::move(op.outputs_[0]).take()" # small optimization + ret_expr = "std::move(op.outputs_[0])" # small optimization else: moved = ", ".join( - f"std::move(op.outputs_[{i}]).take()" + f"std::move(op.outputs_[{i}])" for i in range(len(f.func.returns)) ) ret_expr = f"std::make_tuple({moved})"