diff --git a/.github/ci_commit_pins/xla.txt b/.github/ci_commit_pins/xla.txt index 5fe9866a7177a..cb2f7a4ea1d27 100644 --- a/.github/ci_commit_pins/xla.txt +++ b/.github/ci_commit_pins/xla.txt @@ -1 +1 @@ -einsum-path +76c65b13280cd5782ace8050df45564ef17891f9 diff --git a/aten/src/ATen/BatchingRegistrations.cpp b/aten/src/ATen/BatchingRegistrations.cpp index 02bbbb7088d6e..6d95ae3553322 100644 --- a/aten/src/ATen/BatchingRegistrations.cpp +++ b/aten/src/ATen/BatchingRegistrations.cpp @@ -4,6 +4,7 @@ #include #include #include +#include #include #include @@ -916,7 +917,7 @@ Tensor mm_batching_rule(const Tensor& self, const Tensor& other) { TORCH_INTERNAL_ASSERT(false, "either self or other must be a BatchedTensor"); } -Tensor cat_batching_rule(TensorList tensors, int64_t dim) { +Tensor cat_batching_rule(const ITensorListRef& tensors, int64_t dim) { auto physical_views = MultiBatchVmapTransform::logicalToPhysical(tensors); auto physical_tensors = fmap( physical_views, [](const VmapPhysicalView& view) -> Tensor { return view.tensor(); }); diff --git a/aten/src/ATen/DeviceGuard.h b/aten/src/ATen/DeviceGuard.h index a827a1ccc7fad..83bb31d7fd425 100644 --- a/aten/src/ATen/DeviceGuard.h +++ b/aten/src/ATen/DeviceGuard.h @@ -1,5 +1,6 @@ #pragma once +#include #include #include #include // TensorList whyyyyy @@ -29,7 +30,7 @@ inline c10::optional device_of(const c10::optional& t) { /// Return the Device of a TensorList, if the list is non-empty and /// the first Tensor is defined. (This function implicitly assumes /// that all tensors in the list have the same device.) -inline c10::optional device_of(TensorList t) { +inline c10::optional device_of(ITensorListRef t) { if (!t.empty()) { return device_of(t.front()); } else { diff --git a/aten/src/ATen/FunctionalTensorWrapper.cpp b/aten/src/ATen/FunctionalTensorWrapper.cpp index 2c60d3e77ba49..0c0faae0df6ab 100644 --- a/aten/src/ATen/FunctionalTensorWrapper.cpp +++ b/aten/src/ATen/FunctionalTensorWrapper.cpp @@ -4,6 +4,7 @@ #include #include #include +#include #include #include @@ -370,14 +371,6 @@ c10::optional to_functional_tensor(const c10::optional& tensor) } return c10::nullopt; } -c10::List to_functional_tensor(const c10::List& t_list) { - c10::List outputs; - outputs.reserve(t_list.size()); - for (const auto i : c10::irange(t_list.size())) { - outputs.push_back(to_functional_tensor(t_list[i])); - } - return outputs; -} c10::List> to_functional_tensor(const c10::List>& t_list) { c10::List> outputs; outputs.reserve(t_list.size()); @@ -386,17 +379,11 @@ c10::List> to_functional_tensor(const c10::List to_functional_tensor(const std::vector& t_list) { - std::vector outputs(t_list.size()); - for (const auto i : c10::irange(t_list.size())) { - outputs[i] = to_functional_tensor(t_list[i]); - } - return outputs; -} -std::vector to_functional_tensor(const TensorList& t_list) { - std::vector outputs(t_list.size()); - for (const auto i : c10::irange(t_list.size())) { - outputs[i] = to_functional_tensor(t_list[i]); +std::vector to_functional_tensor(ITensorListRef t_list) { + std::vector outputs; + outputs.reserve(t_list.size()); + for (const auto& tensor : t_list) { + outputs.push_back(to_functional_tensor(tensor)); } return outputs; } @@ -422,17 +409,17 @@ c10::optional from_functional_tensor(const c10::optional& t, boo } return c10::nullopt; } -c10::List from_functional_tensor(const c10::List& t_list) { - c10::List outputs; +std::vector from_functional_tensor(ITensorListRef t_list) { + std::vector outputs; outputs.reserve(t_list.size()); - for (const auto 
i : c10::irange(t_list.size())) { + for (const auto& tensor : t_list) { // from_functional_tensor(Tensor) has asserts to make sure you don't accidentally call // it on a non-functional input, // but from_functional_tensor(TensorList) can recieve a list containing both // functional and non-functional tensors. // Example of when that can happen: torch.cat(function_input_tensor, global_state_tensor). // When that happens, we're okay with only unwrapping the functional tensors. - outputs.push_back(from_functional_tensor(t_list[i], /*assert_functional=*/false)); + outputs.push_back(from_functional_tensor(tensor, /*assert_functional=*/false)); } return outputs; } @@ -444,13 +431,6 @@ c10::List> from_functional_tensor(const c10::List from_functional_tensor(const TensorList& t_list) { - std::vector outputs(t_list.size()); - for (const auto i : c10::irange(t_list.size())) { - outputs[i] = from_functional_tensor(t_list[i], /*assert_functional=*/false); - } - return outputs; -} void sync(const Tensor& t) { if (t.unsafeGetTensorImpl()->is_wrapped_number()) { @@ -474,13 +454,8 @@ void sync(const c10::optional& t) { sync(*t); } } -void sync(const c10::List t_list) { - for (const auto i : c10::irange(t_list.size())) { - sync(t_list[i]); - } -} -void sync(const at::TensorList t_list) { - for (auto t: t_list) { +void sync(ITensorListRef t_list) { + for (const auto& t : t_list) { sync(t); } } @@ -495,22 +470,24 @@ void replace_(const Tensor& functional_tensor, const Tensor& other) { unsafeGetFunctionalWrapper(functional_tensor)->replace_(other); } -void replace_(const TensorList functional_tensor, TensorList other) { +void replace_(const ITensorListRef functional_tensor, ITensorListRef other) { TORCH_INTERNAL_ASSERT_DEBUG_ONLY(functional_tensor.size() == other.size()); + auto functional_tensor_it = functional_tensor.begin(); + auto other_it = other.begin(); for (const auto i : c10::irange(functional_tensor.size())) { - replace_(functional_tensor[i], other[i]); + (void)i; // Suppress unused variable warning + replace_(*functional_tensor_it++, *other_it++); } } - void commit_update(const Tensor& functional_tensor) { TORCH_INTERNAL_ASSERT_DEBUG_ONLY(isFunctionalTensor(functional_tensor)); unsafeGetFunctionalWrapper(functional_tensor)->commit_update(); } -void commit_update(const TensorList functional_tensor) { - for (const auto i : c10::irange(functional_tensor.size())) { - commit_update(functional_tensor[i]); +void commit_update(ITensorListRef functional_tensor) { + for (const auto& t : functional_tensor) { + commit_update(t); } } @@ -526,21 +503,6 @@ bool isFunctionalTensor(const c10::optional& t) { } } -// For lists that have a mix of functional and nonfunctional tensors, -// functionalization machinery should just unwrap the functional wrappers -// and leave the ordinary tensors alone. 
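The helpers above — to_functional_tensor, from_functional_tensor, sync, replace_ and commit_update — all move from indexed access over TensorList/c10::List to range-based iteration over ITensorListRef, building results with reserve() plus push_back(). A minimal sketch of that recurring shape, assuming only the ITensorListRef alias used in this patch; map_tensors is a hypothetical name, not part of the change:

```cpp
#include <vector>

#include <ATen/core/Tensor.h>
#include <ATen/core/IListRef.h>

// Illustrative only: ITensorListRef is a non-owning view over ArrayRef<Tensor>,
// std::vector<Tensor> or c10::List<Tensor>, so it is iterated by reference and
// the output is grown with reserve() + push_back() instead of pre-sizing.
template <class Fn>
std::vector<at::Tensor> map_tensors(at::ITensorListRef tensors, Fn fn) {
  std::vector<at::Tensor> outputs;
  outputs.reserve(tensors.size());
  for (const at::Tensor& tensor : tensors) {
    outputs.push_back(fn(tensor));
  }
  return outputs;
}
```

The parallel-iterator form used in replace_ above exists for the same reason: IListRef exposes iterators and size(), but no operator[].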
-bool isFunctionalTensor(const c10::List& t_list) { - if (t_list.size() == 0) return false; - auto functional_count = 0; - for (const auto i : c10::irange(t_list.size())) { - if (!t_list[i].defined()) continue; - if (isFunctionalTensor(t_list[i])) { - ++functional_count; - } - } - return functional_count > 0; -} - bool isFunctionalTensor(const c10::List>& t_list) { if (t_list.size() == 0) return false; auto functional_count = 0; @@ -553,18 +515,23 @@ bool isFunctionalTensor(const c10::List>& t_list) { return functional_count > 0; } -bool isFunctionalTensor(const c10::ArrayRef t_list) { - if (t_list.size() == 0) return false; +template +bool isFunctionalTensorIListRef(c10::IListRef list) { + if (list.size() == 0) return false; auto functional_count = 0; - for (const auto i : c10::irange(t_list.size())) { - if (!t_list[i].defined()) continue; - if (isFunctionalTensor(t_list[i])) { + for (const auto& tensor : list) { + if (!tensor.defined()) continue; + if (isFunctionalTensor(tensor)) { ++functional_count; } } return functional_count > 0; } +bool isFunctionalTensor(ITensorListRef list) { + return isFunctionalTensorIListRef(list); +} + Tensor create_functional_tensor_with_view_meta(const at::Tensor& view_to_wrap, const at::Tensor& base, functionalization::ViewMeta meta, int64_t out_idx) { TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(view_to_wrap)); TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(base)); @@ -578,18 +545,12 @@ Tensor create_functional_tensor_with_view_meta(const at::Tensor& view_to_wrap, c return at::detail::make_tensor(view_to_wrap, functional_base_impl, meta); } -std::vector create_functional_tensor_with_view_meta(const c10::List& view_to_wrap, const at::Tensor& base, functionalization::ViewMeta meta) { - std::vector outputs(view_to_wrap.size()); - for (const auto i : c10::irange(view_to_wrap.size())) { - outputs[i] = create_functional_tensor_with_view_meta(view_to_wrap[i], base, meta, i); - } - return outputs; -} - -std::vector create_functional_tensor_with_view_meta(const std::vector& view_to_wrap, const at::Tensor& base, functionalization::ViewMeta meta) { +std::vector create_functional_tensor_with_view_meta(ITensorListRef view_to_wrap, const at::Tensor& base, functionalization::ViewMeta meta) { std::vector outputs(view_to_wrap.size()); - for (const auto i : c10::irange(view_to_wrap.size())) { - outputs[i] = create_functional_tensor_with_view_meta(view_to_wrap[i], base, meta, i); + int64_t i = 0; + for (const auto& tensor : view_to_wrap) { + outputs[i] = create_functional_tensor_with_view_meta(tensor, base, meta, i); + i++; } return outputs; } diff --git a/aten/src/ATen/FunctionalTensorWrapper.h b/aten/src/ATen/FunctionalTensorWrapper.h index cf389715795af..27a7440fe36b5 100644 --- a/aten/src/ATen/FunctionalTensorWrapper.h +++ b/aten/src/ATen/FunctionalTensorWrapper.h @@ -3,6 +3,7 @@ #include #include +#include #include #include #include @@ -183,44 +184,38 @@ TORCH_API inline FunctionalTensorWrapper* unsafeGetFunctionalWrapper( TORCH_API bool isFunctionalTensor(const at::Tensor& tensor); TORCH_API bool isFunctionalTensor(const c10::optional& t); -TORCH_API bool isFunctionalTensor(const c10::List& t_list); TORCH_API bool isFunctionalTensor( const c10::List>& t_list); -TORCH_API bool isFunctionalTensor(const c10::ArrayRef t_list); +TORCH_API bool isFunctionalTensor(ITensorListRef list); TORCH_API Tensor to_functional_tensor(const Tensor& tensor); TORCH_API c10::optional to_functional_tensor( const c10::optional& tensor); 
-TORCH_API c10::List to_functional_tensor( - const c10::List& t_list); TORCH_API c10::List> to_functional_tensor( const c10::List>& t_list); -TORCH_API std::vector to_functional_tensor( - const std::vector& t_list); -TORCH_API std::vector to_functional_tensor(const TensorList& t_list); +TORCH_API std::vector to_functional_tensor(ITensorListRef t_list); TORCH_API Tensor from_functional_tensor(const Tensor& tensor, bool assert_functional = true); TORCH_API c10::optional from_functional_tensor( const c10::optional& t, bool assert_functional = true); -TORCH_API c10::List from_functional_tensor( - const c10::List& t_list); TORCH_API c10::List> from_functional_tensor( const c10::List>& t_list); -TORCH_API std::vector from_functional_tensor(const TensorList& tensors); +TORCH_API std::vector from_functional_tensor(ITensorListRef t_list); TORCH_API void sync(const at::Tensor& t); TORCH_API void sync(const c10::optional& t); -TORCH_API void sync(const c10::List t_list); -TORCH_API void sync(const at::TensorList t_list); TORCH_API void sync(const c10::List> t_list); +TORCH_API void sync(ITensorListRef t_list); TORCH_API void replace_(const Tensor& functional_tensor, const Tensor& other); -TORCH_API void replace_(const TensorList functional_tensor, TensorList other); +TORCH_API void replace_( + const ITensorListRef functional_tensor, + ITensorListRef other); TORCH_API void commit_update(const Tensor& functional_tensor); -TORCH_API void commit_update(const TensorList functional_tensor); +TORCH_API void commit_update(ITensorListRef functional_tensor); Tensor create_functional_tensor_with_view_meta( const Tensor& view_to_wrap, @@ -228,11 +223,7 @@ Tensor create_functional_tensor_with_view_meta( functionalization::ViewMeta meta, int64_t out_idx = 0); std::vector create_functional_tensor_with_view_meta( - const c10::List& view_to_wrap, - const Tensor& base, - functionalization::ViewMeta meta); -std::vector create_functional_tensor_with_view_meta( - const std::vector& view_to_wrap, + ITensorListRef view_to_wrap, const Tensor& base, functionalization::ViewMeta meta); diff --git a/aten/src/ATen/NamedTensorUtils.cpp b/aten/src/ATen/NamedTensorUtils.cpp index ca38f7be31bd9..d9b726e52baca 100644 --- a/aten/src/ATen/NamedTensorUtils.cpp +++ b/aten/src/ATen/NamedTensorUtils.cpp @@ -410,12 +410,12 @@ std::vector broadcast_to_outnames( return unify_from_right(reference_names, tensor_names); } -std::vector compute_cat_outnames(ITensorListRef tensors) { +std::vector compute_cat_outnames(const MaterializedITensorListRef& tensors) { if (!at::has_names(tensors)) { return {}; } std::vector result; - for (const auto& tensor : tensors) { + for (const Tensor& tensor : tensors) { const auto tensor_names = tensor.names(); TORCH_CHECK(tensor_names.size() > 0, "zero-dimensional tensor cannot be concatenated"); TORCH_CHECK(result.empty() || tensor_names.size() == result.size(), diff --git a/aten/src/ATen/NamedTensorUtils.h b/aten/src/ATen/NamedTensorUtils.h index a77f38501f53e..c9ff27c2d1b21 100644 --- a/aten/src/ATen/NamedTensorUtils.h +++ b/aten/src/ATen/NamedTensorUtils.h @@ -118,7 +118,8 @@ TORCH_API void propagate_names_for_expand( const Tensor& result, const Tensor& self); -TORCH_API std::vector compute_cat_outnames(ITensorListRef tensors); +TORCH_API std::vector compute_cat_outnames( + const MaterializedITensorListRef& tensors); TORCH_API std::vector compute_broadcast_outnames( const Tensor& self, diff --git a/aten/src/ATen/VmapTransforms.cpp b/aten/src/ATen/VmapTransforms.cpp index 20c792f73709b..71ef7a169026d 100644 --- 
a/aten/src/ATen/VmapTransforms.cpp +++ b/aten/src/ATen/VmapTransforms.cpp @@ -1,5 +1,6 @@ #include #include +#include #include namespace at { @@ -188,7 +189,7 @@ static Tensor alignBatchDimsAtFront( // 4. Expand each physical tensor so that they have output batch size equal // to `batch_sizes` VmapPhysicalViewVec -MultiBatchVmapTransform::logicalToPhysical(TensorList logical_tensors) { +MultiBatchVmapTransform::logicalToPhysical(ITensorListRef logical_tensors) { // Figure out all of the collective vmap levels in `logical_tensors`. std::bitset collective_levels; for (const auto& logical_tensor : logical_tensors) { diff --git a/aten/src/ATen/VmapTransforms.h b/aten/src/ATen/VmapTransforms.h index 53e476e2243fa..cece52dcbc410 100644 --- a/aten/src/ATen/VmapTransforms.h +++ b/aten/src/ATen/VmapTransforms.h @@ -1,6 +1,7 @@ #pragma once #include +#include namespace at { @@ -55,7 +56,7 @@ using VmapDimVector = SmallVector; // and returns a VmapPhysicalView on the tensor(s). struct TORCH_API MultiBatchVmapTransform { static VmapPhysicalView logicalToPhysical(const Tensor& logical_tensor); - static VmapPhysicalViewVec logicalToPhysical(TensorList logical_tensors); + static VmapPhysicalViewVec logicalToPhysical(ITensorListRef logical_tensors); }; // VmapTransform for operators that broadcast all inputs. diff --git a/aten/src/ATen/WrapDimUtils.h b/aten/src/ATen/WrapDimUtils.h index 13f8658c354d9..1d4f45c6345e7 100644 --- a/aten/src/ATen/WrapDimUtils.h +++ b/aten/src/ATen/WrapDimUtils.h @@ -97,8 +97,10 @@ static inline int64_t legacy_cat_wrap_dim( return dim; } -static inline int64_t legacy_cat_wrap_dim(int64_t dim, ITensorListRef tensors) { - for (auto& tensor : tensors) { +static inline int64_t legacy_cat_wrap_dim( + int64_t dim, + const MaterializedITensorListRef& tensors) { + for (const Tensor& tensor : tensors) { if (tensor.dim() == 1 && tensor.sizes()[0] == 0) { continue; } diff --git a/aten/src/ATen/autocast_mode.cpp b/aten/src/ATen/autocast_mode.cpp index 087bb4c7bfe5e..1247cf31a40ad 100644 --- a/aten/src/ATen/autocast_mode.cpp +++ b/aten/src/ATen/autocast_mode.cpp @@ -699,8 +699,8 @@ TORCH_LIBRARY_IMPL(aten, AutocastCPU, m) { &ADD_NS(linalg_inv_ex)>::type::call))); // promote - KERNEL_CPU(ADD_NS(cat), "cat", Tensor (TensorList, int64_t), promote) KERNEL_CPU(ADD_NS(stack), "stack", Tensor (TensorList, int64_t), promote) + KERNEL_CPU(ADD_NS(cat), "cat", Tensor (const ITensorListRef &, int64_t), promote) KERNEL_CPU(ADD_NS(index_copy), "index_copy", Tensor (const Tensor &, int64_t, const Tensor &, const Tensor &), promote) KERNEL_CPU(ADD_NS(index_copy), "index_copy.dimname", Tensor (const Tensor &, at::Dimname, const Tensor &, const Tensor &), promote) diff --git a/aten/src/ATen/core/IListRef.h b/aten/src/ATen/core/IListRef.h index 3dd669cc7d14e..0b0ff67b02e2d 100644 --- a/aten/src/ATen/core/IListRef.h +++ b/aten/src/ATen/core/IListRef.h @@ -313,7 +313,10 @@ using _MaterializedIListRefElem = typename std::conditional< T>::type; template -using MaterializedIListRef = std::vector<_MaterializedIListRefElem>>; +using MaterializedIListRefElem = _MaterializedIListRefElem>; + +template +using MaterializedIListRef = std::vector>; } // namespace detail @@ -510,6 +513,7 @@ class IListRef { using iterator = IListRefIterator; using const_iterator = IListRefIterator; + using reverse_iterator = std::reverse_iterator; using value_type = typename iterator::value_type; IListRef() : tag_(IListRefTag::None) {} diff --git a/aten/src/ATen/core/IListRef_inl.h b/aten/src/ATen/core/IListRef_inl.h index 
a14bcfddae2de..534272f69b64f 100644 --- a/aten/src/ATen/core/IListRef_inl.h +++ b/aten/src/ATen/core/IListRef_inl.h @@ -93,9 +93,9 @@ class IListRefTagImplBase { * implementation for `IListRefTag::Materialized`. */ template -class IListRefTagImplBase> { +class IListRefTagImplBase> { public: - using elem_type = _MaterializedIListRefElem; + using elem_type = MaterializedIListRefElem; using list_type = MaterializedIListRef; static const list_type& unwrap(const IListRef& ilist) { @@ -141,7 +141,7 @@ class IListRefTagImpl : public IListRefTagImplBase< IListRefTag::Materialized, at::Tensor, - _MaterializedIListRefElem> {}; + MaterializedIListRefElem> {}; /* * [Note: IOptTensorListRef] @@ -182,7 +182,7 @@ class IListRefTagImpl : public IListRefTagImplBase< IListRefTag::Materialized, at::OptionalTensorRef, - _MaterializedIListRefElem> {}; + MaterializedIListRefElem> {}; } // namespace detail } // namespace c10 diff --git a/aten/src/ATen/core/TorchDispatchUtils.cpp b/aten/src/ATen/core/TorchDispatchUtils.cpp index 323019b3bbbb3..e2f981c6a8332 100644 --- a/aten/src/ATen/core/TorchDispatchUtils.cpp +++ b/aten/src/ATen/core/TorchDispatchUtils.cpp @@ -8,8 +8,8 @@ bool tensor_has_dispatch(const at::Tensor& t) { return t.key_set().has_any(key_set); } -bool tensorlist_has_dispatch(const at::TensorList& li) { - for (const auto& t: li) { +bool tensorlist_has_dispatch(at::ITensorListRef li) { + for (const auto& t : li) { if (tensor_has_dispatch(t)) { return true; } diff --git a/aten/src/ATen/core/TorchDispatchUtils.h b/aten/src/ATen/core/TorchDispatchUtils.h index 08c009c81b478..ed7b4181095d5 100644 --- a/aten/src/ATen/core/TorchDispatchUtils.h +++ b/aten/src/ATen/core/TorchDispatchUtils.h @@ -10,7 +10,7 @@ namespace at { namespace impl { bool tensor_has_dispatch(const at::Tensor& t); -bool tensorlist_has_dispatch(const at::TensorList& li); +bool tensorlist_has_dispatch(at::ITensorListRef li); bool tensorlist_has_dispatch(const c10::List>& li); using c10::impl::dispatch_mode_enabled; diff --git a/aten/src/ATen/core/Variadic.h b/aten/src/ATen/core/Variadic.h index d33f3d575177c..61b6a35a0b1cb 100644 --- a/aten/src/ATen/core/Variadic.h +++ b/aten/src/ATen/core/Variadic.h @@ -48,6 +48,15 @@ struct IterArgs { // you may be able to process these structures more efficiently // than handling them one-by-one. 
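The tensorlist_has_dispatch change above, together with the IterArgs and DispatchKeyExtractor overloads that follow, relies on the same property: an ITensorListRef can be folded over with a plain range-for regardless of which container backs it. A minimal sketch of such a fold, assuming the existing Tensor::key_set() accessor; union_key_set is a hypothetical name:

```cpp
#include <ATen/core/Tensor.h>
#include <ATen/core/IListRef.h>
#include <c10/core/DispatchKeySet.h>

// Hypothetical helper: accumulate the dispatch key sets of every tensor in the
// list, without caring whether the list is backed by ArrayRef<Tensor>,
// std::vector<Tensor> or c10::List<Tensor>.
c10::DispatchKeySet union_key_set(at::ITensorListRef tensors) {
  c10::DispatchKeySet ts;
  for (const at::Tensor& t : tensors) {
    ts = ts | t.key_set();
  }
  return ts;
}
```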
+ template + void operator()(c10::IListRef args) { + for (const auto& arg : args) { + self()(arg); + if (self().short_circuit()) + return; + } + } + template void operator()(at::ArrayRef args) { for (const auto& arg : args) { diff --git a/aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h b/aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h index 8c5ced3d462bd..7f728dd0333bf 100644 --- a/aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h +++ b/aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h @@ -4,6 +4,7 @@ #include #include #include +#include #include #include @@ -342,6 +343,13 @@ namespace impl { } }; + template + struct ivalue_to_arg final { + static List call(IValue& v) { + return v.toTensorList(); + } + }; + template struct ivalue_to_arg, AllowDeprecatedTypes> final { // If an argument is ArrayRef, convert the IValue to a std::vector and pass that diff --git a/aten/src/ATen/core/dispatch/DispatchKeyExtractor.h b/aten/src/ATen/core/dispatch/DispatchKeyExtractor.h index 76082c5b01a4b..27c6e26721a2e 100644 --- a/aten/src/ATen/core/dispatch/DispatchKeyExtractor.h +++ b/aten/src/ATen/core/dispatch/DispatchKeyExtractor.h @@ -74,6 +74,12 @@ namespace detail { } } } + // Structured Tensor[] translates to this case + void operator()(at::ITensorListRef xs) { + for (const auto& x : xs) { + ts = ts | x.key_set(); + } + } void operator()(at::ArrayRef>) { // Just checking that the handling of Tensor?[] didn't change. TORCH_INTERNAL_ASSERT(false); diff --git a/aten/src/ATen/core/ivalue.h b/aten/src/ATen/core/ivalue.h index 38211f2dbd5b7..d04107126cc5b 100644 --- a/aten/src/ATen/core/ivalue.h +++ b/aten/src/ATen/core/ivalue.h @@ -28,6 +28,8 @@ template class Dict; template class List; +template +class IListRef; struct IValue; struct ClassType; struct Type; @@ -697,6 +699,15 @@ struct TORCH_API IValue final { template IValue(std::array v); + template + using enable_if_ilist_is_ivalue_constructible = std::enable_if_t< + std::is_constructible::value && + std::is_constructible::boxed_type>::value, + std::nullptr_t>; + + template = nullptr> + IValue(c10::IListRef v); + // GenericDict IValue(c10::Dict v); bool isGenericDict() const { diff --git a/aten/src/ATen/core/ivalue_inl.h b/aten/src/ATen/core/ivalue_inl.h index 5f6eaef1263ec..9bef63d8b2c54 100644 --- a/aten/src/ATen/core/ivalue_inl.h +++ b/aten/src/ATen/core/ivalue_inl.h @@ -7,6 +7,7 @@ #include #include +#include #include #include #include @@ -2037,6 +2038,25 @@ inline IValue::IValue(std::array v) : IValue(c10::List()) { } } +template > +inline IValue::IValue(c10::IListRef v) : IValue() { + constexpr bool boxed_type_constructs_ivalue = + std::is_constructible::boxed_type>::value; + // First, we try to use the boxed value. + // If we fail (either it's not in the boxed state, or its boxed type + // can not construct an IValue), we fallback to copying the list. 
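The constructor body continues just below; in terms of observable behaviour, the intent is that a boxed list is shared while an unboxed one is copied into a fresh c10::List. A hedged usage sketch, assuming the IValue(c10::IListRef) constructor added in this hunk; box_examples is a hypothetical name:

```cpp
#include <vector>

#include <ATen/core/Tensor.h>
#include <ATen/core/IListRef.h>
#include <ATen/core/List.h>
#include <ATen/core/ivalue.h>

// Hypothetical usage of the IValue(c10::IListRef<T>) constructor added here.
void box_examples(const c10::List<at::Tensor>& boxed,
                  const std::vector<at::Tensor>& unboxed) {
  // Boxed state: the new IValue can share the storage of the existing c10::List.
  c10::IValue a{at::ITensorListRef(boxed)};
  // Unboxed state: the constructor falls back to copying the elements into a
  // freshly allocated c10::List.
  c10::IValue b{at::ITensorListRef(unboxed)};
}
```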
+ if (boxed_type_constructs_ivalue && v.isBoxed()) { + *this = IValue(impl::toList(v.toBoxed())); + } else { + c10::List list; + list.reserve(v.size()); + for (const auto& t : v) { + list.push_back(t); + } + *this = IValue(impl::toList(std::move(list))); + } +} + inline IValue::IValue(c10::impl::GenericDict v) : tag(Tag::GenericDict) { payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.impl_.release()); diff --git a/aten/src/ATen/core/jit_type.h b/aten/src/ATen/core/jit_type.h index b850886562a05..56783acdcb707 100644 --- a/aten/src/ATen/core/jit_type.h +++ b/aten/src/ATen/core/jit_type.h @@ -1922,6 +1922,14 @@ struct getMaybeFakeTypePtr_, fake> final { return type; } }; +template +struct getMaybeFakeTypePtr_, fake> final { + static const auto& call() { + static auto inner_type = getMaybeFakeTypePtr_::call(); + static auto type = ListType::get("List", inner_type); + return type; + } +}; template struct getMaybeFakeTypePtr_, fake> final { static const auto& call() { diff --git a/aten/src/ATen/core/op_registration/adaption.h b/aten/src/ATen/core/op_registration/adaption.h index 5bf1b691ebad3..3112a206bb4e1 100644 --- a/aten/src/ATen/core/op_registration/adaption.h +++ b/aten/src/ATen/core/op_registration/adaption.h @@ -68,7 +68,7 @@ inline void check_and_update_common_device(optional& common_device, cons } } -inline void check_and_update_common_device(optional& common_device, at::TensorList tensors, at::CheckedFrom methodName, at::CheckedFrom argName) { +inline void check_and_update_common_device(optional& common_device, at::ITensorListRef tensors, at::CheckedFrom methodName, at::CheckedFrom argName) { for (const auto& tensor : tensors) { check_and_update_common_device(common_device, tensor, methodName, argName); } diff --git a/aten/src/ATen/functorch/LegacyBatchingRegistrations.cpp b/aten/src/ATen/functorch/LegacyBatchingRegistrations.cpp index 15410f1db9928..c9ae562080f6b 100644 --- a/aten/src/ATen/functorch/LegacyBatchingRegistrations.cpp +++ b/aten/src/ATen/functorch/LegacyBatchingRegistrations.cpp @@ -82,7 +82,7 @@ static bool participatesInCurrentLevel(const Tensor& self) { return self_level == current_level; } -static bool participatesInCurrentLevel(TensorList self) { +static bool participatesInCurrentLevel(ITensorListRef self) { for (const Tensor& tensor : self) { if (participatesInCurrentLevel(tensor)) { return true; @@ -606,7 +606,7 @@ Tensor unwrap_and_call_method(const Tensor& input, ExtraArgs... extra_args) { return makeBatched(output_physical, input_batched->bdim(), input_batched->level()); } -Tensor cat_batching_rule(TensorList tensors, int64_t dim) { +Tensor cat_batching_rule(const ITensorListRef& tensors, int64_t dim) { if (!participatesInCurrentLevel(tensors)) { c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); return at::cat(tensors, dim); diff --git a/aten/src/ATen/functorch/LegacyVmapTransforms.cpp b/aten/src/ATen/functorch/LegacyVmapTransforms.cpp index 50ccc9abd520b..682169a52622d 100644 --- a/aten/src/ATen/functorch/LegacyVmapTransforms.cpp +++ b/aten/src/ATen/functorch/LegacyVmapTransforms.cpp @@ -118,7 +118,7 @@ static Tensor moveDimToFrontAndExpand(Tensor tensor, optional dim, int6 // 4. 
Expand each physical tensor so that they have output batch size equal // to `batch_sizes` VmapPhysicalViewVec -MultiBatchVmapTransform::logicalToPhysical(TensorList logical_tensors) { +MultiBatchVmapTransform::logicalToPhysical(ITensorListRef logical_tensors) { auto cur_level = maybeCurrentDynamicLayer().value().layerId(); auto bdim_size = -1; diff --git a/aten/src/ATen/functorch/LegacyVmapTransforms.h b/aten/src/ATen/functorch/LegacyVmapTransforms.h index 199254a504c17..5fc05b6c8038c 100644 --- a/aten/src/ATen/functorch/LegacyVmapTransforms.h +++ b/aten/src/ATen/functorch/LegacyVmapTransforms.h @@ -64,7 +64,7 @@ using VmapDimVector = SmallVector; // and returns a VmapPhysicalView on the tensor(s). struct TORCH_API MultiBatchVmapTransform { static VmapPhysicalView logicalToPhysical(const Tensor& logical_tensor); - static VmapPhysicalViewVec logicalToPhysical(TensorList logical_tensors); + static VmapPhysicalViewVec logicalToPhysical(ITensorListRef logical_tensors); }; // VmapTransform for operators that broadcast all inputs. diff --git a/aten/src/ATen/functorch/PlumbingHelper.cpp b/aten/src/ATen/functorch/PlumbingHelper.cpp index 3101f1dd9a310..5dd01d0abbcbe 100644 --- a/aten/src/ATen/functorch/PlumbingHelper.cpp +++ b/aten/src/ATen/functorch/PlumbingHelper.cpp @@ -50,7 +50,7 @@ bool isBatchedAtLevel(const c10::optional& maybe_tensor, int64_t level) return isBatchedAtLevel(*maybe_tensor, level); } -bool isBatchedAtLevel(TensorList tensors, int64_t level) { +bool isBatchedAtLevel(ITensorListRef tensors, int64_t level) { for (const auto& tensor : tensors) { if (isBatchedAtLevel(tensor, level)) { return true; diff --git a/aten/src/ATen/functorch/PlumbingHelper.h b/aten/src/ATen/functorch/PlumbingHelper.h index d1bd0a22ff3e8..9eb486a6eefa0 100644 --- a/aten/src/ATen/functorch/PlumbingHelper.h +++ b/aten/src/ATen/functorch/PlumbingHelper.h @@ -39,7 +39,7 @@ TORCH_API std::tuple> unwrapTensorAtLevel(const Tensor TORCH_API std::vector makeBatchedVector(const std::vector& tensors, optional bdim, int64_t level); // Returns True if ANY tensor in tensors is batched at level -TORCH_API bool isBatchedAtLevel(TensorList tensors, int64_t level); +TORCH_API bool isBatchedAtLevel(ITensorListRef tensors, int64_t level); TORCH_API bool isBatchedAtLevel(const c10::List> maybe_tensors, int64_t level); TORCH_API bool isBatchedAtLevel(const Tensor& tensor, int64_t level); TORCH_API bool isBatchedAtLevel(const c10::optional& maybe_tensor, int64_t level); diff --git a/aten/src/ATen/native/TensorConversions.cpp b/aten/src/ATen/native/TensorConversions.cpp index a61d8fe7e56bf..a8ce11163fde0 100644 --- a/aten/src/ATen/native/TensorConversions.cpp +++ b/aten/src/ATen/native/TensorConversions.cpp @@ -1697,11 +1697,11 @@ c10::optional to_meta(const c10::optional& tensor) { return c10::nullopt; } -std::vector to_meta(const at::TensorList& t_list) { +std::vector to_meta(at::ITensorListRef t_list) { std::vector outs; outs.reserve(t_list.size()); - for (const auto& i : c10::irange(t_list.size())) { - outs.push_back(to_meta(t_list[i])); + for (const auto& tensor : t_list) { + outs.push_back(to_meta(tensor)); } return outs; } diff --git a/aten/src/ATen/native/TensorConversions.h b/aten/src/ATen/native/TensorConversions.h index 75a01ea0e7554..8ec21a75dcac1 100644 --- a/aten/src/ATen/native/TensorConversions.h +++ b/aten/src/ATen/native/TensorConversions.h @@ -19,7 +19,7 @@ bool to_will_alias( Tensor to_meta(const Tensor& tensor); c10::optional to_meta(const c10::optional& tensor); -std::vector to_meta(const at::TensorList& 
t_list); +std::vector to_meta(at::ITensorListRef t_list); } // namespace native } // namespace at diff --git a/aten/src/ATen/native/TensorShape.cpp b/aten/src/ATen/native/TensorShape.cpp index a1f050ee76f4f..105bbcb2e5444 100644 --- a/aten/src/ATen/native/TensorShape.cpp +++ b/aten/src/ATen/native/TensorShape.cpp @@ -57,7 +57,7 @@ inline c10::MemoryFormat cat_compute_output_memory_format(const MaterializedITen return format.value(); } -TORCH_PRECOMPUTE_META_FUNC(cat)(ITensorListRef tensors, int64_t dim) { +TORCH_PRECOMPUTE_META_FUNC(cat)(const ITensorListRef& tensors, int64_t dim) { // previously, size [0] tensors were the only possible empty tensors; thus, it wasn't possible // to cat empty tensors unless all the other tensors were 1-dimensional, so we allowed these tensors // to be "skipped". We maintain this behavior for backwards compatibility, but only for this specific @@ -65,10 +65,10 @@ TORCH_PRECOMPUTE_META_FUNC(cat)(ITensorListRef tensors, int64_t dim) { auto materialized = tensors.materialize(); cat_check_no_zero_dim(materialized); - dim = at::legacy_cat_wrap_dim(dim, tensors); + dim = at::legacy_cat_wrap_dim(dim, materialized); // Checking names before the actual dimensions. - auto maybe_outnames = namedinference::compute_cat_outnames(tensors); + auto maybe_outnames = namedinference::compute_cat_outnames(materialized); TORCH_CHECK( materialized.size() > 0, "torch.cat(): expected a non-empty list of Tensors"); @@ -367,7 +367,7 @@ std::vector broadcast_tensors(TensorList tensors) { } TORCH_IMPL_FUNC(cat_out_cpu) -(ITensorListRef tensors, +(const ITensorListRef& tensors, int64_t dim, int64_t valid, bool all_contiguous, @@ -515,16 +515,16 @@ static void check_cat_sparse_dims(Tensor const &t, ", but tensor at position ", pos, " has ", t.sparse_dim(), ", ", t.dense_dim(), "."); } -static Tensor cat_sparse_impl(TensorList tensors, int64_t dim) { +static Tensor cat_sparse_impl(const MaterializedITensorListRef& tensors, int64_t dim) { std::vector indices; std::vector values; - int64_t wrapped = maybe_wrap_dim(dim, tensors[0].dim()); - int64_t sparse_dim = tensors[0].sparse_dim(); - int64_t dense_dim = tensors[0].dense_dim(); - IntArrayRef sizes = tensors[0].sizes(); + int64_t wrapped = maybe_wrap_dim(dim, tensors[0].get().dim()); + int64_t sparse_dim = tensors[0].get().sparse_dim(); + int64_t dense_dim = tensors[0].get().dense_dim(); + IntArrayRef sizes = tensors[0].get().sizes(); if (wrapped < sparse_dim) { for (const auto i : c10::irange(tensors.size())) { - auto const &t = tensors[i]; + const Tensor& t = tensors[i]; check_cat_sparse_dims(t, i, sizes, wrapped, sparse_dim, dense_dim); indices.push_back(t._indices()); values.push_back(t._values()); @@ -543,7 +543,7 @@ static Tensor cat_sparse_impl(TensorList tensors, int64_t dim) { int64_t col = 0; int64_t cumulative_offset = 0; for (const auto i : c10::irange(tensors.size())) { - auto const &t = tensors[i]; + const Tensor& t = tensors[i]; int64_t this_piece_size = t._nnz(); // cumulative_offset is zero for the first piece, so // don't waste time doing this operation unless i > 0. 
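cat_sparse_impl now receives a MaterializedITensorListRef, i.e. a std::vector of std::reference_wrapper<const Tensor>, which is why the surrounding hunks add .get() around tensors[0]. A minimal sketch of element access on a materialized list, assuming only the aliases from this patch; first_dim is a hypothetical name:

```cpp
#include <ATen/core/Tensor.h>
#include <ATen/core/IListRef.h>

// Hypothetical helper: IListRef has no operator[], so code needing random
// access materializes the list once and indexes the resulting vector.
int64_t first_dim(at::ITensorListRef tensors) {
  at::MaterializedITensorListRef materialized = tensors.materialize();
  // materialized[0] is a std::reference_wrapper<const Tensor>; use .get() (or
  // bind to a const Tensor& via the implicit conversion) before member access.
  return materialized[0].get().dim();
}
```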
@@ -559,10 +559,10 @@ static Tensor cat_sparse_impl(TensorList tensors, int64_t dim) { idxs, vals, sizes_copy, - optTypeMetaToScalarType(tensors[0].options().dtype_opt()), - tensors[0].options().layout_opt(), - tensors[0].options().device_opt(), - tensors[0].options().pinned_memory_opt()); + optTypeMetaToScalarType(tensors[0].get().options().dtype_opt()), + tensors[0].get().options().layout_opt(), + tensors[0].get().options().device_opt(), + tensors[0].get().options().pinned_memory_opt()); } else { // Catting along a dense dimension requires us to create new values. @@ -584,15 +584,19 @@ static Tensor cat_sparse_impl(TensorList tensors, int64_t dim) { // The dimension in each tensor's values object that corresponds to the overall dimension along which we're catting. int64_t values_dim = wrapped - sparse_dim + 1; // The final size along the catted dimension. - const int64_t total_size = std::accumulate(tensors.begin(), tensors.end(), static_cast(0), [values_dim](int64_t l, Tensor const &r) { - return l + r._values().size(values_dim); - }); - auto zeros_sizes = tensors[0]._values().sizes().vec(); + const int64_t total_size = std::accumulate( + tensors.begin(), + tensors.end(), + static_cast(0), + [values_dim](int64_t l, const Tensor& r) { + return l + r._values().size(values_dim); + }); + auto zeros_sizes = tensors[0].get()._values().sizes().vec(); int64_t cumulative_size = 0; std::vector vals_pieces; std::vector idxs_pieces; for (const auto i : c10::irange(tensors.size())) { - auto const &t = tensors[i]; + const Tensor& t = tensors[i]; check_cat_sparse_dims(t, i, sizes, wrapped, sparse_dim, dense_dim); // dimension 0 of values corresponds to the number of values, // rather than to any logical dimension of the sparse tensor. @@ -622,16 +626,17 @@ static Tensor cat_sparse_impl(TensorList tensors, int64_t dim) { at::cat(idxs_pieces, 1), at::cat(vals_pieces), sizes_copy, - optTypeMetaToScalarType(tensors[0].options().dtype_opt()), - tensors[0].options().layout_opt(), - tensors[0].options().device_opt(), - tensors[0].options().pinned_memory_opt()); + optTypeMetaToScalarType(tensors[0].get().options().dtype_opt()), + tensors[0].get().options().layout_opt(), + tensors[0].get().options().device_opt(), + tensors[0].get().options().pinned_memory_opt()); } } -Tensor cat_sparse(TensorList tensors, int64_t dim) { - auto maybe_outnames = namedinference::compute_cat_outnames(tensors); - auto result = cat_sparse_impl(tensors, at::legacy_cat_wrap_dim(dim, tensors)); +Tensor cat_sparse(const ITensorListRef& tensors, int64_t dim) { + auto materialized = tensors.materialize(); + auto maybe_outnames = namedinference::compute_cat_outnames(materialized); + auto result = cat_sparse_impl(materialized, at::legacy_cat_wrap_dim(dim, materialized)); namedinference::propagate_names_if_nonempty(result, maybe_outnames); return result; } diff --git a/aten/src/ATen/native/TensorShape.h b/aten/src/ATen/native/TensorShape.h index bb296b5ae5bc8..21d0ba78261ec 100644 --- a/aten/src/ATen/native/TensorShape.h +++ b/aten/src/ATen/native/TensorShape.h @@ -1,6 +1,7 @@ #pragma once #include #include +#include namespace at { namespace native { @@ -26,11 +27,12 @@ inline void check_cat_shape_except_dim(const Tensor & first, const Tensor & seco } } -inline void check_cat_no_zero_dim(at::ArrayRef tensors) { - for(const auto i : c10::irange(tensors.size())) { - auto& t = tensors[i]; +inline void check_cat_no_zero_dim(const MaterializedITensorListRef& tensors) { + int64_t i = 0; + for(const Tensor& t : tensors) { TORCH_CHECK(t.dim() > 0, 
"zero-dimensional tensor (at position ", i, ") cannot be concatenated"); + i++; } } diff --git a/aten/src/ATen/native/cuda/Shape.cu b/aten/src/ATen/native/cuda/Shape.cu index 08605cf4ed1b7..389515eac1e6c 100644 --- a/aten/src/ATen/native/cuda/Shape.cu +++ b/aten/src/ATen/native/cuda/Shape.cu @@ -252,7 +252,7 @@ void parallel_cat(const Tensor &out, const MaterializedITensorListRef& inputs, i } // namespace TORCH_IMPL_FUNC(cat_out_cuda) -(ITensorListRef tensors, +(const ITensorListRef& tensors, int64_t dim, int64_t valid, bool all_contiguous, diff --git a/aten/src/ATen/native/mps/operations/Shape.mm b/aten/src/ATen/native/mps/operations/Shape.mm index 99dfbcecc24a9..9beafb3a15f3c 100644 --- a/aten/src/ATen/native/mps/operations/Shape.mm +++ b/aten/src/ATen/native/mps/operations/Shape.mm @@ -217,7 +217,7 @@ void check_shape_except_dim(const Tensor &first, const Tensor &second, //} TORCH_IMPL_FUNC(cat_out_mps) - (ITensorListRef inputs, + (const ITensorListRef& inputs, int64_t dimension, int64_t valid, bool all_contiguous, @@ -239,7 +239,7 @@ void check_shape_except_dim(const Tensor &first, const Tensor &second, idx++; } - dimension = legacy_cat_wrap_dim(dimension, inputs); + dimension = legacy_cat_wrap_dim(dimension, materialized_inputs); // previously, size [0] tensors were the only possible empty tensors; thus, it // wasn't possible to cat empty tensors unless all the other tensors were diff --git a/aten/src/ATen/native/quantized/cpu/QuantizedOps.h b/aten/src/ATen/native/quantized/cpu/QuantizedOps.h index 506f0e46e573f..94023b2f8e973 100644 --- a/aten/src/ATen/native/quantized/cpu/QuantizedOps.h +++ b/aten/src/ATen/native/quantized/cpu/QuantizedOps.h @@ -143,7 +143,7 @@ using qupsample_bilinear2d_fn = void (*)( c10::optional scales_w); using qcat_nhwc_fn = Tensor (*)( - const c10::List& qxs, + const MaterializedITensorListRef& qxs, int64_t dim, double scale, int64_t zero_point); diff --git a/aten/src/ATen/native/quantized/cpu/TensorShape.cpp b/aten/src/ATen/native/quantized/cpu/TensorShape.cpp index 172ad041a610f..c3e846986716e 100644 --- a/aten/src/ATen/native/quantized/cpu/TensorShape.cpp +++ b/aten/src/ATen/native/quantized/cpu/TensorShape.cpp @@ -1,5 +1,6 @@ #include #include +#include #include #include #include @@ -19,7 +20,7 @@ DEFINE_DISPATCH(qcat_relu_nhwc_stub); namespace { -bool is_cat_nhwc_fast_path(const c10::List& qxs, int dim) { +bool is_cat_nhwc_fast_path(const MaterializedITensorListRef& qxs, int64_t dim) { TORCH_CHECK(qxs.size() > 0); bool is_fast_path = dim == 1; // NOLINTNEXTLINE(performance-implicit-conversion-in-loop) @@ -35,21 +36,21 @@ bool is_valid_quantization_scheme(const Tensor& t) { return (qtype == kPerTensorAffine) || (qtype == kPerTensorSymmetric); } -bool all_inputs_sharing_qparams(TensorList qxs) { +bool all_inputs_sharing_qparams(const MaterializedITensorListRef& qxs) { bool is_valid = true; for (const auto i : c10::irange(1, qxs.size())) { - is_valid |= qxs[0].is_quantized(); - is_valid |= qxs[i].is_quantized() == qxs[0].is_quantized(); - is_valid |= qxs[i].qscheme() == qxs[0].qscheme(); - is_valid |= qxs[i].dtype() == qxs[0].dtype(); - if (qxs[0].qscheme() == kPerTensorAffine) { - is_valid |= qxs[i].q_scale() == qxs[0].q_scale(); - is_valid |= qxs[i].q_zero_point() == qxs[0].q_zero_point(); - } else if (qxs[0].qscheme() == kPerChannelAffine) { - is_valid |= qxs[i].q_per_channel_scales().equal(qxs[0].q_per_channel_scales()); - is_valid |= qxs[i].q_per_channel_zero_points().equal(qxs[0].q_per_channel_zero_points()); + is_valid |= qxs[0].get().is_quantized(); 
+ is_valid |= qxs[i].get().is_quantized() == qxs[0].get().is_quantized(); + is_valid |= qxs[i].get().qscheme() == qxs[0].get().qscheme(); + is_valid |= qxs[i].get().dtype() == qxs[0].get().dtype(); + if (qxs[0].get().qscheme() == kPerTensorAffine) { + is_valid |= qxs[i].get().q_scale() == qxs[0].get().q_scale(); + is_valid |= qxs[i].get().q_zero_point() == qxs[0].get().q_zero_point(); + } else if (qxs[0].get().qscheme() == kPerChannelAffine) { + is_valid |= qxs[i].get().q_per_channel_scales().equal(qxs[0].get().q_per_channel_scales()); + is_valid |= qxs[i].get().q_per_channel_zero_points().equal(qxs[0].get().q_per_channel_zero_points()); } else { - TORCH_CHECK(false, "Unrecognized qscheme:", toString(qxs[0].qscheme())); + TORCH_CHECK(false, "Unrecognized qscheme:", toString(qxs[0].get().qscheme())); } } return is_valid; @@ -61,7 +62,7 @@ bool all_inputs_sharing_qparams(TensorList qxs) { */ template Tensor quantized_cat_impl( - const c10::List& qxs, + const MaterializedITensorListRef& qxs, int64_t dim, double scale, int64_t zero_point) { @@ -73,8 +74,8 @@ Tensor quantized_cat_impl( } } - const auto x_dtype = qxs.get(0).scalar_type(); - const auto x_qscheme = qxs.get(0).qscheme(); + const auto x_dtype = qxs[0].get().scalar_type(); + const auto x_qscheme = qxs[0].get().qscheme(); std::vector xs; xs.reserve(qxs.size()); // NOLINTNEXTLINE(performance-implicit-conversion-in-loop) @@ -99,6 +100,15 @@ Tensor quantized_cat_impl( return qy; } +template +Tensor quantized_cat_impl( + ITensorListRef qxs, + int64_t dim, + double scale, + int64_t zero_point) { + return quantized_cat_impl(qxs.materialize(), dim, scale, zero_point); +} + template Tensor qcat( const c10::List& qxs, @@ -134,28 +144,29 @@ TORCH_LIBRARY_IMPL(quantized, QuantizedCPU, m) { m.impl(TORCH_SELECTIVE_NAME("quantized::cat_relu_out"), TORCH_FN(qcat_out)); } -Tensor cat_quantized_cpu(TensorList qxs, int64_t dim) { - TORCH_CHECK(is_valid_quantization_scheme(qxs[0]), +Tensor cat_quantized_cpu(const ITensorListRef& qxs, int64_t dim) { + auto materialized = qxs.materialize(); + TORCH_CHECK(is_valid_quantization_scheme(materialized[0]), "Only per-tensor quantization is supported in 'cat'!"); TORCH_CHECK( - all_inputs_sharing_qparams(qxs), + all_inputs_sharing_qparams(materialized), "All inputs should share the same quantization parameters."); - check_cat_no_zero_dim(qxs); - dim = legacy_cat_wrap_dim(dim, qxs); - double _scale = qxs[0].q_scale(); - int64_t _zero_point = qxs[0].q_zero_point(); - return quantized_cat_impl(c10::List(qxs), dim, _scale, _zero_point); + check_cat_no_zero_dim(materialized); + dim = legacy_cat_wrap_dim(dim, materialized); + double _scale = materialized[0].get().q_scale(); + int64_t _zero_point = materialized[0].get().q_zero_point(); + return quantized_cat_impl(materialized, dim, _scale, _zero_point); } -Tensor& cat_out_quantized_cpu(TensorList qxs, int64_t dim, Tensor& out) { - TORCH_CHECK(is_valid_quantization_scheme(qxs[0]), +Tensor& cat_out_quantized_cpu(const ITensorListRef& qxs, int64_t dim, Tensor& out) { + auto materialized = qxs.materialize(); + TORCH_CHECK(is_valid_quantization_scheme(materialized[0]), "Only per-tensor quantization is supported in 'cat'!") TORCH_CHECK(is_valid_quantization_scheme(out), "Only per-tensor quantization is supported in 'cat'!") - check_cat_no_zero_dim(qxs); - dim = legacy_cat_wrap_dim(dim, qxs); - auto out_ = quantized_cat_impl(c10::List(qxs), dim, out.q_scale(), - out.q_zero_point()); + check_cat_no_zero_dim(materialized); + dim = legacy_cat_wrap_dim(dim, materialized); + auto 
out_ = quantized_cat_impl(qxs, dim, out.q_scale(), out.q_zero_point()); at::native::copy_(out, out_, /*non_blocking=*/false); return out; } diff --git a/aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp b/aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp index 34972d53dd5c7..369de19372e23 100644 --- a/aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp +++ b/aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp @@ -47,7 +47,7 @@ void check_tensor_memory_format(const Tensor& ref, const Tensor& other) { template Tensor qcat_nhwc_kernel( - const c10::List& qxs, + const MaterializedITensorListRef& qxs, int64_t dim, double scale, int64_t zero_point) { diff --git a/aten/src/ATen/native/vulkan/ops/Concat.cpp b/aten/src/ATen/native/vulkan/ops/Concat.cpp index 4ab543f5527f0..ac15b3924b080 100644 --- a/aten/src/ATen/native/vulkan/ops/Concat.cpp +++ b/aten/src/ATen/native/vulkan/ops/Concat.cpp @@ -16,20 +16,22 @@ inline int64_t normalize_dim(int64_t d, int64_t n) { } } // namespace -Tensor cat_batch(const TensorList tensors, vTensor& v_output) { +Tensor cat_batch(const MaterializedITensorListRef& tensors, vTensor& v_output) { TORCH_CHECK(false, "Vulkan cat not implemented for batch dimension!"); } -Tensor cat_feature(const TensorList tensors, vTensor& v_output) { +Tensor cat_feature( + const MaterializedITensorListRef& tensors, + vTensor& v_output) { api::Context* const context = api::context(); int64_t ch_size_allprior = 0; int64_t ch_interval = 0; - for (const auto& tensor : tensors) { + for (const at::Tensor& tensor : tensors) { ch_interval += tensor.sizes()[1]; } - for (const auto& tensor : tensors) { + for (const at::Tensor& tensor : tensors) { const Tensor self = tensor.is_vulkan() ? tensor : tensor.vulkan(); const vTensor& v_self = convert(self); @@ -84,12 +86,14 @@ Tensor cat_feature(const TensorList tensors, vTensor& v_output) { return convert(v_output); } -Tensor cat_feature_mult4ch(const TensorList tensors, vTensor& v_output) { +Tensor cat_feature_mult4ch( + const MaterializedITensorListRef& tensors, + vTensor& v_output) { api::Context* const context = api::context(); int64_t depth_size_allprior = 0; int64_t ch_interval = 0; - for (const auto& tensor : tensors) { + for (const at::Tensor& tensor : tensors) { ch_interval += tensor.sizes()[1]; } const int64_t depth_interval = ch_interval / 4; @@ -97,7 +101,7 @@ Tensor cat_feature_mult4ch(const TensorList tensors, vTensor& v_output) { uvec3 src_offset{}; uvec3 dst_offset{}; - for (const auto& tensor_arg : tensors) { + for (const at::Tensor& tensor_arg : tensors) { const Tensor tensor = tensor_arg.is_vulkan() ? 
tensor_arg : tensor_arg.vulkan(); const vTensor& v_self = convert(tensor); @@ -137,17 +141,19 @@ Tensor cat_feature_mult4ch(const TensorList tensors, vTensor& v_output) { return convert(v_output); } -Tensor cat_width(const TensorList tensors, vTensor& v_output) { +Tensor cat_width(const MaterializedITensorListRef& tensors, vTensor& v_output) { TORCH_CHECK(false, "Vulkan cat not implemented for width dimension!"); } -Tensor cat_height(const TensorList tensors, vTensor& v_output) { +Tensor cat_height( + const MaterializedITensorListRef& tensors, + vTensor& v_output) { api::Context* const context = api::context(); uvec3 src_offset{}; uvec3 dst_offset{}; - for (const auto& tensor : tensors) { + for (const at::Tensor& tensor : tensors) { const vTensor& v_self = convert(tensor); api::PipelineBarrier pipeline_barrier{}; @@ -175,14 +181,15 @@ Tensor cat_height(const TensorList tensors, vTensor& v_output) { return convert(v_output); } -Tensor cat(const at::TensorList tensors, const int64_t dim) { +Tensor cat(const at::ITensorListRef& tensors, const int64_t dim) { TORCH_CHECK(tensors.size() > 0, "Vulkan cat expects at least one tensor"); - at::Tensor tensor = tensors[0]; + auto materialized = tensors.materialize(); + const at::Tensor& tensor = materialized[0]; int64_t cat_dim_size = 0; bool is_mult4ch = true; - for (const auto& t : tensors) { + for (const at::Tensor& t : materialized) { TORCH_INTERNAL_ASSERT( t.dim() == 4, "Vulkan cat expects 4 dimensional inputs"); @@ -207,17 +214,17 @@ Tensor cat(const at::TensorList tensors, const int64_t dim) { vTensor v_output{api::context(), result_size, tensor.options()}; if (dim == 3) { - return cat_width(tensors, v_output); + return cat_width(materialized, v_output); } if (dim == 2) { - return cat_height(tensors, v_output); + return cat_height(materialized, v_output); } else if (dim == 1) { if (is_mult4ch) { - return cat_feature_mult4ch(tensors, v_output); + return cat_feature_mult4ch(materialized, v_output); } - return cat_feature(tensors, v_output); + return cat_feature(materialized, v_output); } - return cat_batch(tensors, v_output); + return cat_batch(materialized, v_output); } #ifdef USE_VULKAN_API diff --git a/aten/src/ATen/templates/RegisterFunctionalization.cpp b/aten/src/ATen/templates/RegisterFunctionalization.cpp index 7b64cde1ad6ee..7160856798932 100644 --- a/aten/src/ATen/templates/RegisterFunctionalization.cpp +++ b/aten/src/ATen/templates/RegisterFunctionalization.cpp @@ -50,10 +50,11 @@ inline c10::optional to_meta(const c10::optional& t) { return c10::nullopt; } -inline std::vector to_meta(const TensorList& t_list) { - std::vector outputs(t_list.size()); - for (const auto i : c10::irange(t_list.size())) { - outputs[i] = to_meta(t_list[i]); +inline std::vector to_meta(at::ITensorListRef t_list) { + std::vector outputs; + outputs.reserve(t_list.size()); + for (const auto& tensor : t_list) { + outputs.push_back(to_meta(tensor)); } return outputs; } diff --git a/aten/src/ATen/templates/TensorBody.h b/aten/src/ATen/templates/TensorBody.h index 662712c641f11..836b3651d6013 100644 --- a/aten/src/ATen/templates/TensorBody.h +++ b/aten/src/ATen/templates/TensorBody.h @@ -41,6 +41,7 @@ namespace c10{ template class List; +template class IListRef; } namespace at { struct Generator; @@ -65,6 +66,7 @@ namespace at { class OptionalTensorRef; class Tensor; using TensorList = ArrayRef; +using ITensorList = c10::IListRef; using Stream = c10::Stream; diff --git a/caffe2/contrib/aten/gen_op.py b/caffe2/contrib/aten/gen_op.py index 55f1faba2750b..dad452b5216a3 
100755 --- a/caffe2/contrib/aten/gen_op.py +++ b/caffe2/contrib/aten/gen_op.py @@ -68,8 +68,13 @@ def value_has_tensors(v): def value_is_tensor_type(v): - return value_has_tensors(v) and v['dynamic_type'] not in ['at::TensorList', 'const c10::List> &'] + return value_has_tensors(v) and v['dynamic_type'] not in TENSORLIST_TYPE +TENSORLIST_TYPE = [ + 'at::TensorList', + 'const at::ITensorListRef &', + 'const c10::List> &', +] # for each aten type, how do we handle a return value of that type? RETURN_MAP = { @@ -208,7 +213,7 @@ def self_as_first_argument(arguments): def get_num_inputs(o): args = 0 for a in o['arguments']: - if a['type'] in ['at::TensorList', 'const c10::List> &']: + if a['type'] in TENSORLIST_TYPE: return '*' elif value_has_tensors(a): args += 1 @@ -277,17 +282,17 @@ def emit_assignments(o, env): # e.g. "Float" is at::kFloat assert('Type' in o['method_of']) - static_tensor_inputs = sum(arg['type'] not in ['at::TensorList', 'const c10::List> &'] and value_is_tensor_type(arg) for arg in o['arguments']) - has_tensorlist = any(arg['type'] in ['at::TensorList', 'const c10::List> &'] for arg in o['arguments']) + static_tensor_inputs = sum(arg['type'] not in TENSORLIST_TYPE and value_is_tensor_type(arg) for arg in o['arguments']) + has_tensorlist = any(arg['type'] in TENSORLIST_TYPE for arg in o['arguments']) if has_tensorlist: - tensorlist_idx = [i for i, arg in enumerate(o['arguments']) if arg['type'] in ['at::TensorList', 'const c10::List> &']][0] + tensorlist_idx = [i for i, arg in enumerate(o['arguments']) if arg['type'] in TENSORLIST_TYPE][0] real_inputs = 0 for i, arg in enumerate(o['arguments']): env['arguments'].append(arg['name']) # Pretend the flat argument list is a stack where the end is the top. view_length = 'InputSize()' if has_tensorlist and i < tensorlist_idx else static_tensor_inputs - if arg['type'] == 'at::TensorList': + if arg['type'] == 'at::TensorList' or arg['type'] == 'const at::ITensorListRef &': # NOTE: do not advance real_inputs here. 
After this we will # switch to indexing the "stack" from the end env['statements'].append( diff --git a/test/test_overrides.py b/test/test_overrides.py index 9a7747b5f2171..70ca676a1762c 100644 --- a/test/test_overrides.py +++ b/test/test_overrides.py @@ -635,7 +635,7 @@ def instance_gen(): func = func.__get__(instance_gen()) continue func_args.append(instance_gen()) - elif t == 'TensorList': + elif t == 'TensorList' or t == 'ITensorListRef': func_args.append([instance_gen(), instance_gen()]) elif t == 'c10::List>': func_args.append([instance_gen(), instance_gen()]) diff --git a/tools/autograd/gen_autograd_functions.py b/tools/autograd/gen_autograd_functions.py index 5c6e75bef895f..b0277ad5cc5f8 100644 --- a/tools/autograd/gen_autograd_functions.py +++ b/tools/autograd/gen_autograd_functions.py @@ -20,6 +20,7 @@ boolT, doubleT, intArrayRefT, + iTensorListRefT, ListCType, longT, MutRefCType, @@ -29,6 +30,7 @@ stringT, symIntArrayRefT, SymIntT, + TENSOR_LIST_LIKE_CTYPES, tensorListT, tensorT, ) @@ -472,10 +474,7 @@ def process_function(info: DifferentiabilityInfo, template: CodeTemplate) -> str py_getsetdef_structs: List[str] = [] for arg in info.args_with_derivatives: - if ( - arg.type == "at::TensorList" - or arg.type == "const c10::List> &" - ): + if arg.type in TENSOR_LIST_LIKE_CTYPES: size = f"{arg.name}_size_" saved_list_sizes.append(f"size_t {arg.name}_size_;") else: @@ -509,7 +508,7 @@ def save_var(var: SavedAttribute, is_output: bool) -> None: ) ) should_append_raw_getsetdef = True - elif type == BaseCType(tensorListT): + elif type == BaseCType(tensorListT) or type == BaseCType(iTensorListRefT): saved_variables.append(f"std::vector {name}_;") saved_variables.append(f"bool {name}_released_ = false;") # Just clear() is sufficient, we don't need to loop and clear each variable. diff --git a/tools/autograd/gen_variable_type.py b/tools/autograd/gen_variable_type.py index 3135038865877..f5f3e3a0a765a 100644 --- a/tools/autograd/gen_variable_type.py +++ b/tools/autograd/gen_variable_type.py @@ -43,6 +43,7 @@ Binding, DispatcherSignature, intArrayRefT, + iTensorListRefT, ListCType, MutRefCType, OptionalCType, @@ -50,6 +51,7 @@ SpecialArgName, stringT, symIntArrayRefT, + TENSOR_LIST_LIKE_CTYPES, tensorListT, tensorT, TupleCType, @@ -1115,11 +1117,7 @@ def emit_check_if_in_complex_autograd_allowlist() -> List[str]: for arg in differentiable_outputs: name = arg.name # TODO: should be `arg.type.is_tensor_like()`? 
- if arg.cpp_type in [ - "at::Tensor", - "at::TensorList", - "const c10::List> &", - ]: + if arg.cpp_type == "at::Tensor" or arg.cpp_type in TENSOR_LIST_LIKE_CTYPES: body.append(f'throw_error_for_complex_autograd({name}, "{base_name}");') return body @@ -1197,8 +1195,10 @@ def save_variables( expr = f"SavedVariable({var}, {str(is_output).lower()}, {is_inplace_view})" else: expr = f"SavedVariable({var}, {str(is_output).lower()})" - elif type == BaseCType(tensorListT) or type == ListCType( - OptionalCType(BaseCType(tensorT)) + elif ( + type == BaseCType(tensorListT) + or type == ListCType(OptionalCType(BaseCType(tensorT))) + or type == BaseCType(iTensorListRefT) ): expr = f"make_saved_variable_list({name})" name += "_" @@ -1277,7 +1277,9 @@ def check_tensorimpl_and_storage( for unpacked_binding in unpacked_bindings: arg = unpacked_binding.name noref_cpp_type = unpacked_binding.nctype.type.remove_const_ref() - if noref_cpp_type == BaseCType(tensorListT): + if noref_cpp_type == BaseCType(tensorListT) or noref_cpp_type == BaseCType( + iTensorListRefT + ): stmts_before_call += [ SAVE_TENSORLIST_STORAGE.substitute(tensorlist_name=arg), SAVE_TENSORLIST_IMPL.substitute(tensorlist_name=arg), diff --git a/tools/autograd/load_derivatives.py b/tools/autograd/load_derivatives.py index 5ee2083250a55..423db4ab51656 100644 --- a/tools/autograd/load_derivatives.py +++ b/tools/autograd/load_derivatives.py @@ -272,7 +272,7 @@ def postprocess_forward_derivatives( def find_required_inputs(formula: str, postfix: str) -> Tuple[str, ...]: required_inputs = set() for arg in args_with_derivatives: - if arg.type == "at::TensorList": + if arg.type in ("at::TensorList", "const at::ITensorListRef &"): # The functions taking TensorList handle everything internally continue arg_name = arg.name diff --git a/tools/autograd/templates/VariableType.h b/tools/autograd/templates/VariableType.h index f7c7450c831f9..ad2abc2bdb724 100644 --- a/tools/autograd/templates/VariableType.h +++ b/tools/autograd/templates/VariableType.h @@ -51,7 +51,7 @@ namespace VariableType { at::Tensor & unpack(Tensor & t, const char * name, int pos); const at::Tensor & unpack(const Tensor & t, const char * name, int pos); at::Tensor unpack_opt(const Tensor & t, const char * name, int pos); - std::vector unpack(at::TensorList tl, const char *name, int pos); + std::vector unpack(at::ITensorListRef tl, const char *name, int pos); }; }} // namespace torch::autograd diff --git a/torch/csrc/autograd/FunctionsManual.cpp b/torch/csrc/autograd/FunctionsManual.cpp index bf2ea9e58ce82..0948424414e1f 100644 --- a/torch/csrc/autograd/FunctionsManual.cpp +++ b/torch/csrc/autograd/FunctionsManual.cpp @@ -5660,18 +5660,19 @@ Tensor lu_unpack_backward( } } -Tensor cat_jvp(at::TensorList tensors, int64_t dim) { +Tensor cat_jvp(at::ITensorListRef tensors, int64_t dim) { Tensor out_fw_grad; + auto materialized = tensors.materialize(); auto any_defined = false; - for (const auto& t : tensors) { + for (const Tensor& t : materialized) { any_defined |= isFwGradDefined(t); } if (any_defined) { std::vector fw_grads; - for (auto& t : tensors) { + for (const Tensor& t : materialized) { fw_grads.push_back( isFwGradDefined(t) ? 
t._fw_grad(/*level*/ 0) diff --git a/torch/csrc/autograd/FunctionsManual.h b/torch/csrc/autograd/FunctionsManual.h index 59ed8daabf992..0816d44ff6e0b 100644 --- a/torch/csrc/autograd/FunctionsManual.h +++ b/torch/csrc/autograd/FunctionsManual.h @@ -955,7 +955,7 @@ Tensor convolution_backward_jvp_grad_bias( const Tensor& grad_out_t, const Tensor& grad_bias); -Tensor cat_jvp(at::TensorList tensors, int64_t dim); +Tensor cat_jvp(at::ITensorListRef tensors, int64_t dim); Tensor block_diag_jvp(at::TensorList tensors); Tensor stack_jvp(at::TensorList tensors, int64_t dim); Tensor cumprod_jvp(Tensor self_t, Tensor self_p, Tensor result, int dim); diff --git a/torch/csrc/autograd/VariableTypeManual.cpp b/torch/csrc/autograd/VariableTypeManual.cpp index 9c668a8ef1f48..fa97497d76231 100644 --- a/torch/csrc/autograd/VariableTypeManual.cpp +++ b/torch/csrc/autograd/VariableTypeManual.cpp @@ -90,14 +90,14 @@ Tensor unpack_opt(const Tensor& t, const char* name, int pos) { return unpack(t, name, pos); } -std::vector unpack(at::TensorList tl, const char* name, int pos) { - std::vector ret(tl.size()); - for (const auto i : c10::irange(tl.size())) { - const auto& t = tl[i]; - if (!t.defined()) { - continue; - } - ret[i] = static_cast(t); +std::vector unpack( + at::ITensorListRef tl, + const char* name, + int pos) { + std::vector ret; + ret.reserve(tl.size()); + for (const auto& t : tl) { + ret.push_back(t.defined() ? static_cast(t) : Variable{}); } return ret; } diff --git a/torch/csrc/autograd/VariableTypeUtils.h b/torch/csrc/autograd/VariableTypeUtils.h index 1cde834704032..169fdd03ae3d7 100644 --- a/torch/csrc/autograd/VariableTypeUtils.h +++ b/torch/csrc/autograd/VariableTypeUtils.h @@ -61,7 +61,7 @@ inline void check_inplace(const at::Tensor& tensor, bool requires_grad) { } } -inline void check_inplace(const at::TensorList tensors, bool requires_grad) { +inline void check_inplace(at::ITensorListRef tensors, bool requires_grad) { for (const auto& tensor : tensors) { check_inplace(tensor, requires_grad); } @@ -98,7 +98,7 @@ inline void throw_error_if_base_and_tensor_are_same( } inline void throw_error_for_complex_autograd( - const at::TensorList& tensorlist, + at::ITensorListRef tensorlist, const char* name) { for (const auto& tensor : tensorlist) { throw_error_for_complex_autograd(tensor, name); @@ -395,7 +395,7 @@ inline void check_no_requires_grad( } inline void check_no_requires_grad( - at::TensorList tensors, + at::ITensorListRef tensors, const char* name, const char* fn_name = "") { // GradMode check is expensive, so check it only once for TensorLists @@ -424,7 +424,7 @@ inline void check_no_requires_grad( // Assumed that saved tensor lists are never inplace outputs inline std::vector make_saved_variable_list( - at::TensorList tensors) { + at::ITensorListRef tensors) { return fmap(tensors, [](const at::Tensor& tensor) -> SavedVariable { return SavedVariable{tensor, false /* is output */}; }); @@ -443,19 +443,22 @@ inline std::vector make_saved_variable_list( }); } -inline std::vector> to_args_sizes(at::TensorList tensors) { +inline std::vector> to_args_sizes( + at::ITensorListRef tensors) { std::vector> args_sizes(tensors.size()); - for (const auto i : c10::irange(tensors.size())) { - args_sizes[i] = tensors[i].sizes().vec(); + size_t i = 0; + for (const auto& t : tensors) { + args_sizes[i++] = t.sizes().vec(); } return args_sizes; } inline std::vector to_args_scalartypes( - at::TensorList tensors) { + at::ITensorListRef tensors) { std::vector args_scalartypes(tensors.size()); - for (const auto i : 
c10::irange(tensors.size())) { - args_scalartypes[i] = tensors[i].scalar_type(); + size_t i = 0; + for (const auto& t : tensors) { + args_scalartypes[i++] = t.scalar_type(); } return args_scalartypes; } diff --git a/torch/csrc/autograd/functions/utils.h b/torch/csrc/autograd/functions/utils.h index 75df1a0302c95..fbc1e79549043 100644 --- a/torch/csrc/autograd/functions/utils.h +++ b/torch/csrc/autograd/functions/utils.h @@ -100,7 +100,7 @@ inline bool isFwGradDefined(const c10::optional& t) { return t.has_value() && t->defined() && t->_fw_grad(/*level */ 0).defined(); } -inline bool isFwGradDefinedTensorList(const at::TensorList& variables) { +inline bool isFwGradDefinedTensorList(const at::ITensorListRef& variables) { bool ret = false; for (auto& variable : variables) { ret |= isFwGradDefined(variable); diff --git a/torch/csrc/jit/frontend/tracer.cpp b/torch/csrc/jit/frontend/tracer.cpp index abe2c43d77e82..057a483deb4c9 100644 --- a/torch/csrc/jit/frontend/tracer.cpp +++ b/torch/csrc/jit/frontend/tracer.cpp @@ -724,11 +724,24 @@ void addInputs( const c10::optional& value) { detail::genericAddOptionalInput(n, name, value); } - void addInputs( Node* n, const char* name, - at::TensorList value, + at::ArrayRef value, + bool allow_undefined) { + addInputs(n, name, at::ITensorListRef(value), allow_undefined); +} +void addInputs( + Node* n, + const char* name, + std::vector value, + bool allow_undefined) { + addInputs(n, name, at::ITensorListRef(value), allow_undefined); +} +void addInputs( + Node* n, + const char* name, + at::ITensorListRef value, bool allow_undefined) { Graph* g = n->owningGraph(); Node* list_node = nullptr; @@ -752,7 +765,6 @@ TORCH_API void addInputs( OptionalType::ofTensor(), fmap(value, getOptTensorValueTrace))); n->addInput(list_node->output()); } - void addInputs( Node* n, const char* name, diff --git a/torch/csrc/jit/frontend/tracer.h b/torch/csrc/jit/frontend/tracer.h index 356c1c21c061f..612a3dfd39956 100644 --- a/torch/csrc/jit/frontend/tracer.h +++ b/torch/csrc/jit/frontend/tracer.h @@ -281,6 +281,16 @@ TORCH_API void addInputs( const char* name, ArrayRef value, bool allow_undefined = false); +TORCH_API void addInputs( + Node* n, + const char* name, + std::vector value, + bool allow_undefined = false); +TORCH_API void addInputs( + Node* n, + const char* name, + at::ITensorListRef value, + bool allow_undefined = false); TORCH_API void addInputs( Node* n, const char* name, diff --git a/torch/csrc/lazy/backend/backend_device.cpp b/torch/csrc/lazy/backend/backend_device.cpp index ca19d1c42d7e8..e178aab755d84 100644 --- a/torch/csrc/lazy/backend/backend_device.cpp +++ b/torch/csrc/lazy/backend/backend_device.cpp @@ -54,7 +54,7 @@ c10::Device backendDeviceToAtenDevice(const BackendDevice& device) { return c10::Device(at::kLazy, device.ordinal()); } -c10::optional GetBackendDevice(const at::TensorList tensors) { +c10::optional GetBackendDevice(at::ITensorListRef tensors) { for (auto& tensor : tensors) { if (auto lt = TryGetLtcTensor(tensor)) { return lt->GetDevice(); @@ -63,6 +63,10 @@ c10::optional GetBackendDevice(const at::TensorList tensors) { return c10::nullopt; } +c10::optional GetBackendDevice(at::TensorList tensors) { + return GetBackendDevice(at::ITensorListRef(tensors)); +} + c10::optional GetBackendDevice(const at::Tensor& tensor) { if (auto lt = TryGetLtcTensor(tensor)) { return lt->GetDevice(); diff --git a/torch/csrc/lazy/backend/backend_device.h b/torch/csrc/lazy/backend/backend_device.h index 55d7ecdb5d3a5..920314993513a 100644 --- 
a/torch/csrc/lazy/backend/backend_device.h +++ b/torch/csrc/lazy/backend/backend_device.h @@ -73,6 +73,8 @@ TORCH_API c10::Device backendDeviceToAtenDevice(const BackendDevice& device); // Tries to extract the backend device out of the lazy tensor. Returns nullopt // if the input is not a lazy tensor. +TORCH_API c10::optional GetBackendDevice( + const at::ITensorListRef tensors); TORCH_API c10::optional GetBackendDevice( const at::TensorList tensors); TORCH_API c10::optional GetBackendDevice( diff --git a/torch/csrc/lazy/core/tensor.cpp b/torch/csrc/lazy/core/tensor.cpp index 86971dc49bcb3..bf673a72361d3 100644 --- a/torch/csrc/lazy/core/tensor.cpp +++ b/torch/csrc/lazy/core/tensor.cpp @@ -470,7 +470,7 @@ int64_t LazyTensor::GetNextTensorId() { return id_generator->fetch_add(1); } -torch::lazy::Value GetTensorList(c10::ArrayRef tensors) { +torch::lazy::Value GetTensorList(at::ITensorListRef tensors) { std::vector values; for (const auto& t : tensors) { auto* impl = dynamic_cast(t.unsafeGetTensorImpl()); diff --git a/torch/csrc/lazy/core/tensor.h b/torch/csrc/lazy/core/tensor.h index bf521bb3f92f2..12cfdd2827d74 100644 --- a/torch/csrc/lazy/core/tensor.h +++ b/torch/csrc/lazy/core/tensor.h @@ -211,7 +211,7 @@ class TORCH_API LazyTensor : public c10::intrusive_ptr_target { // skips // the LazyTensor wrappers, assuming that the list of underlying IR nodes // is actually more useful for downstream computations. TBD. -TORCH_API torch::lazy::Value GetTensorList(c10::ArrayRef tensors); +TORCH_API torch::lazy::Value GetTensorList(at::ITensorListRef tensors); // Section 1: at::Tensor => LazyTensor. // Extracts the LazyTensor out of an at::Tensor. Returns a null LazyTensor diff --git a/torchgen/api/cpp.py b/torchgen/api/cpp.py index 128215f2a2670..991e10a3b4b33 100644 --- a/torchgen/api/cpp.py +++ b/torchgen/api/cpp.py @@ -13,6 +13,7 @@ CType, dimnameListT, intArrayRefT, + iTensorListRefT, ListCType, longT, MutRefCType, @@ -189,7 +190,10 @@ def argumenttype_type( else: return NamedCType(binds, BaseCType(intArrayRefT)) if str(t.elem) == "Tensor": - return NamedCType(binds, BaseCType(tensorListT)) + if local.use_ilistref_for_tensor_lists(): + return NamedCType(binds, ConstRefCType(BaseCType(iTensorListRefT))) + else: + return NamedCType(binds, BaseCType(tensorListT)) elif str(t.elem) == "Scalar": return NamedCType(binds, ArrayRefCType(BaseCType(scalarT))) elif str(t.elem) == "Dimname": diff --git a/torchgen/api/structured.py b/torchgen/api/structured.py index a5ab3f6e54320..9ad45a37ac8cf 100644 --- a/torchgen/api/structured.py +++ b/torchgen/api/structured.py @@ -69,7 +69,7 @@ def argumenttype_type(t: Type, *, mutable: bool, binds: ArgName) -> NamedCType: return NamedCType(binds, OptionalCType(elem.type)) elif isinstance(t, ListType): if t.elem == BaseType(BaseTy.Tensor): - return NamedCType(binds, BaseCType(iTensorListRefT)) + return NamedCType(binds, ConstRefCType(BaseCType(iTensorListRefT))) elif t.elem == OptionalType(BaseType(BaseTy.Tensor)): return NamedCType(binds, BaseCType(iOptTensorListRefT)) # TODO: delete these special cases; see torchgen.api.cpp--these diff --git a/torchgen/api/translate.py b/torchgen/api/translate.py index 913b5f673742c..7c371276a90ca 100644 --- a/torchgen/api/translate.py +++ b/torchgen/api/translate.py @@ -9,7 +9,6 @@ Expr, intArrayRefT, iOptTensorListRefT, - iTensorListRefT, layoutT, ListCType, longT, @@ -27,7 +26,6 @@ SpecialArgName, symIntArrayRefT, SymIntT, - tensorListT, tensorOptionsT, tensorT, VectorCType, @@ -184,12 +182,6 @@ def translate( NamedCType(t.name, 
BaseCType(opmath_t)) ] = f"static_cast({b.expr})" - # [Note: ITensorListRef] - if t.type == BaseCType(tensorListT): - ctx[ - NamedCType(t.name, BaseCType(iTensorListRefT)) - ] = f"at::ITensorListRef({b.expr})" - # [Note: IOptTensorListRef] if t.type == ConstRefCType(ListCType(OptionalCType(BaseCType(tensorT)))): ctx[ diff --git a/torchgen/api/types.py b/torchgen/api/types.py index e8741c0e8f6b3..a3d141489aef2 100644 --- a/torchgen/api/types.py +++ b/torchgen/api/types.py @@ -17,6 +17,12 @@ _T = TypeVar("_T") +TENSOR_LIST_LIKE_CTYPES = [ + "at::TensorList", + "const c10::List> &", + "const at::ITensorListRef &", +] + # An ArgName is just the str name of the argument in schema; # but in some special circumstances, we may add a little extra # context. The Enum SpecialArgName covers all of these cases; diff --git a/torchgen/api/unboxing.py b/torchgen/api/unboxing.py index a1bca50538647..0a3aad42864ed 100644 --- a/torchgen/api/unboxing.py +++ b/torchgen/api/unboxing.py @@ -122,7 +122,9 @@ def convert_arguments(f: NativeFunction) -> Tuple[List[Binding], List[str]]: ) argument: Argument = arg.argument unboxed_name, _, code, decl = argumenttype_ivalue_convert( - argument.type, argument.name, mutable=argument.is_write + argument.type, + argument.name, + mutable=argument.is_write, ) code_list.extend(decl) code_list.extend(code) @@ -149,12 +151,18 @@ def argumenttype_ivalue_convert( elif isinstance(t, OptionalType): out_name = f"{arg_name}_opt_out" code, decl = _gen_code_optional_type( - arg_name=arg_name, out_name=out_name, t=t, ctype=ctype + arg_name=arg_name, + out_name=out_name, + t=t, + ctype=ctype, ) elif isinstance(t, ListType): out_name = f"{arg_name}_list_out" code, decl = _gen_code_list_type( - arg_name=arg_name, out_name=out_name, t=t, ctype=ctype + arg_name=arg_name, + out_name=out_name, + t=t, + ctype=ctype, ) else: raise Exception(f"Cannot handle type {t}. arg_name: {arg_name}") diff --git a/torchgen/context.py b/torchgen/context.py index 6924befb2550e..b643890d97992 100644 --- a/torchgen/context.py +++ b/torchgen/context.py @@ -51,7 +51,8 @@ def native_function_manager( f = g with context(lambda: f"in native_functions.yaml line {f.loc}:\n {f.func}"): with local.parametrize( - use_const_ref_for_mutable_tensors=f.use_const_ref_for_mutable_tensors + use_const_ref_for_mutable_tensors=f.use_const_ref_for_mutable_tensors, + use_ilistref_for_tensor_lists=f.part_of_structured_group, ): yield diff --git a/torchgen/gen_functionalization_type.py b/torchgen/gen_functionalization_type.py index fd35e5cb27883..236c0ab7317fe 100644 --- a/torchgen/gen_functionalization_type.py +++ b/torchgen/gen_functionalization_type.py @@ -8,6 +8,7 @@ CType, DispatcherSignature, FunctionalizationLambda, + iTensorListRefT, NativeSignature, tensorListT, tensorT, @@ -173,6 +174,8 @@ def is_tensor_like(a: Union[Argument, TensorOptionsArguments, SelfArgument]) -> def get_owning_type(t: CType) -> Tuple[CType, Callable[[str], str]]: if t == BaseCType(tensorListT): return VectorCType(BaseCType(tensorT)), lambda x: f"{x}.vec()" + if t == BaseCType(iTensorListRefT): + return VectorCType(BaseCType(tensorT)), lambda x: f"{{{x}.begin(), {x}.end()}}" # There are technically other non-owning types out there (like IntArrayRef), # but functionalization only actually cares about the ones involving tensors. 
return t, lambda x: x diff --git a/torchgen/gen_lazy_tensor.py b/torchgen/gen_lazy_tensor.py index ab0d2a1b4cac4..46f11eeff13c8 100644 --- a/torchgen/gen_lazy_tensor.py +++ b/torchgen/gen_lazy_tensor.py @@ -182,11 +182,11 @@ def get_ltc_helper_fns() -> str: return c10::nullopt; } -std::vector to_meta(const at::TensorList& t_list) { +std::vector to_meta(at::ITensorListRef t_list) { std::vector outs; outs.reserve(t_list.size()); - for (const auto& i : c10::irange(t_list.size())) { - outs.push_back(to_meta(t_list[i])); + for (const auto& tensor : t_list) { + outs.push_back(to_meta(tensor)); } return outs; } diff --git a/torchgen/local.py b/torchgen/local.py index 65efce2c3b11b..f72e53601ab12 100644 --- a/torchgen/local.py +++ b/torchgen/local.py @@ -17,6 +17,7 @@ class Locals(threading.local): use_const_ref_for_mutable_tensors: Optional[bool] = None + use_ilistref_for_tensor_lists: Optional[bool] = None _locals = Locals() @@ -30,13 +31,26 @@ def use_const_ref_for_mutable_tensors() -> bool: return _locals.use_const_ref_for_mutable_tensors +def use_ilistref_for_tensor_lists() -> bool: + assert _locals.use_ilistref_for_tensor_lists is not None, ( + "need to initialize local.use_ilistref_for_tensor_lists with " + "local.parametrize" + ) + return _locals.use_ilistref_for_tensor_lists + + @contextmanager -def parametrize(*, use_const_ref_for_mutable_tensors: bool) -> Iterator[None]: +def parametrize( + *, use_const_ref_for_mutable_tensors: bool, use_ilistref_for_tensor_lists: bool +) -> Iterator[None]: old_use_const_ref_for_mutable_tensors = _locals.use_const_ref_for_mutable_tensors + old_use_ilistref_for_tensor_lists = _locals.use_ilistref_for_tensor_lists try: _locals.use_const_ref_for_mutable_tensors = use_const_ref_for_mutable_tensors + _locals.use_ilistref_for_tensor_lists = use_ilistref_for_tensor_lists yield finally: _locals.use_const_ref_for_mutable_tensors = ( old_use_const_ref_for_mutable_tensors ) + _locals.use_ilistref_for_tensor_lists = old_use_ilistref_for_tensor_lists diff --git a/torchgen/model.py b/torchgen/model.py index 6f484c0eae4cc..f87f2be28b7f3 100644 --- a/torchgen/model.py +++ b/torchgen/model.py @@ -965,6 +965,10 @@ def view_schema_kind(self) -> ViewSchemaKind: def root_name(self) -> str: return self.func.name.name.base + @property + def part_of_structured_group(self) -> bool: + return self.structured or self.structured_delegate is not None + SchemaKind = Enum("SchemaKind", ("functional", "inplace", "out", "mutable", "scratch")) @@ -994,6 +998,12 @@ def __post_init__(self) -> None: "NativeFunctionsGroup constructed from two NativeFunctions " f"that don't have matching signatures: {test_sig} != {f.func.signature()}" ) + + if self.structured != f.part_of_structured_group: + raise AssertionError( + "NativeFunctionsGroup constructed from structured and unstructured " + f"functions: {self.out.func.name} and {f.func.name}" + ) assert self.functional.func.kind() == SchemaKind.functional assert self.out.func.kind() == SchemaKind.out assert self.functional.namespace == self.out.namespace