Change API type Tensor[] for structured kernels. (pytorch#73350)
Partially fixes: pytorch#66328

This PR:
- adds support for `ITensorList` to the dispatcher for:
  - computing the dispatch key
  - boxing and unboxing `ITensorList`
- modifies the codegen for structured kernels:
  - codegen APIs use `ITensorList` instead of `ArrayRef<Tensor>`
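
To make the new kernel-facing type concrete, here is a minimal sketch (my own, not code from this PR): a hypothetical `count_defined` helper written against `ITensorListRef`, assuming the type is iterable like `TensorList` and implicitly constructible from both `std::vector<Tensor>`/`ArrayRef<Tensor>` and `c10::List<Tensor>`:

```cpp
#include <vector>

#include <ATen/ATen.h>
#include <ATen/core/IListRef.h>
#include <ATen/core/List.h>

// Hypothetical helper (illustration only): a kernel-style function written
// against the new list type instead of ArrayRef<Tensor> (at::TensorList).
int64_t count_defined(at::ITensorListRef tensors) {
  int64_t n = 0;
  // IListRef is iterable just like a TensorList.
  for (const at::Tensor& t : tensors) {
    if (t.defined()) {
      ++n;
    }
  }
  return n;
}

int main() {
  std::vector<at::Tensor> unboxed = {at::ones({2}), at::zeros({2})};
  c10::List<at::Tensor> boxed({at::ones({3})});

  // Both the unboxed (ArrayRef-backed) and boxed (c10::List-backed) containers
  // are accepted without first copying into a std::vector.
  count_defined(unboxed);
  count_defined(boxed);
  return 0;
}
```

That is the motivation for the new API type: one signature that can serve both boxed and unboxed call paths without an intermediate copy.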

**Changes summary:**

- Signature changes due to the different APIs:
  - dispatcher API (e.g. `BatchingRegistrations.cpp`)
  - C++ API (e.g. `TensorShape.cpp`)
- Miscellaneous functions used by codegen'd functions (e.g. `FunctionalTensorWrapper.*`)
- Dispatcher changes for handling `ITensorList` correctly (e.g. `DispatchKeyExtractor.h`)
- Signature changes of `at::cat`, due to the need for `const` inside `TensorBody.h` (see the caller-side sketch after this list)
- Forward declarations of `ITensorList` (e.g. `MethodOperators.h`)
- Codegen changes, special casing structured kernels (e.g. `gen.py`)
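
A caller-side sketch of the `at::cat` change (hypothetical usage, not part of the diff), assuming `ITensorListRef` remains implicitly constructible from an initializer list and from `std::vector<Tensor>`, so existing call sites should keep compiling:

```cpp
#include <vector>

#include <ATen/ATen.h>

int main() {
  at::Tensor a = at::ones({2, 3});
  at::Tensor b = at::zeros({2, 3});

  // Brace-initialized argument: ITensorListRef is constructible from an
  // initializer list, so this call keeps compiling against the new
  // `Tensor cat(const ITensorListRef&, int64_t)` declaration.
  at::Tensor c1 = at::cat({a, b}, /*dim=*/0);

  // std::vector<Tensor> argument: converted through ArrayRef<Tensor>, also
  // unchanged from the caller's point of view.
  std::vector<at::Tensor> v = {a, b};
  at::Tensor c2 = at::cat(v, /*dim=*/0);
  return 0;
}
```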

**Short description of structured kernels special casing:**

I introduced five main types of changes to the codegen, so that it generates different code depending on
whether the kernel is structured or not:

1. Added a `structured_type_override` flag to the `argument_type` function definition of
the affected APIs (mainly the dispatcher and C++ APIs).
  - `api/cpp.py`, `api/dispatcher.py`, `api/native.py`
2. Added a `structured_type_override` member to the signature
classes (e.g. `CppSignature`), since `FunctionSchema` doesn't really know whether the
function is structured or not
  - `api/types.py`
3. Added a `part_of_structured_group` function to the `NativeFunction` class, a convenience that
forwards to `structured_type_override` wherever needed
  - `model.py`
4. Appropriately changed the rest of the codegen, whenever it used either the signature
classes or the `arguments` function directly
5. Added a check for `const ITensorList&` type wherever there was a check for `TensorList`
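
For intuition, a hand-written sketch of the kind of signature difference this special casing produces in the generated headers; `unstructured_concat` is a made-up op used only to contrast with a structured one like `cat`, and the exact generated declarations may differ:

```cpp
#include <ATen/core/IListRef.h>
#include <ATen/core/Tensor.h>

namespace sketch {

// Op in a structured group (e.g. cat): the Tensor[] argument is now emitted as
// const ITensorListRef& in the generated C++ and dispatcher API signatures.
at::Tensor cat(const at::ITensorListRef& tensors, int64_t dim = 0);

// Made-up op outside any structured group: the Tensor[] argument keeps the
// legacy TensorList (ArrayRef<Tensor>) type, so its kernels are unchanged.
at::Tensor unstructured_concat(at::TensorList tensors, int64_t dim = 0);

} // namespace sketch
```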
Pull Request resolved: pytorch#73350
Approved by: https://github.com/bdhirsh
bdhirsh authored and pytorchmergebot committed Sep 26, 2022
1 parent 1a2734e commit 4a2d2e5
Showing 66 changed files with 408 additions and 279 deletions.
2 changes: 1 addition & 1 deletion .github/ci_commit_pins/xla.txt
@@ -1 +1 @@
einsum-path
76c65b13280cd5782ace8050df45564ef17891f9
3 changes: 2 additions & 1 deletion aten/src/ATen/BatchingRegistrations.cpp
@@ -4,6 +4,7 @@
#include <ATen/BatchedFallback.h>
#include <ATen/native/ResizeCommon.h>
#include <ATen/ATen.h>
#include <ATen/core/IListRef.h>
#include <c10/util/irange.h>
#include <c10/core/SymIntArrayRef.h>

@@ -916,7 +917,7 @@ Tensor mm_batching_rule(const Tensor& self, const Tensor& other) {
TORCH_INTERNAL_ASSERT(false, "either self or other must be a BatchedTensor");
}

Tensor cat_batching_rule(TensorList tensors, int64_t dim) {
Tensor cat_batching_rule(const ITensorListRef& tensors, int64_t dim) {
auto physical_views = MultiBatchVmapTransform::logicalToPhysical(tensors);
auto physical_tensors = fmap(
physical_views, [](const VmapPhysicalView& view) -> Tensor { return view.tensor(); });
3 changes: 2 additions & 1 deletion aten/src/ATen/DeviceGuard.h
@@ -1,5 +1,6 @@
#pragma once

#include <ATen/core/IListRef.h>
#include <ATen/core/Tensor.h>
#include <c10/core/DeviceGuard.h>
#include <c10/core/ScalarType.h> // TensorList whyyyyy
@@ -29,7 +30,7 @@ inline c10::optional<Device> device_of(const c10::optional<Tensor>& t) {
/// Return the Device of a TensorList, if the list is non-empty and
/// the first Tensor is defined. (This function implicitly assumes
/// that all tensors in the list have the same device.)
inline c10::optional<Device> device_of(TensorList t) {
inline c10::optional<Device> device_of(ITensorListRef t) {
if (!t.empty()) {
return device_of(t.front());
} else {
109 changes: 35 additions & 74 deletions aten/src/ATen/FunctionalTensorWrapper.cpp
@@ -4,6 +4,7 @@
#include <ATen/FunctionalInverses.h>
#include <ATen/TensorUtils.h>
#include <ATen/WrapDimUtils.h>
#include <ATen/core/IListRef.h>
#include <ATen/core/LegacyTypeDispatch.h>
#include <c10/util/Exception.h>

@@ -370,14 +371,6 @@ c10::optional<Tensor> to_functional_tensor(const c10::optional<Tensor>& tensor)
}
return c10::nullopt;
}
c10::List<Tensor> to_functional_tensor(const c10::List<Tensor>& t_list) {
c10::List<Tensor> outputs;
outputs.reserve(t_list.size());
for (const auto i : c10::irange(t_list.size())) {
outputs.push_back(to_functional_tensor(t_list[i]));
}
return outputs;
}
c10::List<c10::optional<Tensor>> to_functional_tensor(const c10::List<c10::optional<Tensor>>& t_list) {
c10::List<c10::optional<Tensor>> outputs;
outputs.reserve(t_list.size());
@@ -386,17 +379,11 @@ c10::List<c10::optional<Tensor>> to_functional_tensor(const c10::List<c10::optio
}
return outputs;
}
std::vector<Tensor> to_functional_tensor(const std::vector<Tensor>& t_list) {
std::vector<Tensor> outputs(t_list.size());
for (const auto i : c10::irange(t_list.size())) {
outputs[i] = to_functional_tensor(t_list[i]);
}
return outputs;
}
std::vector<Tensor> to_functional_tensor(const TensorList& t_list) {
std::vector<Tensor> outputs(t_list.size());
for (const auto i : c10::irange(t_list.size())) {
outputs[i] = to_functional_tensor(t_list[i]);
std::vector<Tensor> to_functional_tensor(ITensorListRef t_list) {
std::vector<Tensor> outputs;
outputs.reserve(t_list.size());
for (const auto& tensor : t_list) {
outputs.push_back(to_functional_tensor(tensor));
}
return outputs;
}
@@ -422,17 +409,17 @@ c10::optional<Tensor> from_functional_tensor(const c10::optional<Tensor>& t, boo
}
return c10::nullopt;
}
c10::List<Tensor> from_functional_tensor(const c10::List<Tensor>& t_list) {
c10::List<Tensor> outputs;
std::vector<Tensor> from_functional_tensor(ITensorListRef t_list) {
std::vector<Tensor> outputs;
outputs.reserve(t_list.size());
for (const auto i : c10::irange(t_list.size())) {
for (const auto& tensor : t_list) {
// from_functional_tensor(Tensor) has asserts to make sure you don't accidentally call
// it on a non-functional input,
// but from_functional_tensor(TensorList) can recieve a list containing both
// functional and non-functional tensors.
// Example of when that can happen: torch.cat(function_input_tensor, global_state_tensor).
// When that happens, we're okay with only unwrapping the functional tensors.
outputs.push_back(from_functional_tensor(t_list[i], /*assert_functional=*/false));
outputs.push_back(from_functional_tensor(tensor, /*assert_functional=*/false));
}
return outputs;
}
@@ -444,13 +431,6 @@ c10::List<c10::optional<Tensor>> from_functional_tensor(const c10::List<c10::opt
}
return outputs;
}
std::vector<Tensor> from_functional_tensor(const TensorList& t_list) {
std::vector<Tensor> outputs(t_list.size());
for (const auto i : c10::irange(t_list.size())) {
outputs[i] = from_functional_tensor(t_list[i], /*assert_functional=*/false);
}
return outputs;
}

void sync(const Tensor& t) {
if (t.unsafeGetTensorImpl()->is_wrapped_number()) {
@@ -474,13 +454,8 @@ void sync(const c10::optional<Tensor>& t) {
sync(*t);
}
}
void sync(const c10::List<Tensor> t_list) {
for (const auto i : c10::irange(t_list.size())) {
sync(t_list[i]);
}
}
void sync(const at::TensorList t_list) {
for (auto t: t_list) {
void sync(ITensorListRef t_list) {
for (const auto& t : t_list) {
sync(t);
}
}
@@ -495,22 +470,24 @@ void replace_(const Tensor& functional_tensor, const Tensor& other) {
unsafeGetFunctionalWrapper(functional_tensor)->replace_(other);
}

void replace_(const TensorList functional_tensor, TensorList other) {
void replace_(const ITensorListRef functional_tensor, ITensorListRef other) {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(functional_tensor.size() == other.size());
auto functional_tensor_it = functional_tensor.begin();
auto other_it = other.begin();
for (const auto i : c10::irange(functional_tensor.size())) {
replace_(functional_tensor[i], other[i]);
(void)i; // Suppress unused variable warning
replace_(*functional_tensor_it++, *other_it++);
}
}


void commit_update(const Tensor& functional_tensor) {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(isFunctionalTensor(functional_tensor));
unsafeGetFunctionalWrapper(functional_tensor)->commit_update();
}

void commit_update(const TensorList functional_tensor) {
for (const auto i : c10::irange(functional_tensor.size())) {
commit_update(functional_tensor[i]);
void commit_update(ITensorListRef functional_tensor) {
for (const auto& t : functional_tensor) {
commit_update(t);
}
}

@@ -526,21 +503,6 @@ bool isFunctionalTensor(const c10::optional<Tensor>& t) {
}
}

// For lists that have a mix of functional and nonfunctional tensors,
// functionalization machinery should just unwrap the functional wrappers
// and leave the ordinary tensors alone.
bool isFunctionalTensor(const c10::List<Tensor>& t_list) {
if (t_list.size() == 0) return false;
auto functional_count = 0;
for (const auto i : c10::irange(t_list.size())) {
if (!t_list[i].defined()) continue;
if (isFunctionalTensor(t_list[i])) {
++functional_count;
}
}
return functional_count > 0;
}

bool isFunctionalTensor(const c10::List<c10::optional<Tensor>>& t_list) {
if (t_list.size() == 0) return false;
auto functional_count = 0;
@@ -553,18 +515,23 @@ bool isFunctionalTensor(const c10::List<c10::optional<Tensor>>& t_list) {
return functional_count > 0;
}

bool isFunctionalTensor(const c10::ArrayRef<Tensor> t_list) {
if (t_list.size() == 0) return false;
template <typename T>
bool isFunctionalTensorIListRef(c10::IListRef<T> list) {
if (list.size() == 0) return false;
auto functional_count = 0;
for (const auto i : c10::irange(t_list.size())) {
if (!t_list[i].defined()) continue;
if (isFunctionalTensor(t_list[i])) {
for (const auto& tensor : list) {
if (!tensor.defined()) continue;
if (isFunctionalTensor(tensor)) {
++functional_count;
}
}
return functional_count > 0;
}

bool isFunctionalTensor(ITensorListRef list) {
return isFunctionalTensorIListRef(list);
}

Tensor create_functional_tensor_with_view_meta(const at::Tensor& view_to_wrap, const at::Tensor& base, functionalization::ViewMeta meta, int64_t out_idx) {
TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(view_to_wrap));
TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(base));
@@ -578,18 +545,12 @@ Tensor create_functional_tensor_with_view_meta(const at::Tensor& view_to_wrap, c
return at::detail::make_tensor<FunctionalTensorWrapper>(view_to_wrap, functional_base_impl, meta);
}

std::vector<Tensor> create_functional_tensor_with_view_meta(const c10::List<at::Tensor>& view_to_wrap, const at::Tensor& base, functionalization::ViewMeta meta) {
std::vector<Tensor> outputs(view_to_wrap.size());
for (const auto i : c10::irange(view_to_wrap.size())) {
outputs[i] = create_functional_tensor_with_view_meta(view_to_wrap[i], base, meta, i);
}
return outputs;
}

std::vector<Tensor> create_functional_tensor_with_view_meta(const std::vector<at::Tensor>& view_to_wrap, const at::Tensor& base, functionalization::ViewMeta meta) {
std::vector<Tensor> create_functional_tensor_with_view_meta(ITensorListRef view_to_wrap, const at::Tensor& base, functionalization::ViewMeta meta) {
std::vector<Tensor> outputs(view_to_wrap.size());
for (const auto i : c10::irange(view_to_wrap.size())) {
outputs[i] = create_functional_tensor_with_view_meta(view_to_wrap[i], base, meta, i);
int64_t i = 0;
for (const auto& tensor : view_to_wrap) {
outputs[i] = create_functional_tensor_with_view_meta(tensor, base, meta, i);
i++;
}
return outputs;
}
29 changes: 10 additions & 19 deletions aten/src/ATen/FunctionalTensorWrapper.h
@@ -3,6 +3,7 @@

#include <ATen/ArrayRef.h>
#include <ATen/FunctionalStorageImpl.h>
#include <ATen/core/IListRef.h>
#include <ATen/core/List.h>
#include <ATen/core/boxing/BoxedKernel.h>
#include <ATen/core/boxing/impl/boxing.h>
@@ -183,56 +184,46 @@ TORCH_API inline FunctionalTensorWrapper* unsafeGetFunctionalWrapper(

TORCH_API bool isFunctionalTensor(const at::Tensor& tensor);
TORCH_API bool isFunctionalTensor(const c10::optional<Tensor>& t);
TORCH_API bool isFunctionalTensor(const c10::List<Tensor>& t_list);
TORCH_API bool isFunctionalTensor(
const c10::List<c10::optional<Tensor>>& t_list);
TORCH_API bool isFunctionalTensor(const c10::ArrayRef<Tensor> t_list);
TORCH_API bool isFunctionalTensor(ITensorListRef list);

TORCH_API Tensor to_functional_tensor(const Tensor& tensor);
TORCH_API c10::optional<Tensor> to_functional_tensor(
const c10::optional<Tensor>& tensor);
TORCH_API c10::List<Tensor> to_functional_tensor(
const c10::List<Tensor>& t_list);
TORCH_API c10::List<c10::optional<Tensor>> to_functional_tensor(
const c10::List<c10::optional<Tensor>>& t_list);
TORCH_API std::vector<Tensor> to_functional_tensor(
const std::vector<Tensor>& t_list);
TORCH_API std::vector<Tensor> to_functional_tensor(const TensorList& t_list);
TORCH_API std::vector<Tensor> to_functional_tensor(ITensorListRef t_list);

TORCH_API Tensor
from_functional_tensor(const Tensor& tensor, bool assert_functional = true);
TORCH_API c10::optional<Tensor> from_functional_tensor(
const c10::optional<Tensor>& t,
bool assert_functional = true);
TORCH_API c10::List<Tensor> from_functional_tensor(
const c10::List<Tensor>& t_list);
TORCH_API c10::List<c10::optional<Tensor>> from_functional_tensor(
const c10::List<c10::optional<Tensor>>& t_list);
TORCH_API std::vector<Tensor> from_functional_tensor(const TensorList& tensors);
TORCH_API std::vector<Tensor> from_functional_tensor(ITensorListRef t_list);

TORCH_API void sync(const at::Tensor& t);
TORCH_API void sync(const c10::optional<Tensor>& t);
TORCH_API void sync(const c10::List<Tensor> t_list);
TORCH_API void sync(const at::TensorList t_list);
TORCH_API void sync(const c10::List<c10::optional<Tensor>> t_list);
TORCH_API void sync(ITensorListRef t_list);

TORCH_API void replace_(const Tensor& functional_tensor, const Tensor& other);
TORCH_API void replace_(const TensorList functional_tensor, TensorList other);
TORCH_API void replace_(
const ITensorListRef functional_tensor,
ITensorListRef other);

TORCH_API void commit_update(const Tensor& functional_tensor);
TORCH_API void commit_update(const TensorList functional_tensor);
TORCH_API void commit_update(ITensorListRef functional_tensor);

Tensor create_functional_tensor_with_view_meta(
const Tensor& view_to_wrap,
const Tensor& base,
functionalization::ViewMeta meta,
int64_t out_idx = 0);
std::vector<Tensor> create_functional_tensor_with_view_meta(
const c10::List<Tensor>& view_to_wrap,
const Tensor& base,
functionalization::ViewMeta meta);
std::vector<Tensor> create_functional_tensor_with_view_meta(
const std::vector<Tensor>& view_to_wrap,
ITensorListRef view_to_wrap,
const Tensor& base,
functionalization::ViewMeta meta);

4 changes: 2 additions & 2 deletions aten/src/ATen/NamedTensorUtils.cpp
@@ -410,12 +410,12 @@ std::vector<Dimname> broadcast_to_outnames(
return unify_from_right(reference_names, tensor_names);
}

std::vector<Dimname> compute_cat_outnames(ITensorListRef tensors) {
std::vector<Dimname> compute_cat_outnames(const MaterializedITensorListRef& tensors) {
if (!at::has_names(tensors)) {
return {};
}
std::vector<Dimname> result;
for (const auto& tensor : tensors) {
for (const Tensor& tensor : tensors) {
const auto tensor_names = tensor.names();
TORCH_CHECK(tensor_names.size() > 0, "zero-dimensional tensor cannot be concatenated");
TORCH_CHECK(result.empty() || tensor_names.size() == result.size(),
3 changes: 2 additions & 1 deletion aten/src/ATen/NamedTensorUtils.h
@@ -118,7 +118,8 @@ TORCH_API void propagate_names_for_expand(
const Tensor& result,
const Tensor& self);

TORCH_API std::vector<Dimname> compute_cat_outnames(ITensorListRef tensors);
TORCH_API std::vector<Dimname> compute_cat_outnames(
const MaterializedITensorListRef& tensors);

TORCH_API std::vector<Dimname> compute_broadcast_outnames(
const Tensor& self,
3 changes: 2 additions & 1 deletion aten/src/ATen/VmapTransforms.cpp
@@ -1,5 +1,6 @@
#include <ATen/VmapTransforms.h>
#include <ATen/ATen.h>
#include <ATen/core/IListRef.h>
#include <c10/util/irange.h>

namespace at {
@@ -188,7 +189,7 @@ static Tensor alignBatchDimsAtFront(
// 4. Expand each physical tensor so that they have output batch size equal
// to `batch_sizes`
VmapPhysicalViewVec
MultiBatchVmapTransform::logicalToPhysical(TensorList logical_tensors) {
MultiBatchVmapTransform::logicalToPhysical(ITensorListRef logical_tensors) {
// Figure out all of the collective vmap levels in `logical_tensors`.
std::bitset<kVmapNumLevels> collective_levels;
for (const auto& logical_tensor : logical_tensors) {
3 changes: 2 additions & 1 deletion aten/src/ATen/VmapTransforms.h
@@ -1,6 +1,7 @@
#pragma once

#include <ATen/BatchedTensorImpl.h>
#include <ATen/core/IListRef.h>

namespace at {

@@ -55,7 +56,7 @@ using VmapDimVector = SmallVector<int64_t, kVmapStaticDimVecSize>;
// and returns a VmapPhysicalView on the tensor(s).
struct TORCH_API MultiBatchVmapTransform {
static VmapPhysicalView logicalToPhysical(const Tensor& logical_tensor);
static VmapPhysicalViewVec logicalToPhysical(TensorList logical_tensors);
static VmapPhysicalViewVec logicalToPhysical(ITensorListRef logical_tensors);
};

// VmapTransform for operators that broadcast all inputs.
6 changes: 4 additions & 2 deletions aten/src/ATen/WrapDimUtils.h
@@ -97,8 +97,10 @@ static inline int64_t legacy_cat_wrap_dim(
return dim;
}

static inline int64_t legacy_cat_wrap_dim(int64_t dim, ITensorListRef tensors) {
for (auto& tensor : tensors) {
static inline int64_t legacy_cat_wrap_dim(
int64_t dim,
const MaterializedITensorListRef& tensors) {
for (const Tensor& tensor : tensors) {
if (tensor.dim() == 1 && tensor.sizes()[0] == 0) {
continue;
}
2 changes: 1 addition & 1 deletion aten/src/ATen/autocast_mode.cpp
@@ -699,8 +699,8 @@ TORCH_LIBRARY_IMPL(aten, AutocastCPU, m) {
&ADD_NS(linalg_inv_ex)>::type::call)));

// promote
KERNEL_CPU(ADD_NS(cat), "cat", Tensor (TensorList, int64_t), promote)
KERNEL_CPU(ADD_NS(stack), "stack", Tensor (TensorList, int64_t), promote)
KERNEL_CPU(ADD_NS(cat), "cat", Tensor (const ITensorListRef &, int64_t), promote)
KERNEL_CPU(ADD_NS(index_copy), "index_copy", Tensor (const Tensor &, int64_t, const Tensor &, const Tensor &), promote)
KERNEL_CPU(ADD_NS(index_copy), "index_copy.dimname", Tensor (const Tensor &, at::Dimname, const Tensor &, const Tensor &), promote)
