Skip to content

Commit

Permalink
Address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
shraiysh committed Jan 13, 2025
1 parent d97beec commit 0995afc
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 15 deletions.
26 changes: 14 additions & 12 deletions xla/service/gpu/runtime/dynamic_slice_thunk.cc
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,11 @@ namespace gpu {

namespace {

std::unique_ptr<Literal>& Indvar(DynamicSliceThunk* thunk) {
static thread_local absl::flat_hash_map<DynamicSliceThunk*,
std::unique_ptr<Literal>>
// Indvar returns the induction variable for the given dynamic slice thunk,
// stored in a thread-local map keyed by the thunk pointer. The use of
// `thread_local` here is similar to `LoopCounters` in while_thunk.cc
// (b/343294327).
Literal& Indvar(DynamicSliceThunk* thunk) {
static thread_local absl::flat_hash_map<DynamicSliceThunk*, Literal>
indvar_map;
return indvar_map[thunk];
}
Expand All @@ -67,7 +69,7 @@ DynamicSliceThunk::DynamicSliceThunk(
std::vector<std::optional<Shape>> orig_shapes,
std::vector<std::optional<Shape>> sliced_shapes,
std::vector<std::optional<uint64_t>> offset_byte_sizes,
std::vector<std::unique_ptr<HloModule>> temp_modules,
std::vector<std::unique_ptr<HloModule>> extracted_offset_modules,
std::unique_ptr<HloModule> indvar_init,
std::unique_ptr<HloModule> indvar_update)
: Thunk(Kind::kDynamicSlice, thunk_info),
Expand All @@ -79,7 +81,7 @@ DynamicSliceThunk::DynamicSliceThunk(
orig_shapes_(orig_shapes),
sliced_shapes_(sliced_shapes),
offset_byte_sizes_(offset_byte_sizes),
temp_modules_(std::move(temp_modules)),
extracted_offset_modules_(std::move(extracted_offset_modules)),
indvar_init_(std::move(indvar_init)),
indvar_update_(
std::move(indvar_update)) { // Zip all arguments together to create a
Expand Down Expand Up @@ -128,8 +130,8 @@ absl::Status DynamicSliceThunk::Prepare(
if (indvar_init_ != nullptr) {
Indvar(this) = HloEvaluator()
.Evaluate(/*module=*/*indvar_init_, /*arg_literals=*/{})
->CloneToUnique();
VLOG(0) << "Indvar = " << Indvar(this)->ToString();
.value();
VLOG(2) << "Indvar = " << Indvar(this).ToString();
}
return absl::OkStatus();
}
Expand Down Expand Up @@ -211,7 +213,7 @@ absl::Status DynamicSliceThunk::ExecuteOnStream(const ExecuteParams& params) {
} else if (HloModule** offset_module = std::get_if<HloModule*>(&offset)) {
TF_ASSIGN_OR_RETURN(
Literal offset,
HloEvaluator().Evaluate(**offset_module, {Indvar(this).get()}));
HloEvaluator().Evaluate(**offset_module, {&Indvar(this)}));
auto offset_int = LiteralUtil::LiteralAsScalarInt64(offset);
if (offset_int.has_value()) {
offset_value(argument_idx, offset_idx) = *offset_int;
Expand All @@ -220,7 +222,7 @@ absl::Status DynamicSliceThunk::ExecuteOnStream(const ExecuteParams& params) {
absl::StrFormat("Unhandled type returned from offset module: %s",
offset.shape().ToString()));
}
VLOG(1) << "Offset value = " << offset_value(argument_idx, offset_idx);
VLOG(2) << "Offset value = " << offset_value(argument_idx, offset_idx);

} else {
// Transfer slice offset value from device to host.
Expand Down Expand Up @@ -293,9 +295,9 @@ absl::Status DynamicSliceThunk::ExecuteOnStream(const ExecuteParams& params) {
TF_RETURN_IF_ERROR(embedded_thunk_->ExecuteOnStream(new_params));

if (indvar_update_ != nullptr) {
Indvar(this) = HloEvaluator()
.Evaluate(*indvar_update_, {Indvar(this).get()})
->CloneToUnique();
Indvar(this) =
HloEvaluator().Evaluate(*indvar_update_, {&Indvar(this)}).value();
VLOG(2) << "Indvar = " << Indvar(this).ToString();
}

return absl::OkStatus();
Expand Down
20 changes: 17 additions & 3 deletions xla/service/gpu/runtime/dynamic_slice_thunk.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@ class DynamicSliceThunk : public Thunk {
public:
// Dynamic slice offset can be either: (1) a statically known constant value
// or (2) a truly dynamic offset that is computed on device and has to be
// transferred to host.
// transferred to host or (3) a temporary HloModule that computes the offset
// with the induction variable as its input.
using Offset = std::variant<int64_t, BufferAllocation::Slice, HloModule*>;

DynamicSliceThunk(
Expand All @@ -58,7 +59,7 @@ class DynamicSliceThunk : public Thunk {
std::vector<std::optional<Shape>> orig_shapes,
std::vector<std::optional<Shape>> sliced_shapes,
std::vector<std::optional<uint64_t>> offset_byte_sizes,
std::vector<std::unique_ptr<HloModule>> temp_modules = {},
std::vector<std::unique_ptr<HloModule>> extracted_offset_modules = {},
std::unique_ptr<HloModule> indvar_init = nullptr,
std::unique_ptr<HloModule> indvar_update = nullptr);
DynamicSliceThunk(const DynamicSliceThunk&) = delete;
Expand Down Expand Up @@ -135,7 +136,20 @@ class DynamicSliceThunk : public Thunk {
// A mapping from argument index to the base offset in the `offsets_allocs_`.
std::vector<int64_t> offsets_allocs_base_;

std::vector<std::unique_ptr<HloModule>> temp_modules_;
// Extracted HloModules for computing dynamic offsets. This vector is not
// read directly; it exists solely to keep the modules alive and owned by
// the thunk. Raw pointers to these modules are stored in the `offsets_`
// vector. Each module has the signature `(integer[]) -> integer[]`, where
// the input is the current value of the loop induction variable, and the
// output is the offset value for that iteration.
std::vector<std::unique_ptr<HloModule>> extracted_offset_modules_;

// These two modules help keep track of the induction variable. The module
// `indvar_init_` is a module of the prototype `() -> integer[]`. It takes no
// input, and returns the initial value of the induction variable. The module
// `indvar_update_` is a module of the prototype `(integer[]) -> integer[]`.
// It takes the current value of the induction variable, and returns the next
// value of the induction variable.
std::unique_ptr<HloModule> indvar_init_, indvar_update_;
};

Expand Down

0 comments on commit 0995afc

Please sign in to comment.