From af7a83817013c948370e1be2838f83323459285f Mon Sep 17 00:00:00 2001
From: Shraiysh
Date: Thu, 23 Jan 2025 00:52:28 -0800
Subject: [PATCH] PR #20794: [gpu][ds-fusion] Add handling for offset module in
 ds-fusion thunk

Imported from GitHub PR https://github.com/openxla/xla/pull/20794

This patch adds support for offset modules in the ds-fusion thunk. It also
moves the `ResourceRequests` structure from `gpu_executable.cc` to
`gpu_executable.h`, because a valid implementation of the abstract class
`Thunk::ResourceRequests` is required for calling `Thunk::Prepare()`.

This is split from #20332 as per request.

Copybara import of the project:

--
53226e9428dff816819c192735b9569ef3d309ea by Shraiysh Vaishay:

[gpu][ds-fusion] Add handling for offset module in ds-fusion thunk

This patch adds support for offset modules in the ds-fusion thunk. It also
moves the `ResourceRequests` structure from `gpu_executable.cc` to
`gpu_executable.h`, because a valid implementation of the abstract class
`Thunk::ResourceRequests` is required for calling `Thunk::Prepare()`.

--
e554ac6b09b5023082210c5100a1f1a86c1c6605 by Shraiysh Vaishay:

Address comments and rebase

--
4ac8efa0c9885b0f21f526072f821ace40be4043 by Shraiysh Vaishay:

Address comments

--
8c23e2a4e2ff66fb7361c48c9d988ebb2261bc41 by Shraiysh Vaishay:

Addressed comments.

--
efbe6bfa6f10f1ee998683161e398ae3adc0f183 by Shraiysh Vaishay:

Addressed comments

Merging this change closes #20794

FUTURE_COPYBARA_INTEGRATE_REVIEW=https://github.com/openxla/xla/pull/20794 from shraiysh:ds_fusion_thunk_changes efbe6bfa6f10f1ee998683161e398ae3adc0f183
PiperOrigin-RevId: 718748248
---
 xla/backends/gpu/runtime/BUILD                |   4 +
 .../gpu/runtime/dynamic_slice_thunk.cc        |  61 ++++-
 .../gpu/runtime/dynamic_slice_thunk.h         |  67 ++++-
 .../gpu/runtime/dynamic_slice_thunk_test.cc   | 235 +++++++++++++++++-
 4 files changed, 352 insertions(+), 15 deletions(-)

diff --git a/xla/backends/gpu/runtime/BUILD b/xla/backends/gpu/runtime/BUILD
index fb4c4e33912d7a..46b166d26d65b6 100644
--- a/xla/backends/gpu/runtime/BUILD
+++ b/xla/backends/gpu/runtime/BUILD
@@ -198,6 +198,7 @@ cc_library(
         "//xla:literal",
         "//xla:shape_util",
         "//xla:status_macros",
+        "//xla/hlo/evaluator:hlo_evaluator",
         "//xla/service:buffer_assignment",
         "//xla/service/gpu:buffer_allocations",
         "//xla/service/gpu:ir_emission_utils",
@@ -243,10 +244,12 @@ xla_test(
         "//xla/ffi:ffi_api",
         "//xla/service:buffer_assignment",
         "//xla/service:executable",
+        "//xla/service:hlo_runner",
         "//xla/service:platform_util",
         "//xla/service/gpu:buffer_allocations",
         "//xla/service/gpu:launch_dimensions",
         "//xla/service/gpu:matmul_utils",
+        "//xla/service/gpu:resource_requests",
         "//xla/stream_executor:blas",
         "//xla/stream_executor:command_buffer",
         "//xla/stream_executor:device_memory",
@@ -258,6 +261,7 @@ xla_test(
         "//xla/stream_executor:stream_executor_memory_allocator",
         "//xla/stream_executor/gpu:gpu_test_kernels",
         "//xla/stream_executor/gpu:gpu_types_header",
+        "//xla/tests:hlo_runner_agnostic_test_base",
         "//xla/tsl/lib/core:status_test_util",
         "@com_google_absl//absl/algorithm:container",
         "@com_google_absl//absl/status",
diff --git a/xla/backends/gpu/runtime/dynamic_slice_thunk.cc b/xla/backends/gpu/runtime/dynamic_slice_thunk.cc
index 46f21f6cbf84ec..3d7abbbbbcc870 100644
--- a/xla/backends/gpu/runtime/dynamic_slice_thunk.cc
+++ b/xla/backends/gpu/runtime/dynamic_slice_thunk.cc
@@ -31,6 +31,7 @@ limitations under the License.
#include "xla/backends/gpu/runtime/sequential_thunk.h" #include "xla/backends/gpu/runtime/thunk.h" #include "xla/backends/gpu/runtime/while_thunk.h" +#include "xla/hlo/evaluator/hlo_evaluator.h" #include "xla/service/buffer_assignment.h" #include "xla/service/gpu/buffer_allocations.h" #include "xla/service/gpu/ir_emission_utils.h" @@ -47,6 +48,25 @@ limitations under the License. namespace xla { namespace gpu { +namespace { + +// Indvar is a thread-local map that stores the induction variable for each +// dynamic slice thunk. The same thunk object in the memory is shared by +// multiple replicas of the same computation. So, each replica should have its +// own tracking of the induction variable (threadlocal). With threadlocal, we +// cannot embed this inside the dynamic slice thunk object, and so we have a +// static map. There could be multiple dynamic slice thunks in the same module, +// and so we need a map to store the induction variable for each thunk. The +// usage of threadlocal in this context is similar to `LoopCounters` in +// while_thunk.cc (b/343294327). +Literal& Indvar(DynamicSliceThunk* thunk) { + static thread_local absl::flat_hash_map + indvar_map; + return indvar_map[thunk]; +} + +} // namespace + DynamicSliceThunk::DynamicSliceThunk( ThunkInfo thunk_info, std::unique_ptr embedded_thunk, std::vector> arguments, @@ -54,7 +74,9 @@ DynamicSliceThunk::DynamicSliceThunk( std::vector>> offsets, std::vector> orig_shapes, std::vector> sliced_shapes, - std::vector> offset_byte_sizes) + std::vector> offset_byte_sizes, + std::optional + offset_as_function_of_indvar_metadata) : Thunk(Kind::kDynamicSlice, thunk_info), embedded_thunk_(std::make_unique( ThunkInfo(), std::move(*embedded_thunk))), @@ -63,7 +85,9 @@ DynamicSliceThunk::DynamicSliceThunk( offsets_(offsets), orig_shapes_(orig_shapes), sliced_shapes_(sliced_shapes), - offset_byte_sizes_(offset_byte_sizes) { + offset_byte_sizes_(offset_byte_sizes), + offset_as_function_of_indvar_metadata_( + std::move(offset_as_function_of_indvar_metadata)) { // Zip all arguments together to create a list of SliceDef. for (auto [arg, offsets, orig_shape, sliced_shape, offset_byte_size] : llvm::zip_equal(arguments, offsets, orig_shapes, sliced_shapes, @@ -105,6 +129,16 @@ absl::Status DynamicSliceThunk::Prepare( } TF_RETURN_IF_ERROR(embedded_thunk_->Prepare(params, resource_requests)); + + if (offset_as_function_of_indvar_metadata_ != std::nullopt) { + Indvar(this) = + HloEvaluator() + .Evaluate( + /*module=*/*offset_as_function_of_indvar_metadata_->indvar_init, + /*arg_literals=*/{}) + .value(); + VLOG(2) << "Indvar = " << Indvar(this).ToString(); + } return absl::OkStatus(); } @@ -182,6 +216,20 @@ absl::Status DynamicSliceThunk::ExecuteOnStream(const ExecuteParams& params) { << "]: constant offset = " << *const_offset; offset_value(argument_idx, offset_idx) = *const_offset; + } else if (HloModule** offset_module = std::get_if(&offset)) { + TF_ASSIGN_OR_RETURN( + Literal offset, + HloEvaluator().Evaluate(**offset_module, {&Indvar(this)})); + auto offset_int = LiteralUtil::LiteralAsScalarInt64(offset); + if (offset_int.has_value()) { + offset_value(argument_idx, offset_idx) = *offset_int; + } else { + return absl::InternalError( + absl::StrFormat("Unhandled type returned from offset module: %s", + offset.shape().ToString())); + } + VLOG(2) << "Offset value = " << offset_value(argument_idx, offset_idx); + } else { // Transfer slice offset value from device to host. 
         auto alloc_slice = std::get<BufferAllocation::Slice>(offset);
@@ -252,6 +300,15 @@ absl::Status DynamicSliceThunk::ExecuteOnStream(const ExecuteParams& params) {
 
   // Execute the underlying custom call thunk with the new buffers.
   TF_RETURN_IF_ERROR(embedded_thunk_->ExecuteOnStream(new_params));
 
+  if (offset_as_function_of_indvar_metadata_ != std::nullopt) {
+    Indvar(this) =
+        HloEvaluator()
+            .Evaluate(*offset_as_function_of_indvar_metadata_->indvar_update,
+                      {&Indvar(this)})
+            .value();
+    VLOG(2) << "Indvar = " << Indvar(this).ToString();
+  }
+
   return absl::OkStatus();
 }
 
diff --git a/xla/backends/gpu/runtime/dynamic_slice_thunk.h b/xla/backends/gpu/runtime/dynamic_slice_thunk.h
index 10371d74c72ec9..6b0717265d85e4 100644
--- a/xla/backends/gpu/runtime/dynamic_slice_thunk.h
+++ b/xla/backends/gpu/runtime/dynamic_slice_thunk.h
@@ -32,6 +32,7 @@ limitations under the License.
 #include "xla/literal.h"
 #include "xla/service/buffer_assignment.h"
 #include "xla/shape.h"
+#include "xla/shape_util.h"
 #include "xla/stream_executor/memory_allocation.h"
 #include "xla/stream_executor/stream_executor.h"
 
@@ -47,8 +48,59 @@ class DynamicSliceThunk : public Thunk {
  public:
   // Dynamic slice offset can be either: (1) a statically known constant value
   // or (2) a truly dynamic offset that is computed on device and have to be
-  // transferred to host.
-  using Offset = std::variant<int64_t, BufferAllocation::Slice>;
+  // transferred to host or (3) a temporary HloModule that computes the offset
+  // with a single induction variable as the input.
+  using Offset = std::variant<int64_t, BufferAllocation::Slice, HloModule*>;
+
+  struct OffsetAsFunctionOfIndvarModulesMetadata {
+    // These two modules help keep track of the induction variable. The module
+    // `indvar_init` has the prototype `() -> integer[]`. It takes no input,
+    // and returns the initial value of the induction variable. The module
+    // `indvar_update` has the prototype `(integer[]) -> integer[]`. It takes
+    // the current value of the induction variable, and returns the next
+    // value of the induction variable.
+    std::unique_ptr<HloModule> indvar_init, indvar_update;
+
+    // Extracted HloModules for computing dynamic offsets. The modules are not
+    // used here; this solely keeps the modules alive and maintains their
+    // ownership with the thunk, while their raw pointers are used during
+    // execution from the `offsets_` vector. These modules have the signature
+    // `(integer[]) -> integer[]`, where the input is the current value of
+    // the loop induction variable, and the output is the offset value
+    // for that iteration.
+    std::vector<std::unique_ptr<HloModule>> extracted_offset_modules;
+
+    OffsetAsFunctionOfIndvarModulesMetadata(
+        std::unique_ptr<HloModule> indvar_init,
+        std::unique_ptr<HloModule> indvar_update,
+        std::vector<std::unique_ptr<HloModule>> extracted_offset_modules)
+        : indvar_init(std::move(indvar_init)),
+          indvar_update(std::move(indvar_update)),
+          extracted_offset_modules(std::move(extracted_offset_modules)) {
+      CHECK(this->indvar_init != nullptr && this->indvar_update != nullptr);
+      Shape init_output_shape =
+          this->indvar_init->entry_computation()->root_instruction()->shape();
+      CHECK(this->indvar_init->entry_computation()->num_parameters() == 0 &&
+            init_output_shape.IsInteger() &&
+            ShapeUtil::IsScalar(init_output_shape))
+          << "Induction variable init module expected with signature `() -> "
+             "integer[]`.";
+      Shape update_output_shape =
+          this->indvar_update->entry_computation()->root_instruction()->shape();
+      CHECK(this->indvar_update->entry_computation()->num_parameters() == 1 &&
+            update_output_shape.IsInteger() &&
+            ShapeUtil::IsScalar(update_output_shape))
+          << "Induction variable update module expected with signature "
+             "`(integer[]) -> integer[]`.";
+      Shape update_input_shape = this->indvar_update->entry_computation()
+                                     ->parameter_instruction(0)
+                                     ->shape();
+      CHECK(ShapeUtil::IsScalar(update_input_shape) &&
+            update_input_shape.IsInteger())
+          << "Induction variable update module expected with signature "
+             "`(integer[]) -> integer[]`.";
+    }
+  };
 
   DynamicSliceThunk(
       ThunkInfo thunk_info, std::unique_ptr<ThunkSequence> embedded_thunk,
@@ -57,8 +109,9 @@ class DynamicSliceThunk : public Thunk {
       std::vector<std::optional<std::vector<Offset>>> offsets,
      std::vector<std::optional<Shape>> orig_shapes,
       std::vector<std::optional<Shape>> sliced_shapes,
-      std::vector<std::optional<uint64_t>> offset_byte_sizes);
-
+      std::vector<std::optional<uint64_t>> offset_byte_sizes,
+      std::optional<OffsetAsFunctionOfIndvarModulesMetadata>
+          offset_as_function_of_indvar_metadata = std::nullopt);
   DynamicSliceThunk(const DynamicSliceThunk&) = delete;
   DynamicSliceThunk& operator=(const DynamicSliceThunk&) = delete;
 
@@ -132,6 +185,12 @@
   // A mapping from argument index to the base offset in the `offsets_allocs_`.
   std::vector<int64_t> offsets_allocs_base_;
+
+  // This structure holds the metadata for offset computations on the host. It
+  // stores a single induction variable initialization module, its update
+  // module, and the offsets that are a function of the induction variable.
+  std::optional<OffsetAsFunctionOfIndvarModulesMetadata>
+      offset_as_function_of_indvar_metadata_;
 };
 
 }  // namespace gpu
diff --git a/xla/backends/gpu/runtime/dynamic_slice_thunk_test.cc b/xla/backends/gpu/runtime/dynamic_slice_thunk_test.cc
index 6c9977f9c7ac8e..5985a998844f0e 100644
--- a/xla/backends/gpu/runtime/dynamic_slice_thunk_test.cc
+++ b/xla/backends/gpu/runtime/dynamic_slice_thunk_test.cc
@@ -35,6 +35,8 @@ limitations under the License.
 #include "xla/service/buffer_assignment.h"
 #include "xla/service/gpu/buffer_allocations.h"
 #include "xla/service/gpu/matmul_utils.h"
+#include "xla/service/gpu/resource_requests.h"
+#include "xla/service/hlo_runner.h"
 #include "xla/service/platform_util.h"
 #include "xla/service/service_executable_run_options.h"
 #include "xla/shape_util.h"
@@ -48,6 +50,7 @@ limitations under the License.
#include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_executor.h" #include "xla/stream_executor/stream_executor_memory_allocator.h" +#include "xla/tests/hlo_runner_agnostic_test_base.h" #include "xla/tsl/lib/core/status_test_util.h" #include "xla/types.h" // IWYU pragma: keep #include "tsl/platform/statusor.h" @@ -63,6 +66,13 @@ namespace xla::gpu { namespace { +class DynamicSliceThunkTest : public HloRunnerAgnosticTestBase { + public: + DynamicSliceThunkTest() + : HloRunnerAgnosticTestBase(std::make_unique( + PlatformUtil::GetDefaultPlatform().value())) {} +}; + static se::StreamExecutor* GpuExecutor() { auto name = absl::AsciiStrToUpper(PlatformUtil::CanonicalPlatformName("gpu").value()); @@ -72,7 +82,7 @@ static se::StreamExecutor* GpuExecutor() { } // namespace -TEST(DynamicSliceThunkTest, SlicedGemm) { +TEST_F(DynamicSliceThunkTest, SlicedGemm) { se::StreamExecutor* executor = GpuExecutor(); TF_ASSERT_OK_AND_ASSIGN(auto stream, executor->CreateStream()); @@ -212,7 +222,7 @@ TEST(DynamicSliceThunkTest, SlicedGemm) { ASSERT_EQ(dst, std::vector({9})); } -TEST(DynamicSliceThunkTest, MulipleSlicedOperandsGemm) { +TEST_F(DynamicSliceThunkTest, MulipleSlicedOperandsGemm) { se::StreamExecutor* executor = GpuExecutor(); TF_ASSERT_OK_AND_ASSIGN(auto stream, executor->CreateStream()); @@ -404,7 +414,7 @@ XLA_FFI_DEFINE_HANDLER(kMemcpy, Memcpy, XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "__xla_test$$memcpy", PLATFORM, kMemcpy); -TEST(DynamicSliceThunkTest, SlicedMemcpy) { +TEST_F(DynamicSliceThunkTest, SlicedMemcpy) { se::StreamExecutor* executor = GpuExecutor(); TF_ASSERT_OK_AND_ASSIGN(auto stream, executor->CreateStream()); @@ -540,7 +550,7 @@ TEST(DynamicSliceThunkTest, SlicedMemcpy) { ASSERT_EQ(out, ref); } -TEST(DynamicSliceThunkTest, SlicedOutputMemcpy) { +TEST_F(DynamicSliceThunkTest, SlicedOutputMemcpy) { se::StreamExecutor* executor = GpuExecutor(); TF_ASSERT_OK_AND_ASSIGN(auto stream, executor->CreateStream()); @@ -736,7 +746,7 @@ TEST(DynamicSliceThunkTest, SlicedOutputMemcpy) { ASSERT_EQ(out, ref); } -TEST(DynamicSliceThunkTest, SlicedGemmArbitraryArgumentOrder) { +TEST_F(DynamicSliceThunkTest, SlicedGemmArbitraryArgumentOrder) { se::StreamExecutor* executor = GpuExecutor(); TF_ASSERT_OK_AND_ASSIGN(auto stream, executor->CreateStream()); @@ -885,7 +895,7 @@ TEST(DynamicSliceThunkTest, SlicedGemmArbitraryArgumentOrder) { ASSERT_EQ(dst, std::vector({9})); } -TEST(DynamicSliceThunkTest, SlicedGemmArbitraryNumberOfArguments) { +TEST_F(DynamicSliceThunkTest, SlicedGemmArbitraryNumberOfArguments) { se::StreamExecutor* executor = GpuExecutor(); TF_ASSERT_OK_AND_ASSIGN(auto stream, executor->CreateStream()); @@ -1036,7 +1046,7 @@ TEST(DynamicSliceThunkTest, SlicedGemmArbitraryNumberOfArguments) { ASSERT_EQ(dst, std::vector({9})); } -TEST(DynamicSliceThunkTest, SlicedTupledOperandGemm) { +TEST_F(DynamicSliceThunkTest, SlicedTupledOperandGemm) { se::StreamExecutor* executor = GpuExecutor(); TF_ASSERT_OK_AND_ASSIGN(auto stream, executor->CreateStream()); @@ -1184,7 +1194,7 @@ TEST(DynamicSliceThunkTest, SlicedTupledOperandGemm) { ASSERT_EQ(dst, std::vector({9})); } -TEST(DynamicSliceThunkTest, SlicedMemcpyOOB) { +TEST_F(DynamicSliceThunkTest, SlicedMemcpyOOB) { se::StreamExecutor* executor = GpuExecutor(); TF_ASSERT_OK_AND_ASSIGN(auto stream, executor->CreateStream()); @@ -1383,7 +1393,7 @@ TEST(DynamicSliceThunkTest, SlicedMemcpyOOB) { ASSERT_EQ(out, ref); } -TEST(DynamicSliceThunkTest, SlicedOperandsSameBufferGemm) { +TEST_F(DynamicSliceThunkTest, 
   se::StreamExecutor* executor = GpuExecutor();
   TF_ASSERT_OK_AND_ASSIGN(auto stream, executor->CreateStream());
 
@@ -1537,4 +1547,211 @@ TEST(DynamicSliceThunkTest, SlicedOperandsSameBufferGemm) {
   ASSERT_EQ(dst, std::vector<float>({9}));
 }
 
+TEST_F(DynamicSliceThunkTest,
+       HostInductionVariableAndOffsetEvaluationExecutesCorrectly) {
+  std::vector<std::unique_ptr<HloModule>> offset_modules;
+  const char* offset = R"(
+    HloModule offset
+    ENTRY main {
+      p0 = s32[] parameter(0)
+      c32 = s32[] constant(32)
+      c0 = s32[] constant(0)
+      add = s32[] add(p0, c32)
+      compare = pred[] compare(add, c0), direction=LT
+      ROOT select = s32[] select(compare, add, p0)
+    }
+  )";
+  TF_ASSERT_OK_AND_ASSIGN(offset_modules.emplace_back(),
+                          ParseAndReturnVerifiedModule(offset));
+  HloModule* offset_module = offset_modules.back().get();
+  const char* indvar_init = R"(
+    HloModule indvar_init
+    ENTRY main {
+      ROOT c0 = s32[] constant(0)
+    }
+  )";
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> indvar_init_module,
+                          ParseAndReturnVerifiedModule(indvar_init));
+  const char* indvar_update = R"(
+    HloModule indvar_update
+    ENTRY main {
+      p0 = s32[] parameter(0)
+      c1 = s32[] constant(1)
+      ROOT add = s32[] add(p0, c1)
+    }
+  )";
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> indvar_update_module,
+                          ParseAndReturnVerifiedModule(indvar_update));
+
+  se::StreamExecutor* executor = GpuExecutor();
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<se::Stream> stream,
+                          executor->CreateStream());
+
+  int64_t lhs_length = sizeof(float) * 2 * 4;
+  int64_t rhs_length = sizeof(float) * 4 * 1;
+  int64_t out_length = sizeof(float) * 1 * 1;
+
+  // Step 1:
+  // Prepare embedded and address computation thunks.
+
+  // Preparing buffer allocation slices for thunk creations.
+  std::vector<std::unique_ptr<BufferAllocation>> fake_allocations;
+  fake_allocations.reserve(4);
+
+  fake_allocations.push_back(std::make_unique<BufferAllocation>(
+      /*index=*/0, /*size=*/rhs_length, /*color=*/0));
+  BufferAllocation::Slice slice_lhs_fake(
+      /*allocation=*/fake_allocations.back().get(), /*offset=*/0,
+      /*size=*/rhs_length);
+
+  BufferAllocation alloc_lhs(/*index=*/0, /*size=*/lhs_length, /*color=*/0);
+  BufferAllocation::Slice slice_lhs(&alloc_lhs, /*offset=*/0,
+                                    /*size=*/lhs_length);
+
+  fake_allocations.push_back(std::make_unique<BufferAllocation>(
+      /*index=*/1, /*size=*/rhs_length, /*color=*/0));
+  BufferAllocation::Slice slice_rhs(
+      /*allocation=*/fake_allocations.back().get(), /*offset=*/0,
+      /*size=*/rhs_length);
+
+  fake_allocations.push_back(std::make_unique<BufferAllocation>(
+      /*index=*/2, /*size=*/out_length, /*color=*/0));
+  BufferAllocation::Slice slice_out(
+      /*allocation=*/fake_allocations.back().get(), /*offset=*/0,
+      /*size=*/out_length);
+
+  fake_allocations.push_back(std::make_unique<BufferAllocation>(
+      /*index=*/3, /*size=*/1024 * 1024, /*color=*/0));
+  BufferAllocation::Slice slice_workspace(
+      /*allocation=*/fake_allocations.back().get(), /*offset=*/0,
+      /*size=*/1024 * 1024);
+
+  // Preparing config for GEMM thunk.
+  absl::StatusOr<GemmConfig> config = GemmConfig::For(
+      /*lhs_shape=*/ShapeUtil::MakeShape(/*element_type=*/PrimitiveType::F32,
+                                         /*dimensions=*/{1, 4}),
+      /*lhs_batch_dims=*/{}, /*lhs_contracting_dims=*/{1},
+      /*rhs_shape=*/
+      ShapeUtil::MakeShape(/*element_type=*/PrimitiveType::F32,
+                           /*dimensions=*/{4, 1}),
+      /*rhs_batch_dims=*/{}, /*rhs_contracting_dims=*/{0},
+      /*output_shape=*/
+      ShapeUtil::MakeShape(/*element_type=*/PrimitiveType::F32,
+                           /*dimensions=*/{1, 1}),
+      /*alpha_real=*/1.0, /*alpha_imag=*/0.0, /*beta=*/0.0,
+      /*precision_algorithm=*/PrecisionConfig::ALG_UNSET,
+      /*algorithm=*/std::nullopt,
+      /*compute_precision=*/se::blas::kDefaultComputePrecision,
+      /*grad_x=*/false, /*grad_y=*/false,
+      /*gpu_version=*/
+      executor->GetDeviceDescription().gpu_compute_capability());
+  ASSERT_TRUE(config.ok());
+
+  // Creating embedded GEMM thunk.
+  ThunkSequence seq;
+  seq.emplace_back(std::make_unique<GemmThunk>(
+      /*thunk_info=*/Thunk::ThunkInfo(), /*config=*/config.value(),
+      /*lhs_buffer=*/slice_lhs_fake, /*rhs_buffer=*/slice_rhs,
+      /*output_buffer=*/slice_out,
+      /*workspace=*/slice_workspace, /*deterministic=*/true));
+
+  // Wrapping address computation thunk around the GEMM thunk.
+  std::vector<DynamicSliceThunk::Offset> lhs_offsets{offset_module, 0l};
+  DynamicSliceThunk::OffsetAsFunctionOfIndvarModulesMetadata
+      offset_as_function_of_indvar_modules_metadata(
+          std::move(indvar_init_module), std::move(indvar_update_module),
+          std::move(offset_modules));
+  DynamicSliceThunk thunk(
+      /*thunk_info=*/Thunk::ThunkInfo(),
+      /*embedded_thunk=*/std::make_unique<ThunkSequence>(std::move(seq)),
+      /*arguments=*/{slice_lhs, slice_rhs, slice_out, slice_workspace},
+      /*fake_allocations=*/std::move(fake_allocations),
+      /*offsets=*/{lhs_offsets, std::nullopt, std::nullopt, std::nullopt},
+      /*orig_shapes=*/
+      {ShapeUtil::MakeShape(PrimitiveType::F32, {2, 4}), std::nullopt,
+       std::nullopt, std::nullopt},
+      /*sliced_shapes=*/
+      {ShapeUtil::MakeShape(PrimitiveType::F32, {1, 4}), std::nullopt,
+       std::nullopt, std::nullopt},
+      /*offset_byte_sizes=*/
+      {sizeof(int64_t), std::nullopt, std::nullopt, std::nullopt},
+      /*offset_as_function_of_indvar_metadata=*/
+      std::move(offset_as_function_of_indvar_modules_metadata));
+
+  // Step 2:
+  // Execute address computation thunk.
+  //
+  // Given a `lhs` tensor of shape f32[2,4]{1,0}, the `lhs` slice that we want
+  // to use on the first iteration will be equivalent to this static slice op:
+  // f32[1,4]{1,0} slice(lhs), slice={[0:1], [0:4]}
+  // On later iterations, the offset module advances the slice by one row per
+  // step of the induction variable.
+
+  // Preparing memory for thunk arguments.
+  // lhs = [1.0, 2.0, 3.0, 4.0,
+  //        5.0, 6.0, 7.0, 8.0]
+  se::DeviceMemory<float> lhs =
+      executor->AllocateArray<float>(/*element_count=*/2 * 4);
+  std::vector<float> lhs_arr{1, 2, 3, 4, 5, 6, 7, 8};
+  TF_ASSERT_OK(stream->Memcpy(/*gpu_dst=*/&lhs, /*host_src=*/lhs_arr.data(),
+                              /*size=*/lhs_length));
+
+  // rhs = [4.0,
+  //        3.0,
+  //        2.0,
+  //        1.0]
+  se::DeviceMemory<float> rhs =
+      executor->AllocateArray<float>(/*element_count=*/4 * 1);
+  std::vector<float> rhs_arr{4, 3, 2, 1};
+  TF_ASSERT_OK(stream->Memcpy(/*gpu_dst=*/&rhs, /*host_src=*/rhs_arr.data(),
+                              /*size=*/rhs_length));
+
+  se::DeviceMemory<float> out =
+      executor->AllocateArray<float>(/*element_count=*/1 * 1);
+  TF_ASSERT_OK(stream->MemZero(/*location=*/&out, /*size=*/out_length));
+
+  se::DeviceMemory<float> workspace =
+      executor->AllocateArray<float>(/*element_count=*/1024 * 1024);
+  TF_ASSERT_OK(stream->MemZero(/*location=*/&workspace, /*size=*/1024 * 1024));
+
+  // Preparing parameters for thunk execution.
+  ServiceExecutableRunOptions run_options;
+  se::StreamExecutorMemoryAllocator allocator(executor);
+  BufferAllocations allocations(/*buffers=*/{lhs, rhs, out, workspace},
+                                /*device_ordinal=*/0,
+                                /*memory_allocator=*/&allocator);
+
+  Thunk::PrepareParams prepare_params{nullptr};
+
+  Thunk::ExecuteParams params = Thunk::ExecuteParams::Create(
+      run_options, /*buffer_allocations=*/allocations, stream.get(),
+      /*command_buffer_trace_stream=*/stream.get(),
+      /*collective_params=*/nullptr, /*collective_cliques=*/nullptr);
+
+  Thunk::ExecutableSource source = {/*text=*/"", /*binary=*/{}};
+  TF_ASSERT_OK(thunk.Initialize(
+      {executor, source, &allocations, stream.get(), stream.get()}));
+
+  // Executing address computation thunk.
+  ResourceRequests resource_requests;
+  TF_ASSERT_OK(thunk.Prepare(prepare_params, resource_requests));
+  TF_ASSERT_OK(thunk.ExecuteOnStream(params));
+  TF_ASSERT_OK(stream->BlockHostUntilDone());
+
+  // Copying `out` data back to host for verification.
+  std::vector<float> dst(1, 0);
+  TF_ASSERT_OK(stream->Memcpy(/*host_dst=*/dst.data(), /*gpu_src=*/out,
+                              /*size=*/out_length));
+
+  // First execution: indvar = 0, so the offset module selects row 0 of `lhs`.
+  ASSERT_EQ(dst, std::vector<float>({1 * 4 + 2 * 3 + 3 * 2 + 4 * 1}));
+
+  TF_ASSERT_OK(thunk.ExecuteOnStream(params));
+  TF_ASSERT_OK(stream->BlockHostUntilDone());
+
+  // Copying `out` data back to host for verification.
+  TF_ASSERT_OK(stream->Memcpy(/*host_dst=*/dst.data(), /*gpu_src=*/out,
+                              /*size=*/out_length));
+
+  // Second execution: indvar = 1, so the offset module selects row 1 of `lhs`.
+  EXPECT_EQ(dst, std::vector<float>({5 * 4 + 6 * 3 + 7 * 2 + 8 * 1}));
+}
+
 }  // namespace xla::gpu
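
For readers tracing the control flow above: `Prepare()` evaluates `indvar_init` once to seed the induction variable, each `ExecuteOnStream()` evaluates the per-argument offset modules against the current induction variable before launching the embedded thunk, and `indvar_update` advances the variable afterwards. The following standalone sketch (not part of the patch) mirrors that protocol with plain C++ lambdas standing in for `HloEvaluator` and the HLO modules; the names `indvar_init`, `indvar_update`, and `offset_fn` are illustrative stand-ins, not APIs from the patch.

// Host-side sketch of the DynamicSliceThunk induction-variable protocol.
#include <cstdint>
#include <functional>
#include <iostream>

int main() {
  // indvar_init: () -> s32[], mirrors `ROOT c0 = s32[] constant(0)`.
  std::function<int64_t()> indvar_init = [] { return int64_t{0}; };
  // indvar_update: (s32[]) -> s32[], mirrors `ROOT add = add(p0, c1)`.
  std::function<int64_t(int64_t)> indvar_update = [](int64_t i) {
    return i + 1;
  };
  // Offset module from the test: select(p0 + 32 < 0, p0 + 32, p0).
  std::function<int64_t(int64_t)> offset_fn = [](int64_t i) {
    int64_t add = i + 32;
    return add < 0 ? add : i;
  };

  int64_t indvar = indvar_init();  // What Thunk::Prepare() does once.
  for (int iteration = 0; iteration < 2; ++iteration) {
    // ExecuteOnStream(): evaluate the offset against the current indvar.
    int64_t offset = offset_fn(indvar);
    std::cout << "iteration " << iteration << ": slice row " << offset << "\n";
    // ExecuteOnStream(): advance the indvar after the embedded thunk runs.
    indvar = indvar_update(indvar);
  }
  // Prints rows 0 and 1, matching the two GEMM results asserted in the test.
}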