[gpu][ds-fusion] Add handling for offset module in ds-fusion thunk #20794

Closed
wants to merge 5 commits
4 changes: 4 additions & 0 deletions xla/backends/gpu/runtime/BUILD
@@ -198,6 +198,7 @@ cc_library(
         "//xla:literal",
         "//xla:shape_util",
         "//xla:status_macros",
+        "//xla/hlo/evaluator:hlo_evaluator",
         "//xla/service:buffer_assignment",
         "//xla/service/gpu:buffer_allocations",
         "//xla/service/gpu:ir_emission_utils",
@@ -247,6 +248,8 @@ xla_test(
         "//xla/service/gpu:buffer_allocations",
         "//xla/service/gpu:launch_dimensions",
         "//xla/service/gpu:matmul_utils",
+        "//xla/service/gpu:resource_requests",
+        "//xla/service:hlo_runner",
         "//xla/stream_executor:blas",
         "//xla/stream_executor:command_buffer",
         "//xla/stream_executor:device_memory",
@@ -258,6 +261,7 @@ xla_test(
         "//xla/stream_executor:stream_executor_memory_allocator",
         "//xla/stream_executor/gpu:gpu_test_kernels",
         "//xla/stream_executor/gpu:gpu_types_header",
+        "//xla/tests:hlo_runner_agnostic_test_base",
         "//xla/tsl/lib/core:status_test_util",
         "@com_google_absl//absl/algorithm:container",
         "@com_google_absl//absl/status",
61 changes: 59 additions & 2 deletions xla/backends/gpu/runtime/dynamic_slice_thunk.cc
@@ -28,6 +28,7 @@ limitations under the License.
 #include "absl/status/status.h"
 #include "absl/synchronization/mutex.h"
 #include "llvm/ADT/STLExtras.h"
+#include "xla/hlo/evaluator/hlo_evaluator.h"
 #include "xla/backends/gpu/runtime/sequential_thunk.h"
 #include "xla/backends/gpu/runtime/thunk.h"
 #include "xla/backends/gpu/runtime/while_thunk.h"
@@ -47,14 +48,35 @@ limitations under the License.
 namespace xla {
 namespace gpu {
 
+namespace {
+
+// Indvar is a thread-local map that stores the induction variable for each
+// dynamic slice thunk. The same thunk object in memory is shared by multiple
+// replicas of the same computation, so each replica needs its own tracking of
+// the induction variable, hence the thread-local storage. A thread_local
+// variable cannot be a non-static member of the thunk object, so we keep a
+// static map instead. Since the same module can contain multiple dynamic
+// slice thunks, the map is keyed by thunk. The use of thread_local here is
+// similar to `LoopCounters` in while_thunk.cc (b/343294327).
+Literal& Indvar(DynamicSliceThunk* thunk) {
Member:

These definitions are also in another change, where I commented on them and their documentation. Where should I review their introduction?

Either way, I'd like to see better documentation of why this is an appropriate implementation, and what this relies on :)

Contributor Author:

> Where should I review their introduction?

In this change. That change is supposed to be dependent on this one, but it is hard to keep changes in sync frequently on GitHub.

> Either way, I'd like to see better documentation of why this is an appropriate implementation, and what this relies on :)

I will edit that comment with more details.
+  static thread_local absl::flat_hash_map<DynamicSliceThunk*, Literal>
+      indvar_map;
+  return indvar_map[thunk];
+}
+
+}  // namespace
+
 DynamicSliceThunk::DynamicSliceThunk(
     ThunkInfo thunk_info, std::unique_ptr<ThunkSequence> embedded_thunk,
     std::vector<std::optional<BufferAllocation::Slice>> arguments,
     std::vector<std::unique_ptr<BufferAllocation>> fake_allocations,
     std::vector<std::optional<std::vector<Offset>>> offsets,
     std::vector<std::optional<Shape>> orig_shapes,
     std::vector<std::optional<Shape>> sliced_shapes,
-    std::vector<std::optional<uint64_t>> offset_byte_sizes)
+    std::vector<std::optional<uint64_t>> offset_byte_sizes,
+    std::optional<OffsetAsFunctionOfIndvarModulesMetadata>
+        offset_as_function_of_indvar_metadata)
     : Thunk(Kind::kDynamicSlice, thunk_info),
       embedded_thunk_(std::make_unique<SequentialThunk>(
           ThunkInfo(), std::move(*embedded_thunk))),
@@ -63,7 +85,9 @@ DynamicSliceThunk::DynamicSliceThunk(
       offsets_(offsets),
       orig_shapes_(orig_shapes),
       sliced_shapes_(sliced_shapes),
-      offset_byte_sizes_(offset_byte_sizes) {
+      offset_byte_sizes_(offset_byte_sizes),
+      offset_as_function_of_indvar_metadata_(
+          std::move(offset_as_function_of_indvar_metadata)) {
   // Zip all arguments together to create a list of SliceDef.
   for (auto [arg, offsets, orig_shape, sliced_shape, offset_byte_size] :
        llvm::zip_equal(arguments, offsets, orig_shapes, sliced_shapes,
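Aside for readers unfamiliar with the pattern: below is a minimal, self-contained sketch of the thread-local per-thunk state that `Indvar` implements in this file, with `std::unordered_map` and `int` standing in for `absl::flat_hash_map` and `Literal`. All names in the sketch are illustrative, not code from this change.

```cpp
#include <cstdio>
#include <thread>
#include <unordered_map>

struct FakeThunk {};  // stand-in for DynamicSliceThunk

// Each thread gets its own map, so two replicas executing the same thunk
// object concurrently each see an independent induction variable.
int& Counter(FakeThunk* thunk) {
  static thread_local std::unordered_map<FakeThunk*, int> counters;
  return counters[thunk];  // value-initializes to 0 on first access
}

int main() {
  FakeThunk t;
  Counter(&t) = 5;  // this thread's copy only
  std::thread replica([&t] {
    std::printf("replica thread sees %d\n", Counter(&t));  // prints 0
  });
  replica.join();
  std::printf("main thread still sees %d\n", Counter(&t));  // prints 5
}
```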
@@ -105,6 +129,16 @@ absl::Status DynamicSliceThunk::Prepare(
   }
 
   TF_RETURN_IF_ERROR(embedded_thunk_->Prepare(params, resource_requests));
+
+  if (offset_as_function_of_indvar_metadata_ != std::nullopt) {
+    Indvar(this) =
+        HloEvaluator()
+            .Evaluate(
+                /*module=*/*offset_as_function_of_indvar_metadata_->indvar_init,
+                /*arg_literals=*/{})
+            .value();
+    VLOG(2) << "Indvar = " << Indvar(this).ToString();
+  }
   return absl::OkStatus();
 }
 
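For context, a hedged sketch of what this `Prepare` hook amounts to for an induction variable that starts at zero: parse a parameterless HLO module and fold it to a `Literal` with `HloEvaluator`. The HLO text, error handling, and the parser header location are assumptions for illustration; they are not part of this change.

```cpp
#include <memory>

#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "xla/hlo/evaluator/hlo_evaluator.h"
#include "xla/hlo/parser/hlo_parser.h"  // assumed location of the HLO parser
#include "xla/literal.h"

namespace example {

// Evaluates a `() -> integer[]` module, as Prepare() does for indvar_init.
absl::StatusOr<xla::Literal> EvaluateIndvarInit() {
  constexpr absl::string_view kInitHlo = R"(
    HloModule indvar_init
    ENTRY main {
      ROOT init = s64[] constant(0)
    })";
  absl::StatusOr<std::unique_ptr<xla::HloModule>> module =
      xla::ParseAndReturnUnverifiedModule(kInitHlo);
  if (!module.ok()) return module.status();
  // No argument literals: the init module takes no parameters.
  return xla::HloEvaluator().Evaluate(**module, /*arg_literals=*/{});
}

}  // namespace example
```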
@@ -182,6 +216,20 @@ absl::Status DynamicSliceThunk::ExecuteOnStream(const ExecuteParams& params) {
                 << "]: constant offset = " << *const_offset;
         offset_value(argument_idx, offset_idx) = *const_offset;
 
+      } else if (HloModule** offset_module = std::get_if<HloModule*>(&offset)) {
+        TF_ASSIGN_OR_RETURN(
+            Literal offset,
+            HloEvaluator().Evaluate(**offset_module, {&Indvar(this)}));
+        auto offset_int = LiteralUtil::LiteralAsScalarInt64(offset);
+        if (offset_int.has_value()) {
+          offset_value(argument_idx, offset_idx) = *offset_int;
+        } else {
+          return absl::InternalError(
+              absl::StrFormat("Unhandled type returned from offset module: %s",
+                              offset.shape().ToString()));
+        }
+        VLOG(2) << "Offset value = " << offset_value(argument_idx, offset_idx);
+
       } else {
         // Transfer slice offset value from device to host.
         auto alloc_slice = std::get<BufferAllocation::Slice>(offset);
@@ -252,6 +300,15 @@ absl::Status DynamicSliceThunk::ExecuteOnStream(const ExecuteParams& params) {
   // Execute the underlying custom call thunk with the new buffers.
   TF_RETURN_IF_ERROR(embedded_thunk_->ExecuteOnStream(new_params));
 
+  if (offset_as_function_of_indvar_metadata_ != std::nullopt) {
+    Indvar(this) =
+        HloEvaluator()
+            .Evaluate(*offset_as_function_of_indvar_metadata_->indvar_update,
+                      {&Indvar(this)})
+            .value();
+    VLOG(2) << "Indvar = " << Indvar(this).ToString();
+  }
+
   return absl::OkStatus();
 }
 
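Putting the two `ExecuteOnStream` pieces above together, here is a hedged sketch of what one loop iteration does on the host: evaluate an offset module at the current induction variable, extract a scalar `int64_t`, then advance the induction variable. The function names and error messages are illustrative; `LiteralUtil::LiteralAsScalarInt64` is the helper the change itself uses.

```cpp
#include <cstdint>
#include <optional>

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "xla/hlo/evaluator/hlo_evaluator.h"
#include "xla/literal.h"
#include "xla/literal_util.h"

namespace example {

// Host-side offset computation for one iteration, mirroring the
// `HloModule*` branch of ExecuteOnStream above.
absl::StatusOr<int64_t> OffsetForIteration(const xla::HloModule& offset_module,
                                           const xla::Literal& indvar) {
  absl::StatusOr<xla::Literal> offset =
      xla::HloEvaluator().Evaluate(offset_module, {&indvar});
  if (!offset.ok()) return offset.status();
  std::optional<int64_t> as_int64 =
      xla::LiteralUtil::LiteralAsScalarInt64(*offset);
  if (!as_int64.has_value()) {
    return absl::InternalError("offset module did not produce a scalar int");
  }
  return *as_int64;
}

// Advancing the induction variable, mirroring the end of ExecuteOnStream.
absl::StatusOr<xla::Literal> NextIndvar(const xla::HloModule& indvar_update,
                                        const xla::Literal& indvar) {
  return xla::HloEvaluator().Evaluate(indvar_update, {&indvar});
}

}  // namespace example
```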
67 changes: 63 additions & 4 deletions xla/backends/gpu/runtime/dynamic_slice_thunk.h
@@ -32,6 +32,7 @@ limitations under the License.
 #include "xla/literal.h"
 #include "xla/service/buffer_assignment.h"
 #include "xla/shape.h"
+#include "xla/shape_util.h"
 #include "xla/stream_executor/memory_allocation.h"
 #include "xla/stream_executor/stream_executor.h"
 
@@ -47,8 +48,59 @@ class DynamicSliceThunk : public Thunk {
  public:
   // Dynamic slice offset can be either: (1) a statically known constant value,
   // (2) a truly dynamic offset that is computed on device and has to be
-  // transferred to host.
-  using Offset = std::variant<int64_t, BufferAllocation::Slice>;
+  // transferred to host, or (3) a temporary HloModule that computes the offset
+  // with a single induction variable as the input.
+  using Offset = std::variant<int64_t, BufferAllocation::Slice, HloModule*>;
+
+  struct OffsetAsFunctionOfIndvarModulesMetadata {
+    // These two modules keep track of the induction variable. The module
+    // `indvar_init` has the prototype `() -> integer[]`: it takes no input
+    // and returns the initial value of the induction variable. The module
+    // `indvar_update` has the prototype `(integer[]) -> integer[]`: it takes
+    // the current value of the induction variable and returns the next value.
+    std::unique_ptr<HloModule> indvar_init, indvar_update;
+
+    // Extracted HloModules for computing dynamic offsets. The modules are not
+    // used here; this member solely keeps the modules alive and maintains
+    // their ownership with the thunk, while their raw pointers are used
+    // during execution from the `offsets_` vector. These modules have the
+    // signature `(integer[]) -> integer[]`, where the input is the current
+    // value of the loop induction variable and the output is the offset value
+    // for that iteration.
+    std::vector<std::unique_ptr<HloModule>> extracted_offset_modules;
+
+    OffsetAsFunctionOfIndvarModulesMetadata(
+        std::unique_ptr<HloModule> indvar_init,
+        std::unique_ptr<HloModule> indvar_update,
+        std::vector<std::unique_ptr<HloModule>> extracted_offset_modules)
+        : indvar_init(std::move(indvar_init)),
+          indvar_update(std::move(indvar_update)),
+          extracted_offset_modules(std::move(extracted_offset_modules)) {
+      CHECK(this->indvar_init != nullptr && this->indvar_update != nullptr);
+      Shape init_output_shape =
+          this->indvar_init->entry_computation()->root_instruction()->shape();
+      CHECK(this->indvar_init->entry_computation()->num_parameters() == 0 &&
+            init_output_shape.IsInteger() &&
+            ShapeUtil::IsScalar(init_output_shape))
+          << "Induction variable init module expected with signature `() -> "
+             "integer[]`.";
+      Shape update_output_shape =
+          this->indvar_update->entry_computation()->root_instruction()->shape();
+      CHECK(this->indvar_update->entry_computation()->num_parameters() == 1 &&
+            update_output_shape.IsInteger() &&
+            ShapeUtil::IsScalar(update_output_shape))
+          << "Induction variable update module expected with signature "
+             "`(integer[]) -> integer[]`.";
+      Shape update_input_shape = this->indvar_update->entry_computation()
+                                     ->parameter_instruction(0)
+                                     ->shape();
+      CHECK(ShapeUtil::IsScalar(update_input_shape) &&
+            update_input_shape.IsInteger())
+          << "Induction variable update module expected with signature "
+             "`(integer[]) -> integer[]`.";
+    }
+  };
 
   DynamicSliceThunk(
       ThunkInfo thunk_info, std::unique_ptr<ThunkSequence> embedded_thunk,
@@ -57,8 +109,9 @@ class DynamicSliceThunk : public Thunk {
       std::vector<std::optional<std::vector<Offset>>> offsets,
       std::vector<std::optional<Shape>> orig_shapes,
       std::vector<std::optional<Shape>> sliced_shapes,
-      std::vector<std::optional<uint64_t>> offset_byte_sizes);
-
+      std::vector<std::optional<uint64_t>> offset_byte_sizes,
+      std::optional<OffsetAsFunctionOfIndvarModulesMetadata>
+          offset_as_function_of_indvar_metadata = std::nullopt);
   DynamicSliceThunk(const DynamicSliceThunk&) = delete;
   DynamicSliceThunk& operator=(const DynamicSliceThunk&) = delete;
 
@@ -132,6 +185,12 @@ class DynamicSliceThunk : public Thunk {
 
   // A mapping from argument index to the base offset in the `offsets_allocs_`.
   std::vector<int64_t> offsets_allocs_base_;
+
+  // This structure holds the metadata for offset computations on the host. It
+  // stores the induction variable's initialization module, its update module,
+  // and the offset modules that are functions of the induction variable.
+  std::optional<OffsetAsFunctionOfIndvarModulesMetadata>
+      offset_as_function_of_indvar_metadata_;
 };
 
 }  // namespace gpu
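To make the contract of `OffsetAsFunctionOfIndvarModulesMetadata` concrete, here is a hedged sketch of how a caller might assemble the metadata for a loop whose induction variable starts at 0 and advances by 1, which satisfies the CHECKs in the constructor above. The HLO snippets, the `Parse` helper, and the parser header location are assumptions for illustration; in the actual pipeline these modules are extracted from the surrounding computation.

```cpp
#include <memory>
#include <utility>
#include <vector>

#include "absl/strings/string_view.h"
#include "xla/backends/gpu/runtime/dynamic_slice_thunk.h"
#include "xla/hlo/parser/hlo_parser.h"  // assumed location of the HLO parser

namespace example {

// Crashes on invalid HLO, which is fine for a sketch.
std::unique_ptr<xla::HloModule> Parse(absl::string_view text) {
  return xla::ParseAndReturnUnverifiedModule(text).value();
}

// `indvar_init` has signature () -> s64[]; `indvar_update` has signature
// (s64[]) -> s64[].
xla::gpu::DynamicSliceThunk::OffsetAsFunctionOfIndvarModulesMetadata
MakeMetadata() {
  auto init = Parse(R"(
    HloModule indvar_init
    ENTRY main { ROOT init = s64[] constant(0) })");
  auto update = Parse(R"(
    HloModule indvar_update
    ENTRY main {
      iv = s64[] parameter(0)
      one = s64[] constant(1)
      ROOT next = s64[] add(iv, one)
    })");
  // Offset modules (e.g. i -> i) would go into `extracted_offset_modules`,
  // with their raw pointers stored in the thunk's `offsets_` as the
  // HloModule* variant alternative.
  std::vector<std::unique_ptr<xla::HloModule>> offset_modules;
  return xla::gpu::DynamicSliceThunk::OffsetAsFunctionOfIndvarModulesMetadata(
      std::move(init), std::move(update), std::move(offset_modules));
}

}  // namespace example
```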