PR #20794: [gpu][ds-fusion] Add handling for offset module in ds-fusion thunk

Imported from GitHub PR #20794

This patch adds support for offset modules in ds fusion thunk. It also moves the `ResourceRequests` structure from `gpu_executable.cc` to `gpu_executable.h` because a valid implementation of the abstract class `Thunk::ResourceRequests` is required for calling `Thunk::Prepare()`.

This is split from #20332 as per request.
Copybara import of the project:

--
53226e9 by Shraiysh Vaishay <[email protected]>:

[gpu][ds-fusion] Add handling for offset module in ds-fusion thunk

This patch adds support for offset modules in ds fusion thunk. It also
moves the `ResourceRequests` structure from `gpu_executable.cc` to
`gpu_executable.h` because a valid implementation of the abstract class
`Thunk::ResourceRequests` is required for calling `Thunk::Prepare()`.

--
e554ac6 by Shraiysh Vaishay <[email protected]>:

Address comments and rebase

--
4ac8efa by Shraiysh Vaishay <[email protected]>:

Address comments

--
8c23e2a by Shraiysh Vaishay <[email protected]>:

Addressed comments.

--
efbe6bf by Shraiysh Vaishay <[email protected]>:

Addressed comments

Merging this change closes #20794

FUTURE_COPYBARA_INTEGRATE_REVIEW=#20794 from shraiysh:ds_fusion_thunk_changes efbe6bf
PiperOrigin-RevId: 718748248
shraiysh authored and Google-ML-Automation committed Jan 27, 2025
1 parent db8d9e1 commit 4c9b9e8
Showing 4 changed files with 352 additions and 15 deletions.
4 changes: 4 additions & 0 deletions xla/backends/gpu/runtime/BUILD
@@ -198,6 +198,7 @@ cc_library(
"//xla:literal",
"//xla:shape_util",
"//xla:status_macros",
"//xla/hlo/evaluator:hlo_evaluator",
"//xla/service:buffer_assignment",
"//xla/service/gpu:buffer_allocations",
"//xla/service/gpu:ir_emission_utils",
@@ -243,10 +244,12 @@ xla_test(
"//xla/ffi:ffi_api",
"//xla/service:buffer_assignment",
"//xla/service:executable",
"//xla/service:hlo_runner",
"//xla/service:platform_util",
"//xla/service/gpu:buffer_allocations",
"//xla/service/gpu:launch_dimensions",
"//xla/service/gpu:matmul_utils",
"//xla/service/gpu:resource_requests",
"//xla/stream_executor:blas",
"//xla/stream_executor:command_buffer",
"//xla/stream_executor:device_memory",
@@ -258,6 +261,7 @@ xla_test(
"//xla/stream_executor:stream_executor_memory_allocator",
"//xla/stream_executor/gpu:gpu_test_kernels",
"//xla/stream_executor/gpu:gpu_types_header",
"//xla/tests:hlo_runner_agnostic_test_base",
"//xla/tsl/lib/core:status_test_util",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/status",
61 changes: 59 additions & 2 deletions xla/backends/gpu/runtime/dynamic_slice_thunk.cc
@@ -31,6 +31,7 @@ limitations under the License.
#include "xla/backends/gpu/runtime/sequential_thunk.h"
#include "xla/backends/gpu/runtime/thunk.h"
#include "xla/backends/gpu/runtime/while_thunk.h"
#include "xla/hlo/evaluator/hlo_evaluator.h"
#include "xla/service/buffer_assignment.h"
#include "xla/service/gpu/buffer_allocations.h"
#include "xla/service/gpu/ir_emission_utils.h"
@@ -47,14 +48,35 @@ limitations under the License.
namespace xla {
namespace gpu {

namespace {

// Indvar is a thread-local map that stores the induction variable for each
// dynamic slice thunk. The same thunk object in memory is shared by multiple
// replicas of the same computation, so each replica must track the induction
// variable independently, which is why the storage is thread-local. Because
// thread-local storage cannot be embedded inside the dynamic slice thunk
// object, we keep a static map instead. A module may contain several dynamic
// slice thunks, so the map is keyed by thunk, storing one induction variable
// per thunk. The use of thread-local storage in this context is similar to
// `LoopCounters` in while_thunk.cc (b/343294327).
Literal& Indvar(DynamicSliceThunk* thunk) {
static thread_local absl::flat_hash_map<DynamicSliceThunk*, Literal>
indvar_map;
return indvar_map[thunk];
}

} // namespace
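
The thread-local pattern above is worth seeing in isolation: a static `thread_local` map keyed by object pointer gives every thread running the same object its own private state. Below is a minimal, self-contained sketch of the same pattern, with hypothetical names unrelated to this patch:

```cpp
#include <thread>

#include "absl/container/flat_hash_map.h"

class Tracker {};  // Hypothetical stand-in for DynamicSliceThunk.

// Each thread gets its own map, so two threads running the same Tracker
// object observe independent counters.
int& Counter(Tracker* t) {
  static thread_local absl::flat_hash_map<Tracker*, int> counters;
  return counters[t];
}

void Demo() {
  Tracker tracker;
  auto body = [&tracker] {
    Counter(&tracker) = 0;  // Per-thread initialization.
    for (int i = 0; i < 3; ++i) {
      ++Counter(&tracker);  // Advances only this thread's copy.
    }
  };
  std::thread t1(body), t2(body);  // Both threads share `tracker`,
  t1.join();
  t2.join();  // yet each thread counts to 3 independently.
}
```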

DynamicSliceThunk::DynamicSliceThunk(
ThunkInfo thunk_info, std::unique_ptr<ThunkSequence> embedded_thunk,
std::vector<std::optional<BufferAllocation::Slice>> arguments,
std::vector<std::unique_ptr<BufferAllocation>> fake_allocations,
std::vector<std::optional<std::vector<Offset>>> offsets,
std::vector<std::optional<Shape>> orig_shapes,
std::vector<std::optional<Shape>> sliced_shapes,
std::vector<std::optional<uint64_t>> offset_byte_sizes)
std::vector<std::optional<uint64_t>> offset_byte_sizes,
std::optional<OffsetAsFunctionOfIndvarModulesMetadata>
offset_as_function_of_indvar_metadata)
: Thunk(Kind::kDynamicSlice, thunk_info),
embedded_thunk_(std::make_unique<SequentialThunk>(
ThunkInfo(), std::move(*embedded_thunk))),
Expand All @@ -63,7 +85,9 @@ DynamicSliceThunk::DynamicSliceThunk(
offsets_(offsets),
orig_shapes_(orig_shapes),
sliced_shapes_(sliced_shapes),
offset_byte_sizes_(offset_byte_sizes) {
offset_byte_sizes_(offset_byte_sizes),
offset_as_function_of_indvar_metadata_(
std::move(offset_as_function_of_indvar_metadata)) {
// Zip all arguments together to create a list of SliceDef.
for (auto [arg, offsets, orig_shape, sliced_shape, offset_byte_size] :
llvm::zip_equal(arguments, offsets, orig_shapes, sliced_shapes,
@@ -105,6 +129,16 @@ absl::Status DynamicSliceThunk::Prepare(
}

TF_RETURN_IF_ERROR(embedded_thunk_->Prepare(params, resource_requests));

if (offset_as_function_of_indvar_metadata_ != std::nullopt) {
Indvar(this) =
HloEvaluator()
.Evaluate(
/*module=*/*offset_as_function_of_indvar_metadata_->indvar_init,
/*arg_literals=*/{})
.value();
VLOG(2) << "Indvar = " << Indvar(this).ToString();
}
return absl::OkStatus();
}
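
For intuition, the init/update protocol that `Prepare` (above) and `ExecuteOnStream` (below) rely on can be exercised directly with the HLO evaluator. The following is a sketch under stated assumptions: the module texts are hypothetical, header paths may differ between XLA versions, and error handling is reduced to `TF_ASSIGN_OR_RETURN`:

```cpp
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "xla/hlo/evaluator/hlo_evaluator.h"
#include "xla/hlo/parser/hlo_parser.h"
#include "xla/literal.h"
#include "tsl/platform/statusor.h"

// Hypothetical init module with signature () -> s64[]: initial indvar value.
constexpr absl::string_view kInit = R"(
  HloModule init
  ENTRY main { ROOT c = s64[] constant(0) }
)";

// Hypothetical update module with signature (s64[]) -> s64[]: indvar + 1.
constexpr absl::string_view kUpdate = R"(
  HloModule update
  ENTRY main {
    p = s64[] parameter(0)
    one = s64[] constant(1)
    ROOT next = s64[] add(p, one)
  }
)";

absl::StatusOr<xla::Literal> RunOneIteration() {
  TF_ASSIGN_OR_RETURN(auto init, xla::ParseAndReturnUnverifiedModule(kInit));
  TF_ASSIGN_OR_RETURN(auto update,
                      xla::ParseAndReturnUnverifiedModule(kUpdate));
  // What Prepare() does: evaluate the init module with no arguments.
  TF_ASSIGN_OR_RETURN(xla::Literal indvar,
                      xla::HloEvaluator().Evaluate(*init, {}));
  // What the end of ExecuteOnStream() does: advance the induction variable.
  return xla::HloEvaluator().Evaluate(*update, {&indvar});
}
```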

@@ -182,6 +216,20 @@ absl::Status DynamicSliceThunk::ExecuteOnStream(const ExecuteParams& params) {
<< "]: constant offset = " << *const_offset;
offset_value(argument_idx, offset_idx) = *const_offset;

} else if (HloModule** offset_module = std::get_if<HloModule*>(&offset)) {
TF_ASSIGN_OR_RETURN(
Literal offset,
HloEvaluator().Evaluate(**offset_module, {&Indvar(this)}));
auto offset_int = LiteralUtil::LiteralAsScalarInt64(offset);
if (offset_int.has_value()) {
offset_value(argument_idx, offset_idx) = *offset_int;
} else {
return absl::InternalError(
absl::StrFormat("Unhandled type returned from offset module: %s",
offset.shape().ToString()));
}
VLOG(2) << "Offset value = " << offset_value(argument_idx, offset_idx);

} else {
// Transfer slice offset value from device to host.
auto alloc_slice = std::get<BufferAllocation::Slice>(offset);
@@ -252,6 +300,15 @@ absl::Status DynamicSliceThunk::ExecuteOnStream(const ExecuteParams& params) {
// Execute the underlying custom call thunk with the new buffers.
TF_RETURN_IF_ERROR(embedded_thunk_->ExecuteOnStream(new_params));

if (offset_as_function_of_indvar_metadata_ != std::nullopt) {
Indvar(this) =
HloEvaluator()
.Evaluate(*offset_as_function_of_indvar_metadata_->indvar_update,
{&Indvar(this)})
.value();
VLOG(2) << "Indvar = " << Indvar(this).ToString();
}

return absl::OkStatus();
}

67 changes: 63 additions & 4 deletions xla/backends/gpu/runtime/dynamic_slice_thunk.h
@@ -32,6 +32,7 @@ limitations under the License.
#include "xla/literal.h"
#include "xla/service/buffer_assignment.h"
#include "xla/shape.h"
#include "xla/shape_util.h"
#include "xla/stream_executor/memory_allocation.h"
#include "xla/stream_executor/stream_executor.h"

@@ -47,8 +48,59 @@ class DynamicSliceThunk : public Thunk {
public:
// Dynamic slice offset can be either: (1) a statically known constant value
// or (2) a truly dynamic offset that is computed on device and has to be
// transferred to host.
using Offset = std::variant<int64_t, BufferAllocation::Slice>;
// transferred to host, or (3) a temporary HloModule that computes the
// offset with a single induction variable as the input.
using Offset = std::variant<int64_t, BufferAllocation::Slice, HloModule*>;

struct OffsetAsFunctionOfIndvarModulesMetadata {
// These two modules keep track of the induction variable. The module
// `indvar_init` has the prototype `() -> integer[]`: it takes no input and
// returns the initial value of the induction variable. The module
// `indvar_update` has the prototype `(integer[]) -> integer[]`: it takes
// the current value of the induction variable and returns the next value.
std::unique_ptr<HloModule> indvar_init, indvar_update;

// Extracted HloModules for computing dynamic offsets. The modules are not
// used here directly; they are stored solely to keep them alive and to
// maintain their ownership with the thunk, while their raw pointers are
// used during execution via the `offsets_` vector. These modules have the
// signature `(integer[]) -> integer[]`, where the input is the current
// value of the loop induction variable and the output is the offset value
// for that iteration.
std::vector<std::unique_ptr<HloModule>> extracted_offset_modules;

OffsetAsFunctionOfIndvarModulesMetadata(
std::unique_ptr<HloModule> indvar_init,
std::unique_ptr<HloModule> indvar_update,
std::vector<std::unique_ptr<HloModule>> extracted_offset_modules)
: indvar_init(std::move(indvar_init)),
indvar_update(std::move(indvar_update)),
extracted_offset_modules(std::move(extracted_offset_modules)) {
CHECK(this->indvar_init != nullptr && this->indvar_update != nullptr);
Shape init_output_shape =
this->indvar_init->entry_computation()->root_instruction()->shape();
CHECK(this->indvar_init->entry_computation()->num_parameters() == 0 &&
init_output_shape.IsInteger() &&
ShapeUtil::IsScalar(init_output_shape))
<< "Induction variable init module expected with signature `() -> "
"integer[]`.";
Shape update_output_shape =
this->indvar_update->entry_computation()->root_instruction()->shape();
CHECK(this->indvar_update->entry_computation()->num_parameters() == 1 &&
update_output_shape.IsInteger() &&
ShapeUtil::IsScalar(update_output_shape))
<< "Induction variable update module expected with signature "
"`(integer[]) -> integer[]`.";
Shape update_input_shape = this->indvar_update->entry_computation()
->parameter_instruction(0)
->shape();
CHECK(ShapeUtil::IsScalar(update_input_shape) &&
update_input_shape.IsInteger())
<< "Induction variable update module expected with signature "
"`(integer[]) -> integer[]`.";
}
};
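
Given the constructor checks above, wiring the metadata together might look like the following sketch. All names are hypothetical; the modules are assumed to already satisfy the `() -> integer[]` and `(integer[]) -> integer[]` signatures that the CHECKs enforce (for example, parsed as in the sketch in the .cc file):

```cpp
#include <memory>
#include <utility>
#include <vector>

#include "xla/backends/gpu/runtime/dynamic_slice_thunk.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/service/buffer_assignment.h"

namespace xla::gpu {

// Hypothetical helper: bundles parsed modules into the metadata and builds an
// offset list that mixes all three Offset variants.
DynamicSliceThunk::OffsetAsFunctionOfIndvarModulesMetadata MakeMetadata(
    std::unique_ptr<HloModule> indvar_init,
    std::unique_ptr<HloModule> indvar_update,
    std::unique_ptr<HloModule> offset_module,
    BufferAllocation::Slice device_offset_slice,
    std::vector<DynamicSliceThunk::Offset>& offsets_out) {
  std::vector<std::unique_ptr<HloModule>> offset_modules;
  offset_modules.push_back(std::move(offset_module));
  // The thunk keeps ownership of the module; execution reads this raw
  // pointer out of the offsets vector.
  HloModule* offset_ptr = offset_modules.back().get();

  offsets_out = {
      DynamicSliceThunk::Offset(int64_t{0}),           // (1) static constant
      DynamicSliceThunk::Offset(device_offset_slice),  // (2) computed on device
      DynamicSliceThunk::Offset(offset_ptr),           // (3) function of indvar
  };

  return DynamicSliceThunk::OffsetAsFunctionOfIndvarModulesMetadata(
      std::move(indvar_init), std::move(indvar_update),
      std::move(offset_modules));
}

}  // namespace xla::gpu
```

The resulting metadata would then be passed as the trailing constructor argument shown below, which defaults to `std::nullopt` when no offsets are functions of the induction variable.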

DynamicSliceThunk(
ThunkInfo thunk_info, std::unique_ptr<ThunkSequence> embedded_thunk,
@@ -57,8 +109,9 @@
std::vector<std::optional<std::vector<Offset>>> offsets,
std::vector<std::optional<Shape>> orig_shapes,
std::vector<std::optional<Shape>> sliced_shapes,
std::vector<std::optional<uint64_t>> offset_byte_sizes);

std::vector<std::optional<uint64_t>> offset_byte_sizes,
std::optional<OffsetAsFunctionOfIndvarModulesMetadata>
offset_as_function_of_indvar_metadata = std::nullopt);
DynamicSliceThunk(const DynamicSliceThunk&) = delete;
DynamicSliceThunk& operator=(const DynamicSliceThunk&) = delete;

@@ -132,6 +185,12 @@ class DynamicSliceThunk : public Thunk {

// A mapping from argument index to the base offset in the `offsets_allocs_`.
std::vector<int64_t> offsets_allocs_base_;

// This structure holds the metadata for offset computations on the host. It
// stores the induction variable initialization module, its update module,
// and the offset modules that are functions of the induction variable.
std::optional<OffsetAsFunctionOfIndvarModulesMetadata>
offset_as_function_of_indvar_metadata_;
};

} // namespace gpu