Skip to content

Commit

Permalink
Port cat kernel to structured kernels.
Browse files Browse the repository at this point in the history
Tracking issue: pytorch#55070

Pull Request resolved: pytorch#68640

Approved by: https://github.com/ezyang
  • Loading branch information
ysiraichi authored and pytorchmergebot committed Apr 14, 2022
1 parent 4c3ee53 commit 22a10ce
Show file tree
Hide file tree
Showing 19 changed files with 254 additions and 349 deletions.
2 changes: 1 addition & 1 deletion aten/src/ATen/NamedTensorUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -459,7 +459,7 @@ std::vector<Dimname> broadcast_to_outnames(
return unify_from_right(reference_names, tensor_names);
}

std::vector<Dimname> compute_cat_outnames(TensorList tensors) {
std::vector<Dimname> compute_cat_outnames(ITensorListRef tensors) {
if (!at::has_names(tensors)) {
return {};
}
Expand Down
4 changes: 2 additions & 2 deletions aten/src/ATen/NamedTensorUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ namespace at {

using NameVector = SmallVector<Dimname, kDimVectorStaticSize>;

inline bool has_names(TensorList tensors) {
inline bool has_names(ITensorListRef tensors) {
return std::any_of(
tensors.begin(), tensors.end(), [](const Tensor& t) { return t.has_names(); });
}
Expand Down Expand Up @@ -98,7 +98,7 @@ TORCH_API void propagate_names_for_reduction(const Tensor& result, const Tensor&

TORCH_API void propagate_names_for_expand(const Tensor& result, const Tensor& self);

TORCH_API std::vector<Dimname> compute_cat_outnames(TensorList tensors);
TORCH_API std::vector<Dimname> compute_cat_outnames(ITensorListRef tensors);

TORCH_API std::vector<Dimname> compute_broadcast_outnames(
const Tensor& self,
Expand Down
3 changes: 2 additions & 1 deletion aten/src/ATen/WrapDimUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include <c10/core/TensorImpl.h>
#include <c10/util/irange.h>
#include <ATen/core/Tensor.h>
#include <ATen/core/ITensorListRef.h>

namespace at {

Expand Down Expand Up @@ -74,7 +75,7 @@ static inline int64_t legacy_cat_wrap_dim(int64_t dim, const std::vector<std::ve
return dim;
}

static inline int64_t legacy_cat_wrap_dim(int64_t dim, TensorList tensors) {
static inline int64_t legacy_cat_wrap_dim(int64_t dim, ITensorListRef tensors) {
for (auto& tensor : tensors) {
if (tensor.dim() == 1 && tensor.sizes()[0] == 0) {
continue;
Expand Down
2 changes: 1 addition & 1 deletion aten/src/ATen/native/TensorCompare.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,7 @@ static void isin_sorting(
// 2. Stable sort all elements, maintaining order indices to reverse the
// operation. Stable sort is necessary to keep elements before test
// elements within the sorted list.
Tensor all_elements = at::_cat({elements_flat, test_elements_flat});
Tensor all_elements = at::cat({elements_flat, test_elements_flat});
Tensor sorted_elements, sorted_order;
std::tie (sorted_elements, sorted_order) = all_elements.sort(
/*stable=*/ true, /*dim=*/ 0, /*descending=*/ false);
Expand Down
Loading

0 comments on commit 22a10ce

Please sign in to comment.