Commit
Remove ExclusivelyOwned from register_dispatch_key (pytorch#106791)
This fixes a bug that could occur with Python decompositions.

When an operation is intercepted in PyTorch's C++ code, its outputs are created as `ExclusivelyOwned<at::Tensor>`s. Later, when it dispatches back to Python for the decomposition, these tensors have their ownership shared with Python. In the normal case the exclusively owned tensor is released and its value is returned from the operation as a non-exclusively-owned tensor. However, if the Python decomposition throws an error, the `ExclusivelyOwned` wrapper destroys the `at::Tensor`, leaving Python with a reference to a tensor that is no longer alive (and making PyTorch fall over in debug mode).

Note this will be a performance hit when handling errors.
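
For illustration, a minimal sketch of the hazard (assuming libtorch headers; `run_op` and `python_side` are illustrative stand-ins, not the real dispatcher code):

```cpp
#include <ATen/ATen.h>
#include <c10/util/ExclusivelyOwned.h>
#include <stdexcept>

// `python_side` stands in for the reference the Python decomposition keeps.
at::Tensor run_op(at::Tensor& python_side, bool decomposition_throws) {
  // The output is created exclusively owned, as the old generated code did.
  c10::ExclusivelyOwned<at::Tensor> out(at::empty({2, 2}));

  // Dispatching back to Python shares ownership of the output with Python.
  python_side = *out;

  if (decomposition_throws) {
    // Unwinding destroys `out`: the ExclusivelyOwned wrapper tears the tensor
    // down even though `python_side` still refers to it, which is the dangling
    // reference (and debug-mode assert) reported in pytorch#106790.
    throw std::runtime_error("decomposition failed");
  }

  // Happy path: release exclusive ownership and return the value.
  return std::move(out).take();
}
```

With this change the generated `outputs_` are plain `Tensor`s, so the extra reference held on the Python side keeps the tensor alive even when the call unwinds with an exception.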

Fixes pytorch#106790

Pull Request resolved: pytorch#106791
Approved by: https://github.com/ezyang
Anstow authored and pytorchmergebot committed Aug 11, 2023
1 parent d97b18d commit c9cdcb2
Showing 1 changed file with 12 additions and 14 deletions.
26 changes: 12 additions & 14 deletions torchgen/dest/register_dispatch_key.py
@@ -577,15 +577,14 @@ def gen_class_set_output_functions(
set_output_super = ""

def gen_set_output_function(name: str, maybe_create_proxy: bool) -> str:
maybe_star = "*" if k is SchemaKind.functional else ""
return f"""
void set_output_{name}(
int64_t output_idx, IntArrayRef sizes, IntArrayRef strides,
TensorOptions options, DimnameList names
) override {{
{textwrap.indent(self.gen_class_set_output_body(k, maybe_create_proxy), " ")}
if (!names.empty()) {{
- namedinference::propagate_names({maybe_star}outputs_[output_idx], names);
+ namedinference::propagate_names(outputs_[output_idx], names);
}}
// super must happen after, so that downstream can use maybe_get_output
// to retrieve the output
@@ -621,7 +620,7 @@ def gen_class_set_output_body(self, k: SchemaKind, maybe_create_proxy: bool) ->
create_proxy = """
auto maybe_proxy = maybe_create_proxy(out, sizes, strides, options);
if (C10_UNLIKELY(maybe_proxy.has_value())) {
- proxy_outputs_[output_idx] = c10::ExclusivelyOwned<Tensor>(std::move(maybe_proxy).value());
+ proxy_outputs_[output_idx] = std::move(maybe_proxy).value();
}
"""
else:
@@ -683,17 +682,17 @@ def gen_class(
generate_super: bool,
) -> str:
if k is SchemaKind.functional:
output_type = "c10::ExclusivelyOwned<Tensor>"
output_value = "*outputs_[output_idx]"
output_type = "Tensor"
output_value = "outputs_[output_idx]"
proxy_field = ""
elif k is SchemaKind.inplace:
output_type = "std::reference_wrapper<Tensor>"
output_value = "proxy_outputs_[output_idx].has_value() ? **proxy_outputs_[output_idx] : outputs_[output_idx].get()"
proxy_field = f"std::array<c10::optional<c10::ExclusivelyOwned<Tensor>>, {len(f.func.returns)}> proxy_outputs_;"
output_value = "proxy_outputs_[output_idx].has_value() ? *proxy_outputs_[output_idx] : outputs_[output_idx].get()"
proxy_field = f"std::array<c10::optional<Tensor>, {len(f.func.returns)}> proxy_outputs_;"
elif k is SchemaKind.out:
output_type = "std::reference_wrapper<Tensor>"
output_value = "proxy_outputs_[output_idx].has_value() ? **proxy_outputs_[output_idx] : outputs_[output_idx].get()"
proxy_field = f"std::array<c10::optional<c10::ExclusivelyOwned<Tensor>>, {len(f.func.returns)}> proxy_outputs_;"
output_value = "proxy_outputs_[output_idx].has_value() ? *proxy_outputs_[output_idx] : outputs_[output_idx].get()"
proxy_field = f"std::array<c10::optional<Tensor>, {len(f.func.returns)}> proxy_outputs_;"

if self.backend_index.dispatch_key == DispatchKey.CUDA:
if self.rocm:
@@ -886,8 +885,7 @@ def generate_defn(cpp_sig: CppSignature) -> str:
if k is SchemaKind.out:
expr = f"op.maybe_get_output({i})"
else:
maybe_star = "*" if k is SchemaKind.functional else ""
expr = f"{maybe_star}op.outputs_[{i}]"
expr = f"op.outputs_[{i}]"

context.append(
Expr(
@@ -942,17 +940,17 @@ def generate_defn(cpp_sig: CppSignature) -> str:
if k is SchemaKind.out or k is SchemaKind.inplace:
for i in range(len(f.func.returns)):
sig_body.append(
f"if (op.proxy_outputs_[{i}].has_value()) op.outputs_[{i}].get().copy_(**op.proxy_outputs_[{i}]);"
f"if (op.proxy_outputs_[{i}].has_value()) op.outputs_[{i}].get().copy_(*op.proxy_outputs_[{i}]);"
)

# Destructively return the final tensors
# TODO: Do this in translate instead
if k is SchemaKind.functional:
if len(f.func.returns) == 1:
ret_expr = "std::move(op.outputs_[0]).take()" # small optimization
ret_expr = "std::move(op.outputs_[0])" # small optimization
else:
moved = ", ".join(
f"std::move(op.outputs_[{i}]).take()"
f"std::move(op.outputs_[{i}])"
for i in range(len(f.func.returns))
)
ret_expr = f"std::make_tuple({moved})"
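
For reference, the net effect on the generated code for a single-output functional op, as a compilable sketch (assuming libtorch; `old_style` / `new_style` are illustrative names, not what torchgen actually emits):

```cpp
#include <ATen/ATen.h>
#include <c10/util/ExclusivelyOwned.h>
#include <array>

at::Tensor old_style() {
  // Before: outputs_ held as ExclusivelyOwned and released via take() only on
  // the successful path; an exception in between destroys the tensor.
  std::array<c10::ExclusivelyOwned<at::Tensor>, 1> outputs_{
      c10::ExclusivelyOwned<at::Tensor>(at::empty({2, 2}))};
  return std::move(outputs_[0]).take();
}

at::Tensor new_style() {
  // After: outputs_ held as plain (reference-counted) Tensors, so a reference
  // shared with a Python decomposition stays valid on error paths.
  std::array<at::Tensor, 1> outputs_{at::empty({2, 2})};
  return std::move(outputs_[0]);
}
```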
